Commit b466027e authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[CPU] Fix ambiguous 'op' namespace. (#2683)

parent 105f03bc
......@@ -38,7 +38,7 @@ namespace ngraph
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size();
auto alpha = static_cast<const op::BoundedRelu*>(node)->get_alpha();
auto alpha = static_cast<const ngraph::op::BoundedRelu*>(node)->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
......
......@@ -38,7 +38,7 @@ namespace ngraph
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size();
auto alpha = static_cast<const op::LeakyRelu*>(node)->get_alpha();
auto alpha = static_cast<const ngraph::op::LeakyRelu*>(node)->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
......
......@@ -313,7 +313,7 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto daxes = quantize->get_axes();
op::Quantize::RoundMode round_mode = quantize->get_round_mode();
ngraph::op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (args[0].get_element_type() == element::f32)
{
......
......@@ -705,7 +705,7 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
void runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(Node* node)
{
auto ngraph_op = static_cast<op::Op*>(node);
auto ngraph_op = static_cast<ngraph::op::Op*>(node);
auto op_annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
ngraph_op->set_op_annotations(op_annotations);
......
......@@ -126,7 +126,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConcat)
{
auto quantized_concat = static_cast<op::QuantizedConcat*>(node);
auto quantized_concat = static_cast<ngraph::op::QuantizedConcat*>(node);
if ((node->get_input_element_type(0) == element::i8 ||
node->get_input_element_type(0) == element::u8) &&
......@@ -195,7 +195,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBiasAdd)
{
auto convolution = static_cast<op::ConvolutionBiasAdd*>(node);
auto convolution = static_cast<ngraph::op::ConvolutionBiasAdd*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionBiasAdd>(node))
{
......@@ -212,7 +212,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::GetOutputElement)
{
auto goe = static_cast<op::GetOutputElement*>(node);
auto goe = static_cast<ngraph::op::GetOutputElement*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->add_in_place_oi_pair({0, goe->get_n(), false});
......@@ -222,7 +222,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionAdd)
{
auto convolution = static_cast<op::ConvolutionAdd*>(node);
auto convolution = static_cast<ngraph::op::ConvolutionAdd*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionAdd>(node))
{
......@@ -257,7 +257,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
auto convolution = static_cast<op::ConvolutionBackpropData*>(node);
auto convolution = static_cast<ngraph::op::ConvolutionBackpropData*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
......@@ -282,7 +282,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropFilters)
{
auto convolution = static_cast<op::ConvolutionBackpropFilters*>(node);
auto convolution = static_cast<ngraph::op::ConvolutionBackpropFilters*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
......@@ -316,7 +316,8 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
auto convolution = static_cast<op::ConvolutionBiasBackpropFiltersBias*>(node);
auto convolution =
static_cast<ngraph::op::ConvolutionBiasBackpropFiltersBias*>(node);
auto data_shape = node->get_input_shape(0);
auto delta_shape = node->get_input_shape(1);
......@@ -340,7 +341,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPool)
{
auto avg_pool = static_cast<op::AvgPool*>(node);
auto avg_pool = static_cast<ngraph::op::AvgPool*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -357,7 +358,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPoolBackprop)
{
auto avg_pool = static_cast<op::AvgPoolBackprop*>(node);
auto avg_pool = static_cast<ngraph::op::AvgPoolBackprop*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -374,7 +375,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPool)
{
auto max_pool = static_cast<op::MaxPool*>(node);
auto max_pool = static_cast<ngraph::op::MaxPool*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -391,7 +392,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndices)
{
auto max_pool = static_cast<op::MaxPoolWithIndices*>(node);
auto max_pool = static_cast<ngraph::op::MaxPoolWithIndices*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -407,7 +408,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
auto max_pool = static_cast<op::MaxPoolBackprop*>(node);
auto max_pool = static_cast<ngraph::op::MaxPoolBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
......@@ -424,7 +425,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
auto max_pool = static_cast<op::MaxPoolWithIndicesBackprop*>(node);
auto max_pool = static_cast<ngraph::op::MaxPoolWithIndicesBackprop*>(node);
auto arg1_shape = node->get_input_shape(1);
auto arg1_rank = arg1_shape.size();
......@@ -440,7 +441,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{
auto relu = static_cast<op::Relu*>(node);
auto relu = static_cast<ngraph::op::Relu*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -464,7 +465,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReplaceSlice)
{
auto replace_slice = static_cast<op::ReplaceSlice*>(node);
auto replace_slice = static_cast<ngraph::op::ReplaceSlice*>(node);
// ReplaceSlice is independent of data type. Hence not checking type
auto op_annotations =
......@@ -480,7 +481,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::UpdateSlice)
{
auto update_slice = static_cast<op::UpdateSlice*>(node);
auto update_slice = static_cast<ngraph::op::UpdateSlice*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......@@ -601,7 +602,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Softmax)
{
auto softmax = static_cast<op::Softmax*>(node);
auto softmax = static_cast<ngraph::op::Softmax*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -618,7 +619,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Slice)
{
auto slice = static_cast<op::Slice*>(node);
auto slice = static_cast<ngraph::op::Slice*>(node);
auto strides = slice->get_strides();
if (!is_strided(strides) && node->get_input_element_type(0) == element::f32)
{
......@@ -649,7 +650,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BoundedRelu)
{
auto bounded_relu = static_cast<op::BoundedRelu*>(node);
auto bounded_relu = static_cast<ngraph::op::BoundedRelu*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -673,7 +674,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LeakyRelu)
{
auto leaky_relu = static_cast<op::LeakyRelu*>(node);
auto leaky_relu = static_cast<ngraph::op::LeakyRelu*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
......@@ -719,7 +720,8 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBiasAdd)
{
auto quantized_conv_bias = static_cast<op::QuantizedConvolutionBiasAdd*>(node);
auto quantized_conv_bias =
static_cast<ngraph::op::QuantizedConvolutionBiasAdd*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
......@@ -733,7 +735,7 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBiasSignedAdd)
{
auto quantized_conv_bias =
static_cast<op::QuantizedConvolutionBiasSignedAdd*>(node);
static_cast<ngraph::op::QuantizedConvolutionBiasSignedAdd*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
......@@ -758,7 +760,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
{
auto dequantize = static_cast<op::Dequantize*>(node);
auto dequantize = static_cast<ngraph::op::Dequantize*>(node);
// TODO(nbpatel): Support dynamic offset via mkldnn
// Go through reference if the offset is not a constant
if (!dequantize->get_argument(2)->is_constant())
......@@ -796,7 +798,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
auto quantize = static_cast<op::Quantize*>(node);
auto quantize = static_cast<ngraph::op::Quantize*>(node);
// TODO(nbpatel): Support dynamic offset via mkldnn
// Go through reference if the offset is not a constant
if (!quantize->get_argument(2)->is_constant())
......@@ -805,8 +807,8 @@ namespace ngraph
}
auto offset_const_op =
std::static_pointer_cast<ngraph::op::Constant>(quantize->get_argument(2));
op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (round_mode != op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
ngraph::op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (round_mode != ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
{
return;
}
......@@ -845,7 +847,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Convert)
{
auto convert = static_cast<op::Convert*>(node);
auto convert = static_cast<ngraph::op::Convert*>(node);
if ((node->get_input_element_type(0) == element::i8 &&
node->get_output_element_type(0) == element::u8) ||
(node->get_input_element_type(0) == element::u8 &&
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -64,7 +64,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
if (n->description() == "Concat")
{
auto concat = std::static_pointer_cast<op::Concat>(n);
auto concat = std::static_pointer_cast<ngraph::op::Concat>(n);
auto shape = concat->get_input_shape(0);
auto axis = concat->get_concatenation_axis();
auto product = 1;
......@@ -134,7 +134,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
if (arg->is_op())
{
auto op = std::static_pointer_cast<op::Op>(arg);
auto op = std::static_pointer_cast<ngraph::op::Op>(arg);
auto annotation = op->get_op_annotations();
if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
......@@ -177,7 +177,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
if (user->is_op())
{
auto op = std::static_pointer_cast<op::Op>(user);
auto op = std::static_pointer_cast<ngraph::op::Op>(user);
if (auto op_annotations = op->get_op_annotations())
{
if (op_annotations->get_in_place_oi_pairs().size() > 0)
......@@ -227,7 +227,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
if (n->description() == "Slice")
{
auto slice = std::static_pointer_cast<op::Slice>(n);
auto slice = std::static_pointer_cast<ngraph::op::Slice>(n);
auto in_shape = slice->get_input_shape(0);
auto out_shape = slice->get_output_shape(0);
auto strides = slice->get_strides();
......
......@@ -66,15 +66,16 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
{
// construct variance
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
auto neg_input = std::make_shared<ngraph::op::Negative>(input);
auto exp_neg_input = std::make_shared<ngraph::op::Exp>(neg_input);
// broadcast input
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto broadcast_constant =
std::make_shared<ngraph::op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
auto add_exp = std::make_shared<ngraph::op::Add>(exp_neg_input, broadcast_constant);
auto divide_1_over_exp = std::make_shared<ngraph::op::Divide>(broadcast_constant, add_exp);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
......@@ -96,7 +97,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
return false;
}
auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
auto sigmoid_node = std::make_shared<ngraph::op::Sigmoid>(pattern_map[input]);
ngraph::replace_node(m.get_match_root(), sigmoid_node);
return true;
};
......@@ -147,177 +148,184 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return ((std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr) ||
(std::dynamic_pointer_cast<op::Reshape>(n) != nullptr));
return ((std::dynamic_pointer_cast<ngraph::op::Broadcast>(n) != nullptr) ||
(std::dynamic_pointer_cast<ngraph::op::Reshape>(n) != nullptr));
};
// Fused MatMuls
// (W_{ii} | (W_{if} | W_{ig} | W_{io}) * x_t + (b_{ii} | b_{if} | b_{ig} | b_{io})
auto dot1 = std::make_shared<op::Dot>(xt, w_i2h);
auto add1 = std::make_shared<op::Add>(
auto dot1 = std::make_shared<ngraph::op::Dot>(xt, w_i2h);
auto add1 = std::make_shared<ngraph::op::Add>(
dot1, std::make_shared<pattern::op::Skip>(bias_i2h, broadcast_pred));
// (W_{hi} | (W_{hf} | W_{hg} | W_{ho}) * h_{(t-1)} + (b_{hi} | b_{hf} | b_{hg} | b_{ho})
auto dot2 = std::make_shared<op::Dot>(ht_1, w_h2h);
auto add2 = std::make_shared<op::Add>(
auto dot2 = std::make_shared<ngraph::op::Dot>(ht_1, w_h2h);
auto add2 = std::make_shared<ngraph::op::Add>(
dot2, std::make_shared<pattern::op::Skip>(bias_h2h, broadcast_pred));
auto X = std::make_shared<op::Add>(add2, add1);
auto X = std::make_shared<ngraph::op::Add>(add2, add1);
// construct gates
auto it = std::make_shared<op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100}));
auto ft = std::make_shared<op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200}));
auto gt = std::make_shared<op::Tanh>(
std::make_shared<op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300}));
auto ot = std::make_shared<op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400}));
auto it = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100}));
auto ft = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200}));
auto gt = std::make_shared<ngraph::op::Tanh>(
std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300}));
auto ot = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400}));
// construct (c_t) cell state
auto ct = std::make_shared<op::Add>(std::make_shared<op::Multiply>(ft, ct_1),
std::make_shared<op::Multiply>(it, gt));
auto ct = std::make_shared<ngraph::op::Add>(std::make_shared<ngraph::op::Multiply>(ft, ct_1),
std::make_shared<ngraph::op::Multiply>(it, gt));
auto ct_label = std::make_shared<pattern::op::Label>(ct, nullptr, NodeVector{ct});
// construct (h_t)
auto ht = std::make_shared<op::Multiply>(ot, std::make_shared<op::Tanh>(ct_label));
auto ht =
std::make_shared<ngraph::op::Multiply>(ot, std::make_shared<ngraph::op::Tanh>(ct_label));
// Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback =
[ct_label, w_i2h, bias_i2h, w_h2h, bias_h2h, xt, ht_1, ct_1](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
<< m.get_match_root()->get_name();
pattern::graph_rewrite_callback callback = [ct_label,
w_i2h,
bias_i2h,
w_h2h,
bias_h2h,
xt,
ht_1,
ct_1](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
CHECK_RANK(pattern_map[xt], 2);
CHECK_RANK(pattern_map[ht_1], 2);
CHECK_RANK(pattern_map[w_i2h], 2);
CHECK_RANK(pattern_map[w_h2h], 2);
CHECK_RANK(pattern_map[bias_i2h], 1);
CHECK_RANK(pattern_map[bias_h2h], 1);
auto weights_layer = pattern_map[w_i2h];
auto weights_iter = pattern_map[w_h2h];
auto src_layer = pattern_map[xt];
auto hidden_state = pattern_map[ht_1];
auto cell_state = pattern_map[ct_1];
// TODO: (Pruthvi) temporary workaround for GNMT slow down
// this checks avoids fusing of LSTM cells if its a part of decoder, we
// will remove this once mkldnn optimizes individual LSTM cell or once
// we have decoder pattern for GNMT.
if (!(std::dynamic_pointer_cast<op::Broadcast>(cell_state) &&
std::dynamic_pointer_cast<op::Constant>(cell_state->get_argument(0))) &&
!(std::dynamic_pointer_cast<op::Slice>(cell_state) &&
std::dynamic_pointer_cast<op::GetOutputElement>(cell_state->get_argument(0))))
{
return false;
}
CHECK_RANK(pattern_map[xt], 2);
CHECK_RANK(pattern_map[ht_1], 2);
CHECK_RANK(pattern_map[w_i2h], 2);
CHECK_RANK(pattern_map[w_h2h], 2);
CHECK_RANK(pattern_map[bias_i2h], 1);
CHECK_RANK(pattern_map[bias_h2h], 1);
auto weights_layer = pattern_map[w_i2h];
auto weights_iter = pattern_map[w_h2h];
auto src_layer = pattern_map[xt];
auto hidden_state = pattern_map[ht_1];
auto cell_state = pattern_map[ct_1];
// TODO: (Pruthvi) temporary workaround for GNMT slow down
// this checks avoids fusing of LSTM cells if its a part of decoder, we
// will remove this once mkldnn optimizes individual LSTM cell or once
// we have decoder pattern for GNMT.
if (!(std::dynamic_pointer_cast<ngraph::op::Broadcast>(cell_state) &&
std::dynamic_pointer_cast<ngraph::op::Constant>(cell_state->get_argument(0))) &&
!(std::dynamic_pointer_cast<ngraph::op::Slice>(cell_state) &&
std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(cell_state->get_argument(0))))
{
return false;
}
auto swap_lstm_inputs = [&]() -> void {
src_layer = pattern_map[ht_1];
hidden_state = pattern_map[xt];
weights_layer = pattern_map[w_h2h];
weights_iter = pattern_map[w_i2h];
};
// LSTM kernel expects ht_1 and ct_1 to have the same shape but the
// pattern matcher cannot guarantee this since the computations are
// symmetric around x_t and ht_1. Use heuristics to swap the matched
// labels
if (std::dynamic_pointer_cast<op::Broadcast>(src_layer) &&
std::dynamic_pointer_cast<op::Constant>(src_layer->get_argument(0)))
{
// First timestep of an RNN layer
swap_lstm_inputs();
}
else if (hidden_state->get_shape() != cell_state->get_shape())
auto swap_lstm_inputs = [&]() -> void {
src_layer = pattern_map[ht_1];
hidden_state = pattern_map[xt];
weights_layer = pattern_map[w_h2h];
weights_iter = pattern_map[w_i2h];
};
// LSTM kernel expects ht_1 and ct_1 to have the same shape but the
// pattern matcher cannot guarantee this since the computations are
// symmetric around x_t and ht_1. Use heuristics to swap the matched
// labels
if (std::dynamic_pointer_cast<ngraph::op::Broadcast>(src_layer) &&
std::dynamic_pointer_cast<ngraph::op::Constant>(src_layer->get_argument(0)))
{
// First timestep of an RNN layer
swap_lstm_inputs();
}
else if (hidden_state->get_shape() != cell_state->get_shape())
{
swap_lstm_inputs();
}
else if (std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(
cell_state->get_argument(0)))
{
// swap the inputs if the cell_state and hidden state does not
// belong to the same Lstm
if (!hidden_state->get_argument(0)->get_arguments().size() ||
(hidden_state->get_argument(0)->get_arguments()[0] !=
cell_state->get_argument(0)->get_arguments()[0]))
{
swap_lstm_inputs();
}
else if (std::dynamic_pointer_cast<op::GetOutputElement>(cell_state->get_argument(0)))
{
// swap the inputs if the cell_state and hidden state does not
// belong to the same Lstm
if (!hidden_state->get_argument(0)->get_arguments().size() ||
(hidden_state->get_argument(0)->get_arguments()[0] !=
cell_state->get_argument(0)->get_arguments()[0]))
{
swap_lstm_inputs();
}
}
}
if (hidden_state->get_shape() != cell_state->get_shape())
{
NGRAPH_DEBUG
<< "Lstm MKLDNN kernel requires recurrent output hidden states to match ";
return false;
}
if (hidden_state->get_shape() != cell_state->get_shape())
{
NGRAPH_DEBUG << "Lstm MKLDNN kernel requires recurrent output hidden states to match ";
return false;
}
// set LSTM cell attributes
size_t lstm_n_gates = 4;
size_t batch_size = src_layer->get_shape()[0];
size_t direction = 1;
size_t layers = 1;
auto dlc = weights_layer->get_shape()[1] / (lstm_n_gates * direction * layers);
auto slc = weights_layer->get_shape()[0];
auto dic = weights_iter->get_shape()[1] / (lstm_n_gates * direction * layers);
auto sic = weights_iter->get_shape()[0];
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
if (dlc != dic)
{
NGRAPH_DEBUG << "Not fusing, since Lstm kernel requires dst_layer feature size "
<< "equals to dts_iter feature size";
return false;
}
// set LSTM cell attributes
size_t lstm_n_gates = 4;
size_t batch_size = src_layer->get_shape()[0];
size_t direction = 1;
size_t layers = 1;
auto dlc = weights_layer->get_shape()[1] / (lstm_n_gates * direction * layers);
auto slc = weights_layer->get_shape()[0];
auto dic = weights_iter->get_shape()[1] / (lstm_n_gates * direction * layers);
auto sic = weights_iter->get_shape()[0];
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
std::shared_ptr<Node> src_iter =
std::make_shared<op::Concat>(NodeVector{hidden_state, cell_state}, 0);
if (src_layer->get_shape()[1] != slc || src_iter->get_shape()[1] != sic)
{
NGRAPH_DEBUG << "Feature size mismatch between weights and input tensors";
return false;
}
if (dlc != dic)
{
NGRAPH_DEBUG << "Not fusing, since Lstm kernel requires dst_layer feature size "
<< "equals to dts_iter feature size";
return false;
}
auto bias = std::make_shared<op::Add>(pattern_map[bias_i2h], pattern_map[bias_h2h]);
std::shared_ptr<Node> src_iter =
std::make_shared<ngraph::op::Concat>(NodeVector{hidden_state, cell_state}, 0);
if (src_layer->get_shape()[1] != slc || src_iter->get_shape()[1] != sic)
{
NGRAPH_DEBUG << "Feature size mismatch between weights and input tensors";
return false;
}
auto lstm_node = std::make_shared<op::Lstm>(
src_layer, src_iter, weights_layer, weights_iter, bias, rnn_type);
auto bias = std::make_shared<ngraph::op::Add>(pattern_map[bias_i2h], pattern_map[bias_h2h]);
auto lstm_ht_output = std::make_shared<op::GetOutputElement>(lstm_node, 0);
auto lstm_ht_ct_output = std::make_shared<op::GetOutputElement>(lstm_node, 1);
auto lstm_node = std::make_shared<ngraph::op::Lstm>(
src_layer, src_iter, weights_layer, weights_iter, bias, rnn_type);
// dst_iter of lstm mkldnn output holds the results of both recurrent state
// tensor outputs. we need to slice the ct.
auto ht_slice = std::make_shared<op::Slice>(
lstm_ht_output, Coordinate{0, 0}, Coordinate{batch_size, dlc});
auto ct_slice = std::make_shared<op::Slice>(
lstm_ht_ct_output, Coordinate{batch_size, 0}, Coordinate{(2 * batch_size), dic});
auto lstm_ht_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 0);
auto lstm_ht_ct_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 1);
if (lstm_node->get_outputs().at(0).get_inputs().size() != 2)
{
throw ngraph_error("Lstm node doesnt have two outputs");
}
// Now identify the nodes which consumes the output of LSTM nodes
// and replace them accordingly
// find the user's for {ht|ct} and replace them with lstm_goe_1
if (ngraph::is_used(pattern_map[ct_label].get()))
{
replace_collapse_node_user(pattern_map[ct_label], ct_slice->get_outputs().at(0));
}
// find the user's for {ht} and replace them with lstm_goe_0
ngraph::replace_node(m.get_match_root(), ht_slice);
return true;
};
// dst_iter of lstm mkldnn output holds the results of both recurrent state
// tensor outputs. we need to slice the ct.
auto ht_slice = std::make_shared<ngraph::op::Slice>(
lstm_ht_output, Coordinate{0, 0}, Coordinate{batch_size, dlc});
auto ct_slice = std::make_shared<ngraph::op::Slice>(
lstm_ht_ct_output, Coordinate{batch_size, 0}, Coordinate{(2 * batch_size), dic});
if (lstm_node->get_outputs().at(0).get_inputs().size() != 2)
{
throw ngraph_error("Lstm node doesnt have two outputs");
}
// Now identify the nodes which consumes the output of LSTM nodes
// and replace them accordingly
// find the user's for {ht|ct} and replace them with lstm_goe_1
if (ngraph::is_used(pattern_map[ct_label].get()))
{
replace_collapse_node_user(pattern_map[ct_label], ct_slice->get_outputs().at(0));
}
// find the user's for {ht} and replace them with lstm_goe_0
ngraph::replace_node(m.get_match_root(), ht_slice);
return true;
};
auto m = std::make_shared<pattern::Matcher>(ht, callback, "LSTMFusion.Fprop");
this->add_matcher(m);
}
......@@ -330,44 +338,45 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
auto lstm_ht = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto lstm_ct = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto lstm_src_iter = std::make_shared<op::Concat>(NodeVector{lstm_ht, lstm_ct}, 0);
auto lstm_src_iter = std::make_shared<ngraph::op::Concat>(NodeVector{lstm_ht, lstm_ct}, 0);
auto lstm_src_iter_label =
std::make_shared<pattern::op::Label>(lstm_src_iter, nullptr, NodeVector{lstm_src_iter});
auto lstm_weights_layer_shared = std::make_shared<pattern::op::Label>(
element::f32, Shape{400, 100}, pattern::has_class<op::Parameter>());
auto lstm_weights_layer =
std::make_shared<op::Reshape>(lstm_weights_layer_shared, AxisVector{1, 0}, Shape{100, 400});
element::f32, Shape{400, 100}, pattern::has_class<ngraph::op::Parameter>());
auto lstm_weights_layer = std::make_shared<ngraph::op::Reshape>(
lstm_weights_layer_shared, AxisVector{1, 0}, Shape{100, 400});
auto lstm_weights_layer_label = std::make_shared<pattern::op::Label>(
lstm_weights_layer, nullptr, NodeVector{lstm_weights_layer});
auto lstm_weights_iter_shared = std::make_shared<pattern::op::Label>(
element::f32, Shape{400, 100}, pattern::has_class<op::Parameter>());
auto lstm_weights_iter =
std::make_shared<op::Reshape>(lstm_weights_iter_shared, AxisVector{1, 0}, Shape{100, 400});
element::f32, Shape{400, 100}, pattern::has_class<ngraph::op::Parameter>());
auto lstm_weights_iter = std::make_shared<ngraph::op::Reshape>(
lstm_weights_iter_shared, AxisVector{1, 0}, Shape{100, 400});
auto lstm_weights_iter_label = std::make_shared<pattern::op::Label>(
lstm_weights_iter, nullptr, NodeVector{lstm_weights_iter});
auto lstm_bias_layer_shared = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto lstm_bias_iter_shared = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
auto lstm_bias = std::make_shared<op::Add>(lstm_bias_layer_shared, lstm_bias_iter_shared);
auto lstm_bias =
std::make_shared<ngraph::op::Add>(lstm_bias_layer_shared, lstm_bias_iter_shared);
auto lstm_bias_label =
std::make_shared<pattern::op::Label>(lstm_bias, nullptr, NodeVector{lstm_bias});
ngraph::runtime::cpu::rnn_utils::rnntype ref_rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto lstm = std::make_shared<op::Lstm>(lstm_src_layer,
lstm_src_iter_label,
lstm_weights_layer_label,
lstm_weights_iter_label,
lstm_bias_label,
ref_rnn_type);
auto lstm_goe = std::make_shared<op::GetOutputElement>(lstm, 1);
auto lstm = std::make_shared<ngraph::op::Lstm>(lstm_src_layer,
lstm_src_iter_label,
lstm_weights_layer_label,
lstm_weights_iter_label,
lstm_bias_label,
ref_rnn_type);
auto lstm_goe = std::make_shared<ngraph::op::GetOutputElement>(lstm, 1);
// We cannot attach labels to multi-output nodes, so we attach a label to the goe instead
auto lstm_goe_label =
std::make_shared<pattern::op::Label>(lstm_goe, nullptr, NodeVector{lstm_goe});
auto lstm_goe_slice =
std::make_shared<op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
pattern::recurrent_graph_rewrite_callback callback = [lstm_goe_label,
lstm_src_layer,
......@@ -387,7 +396,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
auto node_labels = m.get_bound_nodes_for_pattern(input_label);
std::reverse(node_labels.begin(), node_labels.end());
return std::make_shared<op::Concat>(node_labels, 0);
return std::make_shared<ngraph::op::Concat>(node_labels, 0);
}
};
......@@ -429,9 +438,9 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
}
auto check_const_input = [&](std::shared_ptr<Node> n) {
if (std::dynamic_pointer_cast<op::Constant>(n) ||
(std::dynamic_pointer_cast<op::Broadcast>(n) &&
std::dynamic_pointer_cast<op::Constant>(n->get_argument(0))))
if (std::dynamic_pointer_cast<ngraph::op::Constant>(n) ||
(std::dynamic_pointer_cast<ngraph::op::Broadcast>(n) &&
std::dynamic_pointer_cast<ngraph::op::Constant>(n->get_argument(0))))
{
return true;
}
......@@ -460,26 +469,27 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
return false;
}
auto rnn = std::make_shared<op::Rnn>(rnn_src_layer,
rnn_src_iter,
rnn_weights_layer,
rnn_weights_iter,
rnn_bias,
sequence_len,
lstm_n_gates,
sequence_len,
num_cell_states,
direction,
num_fused_rnn_layers,
rnn_type);
std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(sequence_len, nullptr);
auto rnn_ht_goe = std::make_shared<op::GetOutputElement>(rnn, 0);
auto rnn_ht_ct_goe = std::make_shared<op::GetOutputElement>(rnn, 1);
auto rnn = std::make_shared<ngraph::op::Rnn>(rnn_src_layer,
rnn_src_iter,
rnn_weights_layer,
rnn_weights_iter,
rnn_bias,
sequence_len,
lstm_n_gates,
sequence_len,
num_cell_states,
direction,
num_fused_rnn_layers,
rnn_type);
std::vector<std::shared_ptr<ngraph::op::Slice>> ht_slice_per_timestep(sequence_len,
nullptr);
auto rnn_ht_goe = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
auto rnn_ht_ct_goe = std::make_shared<ngraph::op::GetOutputElement>(rnn, 1);
for (size_t i = 0, start_index = 0; i < sequence_len; i++, start_index += batch_size)
{
ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
ht_slice_per_timestep[i] = (std::make_shared<ngraph::op::Slice>(
rnn_ht_goe,
Coordinate{start_index, 0},
Coordinate{start_index + batch_size, src_iter_feature_size}));
......@@ -503,7 +513,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
for (size_t index = 0; index < sequence_len; index++)
{
auto goe_nodes = op::get_output_elements(lstm_nodes[index]);
auto goe_nodes = ngraph::op::get_output_elements(lstm_nodes[index]);
// if there is no GOE followed by the Lstm, there might be a pattern match error
// we will return safely
......@@ -521,7 +531,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{
if (ngraph::is_used(goe0_user.get()))
{
if (!std::dynamic_pointer_cast<op::Slice>(goe0_user))
if (!std::dynamic_pointer_cast<ngraph::op::Slice>(goe0_user))
{
NGRAPH_DEBUG << "Did not find LSTM slice to replace with RNN slice";
return false;
......@@ -536,7 +546,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
}
}
auto rnn_ct_goe = op::get_output_elements(lstm_nodes[sequence_len - 1])[1];
auto rnn_ct_goe = ngraph::op::get_output_elements(lstm_nodes[sequence_len - 1])[1];
if (rnn_ct_goe)
{
replace_collapse_node_user(rnn_ct_goe, rnn_ht_ct_goe->get_outputs().at(0));
......@@ -566,7 +576,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
static std::shared_ptr<Node> stack_rnn_inputs(NodeVector rnn_input_nodes)
{
std::reverse(rnn_input_nodes.begin(), rnn_input_nodes.end());
return std::make_shared<op::Concat>(rnn_input_nodes, 0);
return std::make_shared<ngraph::op::Concat>(rnn_input_nodes, 0);
}
void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
......@@ -585,20 +595,20 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
ngraph::runtime::cpu::rnn_utils::rnntype ref_rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
auto ref_rnn_node = std::make_shared<op::Rnn>(rnn_src_layer,
rnn_src_iter,
rnn_weights_layer,
rnn_weights_iter,
rnn_bias,
ref_number_of_timesteps,
ref_number_of_gates_per_cell,
ref_src_seq_length,
ref_num_rnn_cell_states,
ref_rnn_direction,
ref_num_of_rnn_fused_layer,
ref_rnn_type);
auto rnn_goe0 = std::make_shared<op::GetOutputElement>(ref_rnn_node, 0);
auto ref_rnn_node = std::make_shared<ngraph::op::Rnn>(rnn_src_layer,
rnn_src_iter,
rnn_weights_layer,
rnn_weights_iter,
rnn_bias,
ref_number_of_timesteps,
ref_number_of_gates_per_cell,
ref_src_seq_length,
ref_num_rnn_cell_states,
ref_rnn_direction,
ref_num_of_rnn_fused_layer,
ref_rnn_type);
auto rnn_goe0 = std::make_shared<ngraph::op::GetOutputElement>(ref_rnn_node, 0);
auto rnn_goe0_label =
std::make_shared<pattern::op::Label>(rnn_goe0, nullptr, NodeVector{rnn_goe0});
......@@ -622,10 +632,11 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto rnn_goe0_bounded_nodes = m.get_bound_nodes_for_pattern(rnn_goe0_label);
std::vector<std::shared_ptr<op::Rnn>> rnn_nodes;
std::vector<std::shared_ptr<ngraph::op::Rnn>> rnn_nodes;
for (auto rnn_goe : m.get_bound_nodes_for_pattern(rnn_goe0_label))
{
if (auto rnn_op = std::dynamic_pointer_cast<op::Rnn>(rnn_goe->get_arguments()[0]))
if (auto rnn_op =
std::dynamic_pointer_cast<ngraph::op::Rnn>(rnn_goe->get_arguments()[0]))
{
rnn_nodes.push_back(rnn_op);
}
......@@ -695,21 +706,21 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
"layer");
}
auto rnn = std::make_shared<op::Rnn>(mrnn_src_layer,
mrnn_src_iter,
mrnn_weights_layer,
mrnn_weights_iter,
mrnn_bias,
num_timesteps,
lstm_n_gates,
sequence_len,
num_rnn_cell_states,
rnn_direction,
num_fused_rnn_layers,
rnn_type);
auto mrnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
auto mrnn_ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1);
auto rnn = std::make_shared<ngraph::op::Rnn>(mrnn_src_layer,
mrnn_src_iter,
mrnn_weights_layer,
mrnn_weights_iter,
mrnn_bias,
num_timesteps,
lstm_n_gates,
sequence_len,
num_rnn_cell_states,
rnn_direction,
num_fused_rnn_layers,
rnn_type);
auto mrnn_ht = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
auto mrnn_ht_ct = std::make_shared<ngraph::op::GetOutputElement>(rnn, 1);
// Replace all the users of RNN cell state {ct} across different users.
auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe1, size_t layer) {
......@@ -718,7 +729,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
// of all the layers, {{ht_1 | ct_1} || {ht2 |ct2} || ....{htn | ctn}}
// we will slice the cell state output tensor {ct_*} from the fused RNN kernel output and feed
// {ct_*} to its consumer, if any
auto ct_slice = std::make_shared<op::Slice>(
auto ct_slice = std::make_shared<ngraph::op::Slice>(
mrnn_ht_ct,
Coordinate{((layer - 1) * batch_size * num_rnn_cell_states) + batch_size, 0},
Coordinate{layer * batch_size * num_rnn_cell_states, src_iter_feature_size});
......@@ -732,7 +743,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
// i.e {RNN7, RNN6, RNN5.... RNN0}
for (size_t index = 0; index < rnn_nodes.size(); index++)
{
auto goe_nodes = op::get_output_elements(rnn_nodes[index]);
auto goe_nodes = ngraph::op::get_output_elements(rnn_nodes[index]);
// if there is no GOE followed by the Lstm, there might be a pattern match error
// we will return safely
if (goe_nodes.size() != 2)
......@@ -771,15 +782,17 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
{
auto rnn_left_to_right = std::make_shared<pattern::op::Label>(
element::f32, Shape{1, 256}, pattern::has_class<op::Rnn>());
element::f32, Shape{1, 256}, pattern::has_class<ngraph::op::Rnn>());
auto rnn_right_to_left = std::make_shared<pattern::op::Label>(
element::f32, Shape{1, 256}, pattern::has_class<op::Rnn>());
element::f32, Shape{1, 256}, pattern::has_class<ngraph::op::Rnn>());
auto reshape_pred = [](std::shared_ptr<Node> n) {
return (std::dynamic_pointer_cast<op::Reshape>(n) != nullptr);
return (std::dynamic_pointer_cast<ngraph::op::Reshape>(n) != nullptr);
};
auto rnn_left_to_right_goe0 = std::make_shared<op::GetOutputElement>(rnn_left_to_right, 0);
auto rnn_right_to_left_goe0 = std::make_shared<op::GetOutputElement>(rnn_right_to_left, 0);
auto rnn_left_to_right_goe0 =
std::make_shared<ngraph::op::GetOutputElement>(rnn_left_to_right, 0);
auto rnn_right_to_left_goe0 =
std::make_shared<ngraph::op::GetOutputElement>(rnn_right_to_left, 0);
auto rnn_rtol_goe0_reshape_ntc =
std::make_shared<pattern::op::Skip>(rnn_right_to_left_goe0, reshape_pred);
......@@ -791,21 +804,23 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
std::make_shared<pattern::op::Skip>(rnn_ltor_goe0_reshape_ntc, reshape_pred);
auto reverse_seq_predicate = [](std::shared_ptr<Node> node) {
return pattern::has_class<op::ReverseSequence>()(node) ||
pattern::has_class<op::Reverse>()(node);
return pattern::has_class<ngraph::op::ReverseSequence>()(node) ||
pattern::has_class<ngraph::op::Reverse>()(node);
};
auto skip_reverse_seq =
std::make_shared<pattern::op::Skip>(rnn_rtol_goe0_reshape_tnc, reverse_seq_predicate);
auto concat =
std::make_shared<op::Concat>(NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);
auto concat = std::make_shared<ngraph::op::Concat>(
NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [rnn_left_to_right,
rnn_right_to_left](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto rnn_ltor_node = std::static_pointer_cast<op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node = std::static_pointer_cast<op::Rnn>(pattern_map[rnn_right_to_left]);
auto rnn_ltor_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
if (rnn_ltor_node->get_src_sequence_length() != rnn_rtol_node->get_src_sequence_length())
{
......@@ -852,7 +867,7 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
auto nodes =
NodeVector{rnn_ltor_node->get_argument(index), rnn_rtol_node->get_argument(index)};
return std::make_shared<op::Concat>(nodes, 0);
return std::make_shared<ngraph::op::Concat>(nodes, 0);
};
auto src_layer = rnn_ltor_node->get_arguments()[0];
......@@ -861,20 +876,20 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
auto weights_iter = construct_birnn_inputs(3);
auto bias = construct_birnn_inputs(4);
auto rnn = std::make_shared<op::Rnn>(src_layer,
src_iter,
weights_layer,
weights_iter,
bias,
num_time_steps,
lstm_n_gates,
sequence_len,
num_rnn_cell_states,
rnn_direction,
num_fused_rnn_layers,
rnn_type);
auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
auto rnn = std::make_shared<ngraph::op::Rnn>(src_layer,
src_iter,
weights_layer,
weights_iter,
bias,
num_time_steps,
lstm_n_gates,
sequence_len,
num_rnn_cell_states,
rnn_direction,
num_fused_rnn_layers,
rnn_type);
auto layer_rnn_ht = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
size_t batch_size = layer_rnn_ht->get_shape()[0] / num_time_steps;
size_t feature_size = layer_rnn_ht->get_shape()[1];
......@@ -882,17 +897,17 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
std::shared_ptr<Node> layer_rnn_ht_reshape = layer_rnn_ht;
if (m.get_match_root()->get_shape() != layer_rnn_ht->get_shape())
{
layer_rnn_ht_reshape = std::make_shared<op::Reshape>(
layer_rnn_ht_reshape = std::make_shared<ngraph::op::Reshape>(
layer_rnn_ht, AxisVector{0, 1}, Shape{num_time_steps, batch_size, feature_size});
}
// we will check if the node being replaced is in Shape{n, t, c}, if so we will transpose
if (m.get_match_root()->get_shape() == Shape{batch_size, num_time_steps, feature_size})
{
layer_rnn_ht_reshape =
std::make_shared<op::Reshape>(layer_rnn_ht_reshape,
AxisVector{1, 0, 2},
Shape{batch_size, num_time_steps, feature_size});
layer_rnn_ht_reshape = std::make_shared<ngraph::op::Reshape>(
layer_rnn_ht_reshape,
AxisVector{1, 0, 2},
Shape{batch_size, num_time_steps, feature_size});
}
ngraph::replace_node(m.get_match_root(), layer_rnn_ht_reshape);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment