Commit d57ef7d3 authored by pthoreho's avatar pthoreho

style fix

parent 42b4f4d0
...@@ -172,21 +172,23 @@ namespace ngraph ...@@ -172,21 +172,23 @@ namespace ngraph
<< ", " << args[1].get_size() << ", 1>, Eigen::Unaligned> arg1(" << ", " << args[1].get_size() << ", 1>, Eigen::Unaligned> arg1("
<< args[1].get_name() << ");\n"; << args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n"; writer << "out = arg0 + arg1;\n";
#else #else
if (args[0].get_element_type() == element::f32 && args[1].get_element_type() == element::f32) if (args[0].get_element_type() == element::f32 &&
args[1].get_element_type() == element::f32)
{ {
auto input0_size_1d = 1; auto input0_size_1d = 1;
auto input1_size_1d = 1; auto input1_size_1d = 1;
auto result_size_1d = 1; auto result_size_1d = 1;
auto src_size = args[0].get_shape().size(); auto src_size = args[0].get_shape().size();
for (size_t i=0; i< src_size; i++) for (size_t i = 0; i < src_size; i++)
{ {
input0_size_1d *= args[0].get_shape()[i]; input0_size_1d *= args[0].get_shape()[i];
input1_size_1d *= args[1].get_shape()[i]; input1_size_1d *= args[1].get_shape()[i];
result_size_1d *= out[0].get_shape()[i]; result_size_1d *= out[0].get_shape()[i];
} }
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string()); const string& et =
get_mkldnn_data_type(args[0].get_element_type().c_type_string());
// Bind to CPU engine // Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
...@@ -200,8 +202,8 @@ namespace ngraph ...@@ -200,8 +202,8 @@ namespace ngraph
<< "}, " << et << ", memory::format::x);\n"; << "}, " << et << ", memory::format::x);\n";
writer << "memory::desc input1_data_desc = memory::desc({" << input1_size_1d writer << "memory::desc input1_data_desc = memory::desc({" << input1_size_1d
<< "}, " << et << ", memory::format::x);\n"; << "}, " << et << ", memory::format::x);\n";
writer << "memory::desc result_desc = memory::desc({" << result_size_1d writer << "memory::desc result_desc = memory::desc({" << result_size_1d << "}, "
<< "}, " << et << ", memory::format::x);\n"; << et << ", memory::format::x);\n";
// memory for the user data // memory for the user data
writer << "memory input0_data = memory({input0_data_desc, cpu_engine}, " writer << "memory input0_data = memory({input0_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n"; << args[0].get_name() << ");\n";
...@@ -210,28 +212,30 @@ namespace ngraph ...@@ -210,28 +212,30 @@ namespace ngraph
writer << "memory result = memory({result_desc, cpu_engine}, " writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n"; << out[0].get_name() << ");\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input0_data_desc, cpu_engine));\n"; writer << "inputs_pd.push_back(memory::primitive_desc(input0_data_desc, "
writer << "inputs_pd.push_back(memory::primitive_desc(input1_data_desc, cpu_engine));\n"; "cpu_engine));\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input1_data_desc, "
"cpu_engine));\n";
writer << "inputs_primitive.push_back(primitive::at(input0_data));\n"; writer << "inputs_primitive.push_back(primitive::at(input0_data));\n";
writer << "inputs_primitive.push_back(primitive::at(input1_data));\n"; writer << "inputs_primitive.push_back(primitive::at(input1_data));\n";
// elementwise sum primtive descriptor // elementwise sum primtive descriptor
writer << "sum::primitive_desc sum_pd = sum::primitive_desc(result_desc, scale_vector, inputs_pd);\n"; writer << "sum::primitive_desc sum_pd = sum::primitive_desc(result_desc, "
"scale_vector, inputs_pd);\n";
// sum primitive // sum primitive
writer << "sum sum_primitive = sum(sum_pd, inputs_primitive, result);\n"; writer << "sum sum_primitive = sum(sum_pd, inputs_primitive, result);\n";
// create stream and execute // create stream and execute
writer << "stream s = stream(stream::kind::eager);\n" writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({sum_primitive}).wait();\n"; << "s.submit({sum_primitive}).wait();\n";
} }
else else
{ {
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + " writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< args[1].get_name() << "[i];\n"; << "[i] + " << args[1].get_name() << "[i];\n";
writer << "}\n"; writer << "}\n";
} }
#endif #endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment