Commit 8578694b authored by pthoreho's avatar pthoreho

style fix

parent f9191dd9
...@@ -150,34 +150,37 @@ namespace ngraph ...@@ -150,34 +150,37 @@ namespace ngraph
<< args[1].get_name() << ");\n"; << args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n"; writer << "out = arg0 + arg1;\n";
#else #else
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
// get input element type // get input element type
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[1].get_element_type()); args[1].get_element_type());
std::vector<float>scale_vector(2, 1); std::vector<float> scale_vector(2, 1);
std::vector<mkldnn::memory::primitive_desc> inputs_pd; std::vector<mkldnn::memory::primitive_desc> inputs_pd;
auto input0_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0); auto input0_format =
auto input1_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1); runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0); auto input1_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input0_data_desc = mkldnn_emitter->build_memory_descriptor(args[0], input0_format); auto input0_data_desc =
auto input1_data_desc = mkldnn_emitter->build_memory_descriptor(args[1], input1_format); mkldnn_emitter->build_memory_descriptor(args[0], input0_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format); auto input1_data_desc =
inputs_pd.push_back(mkldnn::memory::primitive_desc(input0_data_desc, mkldnn_emitter->build_memory_descriptor(args[1], input1_format);
runtime::cpu::mkldnn_utils::global_cpu_engine)); auto result_desc =
inputs_pd.push_back(mkldnn::memory::primitive_desc(input1_data_desc, mkldnn_emitter->build_memory_descriptor(out[0], result_format);
runtime::cpu::mkldnn_utils::global_cpu_engine)); inputs_pd.push_back(mkldnn::memory::primitive_desc(
input0_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));
size_t add_index=0; inputs_pd.push_back(mkldnn::memory::primitive_desc(
add_index = mkldnn_emitter->build_elementwise_add(input0_data_desc, input1_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));
input1_data_desc,
result_desc, size_t add_index = 0;
scale_vector, add_index = mkldnn_emitter->build_elementwise_add(
inputs_pd); input0_data_desc, input1_data_desc, result_desc, scale_vector, inputs_pd);
auto& deps = mkldnn_emitter->get_primitive_deps(add_index); auto& deps = mkldnn_emitter->get_primitive_deps(add_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n"; << ", " << args[0].get_name() << ");\n";
......
...@@ -137,11 +137,12 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu ...@@ -137,11 +137,12 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
return conv_index; return conv_index;
} }
size_t MKLDNNEmitter::build_elementwise_add(const mkldnn::memory::desc& input0_data_desc, size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& input1_data_desc,
const std::vector<float>& scale_vector, const mkldnn::memory::desc& result_desc,
const std::vector<mkldnn::memory::primitive_desc>& inputs_pd) const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& inputs_pd)
{ {
std::vector<mkldnn::memory::primitive::at> inputs_primitive; std::vector<mkldnn::memory::primitive::at> inputs_primitive;
...@@ -154,9 +155,11 @@ size_t MKLDNNEmitter::build_elementwise_add(const mkldnn::memory::desc& input0_d ...@@ -154,9 +155,11 @@ size_t MKLDNNEmitter::build_elementwise_add(const mkldnn::memory::desc& input0_d
inputs_primitive.push_back(*mkldnn_primitives[input1_data_index]); inputs_primitive.push_back(*mkldnn_primitives[input1_data_index]);
// elementwise sum primtive descriptor // elementwise sum primtive descriptor
mkldnn::sum::primitive_desc sum_pd = mkldnn::sum::primitive_desc(result_desc, scale_vector, inputs_pd); mkldnn::sum::primitive_desc sum_pd =
mkldnn::sum::primitive_desc(result_desc, scale_vector, inputs_pd);
// sum primitive // sum primitive
size_t add_index = insert_primitive(new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index])); size_t add_index = insert_primitive(
new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index]));
primitive_deps[add_index] = {input1_data_index, input0_data_index, result_index}; primitive_deps[add_index] = {input1_data_index, input0_data_index, result_index};
return add_index; return add_index;
......
...@@ -68,11 +68,12 @@ namespace ngraph ...@@ -68,11 +68,12 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below, const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above); const ngraph::CoordinateDiff& padding_above);
size_t build_elementwise_add(const mkldnn::memory::desc& input0_data_desc, size_t build_elementwise_add(
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& result_desc, const mkldnn::memory::desc& input1_data_desc,
const std::vector<float>& scale_vector, const mkldnn::memory::desc& result_desc,
const std::vector<mkldnn::memory::primitive_desc>& input_pd); const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
private: private:
std::shared_ptr<CPU_ExternalFunction> external_function; std::shared_ptr<CPU_ExternalFunction> external_function;
......
...@@ -43,7 +43,6 @@ namespace ngraph ...@@ -43,7 +43,6 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Add) void CPUAssignment::ASSIGN_DECL(ngraph::op::Add)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment