Commit c4c24cb0 authored by Pruthvi, committed by Scott Cyphers

Pruthvi/fix rnn output (#1135)

* - Fixed the output replacement for the multi-layer recurrent cell-state tensor output
- Modified Rnn add_output to take direction and n_layer into account when calculating the output size for the MKLDNN dst_layer and dst_iter (see the shape sketch below)

* fix unit test failure
parent 784735d6
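For orientation, here is a minimal standalone sketch (not the nGraph source itself) of the output-shape arithmetic the commit introduces in op::Rnn::add_output. The sizes are illustrative, and num_cell_states = 2 assumes an LSTM cell:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // Illustrative dimensions, not taken from the commit.
    std::size_t num_timesteps = 10;
    std::size_t batch_size = 32;
    std::size_t feature_size = 100;
    std::size_t direction = 1;        // 2 for a bidirectional RNN
    std::size_t num_fused_layers = 3;
    std::size_t num_cell_states = 2;  // {ht, ct} for an LSTM cell

    // Output 0 (maps to the MKLDNN dst_layer): per-timestep hidden states.
    std::vector<std::size_t> dst_layer_shape{direction * num_timesteps * batch_size,
                                             feature_size};
    // Output 1 (maps to the MKLDNN dst_iter): final {ht, ct} for every direction and layer.
    std::vector<std::size_t> dst_iter_shape{
        num_cell_states * direction * num_fused_layers * batch_size, feature_size};

    std::cout << "dst_layer: {" << dst_layer_shape[0] << ", " << dst_layer_shape[1] << "}\n";
    std::cout << "dst_iter:  {" << dst_iter_shape[0] << ", " << dst_iter_shape[1] << "}\n";
    return 0;
}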
@@ -107,9 +107,10 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
     }
     add_output(src_layer->get_element_type(),
-               Shape{static_cast<unsigned long>(m_num_timesteps * m_batch_size),
+               Shape{static_cast<unsigned long>(m_direction * m_num_timesteps * m_batch_size),
                      static_cast<unsigned long>(m_src_iter_feature_size)});
     add_output(src_layer->get_element_type(),
-               Shape{static_cast<unsigned long>(m_num_cell_states * m_batch_size),
+               Shape{static_cast<unsigned long>(m_num_cell_states * m_direction *
+                                                m_num_fused_layers * m_batch_size),
                      static_cast<unsigned long>(m_src_iter_feature_size)});
 }
@@ -556,12 +556,16 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
         auto ht_slice = std::make_shared<op::Slice>(
             rnn_ht_ct_out,
             Coordinate{0, 0},
-            Coordinate{static_cast<unsigned long>(batch_size),
+            Coordinate{static_cast<unsigned long>(batch_size * direction *
+                                                  num_fused_rnn_layers),
                        static_cast<unsigned long>(feature_size)});
         auto ct_slice = std::make_shared<op::Slice>(
             rnn_ht_ct_out,
-            Coordinate{static_cast<unsigned long>(batch_size), 0},
-            Coordinate{static_cast<unsigned long>(2 * batch_size),
+            Coordinate{static_cast<unsigned long>(batch_size * direction *
+                                                  num_fused_rnn_layers),
+                       0},
+            Coordinate{static_cast<unsigned long>(2 * batch_size * direction *
+                                                  num_fused_rnn_layers),
                        static_cast<unsigned long>(feature_size)});
         // check if the last LSTM cell has any consumers
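The two Slice ops above split the fused RNN's second output into its {ht} and {ct} halves. A minimal standalone sketch of the row ranges involved, assuming the {ht} block precedes the {ct} block and using illustrative sizes:

#include <cstddef>
#include <iostream>

int main()
{
    // Illustrative values, not taken from the commit.
    std::size_t batch_size = 32;
    std::size_t direction = 1;
    std::size_t num_fused_rnn_layers = 3;
    std::size_t block = batch_size * direction * num_fused_rnn_layers;

    // ht_slice covers rows [0, block), ct_slice covers rows [block, 2 * block),
    // mirroring the Coordinate bounds in the hunk above.
    std::cout << "ht rows: [0, " << block << ")\n";
    std::cout << "ct rows: [" << block << ", " << 2 * block << ")\n";
    return 0;
}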
@@ -688,13 +692,13 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
         auto weights_iter = compute_multi_layer_rnn_inputs(weights_iter_label, m);
         auto bias = compute_multi_layer_rnn_inputs(bias_label, m);
-        std::shared_ptr<op::Rnn> rnn_node = nullptr;
-        for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label))
+        std::vector<std::shared_ptr<op::Rnn>> rnn_nodes;
+        for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label)[0]->get_arguments())
         {
-            if (std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input))
+            auto rnn_op = std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input->get_arguments()[0]);
+            if (rnn_op)
             {
-                rnn_node = std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input);
+                rnn_nodes.push_back(rnn_op);
             }
             else
             {
@@ -702,14 +706,14 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
             }
         }
-        size_t num_time_steps = rnn_node->get_num_timesteps();
-        size_t num_gates_in_lstm = rnn_node->get_gates_per_cell();
-        size_t batch_size = rnn_node->get_batch_size();
-        size_t sequence_len = rnn_node->get_src_sequence_length();
-        size_t src_layer_feature_size = rnn_node->get_src_layer_feature_size();
-        size_t feature_size = rnn_node->get_src_iter_feature_size();
-        size_t num_rnn_cell_states = rnn_node->get_num_cell_states();
-        size_t rnn_direction = rnn_node->get_direction();
+        size_t num_time_steps = rnn_nodes[0]->get_num_timesteps();
+        size_t num_gates_in_lstm = rnn_nodes[0]->get_gates_per_cell();
+        size_t batch_size = rnn_nodes[0]->get_batch_size();
+        size_t sequence_len = rnn_nodes[0]->get_src_sequence_length();
+        size_t src_layer_feature_size = rnn_nodes[0]->get_src_layer_feature_size();
+        size_t feature_size = rnn_nodes[0]->get_src_iter_feature_size();
+        size_t num_rnn_cell_states = rnn_nodes[0]->get_num_cell_states();
+        size_t rnn_direction = rnn_nodes[0]->get_direction();
         size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
         NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
@@ -721,7 +725,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
         NGRAPH_DEBUG << "batch_size: " << batch_size;
         NGRAPH_DEBUG << "feature_size: " << feature_size;
-        if ((src_layer->get_arguments().size()) != rnn_node->get_num_timesteps() &&
+        if ((src_layer->get_arguments().size()) != rnn_nodes[0]->get_num_timesteps() &&
             !std::dynamic_pointer_cast<op::Parameter>(src_layer))
         {
             throw ngraph_error(
@@ -730,7 +734,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
         }
         if (std::dynamic_pointer_cast<op::Parameter>(src_layer) &&
-            rnn_node->get_num_timesteps() != 1)
+            rnn_nodes[0]->get_num_timesteps() != 1)
         {
             throw ngraph_error(
                 " input symbols for the layer fused RNN op, should be captured only for the first "
@@ -781,26 +785,75 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
         auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
         auto layer_rnn_ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1);
-        // find the last RNN cell GOE's and replace them with the layer fused RNN GOE.
-        for (auto& rnn_goes : rnn_node->get_users())
-        {
-            NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
-            if (rnn_goes->get_users().empty())
-            {
-                continue;
-            }
-            if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
-            {
-                if (rnn_goe_node->get_n() == 0)
-                {
-                    ngraph::replace_node(rnn_goes, layer_rnn_ht);
-                }
-                else if (rnn_goe_node->get_n() == 1)
-                {
-                    ngraph::replace_node(rnn_goes, layer_rnn_ht_ct);
-                }
-            }
-        }
+        // The multi-layer fused RNN's second output {GOE1} holds the recurrent output state
+        // tensors for the last cell of all the layers; slice the cell state output tensor
+        // {ht | ct} -> {ct} and feed {ct} to its consumers from the fused RNN output.
+        auto ht_slice_across_layer = std::make_shared<op::Slice>(
+            layer_rnn_ht_ct,
+            Coordinate{0, 0},
+            Coordinate{
+                static_cast<unsigned long>(batch_size * rnn_direction * num_fused_rnn_layers),
+                static_cast<unsigned long>(feature_size)});
+        auto ct_slice_across_layer = std::make_shared<op::Slice>(
+            layer_rnn_ht_ct,
+            Coordinate{
+                static_cast<unsigned long>(batch_size * rnn_direction * num_fused_rnn_layers), 0},
+            Coordinate{
+                static_cast<unsigned long>(2 * batch_size * rnn_direction * num_fused_rnn_layers),
+                static_cast<unsigned long>(feature_size)});
+        // Replace all the users of the RNN cell state {ct} across the different layers.
+        auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node>& rnn_ct, size_t layer) {
+            std::shared_ptr<Node> node_to_replace = rnn_ct;
+            auto ct_slice = std::make_shared<op::Slice>(
+                ct_slice_across_layer,
+                Coordinate{static_cast<unsigned long>(batch_size * (layer - 1)), 0},
+                Coordinate{static_cast<unsigned long>(batch_size * rnn_direction * layer),
+                           static_cast<unsigned long>(feature_size)});
+            if (rnn_ct->get_users().size() == 1)
+            {
+                if (std::dynamic_pointer_cast<op::Slice>(rnn_ct->get_users()[0]))
+                {
+                    node_to_replace = rnn_ct->get_users()[0];
+                }
+            }
+            if (ngraph::is_used(node_to_replace))
+            {
+                ngraph::replace_node(node_to_replace, ct_slice);
+            }
+        };
+        for (size_t index = 0; index < rnn_nodes.size(); index++)
+        {
+            for (auto& rnn_goes : rnn_nodes[index]->get_users())
+            {
+                NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
+                if (rnn_goes->get_users().empty())
+                {
+                    continue;
+                }
+                if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
+                {
+                    // We only need to replace the {ht} consumers of the last RNN layer; for the
+                    // other layers the intermediate {ht} outputs are computed within MKLDNN.
+                    if (index == 0)
+                    {
+                        if (rnn_goe_node->get_n() == 0)
+                        {
+                            ngraph::replace_node(rnn_goes, layer_rnn_ht);
+                        }
+                    }
+                    if (rnn_goe_node->get_n() == 1)
+                    {
+                        replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index);
+                    }
+                }
+            }
+        }
         return true;
     };
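The replace_rnn_output_cellstate lambda above selects one layer's {ct} rows out of ct_slice_across_layer. A minimal standalone sketch of that row arithmetic, assuming a unidirectional RNN (rnn_direction == 1) and illustrative sizes; the layer argument counts down from num_fused_rnn_layers, matching the call replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index):

#include <cstddef>
#include <iostream>
#include <utility>

// Row range of one layer's {ct} inside the ct block of the fused cell-state output,
// mirroring the Coordinate bounds built by the lambda in the hunk above.
std::pair<std::size_t, std::size_t> ct_row_bounds(std::size_t batch_size,
                                                  std::size_t rnn_direction,
                                                  std::size_t layer)
{
    std::size_t lower = batch_size * (layer - 1);           // first row of this layer's {ct}
    std::size_t upper = batch_size * rnn_direction * layer; // one past the last row
    return {lower, upper};
}

int main()
{
    // Illustrative values, not taken from the commit.
    std::size_t batch_size = 32;
    std::size_t rnn_direction = 1;
    for (std::size_t layer = 1; layer <= 3; ++layer)
    {
        auto bounds = ct_row_bounds(batch_size, rnn_direction, layer);
        std::cout << "layer " << layer << ": ct rows [" << bounds.first << ", "
                  << bounds.second << ")\n";
    }
    return 0;
}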
@@ -890,7 +890,7 @@
         "outputs": [
             "Multiply_42_0"
         ]
     },
     {
         "inputs": [
             "Multiply_32",
@@ -901,6 +901,16 @@
         "outputs": [
             "Add_43_0"
         ]
+    },
+    {
+        "inputs": [
+            "Add_43"
+        ],
+        "name": "Result_96",
+        "op": "Result",
+        "outputs": [
+            "Result_96_0"
+        ]
     },
     {
         "inputs": [
@@ -1269,7 +1279,8 @@
         ],
         "result": [
             "Result_93",
-            "Result_94"
+            "Result_94",
+            "Result_96"
         ]
     }
 ]