Commit 9a625060 authored by pthoreho's avatar pthoreho

- Addressed PR comments

- support layout assignment pass to mkldnn add, to insert reorder if the input1 and input0 format are different
parent 34d84d3d
...@@ -161,6 +161,6 @@ size_t MKLDNNEmitter::build_elementwise_add( ...@@ -161,6 +161,6 @@ size_t MKLDNNEmitter::build_elementwise_add(
size_t add_index = insert_primitive( size_t add_index = insert_primitive(
new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index])); new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index]));
primitive_deps[add_index] = {input1_data_index, input0_data_index, result_index}; primitive_deps[add_index] = {input0_data_index, input1_data_index, result_index};
return add_index; return add_index;
} }
\ No newline at end of file
...@@ -60,8 +60,8 @@ namespace ngraph ...@@ -60,8 +60,8 @@ namespace ngraph
// insert Add as MKLDNN op, only if the src_size is big. this is to avoid MKLDNN overhead // insert Add as MKLDNN op, only if the src_size is big. this is to avoid MKLDNN overhead
// for smaller tensor sizes // for smaller tensor sizes
if (node->get_input_element_type(0) == element::f32 && if (node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32 && arg0_rank >= 1 && node->get_input_element_type(1) == element::f32 && arg0_rank == 4 &&
arg1_rank >= 1 && src_size > 64000) arg1_rank == 4 && src_size > 64000)
{ {
auto op_annotations = auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "cpu_layout.hpp" #include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
...@@ -666,6 +667,29 @@ namespace ngraph ...@@ -666,6 +667,29 @@ namespace ngraph
set_default_layouts(external_function, node); set_default_layouts(external_function, node);
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Add)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input0_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(input0_layout);
prim_input_formats.push_back(input0_layout);
prim_output_formats.push_back(input0_layout);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment