Commit e5d606b8 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

LSTM MKLDNN integration for ONNX LSTM op (#3327)

* - Add graph pass method for onnx lstmcell rewrite with lstm cpu op
- insert reshapes to keep the weights in ldigo format
- test case for onnx LstmCell to CPU Lstm

* fix typo

* - check LSTMCell for the fused op decomposistion in the backend

* - fix bug in onnx_lstm graph pass
- passes unit test

* style-fix

* - fix compilation error
- use IFCO gate ordering for bias

*  - Skip LSTMCell to LSTM CPU fusion for peephole

* - add comment && remove duplicate function

* -use dynamic_pointer_cast to check for constant

* - onnx bias will be of shape (2 * gates_count * hidden_size) bias of Wb and Rb are concatenated, we will split the bias, add and rearrange in order IFCO

* - Use most derived LSTM ctor for pattern matching

* - Style Fix

* style fix

* Address PR comments

* - add support for graph pass (MKLDNN version > 1) for mapping LSTMCell -> LSTM CPU op

* fix unit test failure for MKLDNN V1.0
parent c831dc71
...@@ -86,6 +86,7 @@ ...@@ -86,6 +86,7 @@
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/group_conv.hpp" #include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp" #include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
...@@ -1157,6 +1158,26 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1157,6 +1158,26 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
auto dex = is_direct_execution(); auto dex = is_direct_execution();
auto is_supported = [dex](const Node& node) { auto is_supported = [dex](const Node& node) {
// this checks averts the decomposition of LSTMCell
// we will map LSTMCell to LSTM CPU op in the later
// graph pass
if (typeid(ngraph::op::LSTMCell) == typeid(node))
{
// MKLDNN version < 1.0 doesnt support peephole for LSTM, we will skip if the LSTMCell has peephole.
// LSTMCell with no peephole support is constant initialized to zero
// TODO (pthoreho) : For MKLDNN > V1.0, change mkldnn kernel integration to compute for LSTMCell
// with peephole as well.
if (std::dynamic_pointer_cast<ngraph::op::Constant>(node.get_argument(6)) != nullptr)
{
return true;
}
else
{
return false;
}
}
if (dex) if (dex)
{ {
auto handler = GetGlobalBuildDispatcher().find(type_index(typeid(node))); auto handler = GetGlobalBuildDispatcher().find(type_index(typeid(node)));
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
...@@ -62,6 +63,177 @@ ...@@ -62,6 +63,177 @@
} }
using namespace ngraph; using namespace ngraph;
void ngraph::runtime::cpu::pass::LSTMFusion::construct_onnx_lstmcell_fprop()
{
size_t ref_batch_size = 2;
size_t ref_input_size = 3;
size_t ref_hidden_size = 3;
size_t ref_gates_count = 4;
auto X =
std::make_shared<pattern::op::Label>(element::f32, Shape{ref_batch_size, ref_input_size});
auto W = std::make_shared<pattern::op::Label>(
element::f32, Shape{ref_gates_count * ref_hidden_size, ref_input_size});
auto R = std::make_shared<pattern::op::Label>(
element::f32, Shape{ref_gates_count * ref_hidden_size, ref_hidden_size});
auto bias_ref = std::make_shared<pattern::op::Label>(
element::f32, Shape{2 * ref_gates_count * ref_hidden_size});
auto peep_hole = std::make_shared<pattern::op::Label>(element::f32, Shape{3 * ref_hidden_size});
auto H_t =
std::make_shared<pattern::op::Label>(element::f32, Shape{ref_batch_size, ref_hidden_size});
auto C_t =
std::make_shared<pattern::op::Label>(element::f32, Shape{ref_batch_size, ref_hidden_size});
auto ref_lstm_cell =
std::make_shared<op::LSTMCell>(X,
W,
R,
H_t,
C_t,
ref_hidden_size,
bias_ref,
peep_hole,
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
std::vector<float>{},
std::vector<float>{},
0.f,
false);
auto callback = [X, W, R, H_t, C_t](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto target_lstm_node = m.get_match_root();
auto lstmcell_op = std::dynamic_pointer_cast<op::LSTMCell>(m.get_match_root());
auto src_iter =
std::make_shared<ngraph::op::Concat>(NodeVector{pattern_map[H_t], pattern_map[C_t]}, 0);
auto bias_iofc = target_lstm_node->get_argument(5);
// we need to reorder W, R and bias from IOFC to IFCO gate order
// Note: ONNX runtime provides W, R and bias in the gate order [IOFC] but
// MKLDNN computes LSTM kernel in the [IFCO] order.
auto get_weights_ifco_gate_order =
[&](std::shared_ptr<Node> weights_graph_node) -> std::shared_ptr<Node> {
// slices will be in ICFO order
std::vector<std::shared_ptr<Node>> gate_slices;
size_t dim0 = weights_graph_node->get_shape()[0] / 4;
size_t dim1 = weights_graph_node->get_shape()[1];
for (size_t i = 0; i < 4; i++)
{
auto slice = std::make_shared<ngraph::op::Slice>(
weights_graph_node, Coordinate{i * dim0, 0}, Coordinate{(i + 1) * dim0, dim1});
gate_slices.push_back(slice);
}
auto weights_ifco = std::make_shared<ngraph::op::Concat>(
NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
return weights_ifco;
};
auto get_bias_ifco_gate_order =
[&](std::shared_ptr<Node> bias_graph_node) -> std::shared_ptr<Node> {
size_t hidden_size = lstmcell_op->get_hidden_size();
auto Wb_bias = std::make_shared<ngraph::op::Slice>(
bias_graph_node, Coordinate{0}, Coordinate{4 * hidden_size});
auto Rb_bias = std::make_shared<ngraph::op::Slice>(
bias_graph_node, Coordinate{4 * hidden_size}, Coordinate{2 * 4 * hidden_size});
auto bias = std::make_shared<op::Add>(Wb_bias, Rb_bias);
// slices will be in ICFO order
std::vector<std::shared_ptr<Node>> gate_slices;
for (size_t i = 0; i < 4; i++)
{
auto slice = std::make_shared<ngraph::op::Slice>(
bias, Coordinate{i * hidden_size}, Coordinate{(i + 1) * hidden_size});
gate_slices.push_back(slice);
}
auto new_bias = std::make_shared<ngraph::op::Concat>(
NodeVector{gate_slices[0], gate_slices[2], gate_slices[3], gate_slices[1]}, 0);
return new_bias;
};
auto W_iofc = pattern_map[W];
auto R_iofc = pattern_map[R];
auto W_ifco = get_weights_ifco_gate_order(W_iofc);
auto R_ifco = get_weights_ifco_gate_order(R_iofc);
// here onnx bias will be of shape (2 * gates_count * hidden_size) bias of Wb and Rb are concatenated, we will split the bias, add and rearrange in order IFCO
auto bias_ifco = get_bias_ifco_gate_order(bias_iofc);
auto W_reshape = std::make_shared<op::Reshape>(
W_ifco, AxisVector{1, 0}, Shape{W_ifco->get_shape()[1], W_ifco->get_shape()[0]});
auto R_reshape = std::make_shared<op::Reshape>(
R_ifco, AxisVector{1, 0}, Shape{R_ifco->get_shape()[1], R_ifco->get_shape()[0]});
#if MKLDNN_VERSION_MAJOR < 1
auto lstm_node = std::make_shared<ngraph::op::Lstm>(
pattern_map[X], src_iter, W_reshape, R_reshape, bias_ifco, rnn_type);
if (lstm_node->get_outputs().size() != 2)
{
throw ngraph_error("Lstm node doesnt have two outputs");
}
#else
auto lstm_node = std::make_shared<ngraph::op::Lstm>(pattern_map[X],
pattern_map[H_t],
pattern_map[C_t],
W_reshape,
R_reshape,
bias_ifco,
rnn_type);
if (lstm_node->get_outputs().size() != 3)
{
throw ngraph_error("Lstm node doesnt have three outputs");
}
#endif
#if MKLDNN_VERSION_MAJOR < 1
auto lstm_ht_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 0);
auto lstm_ht_ct_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 1);
#else
auto lstm_ht_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 1);
auto ct_slice = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 2);
#endif
// set LSTM cell attributes
const size_t lstm_n_gates = 4;
const size_t batch_size = pattern_map[X]->get_shape()[0];
const size_t direction = 1;
const size_t layers = 1;
auto dic = pattern_map[R]->get_shape()[0] / (lstm_n_gates * direction * layers);
auto goe_nodes = ngraph::op::get_output_elements(m.get_match_root());
auto dst_layer = goe_nodes[0];
auto dst_iter = goe_nodes[1];
// dst_iter of lstm mkldnn output holds the results of both recurrent state
// tensor outputs. we need to slice the ct.
#if MKLDNN_VERSION_MAJOR < 1
auto ct_slice = std::make_shared<ngraph::op::Slice>(
lstm_ht_ct_output, Coordinate{batch_size, 0}, Coordinate{(2 * batch_size), dic});
#endif
// find the user's for {ht} and replace them with lstm_goe_0
if (std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(dst_iter) != nullptr)
{
ngraph::replace_node(dst_iter, ct_slice);
}
// find the user's for {ht} and replace them with lstm_goe_0
if (std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(dst_layer) != nullptr)
{
ngraph::replace_node(dst_layer, lstm_ht_output);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(ref_lstm_cell, "LSTMFusion.onnx_lstm_cell");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid() void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
{ {
// construct variance // construct variance
......
...@@ -45,11 +45,13 @@ public: ...@@ -45,11 +45,13 @@ public:
{ {
construct_sigmoid(); construct_sigmoid();
construct_lstm_fprop(); construct_lstm_fprop();
construct_onnx_lstmcell_fprop();
} }
private: private:
void construct_sigmoid(); void construct_sigmoid();
void construct_lstm_fprop(); void construct_lstm_fprop();
void construct_onnx_lstmcell_fprop();
}; };
class CPU_BACKEND_API ngraph::runtime::cpu::pass::RNNFusion class CPU_BACKEND_API ngraph::runtime::cpu::pass::RNNFusion
......
...@@ -3820,6 +3820,53 @@ TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell) ...@@ -3820,6 +3820,53 @@ TEST(cpu_fusion, rnn_fusion_1rnn_layer_3lstm_cell)
} }
} }
TEST(cpu_fusion, lstm_cell)
{
auto make_function = []() {
const size_t batch_size = 3;
const size_t input_size = 4;
const size_t hidden_size = 4;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, W, R, H_t, C_t, hidden_size);
auto ht = make_shared<op::GetOutputElement>(lstm_cell, 0);
auto ct = make_shared<op::GetOutputElement>(lstm_cell, 1);
auto lstm_function =
make_shared<Function>(NodeVector{ht, ct}, ParameterVector{X, W, R, H_t, C_t});
return lstm_function;
};
auto lstm_function_cpu = make_function();
auto lstm_function_inter = make_function();
test::Uniform<float> rng(-1.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : lstm_function_cpu->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(lstm_function_inter, args, "INTERPRETER");
auto cpu_results = execute(lstm_function_cpu, args, "CPU");
size_t lstm_op_count = count_ops_of_type<op::LSTMCell>(lstm_function_cpu);
EXPECT_EQ(lstm_op_count, 0);
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell) TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
{ {
const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json"); const std::string file_name("mxnet/2rnn_layer_3lstm_cell.json");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment