Commit 1d08f073 authored by Pruthvi, committed by Adam Procter

LSTM fusion + RNN fusion across time slices for a single layer (#826)

* - Added pattern matcher for LSTM cell

* WIP added support to replace lstm cell instead of subgraph

* WIP LSTM pattern matcher, fuses recurrent cells

* WIP added RNN CPU op

* WIP mkldnn emitter code for fprop RNN

* WIP RNN mkldnn integration
- Added mkldnn kernel for uni directional LSTM in the CPU emitter

* add a getter for root node

* recurrent graph rewrite

* fix perms, rename match_root -> get_match_root

* fix comp errors

* make match_root return the topmost match; fix tests

* - WIP GetOutputElement for handling multiple LSTM o/ps
- use RecurrentGraphRewrite for replacing nodes after matching LSTM cells

* WIP LSTM multi-output + debug prints

* moved LSTM fusion to cpu_fusion

* WIP added RNN superfused OP

* WIP towards RNN layer fusion

* WIP multiple output slicing RNN

* WIP RNN multiple o/ps fusion across layers

* WIP corrected input params for fused RNN OP

* concat corresponding params across different LSTMs to form inputs to the RNN fused op

* i) Added test case for RNN kernel ii) runs without errors

* refactored and moved LSTM class to standalone file

* Rename RNN -> Rnn, LSTM -> Lstm

* WIP replace lstm slices to the consumer op

* Slicing works on multiple RNN layers

* fixed all bugs

* - Added CPU RNN Recurrent Fusion
- Added CPU LSTM fusion
- removed debug code
- style fix

* - Added support to compute src_iter and dst_iter instead of taking zero_memory_desc
- Added unit test to compute one LSTM cell

* changed RNN op signature to accept the number of states in the basic RNN unit (GRU/LSTM/vanilla RNN) cell

* added sanity checks for RNN op

* Fixed issue related to patching the graph while replacing the RNN sliced outputs

* Fixed issue to feed the input symbols in the order X0, X1, ...Xt to the RNN op

* Added unit test for multi layer RNN fusion

* Removed debug statements

* i) Added multilayered serialized graph ii) fixed compilation issue

* Addressed PR comments

* i) WIP MKLDNN layout for RNN Op ii) added test case for INTERPRETER vs. CPU Rnn results

* - Fixed bug w.r.t. src_layer feature size in rnn mkldnn emitter code
- Refactored cpu_fusion rnn test case

* merge origin/master with branch pruthvi/lstm_fusion

* style fix

* Added test case for multiple RNN layers

* i) make rnn an mkldnn op if it meets the constraints ii) assert if rnn is not an mkldnn op

* fix unit test failure

* - Added support to reliably identify the hidden state and input symbols from the nodes collected by the pattern matcher
- Fixed failing unit tests

* style fix

* - removed the "node type" dependency for replacing the intermediate LSTM outputs

* Addressed PR comments

* Fix unit test

* - added MKLDNN emitter for LSTM op
- graph pass to concat LSTM input recurrent state tensors
- CPU layout assignment for LSTM Op
- Fixed bug in rnn/lstm unit tests
- made changes to use replace_output instead of replace_node for replacing matched graph nodes in LSTM/RNN fusion pass

(cherry picked from commit d16fc709265cc0a73e60c6d5f6d2878e7b908aca)

* style fix

* Renamed passes and style fixes
parent 6425a516
......@@ -216,11 +216,15 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
runtime/cpu/op/conv_relu.cpp
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/rnn.cpp
runtime/cpu/op/lstm.cpp
runtime/cpu/op/matmul_bias.cpp
runtime/cpu/op/max_pool_with_indices.cpp
runtime/cpu/op/batch_norm_relu.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_rnn_fusion.cpp
runtime/cpu/pass/cpu_concat_inputs.cpp
runtime/cpu/pass/cpu_workspace_insertion.cpp
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
......
This diff is collapsed.
......@@ -120,13 +120,17 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_assignment.hpp"
#include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_shuffle_folding.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
......@@ -274,6 +278,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::emit<op::BatchNorm>},
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::CPU_Emitter::emit<op::Lstm>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
&runtime::cpu::CPU_Emitter::emit<op::MaxPoolWithIndicesBackprop>},
......@@ -282,6 +287,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Min), &runtime::cpu::CPU_Emitter::emit<op::Min>},
{TI(ngraph::op::Relu), &runtime::cpu::CPU_Emitter::emit<op::Relu>},
{TI(ngraph::op::ReluBackprop), &runtime::cpu::CPU_Emitter::emit<op::ReluBackprop>},
{TI(ngraph::op::Rnn), &runtime::cpu::CPU_Emitter::emit<op::Rnn>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>},
{TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>},
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
......@@ -317,6 +323,9 @@ void runtime::cpu::CPU_ExternalFunction::compile()
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
......
......@@ -764,6 +764,59 @@ size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weigh
return batchnorm_index;
}
size_t MKLDNNEmitter::build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
const mkldnn::memory::desc& src_iter_desc,
const mkldnn::memory::desc& weights_layer_desc,
const mkldnn::memory::desc& weights_iter_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_layer_desc,
const mkldnn::memory::desc& dst_iter_desc)
{
size_t src_layer_index = build_memory_primitive(src_layer_desc);
size_t src_iter_index = build_memory_primitive(src_iter_desc);
size_t weights_layer_index = build_memory_primitive(weights_layer_desc);
size_t weights_iter_index = build_memory_primitive(weights_iter_desc);
size_t bias_index = build_memory_primitive(bias_desc);
size_t dst_layer_index = build_memory_primitive(dst_layer_desc);
size_t dst_iter_index = build_memory_primitive(dst_iter_desc);
// TODO: figure out the role of the workspace
auto null_memory_ = mkldnn::null_memory(mkldnn_utils::global_cpu_engine);
mkldnn::rnn_cell::desc rnn_cell(mkldnn::algorithm::vanilla_lstm);
mkldnn::rnn_forward::desc rnn_layer_desc(mkldnn::prop_kind::forward_inference,
rnn_cell,
mkldnn::rnn_direction::unidirectional_left2right,
src_layer_desc,
src_iter_desc,
weights_layer_desc,
weights_iter_desc,
bias_desc,
dst_layer_desc,
dst_iter_desc);
auto rnn_layer_prim_desc =
mkldnn::rnn_forward::primitive_desc(rnn_layer_desc, mkldnn_utils::global_cpu_engine);
size_t rnn_index = insert_primitive(
new mkldnn::rnn_forward(rnn_layer_prim_desc,
mkldnn::primitive::at(*m_mkldnn_primitives[src_layer_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[src_iter_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_layer_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_iter_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[bias_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_layer_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[dst_iter_index]),
static_cast<mkldnn::memory>(null_memory_)));
m_primitive_deps[rnn_index] = {src_layer_index,
src_iter_index,
weights_layer_index,
weights_iter_index,
bias_index,
dst_layer_index,
dst_iter_index};
return rnn_index;
}
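For reference, here is a minimal sketch (not part of this commit) of how a caller might assemble the descriptors build_rnn_forward expects. The dimension orderings follow the MKLDNN RNN formats registered in mkldnn_utils in this commit: tnc = {timesteps, batch, channels}, ldsnc = {layers, directions, states, batch, channels}, ldigo = {layers, directions, input_channels, gates, output_channels}, ldgo = {layers, directions, gates, output_channels}. The sizes and the mkldnn_emitter handle are illustrative assumptions:
// Hypothetical sketch: a single unidirectional LSTM layer with t timesteps,
// batch size n, input feature size c, hidden feature size h.
using mkldnn::memory;
const int t = 10, n = 32, c = 100, h = 100;
const int g = 4; // gates per LSTM cell
const int s = 2; // recurrent states per LSTM cell (ht, ct)
memory::desc src_layer_desc({t, n, c}, memory::data_type::f32, memory::format::tnc);
memory::desc src_iter_desc({1, 1, s, n, h}, memory::data_type::f32, memory::format::ldsnc);
memory::desc weights_layer_desc({1, 1, c, g, h}, memory::data_type::f32, memory::format::ldigo);
memory::desc weights_iter_desc({1, 1, h, g, h}, memory::data_type::f32, memory::format::ldigo);
memory::desc bias_desc({1, 1, g, h}, memory::data_type::f32, memory::format::ldgo);
memory::desc dst_layer_desc({t, n, h}, memory::data_type::f32, memory::format::tnc);
memory::desc dst_iter_desc({1, 1, s, n, h}, memory::data_type::f32, memory::format::ldsnc);
// mkldnn_emitter is assumed to be the MKLDNNEmitter owned by the external function.
size_t rnn_index = mkldnn_emitter->build_rnn_forward(src_layer_desc,
src_iter_desc,
weights_layer_desc,
weights_iter_desc,
bias_desc,
dst_layer_desc,
dst_iter_desc);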
size_t MKLDNNEmitter::build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
const mkldnn::memory::desc& result_desc,
const size_t concat_dim)
......
......@@ -202,6 +202,14 @@ namespace ngraph
const mkldnn::memory::desc& dweights_desc,
const double eps);
size_t build_rnn_forward(const mkldnn::memory::desc& src_layer_desc,
const mkldnn::memory::desc& src_iter_desc,
const mkldnn::memory::desc& weights_layer_desc,
const mkldnn::memory::desc& weights_iter_desc,
const mkldnn::memory::desc& bias_desc,
const mkldnn::memory::desc& dst_layer_desc,
const mkldnn::memory::desc& dst_iter_desc);
size_t build_concat(const std::vector<mkldnn::memory::desc>& inputs_data_desc,
const mkldnn::memory::desc& result_desc,
const size_t concat_dim);
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
......@@ -46,6 +47,7 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm),
TI(ngraph::op::BatchNormBackprop),
TI(ngraph::op::Concat),
TI(ngraph::op::Convolution),
TI(ngraph::op::ConvolutionBackpropData),
TI(ngraph::op::ConvolutionBackpropFilters),
......@@ -113,6 +115,10 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::Ohwi8o, "memory::format::Ohwi8o"},
{memory::format::Ohwi16o, "memory::format::Ohwi16o"},
{memory::format::OhIw16o4i, "memory::format::OhIw16o4i"},
{memory::format::tnc, "memory::format::tnc"},
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
};
static const std::set<memory::format> s_filter_formats{
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::Lstm::copy_with_new_args(const NodeVector& new_args) const
{
if (!m_fused_inputs)
{
if (new_args.size() != 7)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Lstm>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6));
}
else
{
if (new_args.size() != 5 && m_fused_inputs)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Lstm>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
}
op::Lstm::Lstm(std::shared_ptr<Node> input_xt_1,
std::shared_ptr<Node> i2h_weights,
std::shared_ptr<Node> hidden_state_ht_1,
std::shared_ptr<Node> h2h_weights,
std::shared_ptr<Node> i2h_bias,
std::shared_ptr<Node> h2h_bias,
std::shared_ptr<Node> cell_state_ct_1)
: RequiresTensorViewArgs("Lstm",
{input_xt_1,
i2h_weights,
hidden_state_ht_1,
h2h_weights,
i2h_bias,
h2h_bias,
cell_state_ct_1})
, m_output_tensor_shape(hidden_state_ht_1->get_shape())
, m_output_cell_shape(cell_state_ct_1->get_shape())
, m_num_timesteps(1)
, m_num_gates_per_cell(4)
, m_src_sequence_length(1)
, m_src_layer_feature_size(static_cast<int>(input_xt_1->get_shape()[1]))
, m_src_iter_feature_size(static_cast<int>(hidden_state_ht_1->get_shape()[1]))
, m_num_cell_states(2)
, m_direction(1)
, m_num_fused_layers(1)
, m_fused_inputs(false)
{
if (input_xt_1->get_shape().size() != i2h_weights->get_shape().size())
{
throw ngraph_error("input_xt_1 and i2h weights size dont match");
}
if (hidden_state_ht_1->get_shape().size() != h2h_weights->get_shape().size())
{
throw ngraph_error("hidden_state_ht_1 and h2h weights size dont match");
}
if (input_xt_1->get_shape().size() == 2)
{
m_batch_size = static_cast<int>(input_xt_1->get_shape()[0]);
}
else
{
throw ngraph_error("input_xt_1 doesnt have a rank 2");
}
if (shape_size(input_xt_1->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{
std::cout << "shape_size: " << shape_size(input_xt_1->get_shape()) << std::endl;
throw ngraph_error("input_xt_1 size is not equal t*n*c");
}
if (i2h_bias->get_shape()[0] != i2h_weights->get_shape()[0] ||
h2h_bias->get_shape()[0] != h2h_weights->get_shape()[0])
{
throw ngraph_error("bias and weights_shape are not compatible");
}
auto et = input_xt_1->get_element_type();
for (auto& lstm_input : get_arguments())
{
if (lstm_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
add_output(hidden_state_ht_1->get_element_type(), hidden_state_ht_1->get_shape());
add_output(cell_state_ct_1->get_element_type(), cell_state_ct_1->get_shape());
}
op::Lstm::Lstm(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias)
: RequiresTensorViewArgs("Lstm", {src_layer, src_iter, weights_layer, weights_iter, bias})
, m_output_tensor_shape(src_layer->get_shape())
, m_output_cell_shape(src_iter->get_shape())
, m_num_timesteps(1)
, m_num_gates_per_cell(4)
, m_src_sequence_length(1)
, m_src_layer_feature_size(static_cast<int>(src_layer->get_shape()[1]))
, m_src_iter_feature_size(static_cast<int>(src_iter->get_shape()[1]))
, m_num_cell_states(2)
, m_direction(1)
, m_num_fused_layers(1)
, m_fused_inputs(true)
{
if (src_layer->get_shape().size() != weights_layer->get_shape().size())
{
throw ngraph_error("src_layer and i2h weights size dont match");
}
if (src_iter->get_shape().size() != weights_iter->get_shape().size())
{
throw ngraph_error("src_iter and h2h weights size dont match");
}
if (src_layer->get_shape().size() == 2)
{
m_batch_size = static_cast<int>(src_layer->get_shape()[0] / m_num_timesteps);
}
else
{
throw ngraph_error("src_layer doesnt have a rank 2");
}
if (shape_size(src_layer->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{
std::cout << "shape_size: " << shape_size(src_layer->get_shape()) << std::endl;
throw ngraph_error("src_layer size is not equal t*n*c");
}
if (bias->get_shape()[0] != weights_layer->get_shape()[0] ||
bias->get_shape()[0] != weights_iter->get_shape()[0])
{
throw ngraph_error("bias and weights_shape are not compatible");
}
auto et = src_layer->get_element_type();
for (auto& rnn_input : get_arguments())
{
if (rnn_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_timesteps * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_cell_states * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
class Lstm : public util::RequiresTensorViewArgs
{
public:
// INPUTS:
// [0] - xt, input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
// [1] - initializer for the input weights matrix, used for the linear transformation of the inputs
// [2] - ht_1, hidden state of shape (batch_size, feature_size)
// [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state
// [4] - initializer for the bias vector w.r.t. the inputs
// [5] - initializer for the bias vector w.r.t. the hidden state
// [6] - ct_1, cell state of shape (batch_size, feature_size)
// OUTPUT VALUE: A tuple with the following structure:
// [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden)
// [1] - ct, output recurrent state tensor with the same shape as the cell state
// This version of the LSTM op is only used to simplify recurrent RNN cell (LSTM) fusion across
// horizontal time steps. It doesn't have MKLDNN emitter code.
Lstm(std::shared_ptr<Node> input_xt_1,
std::shared_ptr<Node> i2h_weights,
std::shared_ptr<Node> hidden_state_ht_1,
std::shared_ptr<Node> h2h_weights,
std::shared_ptr<Node> i2h_bias,
std::shared_ptr<Node> h2h_bias,
std::shared_ptr<Node> cell_state_ct_1);
// INPUTS:
// [0] - {Xt}, input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
// [1] - recurrent state tensors {ht_1 | ct_1} of Shape{sequence_length*batch_size, feature_size}
// [2] - initializer for the input weights matrix, used for the linear transformation of the inputs
// [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state
// [4] - initializer for the bias vector w.r.t. the inputs + hidden state (ibh_bias + hbh_bias)
// OUTPUT VALUE: A tuple with the following structure:
// [0] - ht, output tensor with shape (sequence_length*batch_size, num_hidden)
// [1] - {ht | ct}, output recurrent state tensor with the same shape as the states
// This version of the LSTM op supports MKLDNN emitter code; it can be used standalone to
// compute an RNN without fusing RNN cells (LSTM) across time steps.
Lstm(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias);
Shape get_output_tensor_shape() const { return m_output_tensor_shape; }
Shape get_output_cell_shape() const { return m_output_cell_shape; }
int get_num_timesteps() const { return m_num_timesteps; }
int get_src_sequence_length() const { return m_src_sequence_length; }
int get_gates_per_cell() const { return m_num_gates_per_cell; }
int get_batch_size() const { return m_batch_size; }
int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
int get_num_cell_states() const { return m_num_cell_states; }
int get_direction() const { return m_direction; }
int get_num_fused_layers() const { return m_num_fused_layers; }
int get_fused_inputs() const { return m_fused_inputs; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
Shape m_output_tensor_shape;
Shape m_output_cell_shape;
int m_num_timesteps;
int m_num_gates_per_cell;
int m_src_sequence_length;
int m_batch_size;
int m_src_layer_feature_size;
int m_src_iter_feature_size;
int m_num_cell_states;
int m_direction;
int m_num_fused_layers;
bool m_fused_inputs; // True if node gets fused inputs/weights
};
}
}
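As a rough usage sketch, the 7-input pattern-matching form of Lstm can be rewritten into the fused 5-input MKLDNN form by concatenating the two recurrent states and summing the two bias vectors; this mirrors what the ConcatInputs pass in this commit does. The names xt, ht_1, ct_1, weights_i2h, weights_h2h, bias_i2h, and bias_h2h stand for previously constructed nodes:
// Sketch: fused 5-input Lstm built from the pieces of the 7-input form.
auto src_layer = xt;
auto src_iter = std::make_shared<op::Concat>(NodeVector{ht_1, ct_1}, 0); // {ht | ct}
auto bias = std::make_shared<op::Add>(bias_i2h, bias_h2h);
auto lstm = std::make_shared<op::Lstm>(src_layer, src_iter, weights_i2h, weights_h2h, bias);
auto ht = std::make_shared<op::GetOutputElement>(lstm, 0); // dst_layer
auto ht_ct = std::make_shared<op::GetOutputElement>(lstm, 1); // dst_iter = {ht | ct}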
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
shared_ptr<Node> op::Rnn::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Rnn>(new_args[0],
new_args[1],
new_args[2],
new_args[3],
new_args[4],
m_num_timesteps,
m_num_gates_per_cell,
m_src_sequence_length,
m_src_layer_feature_size,
m_src_iter_feature_size,
m_num_cell_states,
m_direction,
m_num_fused_layers);
}
op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias,
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int num_cell_states,
const int direction,
const int num_fused_layers)
: RequiresTensorViewArgs("Rnn", {src_layer, src_iter, weights_layer, weights_iter, bias})
, m_num_timesteps(num_timesteps)
, m_num_gates_per_cell(num_gates_per_cell)
, m_src_sequence_length(src_sequence_length)
, m_src_layer_feature_size(src_layer_feature_size)
, m_src_iter_feature_size(src_iter_feature_size)
, m_num_cell_states(num_cell_states)
, m_direction(direction)
, m_num_fused_layers(num_fused_layers)
{
if (src_layer->get_shape().size() != weights_layer->get_shape().size())
{
throw ngraph_error("src_layer and i2h weights size dont match");
}
if (src_iter->get_shape().size() != weights_iter->get_shape().size())
{
throw ngraph_error("src_iter and h2h weights size dont match");
}
if (src_layer->get_shape().size() == 2)
{
m_batch_size = static_cast<int>(src_layer->get_shape()[0] / num_timesteps);
}
else
{
throw ngraph_error("src_layer doesnt have a rank 2");
}
if (shape_size(src_layer->get_shape()) !=
m_src_sequence_length * m_batch_size * m_src_layer_feature_size)
{
std::cout << "shape_size: " << shape_size(src_layer->get_shape()) << std::endl;
throw ngraph_error("src_layer size is not equal t*n*c");
}
if (bias->get_shape()[0] != weights_layer->get_shape()[0] ||
bias->get_shape()[0] != weights_iter->get_shape()[0])
{
throw ngraph_error("bias and weights_shape are not compatible");
}
auto et = src_layer->get_element_type();
for (auto& rnn_input : get_arguments())
{
if (rnn_input->get_element_type() != et)
{
throw ngraph_error("all rnn inputs must have the same element type");
}
}
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_timesteps * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_cell_states * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
// This is the RNN op, formed by fusing multiple RNN cells (LSTM/GRU/vanilla RNN)
// across multiple time slices.
// INPUTS:
// [0] - {X0, X1, ..., Xt}, input tensor of layout TNC, Shape{sequence_length*batch_size, feature_size}
// [1] - recurrent state tensors {ht_1 | ct_1} of Shape{sequence_length*batch_size, feature_size}
// [2] - initializer for the input weights matrix, used for the linear transformation of the inputs
// [3] - initializer for the recurrent weights matrix, used for the linear transformation of the recurrent state
// [4] - initializer for the bias vector w.r.t. the inputs + hidden state (ibh_bias + hbh_bias)
// number_of_timesteps - number of unrolled cells up to timestep t
// num_gates_per_cell - number of gates per RNN cell: LSTM = 4, GRU = 3, vanilla RNN = 1
// src_sequence_length - this will be the same as number_of_timesteps
// src_layer_feature_size - feature size w.r.t. the input tensor
// src_iter_feature_size - feature size w.r.t. the hidden state
// num_cell_states - number of recurrent state tensors: LSTM = 2, GRU = 1, vanilla RNN = 1
// OUTPUT VALUE: A tuple with the following structure:
// [0] - ht, output tensor with shape (sequence_length*batch_size, feature_size)
// [1] - {ht | ct}, output recurrent state tensor with the same shape as the states,
// i.e. (sequence_length*batch_size, feature_size)
class Rnn : public util::RequiresTensorViewArgs
{
public:
Rnn(std::shared_ptr<Node> src_layer,
std::shared_ptr<Node> src_iter,
std::shared_ptr<Node> weights_layer,
std::shared_ptr<Node> weights_iter,
std::shared_ptr<Node> bias,
const int num_timesteps,
const int num_gates_per_cell,
const int src_sequence_length,
const int src_layer_feature_size,
const int src_iter_feature_size,
const int num_cell_states,
const int direction,
const int num_fused_layers);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_num_timesteps() const { return m_num_timesteps; }
int get_src_sequence_length() const { return m_src_sequence_length; }
int get_gates_per_cell() const { return m_num_gates_per_cell; }
int get_batch_size() const { return m_batch_size; }
int get_src_layer_feature_size() const { return m_src_layer_feature_size; }
int get_src_iter_feature_size() const { return m_src_iter_feature_size; }
int get_num_cell_states() const { return m_num_cell_states; }
int get_direction() const { return m_direction; }
int get_num_fused_layers() const { return m_num_fused_layers; }
private:
int m_num_timesteps;
int m_num_gates_per_cell;
int m_src_sequence_length;
int m_batch_size;
int m_src_layer_feature_size;
int m_src_iter_feature_size;
int m_num_cell_states;
int m_direction;
int m_num_fused_layers;
};
}
}
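A construction sketch for this op, with the shape relationships spelled out (the rnn_fprop_1_lstm_cell test below builds essentially the same graph): for t timesteps, batch size n, input feature size c, and hidden feature size f, src_layer has shape {t*n, c}, src_iter has shape {num_cell_states*n, f}, and the two outputs have shapes {t*n, f} and {num_cell_states*n, f}. The sizes here are illustrative:
// Sketch: t = 1, n = 10, c = f = 100, LSTM cell (4 gates, 2 states).
auto src_layer = std::make_shared<op::Parameter>(element::f32, Shape{10, 100}); // {t*n, c}
auto src_iter = std::make_shared<op::Parameter>(element::f32, Shape{20, 100}); // {2*n, f}
auto weights_layer = std::make_shared<op::Parameter>(element::f32, Shape{400, 100}); // {4*f, c}
auto weights_iter = std::make_shared<op::Parameter>(element::f32, Shape{400, 100}); // {4*f, f}
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{400}); // {4*f}
auto rnn = std::make_shared<op::Rnn>(src_layer, src_iter, weights_layer, weights_iter, bias,
/*num_timesteps=*/1, /*num_gates_per_cell=*/4, /*src_sequence_length=*/1,
/*src_layer_feature_size=*/100, /*src_iter_feature_size=*/100,
/*num_cell_states=*/2, /*direction=*/1, /*num_fused_layers=*/1);
auto ht = std::make_shared<op::GetOutputElement>(rnn, 0); // Shape{10, 100}
auto ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1); // Shape{20, 100}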
......@@ -37,7 +37,9 @@
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -497,6 +499,48 @@ namespace ngraph
batchnorm->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Lstm)
{
auto src_layer_rank = node->get_input_shape(0).size();
auto src_iter_rank = node->get_input_shape(1).size();
auto weights_layer_rank = node->get_input_shape(2).size();
auto weights_iter_rank = node->get_input_shape(3).size();
auto bias_rank = node->get_input_shape(4).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && weights_layer_rank == 2 &&
weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
auto lstm_node = static_cast<op::Lstm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
lstm_node->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Rnn)
{
auto src_layer_rank = node->get_input_shape(0).size();
auto src_iter_rank = node->get_input_shape(1).size();
auto weights_layer_rank = node->get_input_shape(2).size();
auto weights_iter_rank = node->get_input_shape(3).size();
auto bias_rank = node->get_input_shape(4).size();
if ((src_layer_rank == 2 && src_iter_rank == 2 && weights_layer_rank == 2 &&
weights_iter_rank == 2 && bias_rank == 1 &&
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
auto rnn_node = static_cast<op::Rnn*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
rnn_node->set_op_annotations(op_annotations);
}
}
}
}
}
......@@ -542,6 +586,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
};
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu_concat_inputs.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace ngraph;
void ngraph::runtime::cpu::pass::ConcatInputs::concat_lstm_inputs()
{
auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto weights_h2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto xt = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto weights_i2h = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
auto bias1 = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto bias2 = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
auto lstm = std::make_shared<op::Lstm>(xt, weights_i2h, ht_1, weights_h2h, bias1, bias2, ct_1);
auto goe = std::make_shared<op::GetOutputElement>(lstm, 0);
auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});
pattern::graph_rewrite_callback callback =
[lstm_node_label, xt, weights_h2h, ht_1, weights_i2h, bias1, bias2, ct_1](
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
NGRAPH_DEBUG << " In LSTM MKLDNN callback";
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
std::shared_ptr<Node> src_layer = pattern_map[xt];
std::shared_ptr<Node> src_iter =
std::make_shared<op::Concat>(NodeVector{pattern_map[ht_1], pattern_map[ct_1]}, 0);
std::shared_ptr<Node> bias =
std::make_shared<op::Add>(pattern_map[bias1], pattern_map[bias2]);
auto lstm_node = pattern_map[lstm_node_label]->get_arguments()[0];
auto batch_size = std::dynamic_pointer_cast<op::Lstm>(lstm_node)->get_batch_size();
auto feature_size =
std::dynamic_pointer_cast<op::Lstm>(lstm_node)->get_src_iter_feature_size();
auto lstm_mkldnn_node = std::make_shared<op::Lstm>(
src_layer, src_iter, pattern_map[weights_i2h], pattern_map[weights_h2h], bias);
auto lstm_ht_out = std::make_shared<op::GetOutputElement>(lstm_mkldnn_node, 0);
auto lstm_ht_ct_out = std::make_shared<op::GetOutputElement>(lstm_mkldnn_node, 1);
// dst_iter of the LSTM MKLDNN output holds the results of both recurrent state
// tensors; we need to slice out ct.
auto ht_slice =
std::make_shared<op::Slice>(lstm_ht_ct_out,
Coordinate{0, 0},
Coordinate{static_cast<unsigned long>(batch_size),
static_cast<unsigned long>(feature_size)});
auto ct_slice =
std::make_shared<op::Slice>(lstm_ht_ct_out,
Coordinate{static_cast<unsigned long>(batch_size), 0},
Coordinate{static_cast<unsigned long>(2 * batch_size),
static_cast<unsigned long>(feature_size)});
// now go through the GOEs and replace the slices (ht)
std::set<std::shared_ptr<ngraph::Node>> lstm_outputs;
for (auto& goes : lstm_node->get_outputs().at(0).get_inputs())
{
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node());
lstm_outputs.insert(goes->get_node());
// first output node of lstm
if (goe_node->get_n() == 0)
{
NGRAPH_DEBUG << "Replacing 1st output Lstm node " << goe_node->get_name()
<< " with " << lstm_ht_out->get_name();
ngraph::replace_node(goe_node, lstm_ht_out);
}
else if (goe_node->get_n() == 1)
{
for (auto& goe_ct_user : goe_node->get_users())
{
for (size_t i = 0; i < goe_ct_user->get_input_size(); i++)
{
if (goe_ct_user->get_argument(i) == goe_node)
{
goe_ct_user->get_inputs().at(i).replace_output(
ct_slice->get_outputs().at(0));
}
}
}
NGRAPH_DEBUG << "Replacing 2nd output Lstm node " << goe_node->get_name()
<< " with " << ct_slice->get_name();
}
}
if (lstm_outputs.find(m.get_match_root()) == lstm_outputs.end())
{
throw ngraph_error(
"Pattern matcher error, matched root node should be one of the LSTM outputs");
}
return true;
};
auto m = std::make_shared<pattern::Matcher>(lstm_node_label, callback);
this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class ConcatInputs;
}
}
}
}
class ngraph::runtime::cpu::pass::ConcatInputs : public ngraph::pass::GraphRewrite
{
public:
ConcatInputs()
: GraphRewrite()
{
concat_lstm_inputs();
}
private:
void concat_lstm_inputs();
};
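Typical registration (as in the fuse_lstm_cells test below) runs LSTMFusion first, since ConcatInputs matches the 7-input Lstm nodes that LSTMFusion produces:
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.run_passes(func);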
......@@ -43,7 +43,9 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
using namespace std;
......@@ -1349,6 +1351,38 @@ namespace ngraph
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Lstm)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
// TODO: for now, the framework formats for src_layer, src_iter, weights_layer and
// weights_iter match the expected MKLDNN formats. We need to handle the case of
// inserting convert ops if the formats don't match.
set_default_layouts(external_function, node, false);
}
else
{
throw ngraph_error("LSTM fused op is only supported in MKLDNN for now.");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Rnn)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
// TODO: for now, the framework formats for src_layer, src_iter, weights_layer and
// weights_iter match the expected MKLDNN formats. We need to handle the case of
// inserting convert ops if the formats don't match.
set_default_layouts(external_function, node, false);
}
else
{
throw ngraph_error("RNN fused op is only supported in MKLDNN for now.");
}
}
}
}
}
......@@ -1396,6 +1430,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
This diff is collapsed.
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class LSTMFusion;
class RNNFusion;
}
}
}
}
class ngraph::runtime::cpu::pass::LSTMFusion : public ngraph::pass::GraphRewrite
{
public:
LSTMFusion()
: GraphRewrite()
{
construct_sigmoid();
construct_lstm_fprop();
}
private:
void construct_sigmoid();
void construct_lstm_fprop();
};
class ngraph::runtime::cpu::pass::RNNFusion : public ngraph::pass::RecurrentGraphRewrite
{
public:
RNNFusion()
: RecurrentGraphRewrite()
{
construct_rnn_lstm_fprop();
}
private:
void construct_rnn_lstm_fprop();
};
......@@ -44,10 +44,14 @@
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_concat_inputs.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/serializer.hpp"
......@@ -1078,7 +1082,6 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
const size_t NUM_STEPS = 10;
auto mmb_predicate = [](std::shared_ptr<Node> node) {
auto users = node->get_users();
......@@ -1293,3 +1296,213 @@ TEST(cpu_fusion, batch_norm_folding)
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
TEST(cpu_fusion, rnn_fprop_1_lstm_cell)
{
auto src_layer = make_shared<op::Parameter>(element::f32, Shape{10, 100});
auto src_iter = make_shared<op::Parameter>(element::f32, Shape{20, 100});
auto weights_layer = make_shared<op::Parameter>(element::f32, Shape{400, 100});
auto weights_iter = make_shared<op::Parameter>(element::f32, Shape{400, 100});
auto biases = make_shared<op::Parameter>(element::f32, Shape{400});
const int number_of_timesteps = 1;
const int number_of_gates_per_cell = 4;
const int src_seq_length = 1;
const int src_layer_feature_size = 100;
const int feature_size = 100;
const int num_rnn_cell_states = 2;
const int rnn_direction = 1;
const int num_of_rnn_fused_layer = 1;
auto rnn_node = make_shared<op::Rnn>(src_layer,
src_iter,
weights_layer,
weights_iter,
biases,
number_of_timesteps,
number_of_gates_per_cell,
src_seq_length,
src_layer_feature_size,
feature_size,
num_rnn_cell_states,
rnn_direction,
num_of_rnn_fused_layer);
auto rnn_ht_output = make_shared<op::GetOutputElement>(rnn_node, 0);
auto rnn_ct_output = make_shared<op::GetOutputElement>(rnn_node, 1);
auto func = make_shared<Function>(
NodeVector{rnn_ht_output, rnn_ct_output},
op::ParameterVector{src_layer, src_iter, weights_layer, weights_iter, biases});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::TensorView> src_layer_t =
backend->create_tensor(element::f32, src_layer->get_shape());
shared_ptr<runtime::TensorView> src_iter_t =
backend->create_tensor(element::f32, src_iter->get_shape());
shared_ptr<runtime::TensorView> weights_layer_t =
backend->create_tensor(element::f32, weights_layer->get_shape());
shared_ptr<runtime::TensorView> weights_iter_t =
backend->create_tensor(element::f32, weights_iter->get_shape());
shared_ptr<runtime::TensorView> biases_t =
backend->create_tensor(element::f32, biases->get_shape());
shared_ptr<runtime::TensorView> result_ht = backend->create_tensor(element::f32, {10, 100});
shared_ptr<runtime::TensorView> result_ct =
backend->create_tensor(element::f32, Shape{20, 100});
copy_data(src_layer_t, vector<float>(1000, 1));
copy_data(src_iter_t, vector<float>(2000, 1));
copy_data(weights_layer_t, vector<float>(400 * 100, 1));
copy_data(weights_iter_t, vector<float>(400 * 100, 1));
copy_data(biases_t, vector<float>(400, 1));
backend->call(func,
{result_ht, result_ct},
{src_layer_t, src_iter_t, weights_layer_t, weights_iter_t, biases_t});
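// With all inputs, weights, and biases set to 1, every gate pre-activation is
// large and positive, so the i, f, o gates saturate to ~1 and g = tanh(...) ~= 1.
// Hence ct = f*ct_1 + i*g ~= 1 + 1 = 2 and ht = o*tanh(ct) ~= tanh(2) ~= 0.964028.
// dst_iter holds {ht | ct}: its first 10x100 block is ~0.964028, the rest ~2.0.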
vector<float> expected_ht(10 * 100, 0.964028f);
vector<float> expected_ct;
for (size_t i = 0; i < 20 * 100; i++)
{
if (i < 1000)
{
expected_ct.push_back(0.964028f);
}
else
{
expected_ct.push_back(2.0f);
}
}
EXPECT_TRUE(test::all_close(expected_ht, read_vector<float>(result_ht)));
EXPECT_TRUE(test::all_close(expected_ct, read_vector<float>(result_ct)));
}
TEST(cpu_fusion, fuse_lstm_cells)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
auto lstm_ops = get_ops_of_type<op::Lstm>(func);
EXPECT_EQ(lstm_ops.size(), 6);
}
TEST(cpu_fusion, fuse_2_layer_rnn)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/2rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t count = count_ops_of_type<op::Rnn>(func);
auto rnn_ops = get_ops_of_type<op::Rnn>(func);
EXPECT_EQ(rnn_ops.size(), count);
for (auto& node : rnn_ops)
{
EXPECT_EQ(node->get_num_timesteps(), node->get_src_sequence_length());
EXPECT_EQ(node->get_num_cell_states(), node->get_argument(1)->get_arguments().size());
}
}
TEST(cpu_fusion, fuse_1_layer_rnn)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/1rnn_layer_3lstm_cell.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
size_t count = count_ops_of_type<op::Rnn>(func);
auto rnn_ops = get_ops_of_type<op::Rnn>(func);
EXPECT_EQ(rnn_ops.size(), 1);
EXPECT_EQ(rnn_ops.size(), count);
for (auto& node : rnn_ops)
{
EXPECT_EQ(node->get_num_timesteps(), node->get_src_sequence_length());
EXPECT_EQ(node->get_num_cell_states(), node->get_argument(1)->get_arguments().size());
}
}
static std::shared_ptr<Function> make_function(const std::string& file_name)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, file_name);
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
return func;
}
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_1lstm_cell)
{
const std::string file_name("mxnet/1_lstm_cell_forward.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_1rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/1rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, rnn_fusion_inter_vs_cpu_2rnn_layer_3lstm_cell)
{
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
auto cpu_f = make_function(file_name);
auto int_f = make_function(file_name);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
......@@ -31,6 +31,7 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......
[{
"name" : "Function_0",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_10",
"op" : "Parameter",
"outputs" : ["Parameter_10_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_9",
"op" : "Parameter",
"outputs" : ["Parameter_9_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [400]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [ 400, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 10, 100 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_7",
"op" : "Constant",
"outputs" : ["Constant_7_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_34",
"op" : "Constant",
"outputs" : ["Constant_34_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_30",
"op" : "Constant",
"outputs" : ["Constant_30_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_24",
"op" : "Constant",
"outputs" : ["Constant_24_0"],
"shape" : [],
"value" : ["1"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_17",
"op" : "Constant",
"outputs" : ["Constant_17_0"],
"shape" : [],
"value" : ["1"]
},
{
"axes" : [0],
"inputs" : ["Parameter_10"],
"name" : "Broadcast_13",
"op" : "Broadcast",
"outputs" : ["Broadcast_13_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_9"],
"name" : "Reshape_11",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_11_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_2"],
"name" : "Broadcast_5",
"op" : "Broadcast",
"outputs" : ["Broadcast_5_0"],
"shape" : [ 10, 400 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_1"],
"name" : "Reshape_3",
"op" : "Reshape",
"output_shape" : [ 100, 400 ],
"outputs" : ["Reshape_3_0"]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_7"],
"name" : "Broadcast_8",
"op" : "Broadcast",
"outputs" : ["Broadcast_8_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_34"],
"name" : "Broadcast_35",
"op" : "Broadcast",
"outputs" : ["Broadcast_35_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_30"],
"name" : "Broadcast_31",
"op" : "Broadcast",
"outputs" : ["Broadcast_31_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_24"],
"name" : "Broadcast_25",
"op" : "Broadcast",
"outputs" : ["Broadcast_25_0"],
"shape" : [ 10, 100 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_17"],
"name" : "Broadcast_18",
"op" : "Broadcast",
"outputs" : ["Broadcast_18_0"],
"shape" : [ 10, 100 ]
},
{
"inputs" : [ "Parameter_0", "Reshape_3" ],
"name" : "Dot_4",
"op" : "Dot",
"outputs" : ["Dot_4_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Broadcast_8", "Reshape_11" ],
"name" : "Dot_12",
"op" : "Dot",
"outputs" : ["Dot_12_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_4", "Broadcast_5" ],
"name" : "Add_6",
"op" : "Add",
"outputs" : ["Add_6_0"]
},
{
"inputs" : [ "Dot_12", "Broadcast_13" ],
"name" : "Add_14",
"op" : "Add",
"outputs" : ["Add_14_0"]
},
{
"inputs" : [ "Add_6", "Add_14" ],
"name" : "Add_15",
"op" : "Add",
"outputs" : ["Add_15_0"]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 300 ],
"name" : "Slice_16",
"op" : "Slice",
"outputs" : ["Slice_16_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 400 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 100 ],
"name" : "Slice_23",
"op" : "Slice",
"outputs" : ["Slice_23_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 200 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 0 ],
"name" : "Slice_33",
"op" : "Slice",
"outputs" : ["Slice_33_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 100 ]
},
{
"inputs" : ["Add_15"],
"lower_bounds" : [ 0, 200 ],
"name" : "Slice_40",
"op" : "Slice",
"outputs" : ["Slice_40_0"],
"strides" : [ 1, 1 ],
"upper_bounds" : [ 10, 300 ]
},
{
"inputs" : ["Slice_16"],
"name" : "Negative_19",
"op" : "Negative",
"outputs" : ["Negative_19_0"]
},
{
"inputs" : ["Slice_23"],
"name" : "Negative_26",
"op" : "Negative",
"outputs" : ["Negative_26_0"]
},
{
"inputs" : ["Slice_33"],
"name" : "Negative_36",
"op" : "Negative",
"outputs" : ["Negative_36_0"]
},
{
"inputs" : ["Slice_40"],
"name" : "Tanh_41",
"op" : "Tanh",
"outputs" : ["Tanh_41_0"]
},
{
"inputs" : ["Negative_19"],
"name" : "Exp_20",
"op" : "Exp",
"outputs" : ["Exp_20_0"]
},
{
"inputs" : ["Negative_26"],
"name" : "Exp_27",
"op" : "Exp",
"outputs" : ["Exp_27_0"]
},
{
"inputs" : ["Negative_36"],
"name" : "Exp_37",
"op" : "Exp",
"outputs" : ["Exp_37_0"]
},
{
"inputs" : [ "Broadcast_18", "Exp_20" ],
"name" : "Add_21",
"op" : "Add",
"outputs" : ["Add_21_0"]
},
{
"inputs" : [ "Broadcast_25", "Exp_27" ],
"name" : "Add_28",
"op" : "Add",
"outputs" : ["Add_28_0"]
},
{
"inputs" : [ "Broadcast_35", "Exp_37" ],
"name" : "Add_38",
"op" : "Add",
"outputs" : ["Add_38_0"]
},
{
"inputs" : [ "Broadcast_18", "Add_21" ],
"name" : "Divide_22",
"op" : "Divide",
"outputs" : ["Divide_22_0"]
},
{
"inputs" : [ "Broadcast_25", "Add_28" ],
"name" : "Divide_29",
"op" : "Divide",
"outputs" : ["Divide_29_0"]
},
{
"inputs" : [ "Broadcast_35", "Add_38" ],
"name" : "Divide_39",
"op" : "Divide",
"outputs" : ["Divide_39_0"]
},
{
"inputs" : [ "Divide_29", "Broadcast_31" ],
"name" : "Multiply_32",
"op" : "Multiply",
"outputs" : ["Multiply_32_0"]
},
{
"inputs" : [ "Divide_39", "Tanh_41" ],
"name" : "Multiply_42",
"op" : "Multiply",
"outputs" : ["Multiply_42_0"]
},
{
"inputs" : [ "Multiply_32", "Multiply_42" ],
"name" : "Add_43",
"op" : "Add",
"outputs" : ["Add_43_0"]
},
{
"inputs" : ["Add_43"],
"name" : "Tanh_44",
"op" : "Tanh",
"outputs" : ["Tanh_44_0"]
},
{
"inputs" : [ "Divide_22", "Tanh_44" ],
"name" : "Multiply_45",
"op" : "Multiply",
"outputs" : ["Multiply_45_0"]
},
{
"inputs" : ["Multiply_45"],
"name" : "Result_46",
"op" : "Result",
"outputs" : ["Result_46_0"]
}
],
"parameters" : [
"Parameter_0", "Parameter_1", "Parameter_2", "Parameter_9",
"Parameter_10"
],
"result" : ["Result_46"]
}]
This diff is collapsed.
This diff is collapsed.