Commit 42b4f4d0 authored by pthoreho's avatar pthoreho

- elementwise Add mkldnn support in cpu emitter

parent 233e4b1b
...@@ -173,12 +173,67 @@ namespace ngraph ...@@ -173,12 +173,67 @@ namespace ngraph
<< args[1].get_name() << ");\n"; << args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n"; writer << "out = arg0 + arg1;\n";
#else #else
if (args[0].get_element_type() == element::f32 && args[1].get_element_type() == element::f32)
{
auto input0_size_1d = 1;
auto input1_size_1d = 1;
auto result_size_1d = 1;
auto src_size = args[0].get_shape().size();
for (size_t i=0; i< src_size; i++)
{
input0_size_1d *= args[0].get_shape()[i];
input1_size_1d *= args[1].get_shape()[i];
result_size_1d *= out[0].get_shape()[i];
}
const string& et = get_mkldnn_data_type(args[0].get_element_type().c_type_string());
// Bind to CPU engine
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "std::vector<float>scale_vector(2, 1);\n";
writer << "std::vector<memory::primitive_desc> inputs_pd;\n";
writer << "std::vector<memory::primitive::at> inputs_primitive;\n";
// memory desc for inputs
writer << "memory::desc input0_data_desc = memory::desc({" << input0_size_1d
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc input1_data_desc = memory::desc({" << input1_size_1d
<< "}, " << et << ", memory::format::x);\n";
writer << "memory::desc result_desc = memory::desc({" << result_size_1d
<< "}, " << et << ", memory::format::x);\n";
// memory for the user data
writer << "memory input0_data = memory({input0_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory input1_data = memory({input1_data_desc, cpu_engine}, "
<< args[1].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input0_data_desc, cpu_engine));\n";
writer << "inputs_pd.push_back(memory::primitive_desc(input1_data_desc, cpu_engine));\n";
writer << "inputs_primitive.push_back(primitive::at(input0_data));\n";
writer << "inputs_primitive.push_back(primitive::at(input1_data));\n";
// elementwise sum primtive descriptor
writer << "sum::primitive_desc sum_pd = sum::primitive_desc(result_desc, scale_vector, inputs_pd);\n";
// sum primitive
writer << "sum sum_primitive = sum(sum_pd, inputs_primitive, result);\n";
// create stream and execute
writer << "stream s = stream(stream::kind::eager);\n"
<< "s.submit({sum_primitive}).wait();\n";
}
else
{
writer << "#pragma omp parallel for\n"; writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n"; writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer << "{\n"; writer << "{\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + " writer << " " << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] + "
<< args[1].get_name() << "[i];\n"; << args[1].get_name() << "[i];\n";
writer << "}\n"; writer << "}\n";
}
#endif #endif
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <unordered_set> #include <unordered_set>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/avg_pool.hpp" #include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/batch_norm.hpp" #include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
...@@ -39,6 +40,7 @@ namespace ngraph ...@@ -39,6 +40,7 @@ namespace ngraph
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
static const std::unordered_set<std::type_index> s_op_registry{ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::Add),
TI(ngraph::op::AvgPool), TI(ngraph::op::AvgPool),
TI(ngraph::op::AvgPoolBackprop), TI(ngraph::op::AvgPoolBackprop),
TI(ngraph::op::BatchNorm), TI(ngraph::op::BatchNorm),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment