Commit c4c24cb0 authored by Pruthvi, committed by Scott Cyphers

Pruthvi/fix rnn output (#1135)

* - Fixed the output replacement for the multi-layer recurrent cell-state tensor output
- Modified Rnn add_output to take direction and n_layer into account when calculating the output sizes for the MKLDNN dst_layer and dst_iter

* fix unit test failure
parent 784735d6
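
For context, a minimal standalone sketch of the output-shape arithmetic this change introduces: plain C++ with made-up example values, not the nGraph API; the variable names simply mirror the op::Rnn members used in the hunk below.

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical example values; in op::Rnn these come from the op's members.
    const std::size_t num_timesteps    = 10;
    const std::size_t batch_size       = 32;
    const std::size_t feature_size     = 128; // src_iter_feature_size
    const std::size_t direction        = 1;   // 1 = unidirectional, 2 = bidirectional
    const std::size_t num_fused_layers = 3;
    const std::size_t num_cell_states  = 2;   // {ht, ct} for an LSTM cell

    // Output 0 (mkldnn dst_layer): per-timestep hidden states.
    const std::size_t dst_layer_rows = direction * num_timesteps * batch_size;

    // Output 1 (mkldnn dst_iter): final {ht, ct} for every layer and direction,
    // which is why the fix multiplies in direction and num_fused_layers.
    const std::size_t dst_iter_rows =
        num_cell_states * direction * num_fused_layers * batch_size;

    std::cout << "dst_layer shape: {" << dst_layer_rows << ", " << feature_size << "}\n"  // {320, 128}
              << "dst_iter shape: {" << dst_iter_rows << ", " << feature_size << "}\n";   // {192, 128}
}
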
@@ -107,9 +107,10 @@ op::Rnn::Rnn(std::shared_ptr<Node> src_layer,
}
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_timesteps * m_batch_size),
Shape{static_cast<unsigned long>(m_direction * m_num_timesteps * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
add_output(src_layer->get_element_type(),
Shape{static_cast<unsigned long>(m_num_cell_states * m_batch_size),
Shape{static_cast<unsigned long>(m_num_cell_states * m_direction *
m_num_fused_layers * m_batch_size),
static_cast<unsigned long>(m_src_iter_feature_size)});
}
@@ -556,12 +556,16 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
auto ht_slice = std::make_shared<op::Slice>(
rnn_ht_ct_out,
Coordinate{0, 0},
Coordinate{static_cast<unsigned long>(batch_size),
Coordinate{static_cast<unsigned long>(batch_size * direction *
num_fused_rnn_layers),
static_cast<unsigned long>(feature_size)});
auto ct_slice = std::make_shared<op::Slice>(
rnn_ht_ct_out,
Coordinate{static_cast<unsigned long>(batch_size), 0},
Coordinate{static_cast<unsigned long>(2 * batch_size),
Coordinate{static_cast<unsigned long>(batch_size * direction *
num_fused_rnn_layers),
0},
Coordinate{static_cast<unsigned long>(2 * batch_size * direction *
num_fused_rnn_layers),
static_cast<unsigned long>(feature_size)});
// check if the last LSTM cell has any consumers
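
To make the widened slice bounds above concrete, here is a small sketch (plain C++ with hypothetical numbers, not the fusion pass itself) of where {ht} ends and {ct} begins in the concatenated {ht | ct} output once direction and the number of fused layers are taken into account.

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical values mirroring the pass variables above.
    const std::size_t batch_size           = 32;
    const std::size_t feature_size         = 128;
    const std::size_t direction            = 1;
    const std::size_t num_fused_rnn_layers = 3;

    // Row count of each half of the concatenated {ht | ct} tensor.
    const std::size_t half = batch_size * direction * num_fused_rnn_layers; // 96

    // ht occupies rows [0, half) and ct occupies rows [half, 2 * half);
    // both spans cover all feature_size columns.
    std::cout << "ht rows: [0, " << half << "), ct rows: [" << half << ", "
              << 2 * half << "), columns: [0, " << feature_size << ")\n";
}
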
@@ -688,13 +692,13 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto weights_iter = compute_multi_layer_rnn_inputs(weights_iter_label, m);
auto bias = compute_multi_layer_rnn_inputs(bias_label, m);
std::shared_ptr<op::Rnn> rnn_node = nullptr;
for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label)[0]->get_arguments())
std::vector<std::shared_ptr<op::Rnn>> rnn_nodes;
for (auto& rnn_goe_input : m.get_bound_nodes_for_pattern(rnn_ht_label))
{
if (std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input))
auto rnn_op = std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input->get_arguments()[0]);
if (rnn_op)
{
rnn_node = std::dynamic_pointer_cast<op::Rnn>(rnn_goe_input);
rnn_nodes.push_back(rnn_op);
}
else
{
@@ -702,14 +706,14 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
}
}
size_t num_time_steps = rnn_node->get_num_timesteps();
size_t num_gates_in_lstm = rnn_node->get_gates_per_cell();
size_t batch_size = rnn_node->get_batch_size();
size_t sequence_len = rnn_node->get_src_sequence_length();
size_t src_layer_feature_size = rnn_node->get_src_layer_feature_size();
size_t feature_size = rnn_node->get_src_iter_feature_size();
size_t num_rnn_cell_states = rnn_node->get_num_cell_states();
size_t rnn_direction = rnn_node->get_direction();
size_t num_time_steps = rnn_nodes[0]->get_num_timesteps();
size_t num_gates_in_lstm = rnn_nodes[0]->get_gates_per_cell();
size_t batch_size = rnn_nodes[0]->get_batch_size();
size_t sequence_len = rnn_nodes[0]->get_src_sequence_length();
size_t src_layer_feature_size = rnn_nodes[0]->get_src_layer_feature_size();
size_t feature_size = rnn_nodes[0]->get_src_iter_feature_size();
size_t num_rnn_cell_states = rnn_nodes[0]->get_num_cell_states();
size_t rnn_direction = rnn_nodes[0]->get_direction();
size_t num_fused_rnn_layers = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << "src_layer: " << join(src_layer->get_shape());
@@ -721,7 +725,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
NGRAPH_DEBUG << "batch_size: " << batch_size;
NGRAPH_DEBUG << "feature_size: " << feature_size;
if ((src_layer->get_arguments().size()) != rnn_node->get_num_timesteps() &&
if ((src_layer->get_arguments().size()) != rnn_nodes[0]->get_num_timesteps() &&
!std::dynamic_pointer_cast<op::Parameter>(src_layer))
{
throw ngraph_error(
@@ -730,7 +734,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
}
if (std::dynamic_pointer_cast<op::Parameter>(src_layer) &&
rnn_node->get_num_timesteps() != 1)
rnn_nodes[0]->get_num_timesteps() != 1)
{
throw ngraph_error(
" input symbols for the layer fused RNN op, should be captured only for the first "
@@ -781,26 +785,75 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
auto layer_rnn_ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1);
// find the last RNN cell GOEs and replace them with the layer-fused RNN GOEs.
for (auto& rnn_goes : rnn_node->get_users())
// The multi-layer fused RNN's second output {GOE1} holds the recurrent output state tensors for the last cell
// of all the layers; we slice the cell-state output tensor {ht | ct} -> {ct} and feed
// {ct} to its consumers from the fused RNN output.
auto ht_slice_across_layer = std::make_shared<op::Slice>(
layer_rnn_ht_ct,
Coordinate{0, 0},
Coordinate{
static_cast<unsigned long>(batch_size * rnn_direction * num_fused_rnn_layers),
static_cast<unsigned long>(feature_size)});
auto ct_slice_across_layer = std::make_shared<op::Slice>(
layer_rnn_ht_ct,
Coordinate{
static_cast<unsigned long>(batch_size * rnn_direction * num_fused_rnn_layers), 0},
Coordinate{
static_cast<unsigned long>(2 * batch_size * rnn_direction * num_fused_rnn_layers),
static_cast<unsigned long>(feature_size)});
// Replace all the users of the RNN cell state {ct} across the different layers.
auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node>& rnn_ct, size_t layer) {
std::shared_ptr<Node> node_to_replace = rnn_ct;
auto ct_slice = std::make_shared<op::Slice>(
ct_slice_across_layer,
Coordinate{static_cast<unsigned long>(batch_size * (layer - 1)), 0},
Coordinate{static_cast<unsigned long>(batch_size * rnn_direction * layer),
static_cast<unsigned long>(feature_size)});
if (rnn_ct->get_users().size() == 1)
{
if (std::dynamic_pointer_cast<op::Slice>(rnn_ct->get_users()[0]))
{
node_to_replace = rnn_ct->get_users()[0];
}
}
if (ngraph::is_used(node_to_replace))
{
ngraph::replace_node(node_to_replace, ct_slice);
}
};
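// Worked example with hypothetical numbers: for batch_size = 32, rnn_direction = 1 and
// num_fused_rnn_layers = 3, ct_slice_across_layer covers rows [96, 192) of layer_rnn_ht_ct.
// Within that block, calling the lambda with layer = 3 slices rows [64, 96) and with
// layer = 1 slices rows [0, 32), each spanning all feature_size columns.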
for (size_t index = 0; index < rnn_nodes.size(); index++)
{
for (auto& rnn_goes : rnn_nodes[index]->get_users())
{
NGRAPH_DEBUG << "rnn_goes: " << rnn_goes->get_name();
if (rnn_goes->get_users().empty())
{
continue;
}
if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
// we only need to replace the {ht} consumers of the last RNN layer,
// since for the other layers the intermediate outputs {ht} are computed
// within MKLDNN
if (index == 0)
{
if (rnn_goe_node->get_n() == 0)
{
ngraph::replace_node(rnn_goes, layer_rnn_ht);
}
else if (rnn_goe_node->get_n() == 1)
}
if (rnn_goe_node->get_n() == 1)
{
ngraph::replace_node(rnn_goes, layer_rnn_ht_ct);
replace_rnn_output_cellstate(rnn_goes, num_fused_rnn_layers - index);
}
}
}
}
return true;
};
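
As a reading aid for the loop above, a minimal sketch (plain C++ with a hypothetical layer count, not the pass itself) of how each matched RNN cell is handled: only the first match, i.e. the last layer, has its {ht} GOE consumers redirected to the fused output, while every match has its {ct} consumers redirected to a per-layer slice selected by num_fused_rnn_layers - index.

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical setup mirroring the loop: three matched single-layer RNN cells
    // are being fused; index 0 is treated as the last (topmost) layer.
    const std::size_t num_fused_rnn_layers = 3;

    for (std::size_t index = 0; index < num_fused_rnn_layers; ++index)
    {
        const bool replace_ht = (index == 0);                    // only the last layer's ht moves
        const std::size_t layer = num_fused_rnn_layers - index;  // argument passed for the ct slice
        std::cout << "match " << index << ": replace ht = " << std::boolalpha << replace_ht
                  << ", ct slice layer = " << layer << "\n";
    }
}
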
@@ -901,6 +901,16 @@
"outputs": [
"Add_43_0"
]
},
{
"inputs": [
"Add_43"
],
"name": "Result_96",
"op": "Result",
"outputs": [
"Result_96_0"
]
},
{
"inputs": [
@@ -1269,7 +1279,8 @@
],
"result": [
"Result_93",
"Result_94"
"Result_94",
"Result_96"
]
}
]