Commit b466027e authored by Diego Caballero, committed by Scott Cyphers

[CPU] Fix ambiguous 'op' namespace. (#2683)

parent 105f03bc
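The change itself is mechanical: every unqualified op::Foo in these CPU-backend sources is spelled out as ngraph::op::Foo. The sketch below is illustrative only (the nested namespace layout and type names are invented, not the actual nGraph headers); it shows the kind of name-lookup problem an unqualified op:: can hit once a nearer 'op' namespace is in scope, and how full qualification avoids it.

// Illustrative sketch of the ambiguity this commit avoids; type names are invented.
#include <iostream>

namespace ngraph
{
    namespace op
    {
        struct Relu
        {
        };
    }
    namespace runtime
    {
        namespace cpu
        {
            // A sibling 'op' namespace that is closer in scope than ngraph::op.
            namespace op
            {
                struct ConvertLayout
                {
                };
            }

            inline void example()
            {
                // Unqualified lookup finds runtime::cpu::op first and stops there,
                // so "op::Relu relu;" would fail to compile at this point.
                ngraph::op::Relu relu;    // fully qualified: unambiguous
                op::ConvertLayout layout; // still resolves to the nearby namespace
                (void)relu;
                (void)layout;
                std::cout << "resolved without ambiguity\n";
            }
        }
    }
}

int main()
{
    ngraph::runtime::cpu::example();
}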
@@ -38,7 +38,7 @@ namespace ngraph
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size();
-auto alpha = static_cast<const op::BoundedRelu*>(node)->get_alpha();
+auto alpha = static_cast<const ngraph::op::BoundedRelu*>(node)->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
@@ -38,7 +38,7 @@ namespace ngraph
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size();
-auto alpha = static_cast<const op::LeakyRelu*>(node)->get_alpha();
+auto alpha = static_cast<const ngraph::op::LeakyRelu*>(node)->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
@@ -313,7 +313,7 @@ namespace ngraph
auto arg0_shape = args[0].get_shape();
auto arg1_shape = args[1].get_shape();
auto daxes = quantize->get_axes();
-op::Quantize::RoundMode round_mode = quantize->get_round_mode();
+ngraph::op::Quantize::RoundMode round_mode = quantize->get_round_mode();
if (args[0].get_element_type() == element::f32)
{
@@ -705,7 +705,7 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
void runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(Node* node)
{
-    auto ngraph_op = static_cast<op::Op*>(node);
+    auto ngraph_op = static_cast<ngraph::op::Op*>(node);
    auto op_annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    op_annotations->set_mkldnn_op(true);
    ngraph_op->set_op_annotations(op_annotations);
@@ -126,7 +126,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConcat)
{
-    auto quantized_concat = static_cast<op::QuantizedConcat*>(node);
+    auto quantized_concat = static_cast<ngraph::op::QuantizedConcat*>(node);
    if ((node->get_input_element_type(0) == element::i8 ||
         node->get_input_element_type(0) == element::u8) &&
@@ -195,7 +195,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBiasAdd)
{
-    auto convolution = static_cast<op::ConvolutionBiasAdd*>(node);
+    auto convolution = static_cast<ngraph::op::ConvolutionBiasAdd*>(node);
    if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionBiasAdd>(node))
    {
@@ -212,7 +212,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::GetOutputElement)
{
-    auto goe = static_cast<op::GetOutputElement*>(node);
+    auto goe = static_cast<ngraph::op::GetOutputElement*>(node);
    auto op_annotations =
        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    op_annotations->add_in_place_oi_pair({0, goe->get_n(), false});
@@ -222,7 +222,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionAdd)
{
-    auto convolution = static_cast<op::ConvolutionAdd*>(node);
+    auto convolution = static_cast<ngraph::op::ConvolutionAdd*>(node);
    if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionAdd>(node))
    {
@@ -257,7 +257,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
-    auto convolution = static_cast<op::ConvolutionBackpropData*>(node);
+    auto convolution = static_cast<ngraph::op::ConvolutionBackpropData*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg1_shape = node->get_input_shape(1);
@@ -282,7 +282,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropFilters)
{
-    auto convolution = static_cast<op::ConvolutionBackpropFilters*>(node);
+    auto convolution = static_cast<ngraph::op::ConvolutionBackpropFilters*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg1_shape = node->get_input_shape(1);
@@ -316,7 +316,8 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBiasBackpropFiltersBias)
{
-    auto convolution = static_cast<op::ConvolutionBiasBackpropFiltersBias*>(node);
+    auto convolution =
+        static_cast<ngraph::op::ConvolutionBiasBackpropFiltersBias*>(node);
    auto data_shape = node->get_input_shape(0);
    auto delta_shape = node->get_input_shape(1);
@@ -340,7 +341,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPool)
{
-    auto avg_pool = static_cast<op::AvgPool*>(node);
+    auto avg_pool = static_cast<ngraph::op::AvgPool*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -357,7 +358,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::AvgPoolBackprop)
{
-    auto avg_pool = static_cast<op::AvgPoolBackprop*>(node);
+    auto avg_pool = static_cast<ngraph::op::AvgPoolBackprop*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -374,7 +375,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPool)
{
-    auto max_pool = static_cast<op::MaxPool*>(node);
+    auto max_pool = static_cast<ngraph::op::MaxPool*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -391,7 +392,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndices)
{
-    auto max_pool = static_cast<op::MaxPoolWithIndices*>(node);
+    auto max_pool = static_cast<ngraph::op::MaxPoolWithIndices*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -407,7 +408,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolBackprop)
{
-    auto max_pool = static_cast<op::MaxPoolBackprop*>(node);
+    auto max_pool = static_cast<ngraph::op::MaxPoolBackprop*>(node);
    auto arg1_shape = node->get_input_shape(1);
    auto arg1_rank = arg1_shape.size();
@@ -424,7 +425,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::MaxPoolWithIndicesBackprop)
{
-    auto max_pool = static_cast<op::MaxPoolWithIndicesBackprop*>(node);
+    auto max_pool = static_cast<ngraph::op::MaxPoolWithIndicesBackprop*>(node);
    auto arg1_shape = node->get_input_shape(1);
    auto arg1_rank = arg1_shape.size();
@@ -440,7 +441,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Relu)
{
-    auto relu = static_cast<op::Relu*>(node);
+    auto relu = static_cast<ngraph::op::Relu*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -464,7 +465,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReplaceSlice)
{
-    auto replace_slice = static_cast<op::ReplaceSlice*>(node);
+    auto replace_slice = static_cast<ngraph::op::ReplaceSlice*>(node);
    // ReplaceSlice is independent of data type. Hence not checking type
    auto op_annotations =
@@ -480,7 +481,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::UpdateSlice)
{
-    auto update_slice = static_cast<op::UpdateSlice*>(node);
+    auto update_slice = static_cast<ngraph::op::UpdateSlice*>(node);
    auto op_annotations =
        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
@@ -601,7 +602,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Softmax)
{
-    auto softmax = static_cast<op::Softmax*>(node);
+    auto softmax = static_cast<ngraph::op::Softmax*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -618,7 +619,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Slice)
{
-    auto slice = static_cast<op::Slice*>(node);
+    auto slice = static_cast<ngraph::op::Slice*>(node);
    auto strides = slice->get_strides();
    if (!is_strided(strides) && node->get_input_element_type(0) == element::f32)
    {
@@ -649,7 +650,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BoundedRelu)
{
-    auto bounded_relu = static_cast<op::BoundedRelu*>(node);
+    auto bounded_relu = static_cast<ngraph::op::BoundedRelu*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -673,7 +674,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LeakyRelu)
{
-    auto leaky_relu = static_cast<op::LeakyRelu*>(node);
+    auto leaky_relu = static_cast<ngraph::op::LeakyRelu*>(node);
    auto arg0_shape = node->get_input_shape(0);
    auto arg0_rank = arg0_shape.size();
@@ -719,7 +720,8 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBiasAdd)
{
-    auto quantized_conv_bias = static_cast<op::QuantizedConvolutionBiasAdd*>(node);
+    auto quantized_conv_bias =
+        static_cast<ngraph::op::QuantizedConvolutionBiasAdd*>(node);
    auto op_annotations =
        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    op_annotations->set_mkldnn_op(true);
@@ -733,7 +735,7 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBiasSignedAdd)
{
    auto quantized_conv_bias =
-        static_cast<op::QuantizedConvolutionBiasSignedAdd*>(node);
+        static_cast<ngraph::op::QuantizedConvolutionBiasSignedAdd*>(node);
    auto op_annotations =
        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    op_annotations->set_mkldnn_op(true);
@@ -758,7 +760,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Dequantize)
{
-    auto dequantize = static_cast<op::Dequantize*>(node);
+    auto dequantize = static_cast<ngraph::op::Dequantize*>(node);
    // TODO(nbpatel): Support dynamic offset via mkldnn
    // Go through reference if the offset is not a constant
    if (!dequantize->get_argument(2)->is_constant())
@@ -796,7 +798,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Quantize)
{
-    auto quantize = static_cast<op::Quantize*>(node);
+    auto quantize = static_cast<ngraph::op::Quantize*>(node);
    // TODO(nbpatel): Support dynamic offset via mkldnn
    // Go through reference if the offset is not a constant
    if (!quantize->get_argument(2)->is_constant())
@@ -805,8 +807,8 @@ namespace ngraph
    }
    auto offset_const_op =
        std::static_pointer_cast<ngraph::op::Constant>(quantize->get_argument(2));
-    op::Quantize::RoundMode round_mode = quantize->get_round_mode();
-    if (round_mode != op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
+    ngraph::op::Quantize::RoundMode round_mode = quantize->get_round_mode();
+    if (round_mode != ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
    {
        return;
    }
@@ -845,7 +847,7 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Convert)
{
-    auto convert = static_cast<op::Convert*>(node);
+    auto convert = static_cast<ngraph::op::Convert*>(node);
    if ((node->get_input_element_type(0) == element::i8 &&
         node->get_output_element_type(0) == element::u8) ||
        (node->get_input_element_type(0) == element::u8 &&
@@ -64,7 +64,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
    if (n->description() == "Concat")
    {
-        auto concat = std::static_pointer_cast<op::Concat>(n);
+        auto concat = std::static_pointer_cast<ngraph::op::Concat>(n);
        auto shape = concat->get_input_shape(0);
        auto axis = concat->get_concatenation_axis();
        auto product = 1;
@@ -134,7 +134,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
    if (arg->is_op())
    {
-        auto op = std::static_pointer_cast<op::Op>(arg);
+        auto op = std::static_pointer_cast<ngraph::op::Op>(arg);
        auto annotation = op->get_op_annotations();
        if (annotation && annotation->get_in_place_oi_pairs().size() > 0)
@@ -177,7 +177,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
    if (user->is_op())
    {
-        auto op = std::static_pointer_cast<op::Op>(user);
+        auto op = std::static_pointer_cast<ngraph::op::Op>(user);
        if (auto op_annotations = op->get_op_annotations())
        {
            if (op_annotations->get_in_place_oi_pairs().size() > 0)
@@ -227,7 +227,7 @@ bool runtime::cpu::pass::CPUMemoryOptimization::run_on_function(std::shared_ptr<
{
    if (n->description() == "Slice")
    {
-        auto slice = std::static_pointer_cast<op::Slice>(n);
+        auto slice = std::static_pointer_cast<ngraph::op::Slice>(n);
        auto in_shape = slice->get_input_shape(0);
        auto out_shape = slice->get_output_shape(0);
        auto strides = slice->get_strides();
@@ -66,15 +66,16 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
{
    // construct variance
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
-    auto neg_input = std::make_shared<op::Negative>(input);
-    auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
+    auto neg_input = std::make_shared<ngraph::op::Negative>(input);
+    auto exp_neg_input = std::make_shared<ngraph::op::Exp>(neg_input);
    // broadcast input
    auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
-    auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
+    auto broadcast_constant =
+        std::make_shared<ngraph::op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
-    auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
-    auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
+    auto add_exp = std::make_shared<ngraph::op::Add>(exp_neg_input, broadcast_constant);
+    auto divide_1_over_exp = std::make_shared<ngraph::op::Divide>(broadcast_constant, add_exp);
    // Define a call back that needs to called once the DFG matches the pattern
    ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
@@ -96,7 +97,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
        return false;
    }
-    auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
+    auto sigmoid_node = std::make_shared<ngraph::op::Sigmoid>(pattern_map[input]);
    ngraph::replace_node(m.get_match_root(), sigmoid_node);
    return true;
};
...@@ -147,177 +148,184 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop() ...@@ -147,177 +148,184 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100}); auto ct_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
auto broadcast_pred = [](std::shared_ptr<Node> n) { auto broadcast_pred = [](std::shared_ptr<Node> n) {
return ((std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr) || return ((std::dynamic_pointer_cast<ngraph::op::Broadcast>(n) != nullptr) ||
(std::dynamic_pointer_cast<op::Reshape>(n) != nullptr)); (std::dynamic_pointer_cast<ngraph::op::Reshape>(n) != nullptr));
}; };
// Fused MatMuls // Fused MatMuls
// (W_{ii} | (W_{if} | W_{ig} | W_{io}) * x_t + (b_{ii} | b_{if} | b_{ig} | b_{io}) // (W_{ii} | (W_{if} | W_{ig} | W_{io}) * x_t + (b_{ii} | b_{if} | b_{ig} | b_{io})
auto dot1 = std::make_shared<op::Dot>(xt, w_i2h); auto dot1 = std::make_shared<ngraph::op::Dot>(xt, w_i2h);
auto add1 = std::make_shared<op::Add>( auto add1 = std::make_shared<ngraph::op::Add>(
dot1, std::make_shared<pattern::op::Skip>(bias_i2h, broadcast_pred)); dot1, std::make_shared<pattern::op::Skip>(bias_i2h, broadcast_pred));
// (W_{hi} | (W_{hf} | W_{hg} | W_{ho}) * h_{(t-1)} + (b_{hi} | b_{hf} | b_{hg} | b_{ho}) // (W_{hi} | (W_{hf} | W_{hg} | W_{ho}) * h_{(t-1)} + (b_{hi} | b_{hf} | b_{hg} | b_{ho})
auto dot2 = std::make_shared<op::Dot>(ht_1, w_h2h); auto dot2 = std::make_shared<ngraph::op::Dot>(ht_1, w_h2h);
auto add2 = std::make_shared<op::Add>( auto add2 = std::make_shared<ngraph::op::Add>(
dot2, std::make_shared<pattern::op::Skip>(bias_h2h, broadcast_pred)); dot2, std::make_shared<pattern::op::Skip>(bias_h2h, broadcast_pred));
auto X = std::make_shared<op::Add>(add2, add1); auto X = std::make_shared<ngraph::op::Add>(add2, add1);
// construct gates // construct gates
auto it = std::make_shared<op::Sigmoid>( auto it = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100})); std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 0}, Coordinate{10, 100}));
auto ft = std::make_shared<op::Sigmoid>( auto ft = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200})); std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 100}, Coordinate{10, 200}));
auto gt = std::make_shared<op::Tanh>( auto gt = std::make_shared<ngraph::op::Tanh>(
std::make_shared<op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300})); std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 200}, Coordinate{10, 300}));
auto ot = std::make_shared<op::Sigmoid>( auto ot = std::make_shared<ngraph::op::Sigmoid>(
std::make_shared<op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400})); std::make_shared<ngraph::op::Slice>(X, Coordinate{0, 300}, Coordinate{10, 400}));
// construct (c_t) cell state // construct (c_t) cell state
auto ct = std::make_shared<op::Add>(std::make_shared<op::Multiply>(ft, ct_1), auto ct = std::make_shared<ngraph::op::Add>(std::make_shared<ngraph::op::Multiply>(ft, ct_1),
std::make_shared<op::Multiply>(it, gt)); std::make_shared<ngraph::op::Multiply>(it, gt));
auto ct_label = std::make_shared<pattern::op::Label>(ct, nullptr, NodeVector{ct}); auto ct_label = std::make_shared<pattern::op::Label>(ct, nullptr, NodeVector{ct});
// construct (h_t) // construct (h_t)
auto ht = std::make_shared<op::Multiply>(ot, std::make_shared<op::Tanh>(ct_label)); auto ht =
std::make_shared<ngraph::op::Multiply>(ot, std::make_shared<ngraph::op::Tanh>(ct_label));
// Define a call back that needs to called once the DFG matches the pattern // Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = pattern::graph_rewrite_callback callback = [ct_label,
[ct_label, w_i2h, bias_i2h, w_h2h, bias_h2h, xt, ht_1, ct_1](pattern::Matcher& m) { w_i2h,
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against " bias_i2h,
<< m.get_match_root()->get_name(); w_h2h,
bias_h2h,
xt,
ht_1,
ct_1](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
if (m.get_match_root()->get_element_type() != element::f32) if (m.get_match_root()->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name() NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!"; << " type is not float!";
return false; return false;
} }
CHECK_RANK(pattern_map[xt], 2); CHECK_RANK(pattern_map[xt], 2);
CHECK_RANK(pattern_map[ht_1], 2); CHECK_RANK(pattern_map[ht_1], 2);
CHECK_RANK(pattern_map[w_i2h], 2); CHECK_RANK(pattern_map[w_i2h], 2);
CHECK_RANK(pattern_map[w_h2h], 2); CHECK_RANK(pattern_map[w_h2h], 2);
CHECK_RANK(pattern_map[bias_i2h], 1); CHECK_RANK(pattern_map[bias_i2h], 1);
CHECK_RANK(pattern_map[bias_h2h], 1); CHECK_RANK(pattern_map[bias_h2h], 1);
auto weights_layer = pattern_map[w_i2h]; auto weights_layer = pattern_map[w_i2h];
auto weights_iter = pattern_map[w_h2h]; auto weights_iter = pattern_map[w_h2h];
auto src_layer = pattern_map[xt]; auto src_layer = pattern_map[xt];
auto hidden_state = pattern_map[ht_1]; auto hidden_state = pattern_map[ht_1];
auto cell_state = pattern_map[ct_1]; auto cell_state = pattern_map[ct_1];
// TODO: (Pruthvi) temporary workaround for GNMT slow down // TODO: (Pruthvi) temporary workaround for GNMT slow down
// this checks avoids fusing of LSTM cells if its a part of decoder, we // this checks avoids fusing of LSTM cells if its a part of decoder, we
// will remove this once mkldnn optimizes individual LSTM cell or once // will remove this once mkldnn optimizes individual LSTM cell or once
// we have decoder pattern for GNMT. // we have decoder pattern for GNMT.
if (!(std::dynamic_pointer_cast<op::Broadcast>(cell_state) && if (!(std::dynamic_pointer_cast<ngraph::op::Broadcast>(cell_state) &&
std::dynamic_pointer_cast<op::Constant>(cell_state->get_argument(0))) && std::dynamic_pointer_cast<ngraph::op::Constant>(cell_state->get_argument(0))) &&
!(std::dynamic_pointer_cast<op::Slice>(cell_state) && !(std::dynamic_pointer_cast<ngraph::op::Slice>(cell_state) &&
std::dynamic_pointer_cast<op::GetOutputElement>(cell_state->get_argument(0)))) std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(cell_state->get_argument(0))))
{ {
return false; return false;
} }
auto swap_lstm_inputs = [&]() -> void { auto swap_lstm_inputs = [&]() -> void {
src_layer = pattern_map[ht_1]; src_layer = pattern_map[ht_1];
hidden_state = pattern_map[xt]; hidden_state = pattern_map[xt];
weights_layer = pattern_map[w_h2h]; weights_layer = pattern_map[w_h2h];
weights_iter = pattern_map[w_i2h]; weights_iter = pattern_map[w_i2h];
}; };
// LSTM kernel expects ht_1 and ct_1 to have the same shape but the // LSTM kernel expects ht_1 and ct_1 to have the same shape but the
// pattern matcher cannot guarantee this since the computations are // pattern matcher cannot guarantee this since the computations are
// symmetric around x_t and ht_1. Use heuristics to swap the matched // symmetric around x_t and ht_1. Use heuristics to swap the matched
// labels // labels
if (std::dynamic_pointer_cast<op::Broadcast>(src_layer) && if (std::dynamic_pointer_cast<ngraph::op::Broadcast>(src_layer) &&
std::dynamic_pointer_cast<op::Constant>(src_layer->get_argument(0))) std::dynamic_pointer_cast<ngraph::op::Constant>(src_layer->get_argument(0)))
{ {
// First timestep of an RNN layer // First timestep of an RNN layer
swap_lstm_inputs(); swap_lstm_inputs();
} }
else if (hidden_state->get_shape() != cell_state->get_shape()) else if (hidden_state->get_shape() != cell_state->get_shape())
{
swap_lstm_inputs();
}
else if (std::dynamic_pointer_cast<ngraph::op::GetOutputElement>(
cell_state->get_argument(0)))
{
// swap the inputs if the cell_state and hidden state does not
// belong to the same Lstm
if (!hidden_state->get_argument(0)->get_arguments().size() ||
(hidden_state->get_argument(0)->get_arguments()[0] !=
cell_state->get_argument(0)->get_arguments()[0]))
{ {
swap_lstm_inputs(); swap_lstm_inputs();
} }
else if (std::dynamic_pointer_cast<op::GetOutputElement>(cell_state->get_argument(0))) }
{
// swap the inputs if the cell_state and hidden state does not
// belong to the same Lstm
if (!hidden_state->get_argument(0)->get_arguments().size() ||
(hidden_state->get_argument(0)->get_arguments()[0] !=
cell_state->get_argument(0)->get_arguments()[0]))
{
swap_lstm_inputs();
}
}
if (hidden_state->get_shape() != cell_state->get_shape()) if (hidden_state->get_shape() != cell_state->get_shape())
{ {
NGRAPH_DEBUG NGRAPH_DEBUG << "Lstm MKLDNN kernel requires recurrent output hidden states to match ";
<< "Lstm MKLDNN kernel requires recurrent output hidden states to match "; return false;
return false; }
}
// set LSTM cell attributes // set LSTM cell attributes
size_t lstm_n_gates = 4; size_t lstm_n_gates = 4;
size_t batch_size = src_layer->get_shape()[0]; size_t batch_size = src_layer->get_shape()[0];
size_t direction = 1; size_t direction = 1;
size_t layers = 1; size_t layers = 1;
auto dlc = weights_layer->get_shape()[1] / (lstm_n_gates * direction * layers); auto dlc = weights_layer->get_shape()[1] / (lstm_n_gates * direction * layers);
auto slc = weights_layer->get_shape()[0]; auto slc = weights_layer->get_shape()[0];
auto dic = weights_iter->get_shape()[1] / (lstm_n_gates * direction * layers); auto dic = weights_iter->get_shape()[1] / (lstm_n_gates * direction * layers);
auto sic = weights_iter->get_shape()[0]; auto sic = weights_iter->get_shape()[0];
ngraph::runtime::cpu::rnn_utils::rnntype rnn_type = ngraph::runtime::cpu::rnn_utils::rnntype rnn_type =
ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm; ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
if (dlc != dic)
{
NGRAPH_DEBUG << "Not fusing, since Lstm kernel requires dst_layer feature size "
<< "equals to dts_iter feature size";
return false;
}
std::shared_ptr<Node> src_iter = if (dlc != dic)
std::make_shared<op::Concat>(NodeVector{hidden_state, cell_state}, 0); {
if (src_layer->get_shape()[1] != slc || src_iter->get_shape()[1] != sic) NGRAPH_DEBUG << "Not fusing, since Lstm kernel requires dst_layer feature size "
{ << "equals to dts_iter feature size";
NGRAPH_DEBUG << "Feature size mismatch between weights and input tensors"; return false;
return false; }
}
auto bias = std::make_shared<op::Add>(pattern_map[bias_i2h], pattern_map[bias_h2h]); std::shared_ptr<Node> src_iter =
std::make_shared<ngraph::op::Concat>(NodeVector{hidden_state, cell_state}, 0);
if (src_layer->get_shape()[1] != slc || src_iter->get_shape()[1] != sic)
{
NGRAPH_DEBUG << "Feature size mismatch between weights and input tensors";
return false;
}
auto lstm_node = std::make_shared<op::Lstm>( auto bias = std::make_shared<ngraph::op::Add>(pattern_map[bias_i2h], pattern_map[bias_h2h]);
src_layer, src_iter, weights_layer, weights_iter, bias, rnn_type);
auto lstm_ht_output = std::make_shared<op::GetOutputElement>(lstm_node, 0); auto lstm_node = std::make_shared<ngraph::op::Lstm>(
auto lstm_ht_ct_output = std::make_shared<op::GetOutputElement>(lstm_node, 1); src_layer, src_iter, weights_layer, weights_iter, bias, rnn_type);
// dst_iter of lstm mkldnn output holds the results of both recurrent state auto lstm_ht_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 0);
// tensor outputs. we need to slice the ct. auto lstm_ht_ct_output = std::make_shared<ngraph::op::GetOutputElement>(lstm_node, 1);
auto ht_slice = std::make_shared<op::Slice>(
lstm_ht_output, Coordinate{0, 0}, Coordinate{batch_size, dlc});
auto ct_slice = std::make_shared<op::Slice>(
lstm_ht_ct_output, Coordinate{batch_size, 0}, Coordinate{(2 * batch_size), dic});
if (lstm_node->get_outputs().at(0).get_inputs().size() != 2) // dst_iter of lstm mkldnn output holds the results of both recurrent state
{ // tensor outputs. we need to slice the ct.
throw ngraph_error("Lstm node doesnt have two outputs"); auto ht_slice = std::make_shared<ngraph::op::Slice>(
} lstm_ht_output, Coordinate{0, 0}, Coordinate{batch_size, dlc});
// Now identify the nodes which consumes the output of LSTM nodes auto ct_slice = std::make_shared<ngraph::op::Slice>(
// and replace them accordingly lstm_ht_ct_output, Coordinate{batch_size, 0}, Coordinate{(2 * batch_size), dic});
// find the user's for {ht|ct} and replace them with lstm_goe_1
if (ngraph::is_used(pattern_map[ct_label].get())) if (lstm_node->get_outputs().at(0).get_inputs().size() != 2)
{ {
replace_collapse_node_user(pattern_map[ct_label], ct_slice->get_outputs().at(0)); throw ngraph_error("Lstm node doesnt have two outputs");
} }
// find the user's for {ht} and replace them with lstm_goe_0 // Now identify the nodes which consumes the output of LSTM nodes
ngraph::replace_node(m.get_match_root(), ht_slice); // and replace them accordingly
return true; // find the user's for {ht|ct} and replace them with lstm_goe_1
}; if (ngraph::is_used(pattern_map[ct_label].get()))
{
replace_collapse_node_user(pattern_map[ct_label], ct_slice->get_outputs().at(0));
}
// find the user's for {ht} and replace them with lstm_goe_0
ngraph::replace_node(m.get_match_root(), ht_slice);
return true;
};
auto m = std::make_shared<pattern::Matcher>(ht, callback, "LSTMFusion.Fprop"); auto m = std::make_shared<pattern::Matcher>(ht, callback, "LSTMFusion.Fprop");
this->add_matcher(m); this->add_matcher(m);
} }
@@ -330,44 +338,45 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
    auto lstm_ht = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
    auto lstm_ct = std::make_shared<pattern::op::Label>(element::f32, Shape{10, 100});
-    auto lstm_src_iter = std::make_shared<op::Concat>(NodeVector{lstm_ht, lstm_ct}, 0);
+    auto lstm_src_iter = std::make_shared<ngraph::op::Concat>(NodeVector{lstm_ht, lstm_ct}, 0);
    auto lstm_src_iter_label =
        std::make_shared<pattern::op::Label>(lstm_src_iter, nullptr, NodeVector{lstm_src_iter});
    auto lstm_weights_layer_shared = std::make_shared<pattern::op::Label>(
-        element::f32, Shape{400, 100}, pattern::has_class<op::Parameter>());
-    auto lstm_weights_layer =
-        std::make_shared<op::Reshape>(lstm_weights_layer_shared, AxisVector{1, 0}, Shape{100, 400});
+        element::f32, Shape{400, 100}, pattern::has_class<ngraph::op::Parameter>());
+    auto lstm_weights_layer = std::make_shared<ngraph::op::Reshape>(
+        lstm_weights_layer_shared, AxisVector{1, 0}, Shape{100, 400});
    auto lstm_weights_layer_label = std::make_shared<pattern::op::Label>(
        lstm_weights_layer, nullptr, NodeVector{lstm_weights_layer});
    auto lstm_weights_iter_shared = std::make_shared<pattern::op::Label>(
-        element::f32, Shape{400, 100}, pattern::has_class<op::Parameter>());
-    auto lstm_weights_iter =
-        std::make_shared<op::Reshape>(lstm_weights_iter_shared, AxisVector{1, 0}, Shape{100, 400});
+        element::f32, Shape{400, 100}, pattern::has_class<ngraph::op::Parameter>());
+    auto lstm_weights_iter = std::make_shared<ngraph::op::Reshape>(
+        lstm_weights_iter_shared, AxisVector{1, 0}, Shape{100, 400});
    auto lstm_weights_iter_label = std::make_shared<pattern::op::Label>(
        lstm_weights_iter, nullptr, NodeVector{lstm_weights_iter});
    auto lstm_bias_layer_shared = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
    auto lstm_bias_iter_shared = std::make_shared<pattern::op::Label>(element::f32, Shape{400});
-    auto lstm_bias = std::make_shared<op::Add>(lstm_bias_layer_shared, lstm_bias_iter_shared);
+    auto lstm_bias =
+        std::make_shared<ngraph::op::Add>(lstm_bias_layer_shared, lstm_bias_iter_shared);
    auto lstm_bias_label =
        std::make_shared<pattern::op::Label>(lstm_bias, nullptr, NodeVector{lstm_bias});
    ngraph::runtime::cpu::rnn_utils::rnntype ref_rnn_type =
        ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
-    auto lstm = std::make_shared<op::Lstm>(lstm_src_layer,
+    auto lstm = std::make_shared<ngraph::op::Lstm>(lstm_src_layer,
                                                   lstm_src_iter_label,
                                                   lstm_weights_layer_label,
                                                   lstm_weights_iter_label,
                                                   lstm_bias_label,
                                                   ref_rnn_type);
-    auto lstm_goe = std::make_shared<op::GetOutputElement>(lstm, 1);
+    auto lstm_goe = std::make_shared<ngraph::op::GetOutputElement>(lstm, 1);
    // We cannot attach labels to multi-output nodes, so we attach a label to the goe instead
    auto lstm_goe_label =
        std::make_shared<pattern::op::Label>(lstm_goe, nullptr, NodeVector{lstm_goe});
    auto lstm_goe_slice =
-        std::make_shared<op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
+        std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
    pattern::recurrent_graph_rewrite_callback callback = [lstm_goe_label,
                                                          lstm_src_layer,
@@ -387,7 +396,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
    {
        auto node_labels = m.get_bound_nodes_for_pattern(input_label);
        std::reverse(node_labels.begin(), node_labels.end());
-        return std::make_shared<op::Concat>(node_labels, 0);
+        return std::make_shared<ngraph::op::Concat>(node_labels, 0);
    }
    };
@@ -429,9 +438,9 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
    }
    auto check_const_input = [&](std::shared_ptr<Node> n) {
-        if (std::dynamic_pointer_cast<op::Constant>(n) ||
-            (std::dynamic_pointer_cast<op::Broadcast>(n) &&
-             std::dynamic_pointer_cast<op::Constant>(n->get_argument(0))))
+        if (std::dynamic_pointer_cast<ngraph::op::Constant>(n) ||
+            (std::dynamic_pointer_cast<ngraph::op::Broadcast>(n) &&
+             std::dynamic_pointer_cast<ngraph::op::Constant>(n->get_argument(0))))
        {
            return true;
        }
@@ -460,26 +469,27 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
        return false;
    }
-    auto rnn = std::make_shared<op::Rnn>(rnn_src_layer,
+    auto rnn = std::make_shared<ngraph::op::Rnn>(rnn_src_layer,
                                                 rnn_src_iter,
                                                 rnn_weights_layer,
                                                 rnn_weights_iter,
                                                 rnn_bias,
                                                 sequence_len,
                                                 lstm_n_gates,
                                                 sequence_len,
                                                 num_cell_states,
                                                 direction,
                                                 num_fused_rnn_layers,
                                                 rnn_type);
-    std::vector<std::shared_ptr<op::Slice>> ht_slice_per_timestep(sequence_len, nullptr);
-    auto rnn_ht_goe = std::make_shared<op::GetOutputElement>(rnn, 0);
-    auto rnn_ht_ct_goe = std::make_shared<op::GetOutputElement>(rnn, 1);
+    std::vector<std::shared_ptr<ngraph::op::Slice>> ht_slice_per_timestep(sequence_len,
+                                                                          nullptr);
+    auto rnn_ht_goe = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
+    auto rnn_ht_ct_goe = std::make_shared<ngraph::op::GetOutputElement>(rnn, 1);
    for (size_t i = 0, start_index = 0; i < sequence_len; i++, start_index += batch_size)
    {
-        ht_slice_per_timestep[i] = (std::make_shared<op::Slice>(
+        ht_slice_per_timestep[i] = (std::make_shared<ngraph::op::Slice>(
            rnn_ht_goe,
            Coordinate{start_index, 0},
            Coordinate{start_index + batch_size, src_iter_feature_size}));
@@ -503,7 +513,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
    for (size_t index = 0; index < sequence_len; index++)
    {
-        auto goe_nodes = op::get_output_elements(lstm_nodes[index]);
+        auto goe_nodes = ngraph::op::get_output_elements(lstm_nodes[index]);
        // if there is no GOE followed by the Lstm, their might be pattern match error
        // we will return safely
@@ -521,7 +531,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
    {
        if (ngraph::is_used(goe0_user.get()))
        {
-            if (!std::dynamic_pointer_cast<op::Slice>(goe0_user))
+            if (!std::dynamic_pointer_cast<ngraph::op::Slice>(goe0_user))
            {
                NGRAPH_DEBUG << "Did not find LSTM slice to replace with RNN slice";
                return false;
@@ -536,7 +546,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
        }
    }
-    auto rnn_ct_goe = op::get_output_elements(lstm_nodes[sequence_len - 1])[1];
+    auto rnn_ct_goe = ngraph::op::get_output_elements(lstm_nodes[sequence_len - 1])[1];
    if (rnn_ct_goe)
    {
        replace_collapse_node_user(rnn_ct_goe, rnn_ht_ct_goe->get_outputs().at(0));
@@ -566,7 +576,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
static std::shared_ptr<Node> stack_rnn_inputs(NodeVector rnn_input_nodes)
{
    std::reverse(rnn_input_nodes.begin(), rnn_input_nodes.end());
-    return std::make_shared<op::Concat>(rnn_input_nodes, 0);
+    return std::make_shared<ngraph::op::Concat>(rnn_input_nodes, 0);
}
void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
@@ -585,20 +595,20 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
    ngraph::runtime::cpu::rnn_utils::rnntype ref_rnn_type =
        ngraph::runtime::cpu::rnn_utils::rnntype::vanilla_lstm;
-    auto ref_rnn_node = std::make_shared<op::Rnn>(rnn_src_layer,
+    auto ref_rnn_node = std::make_shared<ngraph::op::Rnn>(rnn_src_layer,
                                                          rnn_src_iter,
                                                          rnn_weights_layer,
                                                          rnn_weights_iter,
                                                          rnn_bias,
                                                          ref_number_of_timesteps,
                                                          ref_number_of_gates_per_cell,
                                                          ref_src_seq_length,
                                                          ref_num_rnn_cell_states,
                                                          ref_rnn_direction,
                                                          ref_num_of_rnn_fused_layer,
                                                          ref_rnn_type);
-    auto rnn_goe0 = std::make_shared<op::GetOutputElement>(ref_rnn_node, 0);
+    auto rnn_goe0 = std::make_shared<ngraph::op::GetOutputElement>(ref_rnn_node, 0);
    auto rnn_goe0_label =
        std::make_shared<pattern::op::Label>(rnn_goe0, nullptr, NodeVector{rnn_goe0});
@@ -622,10 +632,11 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
    auto rnn_goe0_bounded_nodes = m.get_bound_nodes_for_pattern(rnn_goe0_label);
-    std::vector<std::shared_ptr<op::Rnn>> rnn_nodes;
+    std::vector<std::shared_ptr<ngraph::op::Rnn>> rnn_nodes;
    for (auto rnn_goe : m.get_bound_nodes_for_pattern(rnn_goe0_label))
    {
-        if (auto rnn_op = std::dynamic_pointer_cast<op::Rnn>(rnn_goe->get_arguments()[0]))
+        if (auto rnn_op =
+                std::dynamic_pointer_cast<ngraph::op::Rnn>(rnn_goe->get_arguments()[0]))
        {
            rnn_nodes.push_back(rnn_op);
        }
@@ -695,21 +706,21 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
        "layer");
    }
-    auto rnn = std::make_shared<op::Rnn>(mrnn_src_layer,
+    auto rnn = std::make_shared<ngraph::op::Rnn>(mrnn_src_layer,
                                                 mrnn_src_iter,
                                                 mrnn_weights_layer,
                                                 mrnn_weights_iter,
                                                 mrnn_bias,
                                                 num_timesteps,
                                                 lstm_n_gates,
                                                 sequence_len,
                                                 num_rnn_cell_states,
                                                 rnn_direction,
                                                 num_fused_rnn_layers,
                                                 rnn_type);
-    auto mrnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
-    auto mrnn_ht_ct = std::make_shared<op::GetOutputElement>(rnn, 1);
+    auto mrnn_ht = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
+    auto mrnn_ht_ct = std::make_shared<ngraph::op::GetOutputElement>(rnn, 1);
    // Replace all the users of RNN cell state {ct} across different user.
    auto replace_rnn_output_cellstate = [&](std::shared_ptr<Node> rnn_ct_goe1, size_t layer) {
@@ -718,7 +729,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
    // of all the layers, {{ht_1 | ct_1} || {ht2 |ct2} || ....{htn | ctn}}
    // we will slice the cell state output tensor {ct_*} from the fused RNN kerenel output and feeds
    // {ct_*} consumer if any
-    auto ct_slice = std::make_shared<op::Slice>(
+    auto ct_slice = std::make_shared<ngraph::op::Slice>(
        mrnn_ht_ct,
        Coordinate{((layer - 1) * batch_size * num_rnn_cell_states) + batch_size, 0},
        Coordinate{layer * batch_size * num_rnn_cell_states, src_iter_feature_size});
@@ -732,7 +743,7 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
    // i.e {RNN7, RNN6, RNN5.... RNN0}
    for (size_t index = 0; index < rnn_nodes.size(); index++)
    {
-        auto goe_nodes = op::get_output_elements(rnn_nodes[index]);
+        auto goe_nodes = ngraph::op::get_output_elements(rnn_nodes[index]);
        // if there is no GOE followed by the Lstm, their might be pattern match error
        // we will return safely
        if (goe_nodes.size() != 2)
@@ -771,15 +782,17 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
{
    auto rnn_left_to_right = std::make_shared<pattern::op::Label>(
-        element::f32, Shape{1, 256}, pattern::has_class<op::Rnn>());
+        element::f32, Shape{1, 256}, pattern::has_class<ngraph::op::Rnn>());
    auto rnn_right_to_left = std::make_shared<pattern::op::Label>(
-        element::f32, Shape{1, 256}, pattern::has_class<op::Rnn>());
+        element::f32, Shape{1, 256}, pattern::has_class<ngraph::op::Rnn>());
    auto reshape_pred = [](std::shared_ptr<Node> n) {
-        return (std::dynamic_pointer_cast<op::Reshape>(n) != nullptr);
+        return (std::dynamic_pointer_cast<ngraph::op::Reshape>(n) != nullptr);
    };
-    auto rnn_left_to_right_goe0 = std::make_shared<op::GetOutputElement>(rnn_left_to_right, 0);
-    auto rnn_right_to_left_goe0 = std::make_shared<op::GetOutputElement>(rnn_right_to_left, 0);
+    auto rnn_left_to_right_goe0 =
+        std::make_shared<ngraph::op::GetOutputElement>(rnn_left_to_right, 0);
+    auto rnn_right_to_left_goe0 =
+        std::make_shared<ngraph::op::GetOutputElement>(rnn_right_to_left, 0);
    auto rnn_rtol_goe0_reshape_ntc =
        std::make_shared<pattern::op::Skip>(rnn_right_to_left_goe0, reshape_pred);
@@ -791,21 +804,23 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
        std::make_shared<pattern::op::Skip>(rnn_ltor_goe0_reshape_ntc, reshape_pred);
    auto reverse_seq_predicate = [](std::shared_ptr<Node> node) {
-        return pattern::has_class<op::ReverseSequence>()(node) ||
-               pattern::has_class<op::Reverse>()(node);
+        return pattern::has_class<ngraph::op::ReverseSequence>()(node) ||
+               pattern::has_class<ngraph::op::Reverse>()(node);
    };
    auto skip_reverse_seq =
        std::make_shared<pattern::op::Skip>(rnn_rtol_goe0_reshape_tnc, reverse_seq_predicate);
-    auto concat =
-        std::make_shared<op::Concat>(NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);
+    auto concat = std::make_shared<ngraph::op::Concat>(
+        NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);
    // Define a call back that needs to called once the DFG matches the pattern
    ngraph::pattern::graph_rewrite_callback callback = [rnn_left_to_right,
                                                        rnn_right_to_left](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_map();
-        auto rnn_ltor_node = std::static_pointer_cast<op::Rnn>(pattern_map[rnn_left_to_right]);
-        auto rnn_rtol_node = std::static_pointer_cast<op::Rnn>(pattern_map[rnn_right_to_left]);
+        auto rnn_ltor_node =
+            std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
+        auto rnn_rtol_node =
+            std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
        if (rnn_ltor_node->get_src_sequence_length() != rnn_rtol_node->get_src_sequence_length())
        {
@@ -852,7 +867,7 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
        auto nodes =
            NodeVector{rnn_ltor_node->get_argument(index), rnn_rtol_node->get_argument(index)};
-        return std::make_shared<op::Concat>(nodes, 0);
+        return std::make_shared<ngraph::op::Concat>(nodes, 0);
    };
    auto src_layer = rnn_ltor_node->get_arguments()[0];
@@ -861,20 +876,20 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
    auto weights_iter = construct_birnn_inputs(3);
    auto bias = construct_birnn_inputs(4);
-    auto rnn = std::make_shared<op::Rnn>(src_layer,
+    auto rnn = std::make_shared<ngraph::op::Rnn>(src_layer,
                                                 src_iter,
                                                 weights_layer,
                                                 weights_iter,
                                                 bias,
                                                 num_time_steps,
                                                 lstm_n_gates,
                                                 sequence_len,
                                                 num_rnn_cell_states,
                                                 rnn_direction,
                                                 num_fused_rnn_layers,
                                                 rnn_type);
-    auto layer_rnn_ht = std::make_shared<op::GetOutputElement>(rnn, 0);
+    auto layer_rnn_ht = std::make_shared<ngraph::op::GetOutputElement>(rnn, 0);
    size_t batch_size = layer_rnn_ht->get_shape()[0] / num_time_steps;
    size_t feature_size = layer_rnn_ht->get_shape()[1];
@@ -882,17 +897,17 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
    std::shared_ptr<Node> layer_rnn_ht_reshape = layer_rnn_ht;
    if (m.get_match_root()->get_shape() != layer_rnn_ht->get_shape())
    {
-        layer_rnn_ht_reshape = std::make_shared<op::Reshape>(
+        layer_rnn_ht_reshape = std::make_shared<ngraph::op::Reshape>(
            layer_rnn_ht, AxisVector{0, 1}, Shape{num_time_steps, batch_size, feature_size});
    }
    // we will check if the node being replaced is in Shape{n, t, c}, if so we will transpose
    if (m.get_match_root()->get_shape() == Shape{batch_size, num_time_steps, feature_size})
    {
-        layer_rnn_ht_reshape =
-            std::make_shared<op::Reshape>(layer_rnn_ht_reshape,
+        layer_rnn_ht_reshape = std::make_shared<ngraph::op::Reshape>(
+            layer_rnn_ht_reshape,
            AxisVector{1, 0, 2},
            Shape{batch_size, num_time_steps, feature_size});
    }
    ngraph::replace_node(m.get_match_root(), layer_rnn_ht_reshape);