Commit d57ef7d3 authored by pthoreho's avatar pthoreho

style fix

parent 42b4f4d0
......@@ -172,21 +172,23 @@ namespace ngraph
<< ", " << args[1].get_size() << ", 1>, Eigen::Unaligned> arg1("
<< args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n";
#else
if (args[0].get_element_type() == element::f32 && args[1].get_element_type() == element::f32)
#else
if (args[0].get_element_type() == element::f32 &&
args[1].get_element_type() == element::f32)
{
auto input0_size_1d = 1;
auto input1_size_1d = 1;
auto result_size_1d = 1;
auto src_size = args[0].get_shape().size();
for (size_t i=0; i< src_size; i++)
for (size_t i = 0; i < src_size; i++)
{
input0_size_1d *= args[0].get_shape()[i];
input1_size_1d *= args[1].get_shape()[i];
result_size_1d *= out[0].get_shape()[i];
}
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
const string& et =
get_mkldnn_data_type(args[0].get_element_type().c_type_string());
// Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
......@@ -200,8 +202,8 @@ namespace ngraph
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc input1_data_desc = memory::desc({" << input1_size_1d
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc result_desc = memory::desc({" << result_size_1d
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc result_desc = memory::desc({" << result_size_1d << "}, "
<< et << ", memory::format::x);\n";
// memory for the user data
writer << "memory input0_data = memory({input0_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
......@@ -210,28 +212,30 @@ namespace ngraph
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input0_data_desc, cpu_engine));\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input1_data_desc, cpu_engine));\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input0_data_desc, "
"cpu_engine));\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input1_data_desc, "
"cpu_engine));\n";
writer << "inputs_primitive.push_back(primitive::at(input0_data));\n";
writer << "inputs_primitive.push_back(primitive::at(input1_data));\n";
// elementwise sum primtive descriptor
writer << "sum::primitive_desc sum_pd = sum::primitive_desc(result_desc, scale_vector, inputs_pd);\n";
writer << "sum::primitive_desc sum_pd = sum::primitive_desc(result_desc, "
"scale_vector, inputs_pd);\n";
// sum primitive
writer << "sum sum_primitive = sum(sum_pd, inputs_primitive, result);\n";
// create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({sum_primitive}).wait();\n";
}
else
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + "
<< args[1].get_name() << "[i];\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] + " << args[1].get_name() << "[i];\n";
writer << "}\n";
}
#endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment