Commit 8578694b authored by pthoreho's avatar pthoreho

style fix

parent f9191dd9
......@@ -150,34 +150,37 @@ namespace ngraph
<< args[1].get_name() << ");\n";
writer << "out = arg0 + arg1;\n";
#else
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
// get input element type
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[1].get_element_type());
std::vector<float>scale_vector(2, 1);
std::vector<float> scale_vector(2, 1);
std::vector<mkldnn::memory::primitive_desc> inputs_pd;
auto input0_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto input1_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto input0_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto input1_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input0_data_desc = mkldnn_emitter->build_memory_descriptor(args[0], input0_format);
auto input1_data_desc = mkldnn_emitter->build_memory_descriptor(args[1], input1_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
inputs_pd.push_back(mkldnn::memory::primitive_desc(input0_data_desc,
runtime::cpu::mkldnn_utils::global_cpu_engine));
inputs_pd.push_back(mkldnn::memory::primitive_desc(input1_data_desc,
runtime::cpu::mkldnn_utils::global_cpu_engine));
size_t add_index=0;
add_index = mkldnn_emitter->build_elementwise_add(input0_data_desc,
input1_data_desc,
result_desc,
scale_vector,
inputs_pd);
auto input0_data_desc =
mkldnn_emitter->build_memory_descriptor(args[0], input0_format);
auto input1_data_desc =
mkldnn_emitter->build_memory_descriptor(args[1], input1_format);
auto result_desc =
mkldnn_emitter->build_memory_descriptor(out[0], result_format);
inputs_pd.push_back(mkldnn::memory::primitive_desc(
input0_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));
inputs_pd.push_back(mkldnn::memory::primitive_desc(
input1_data_desc, runtime::cpu::mkldnn_utils::global_cpu_engine));
size_t add_index = 0;
add_index = mkldnn_emitter->build_elementwise_add(
input0_data_desc, input1_data_desc, result_desc, scale_vector, inputs_pd);
auto& deps = mkldnn_emitter->get_primitive_deps(add_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
......
......@@ -137,11 +137,12 @@ size_t MKLDNNEmitter::build_convolution_forward(const mkldnn::memory::desc& inpu
return conv_index;
}
size_t MKLDNNEmitter::build_elementwise_add(const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& inputs_pd)
size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& inputs_pd)
{
std::vector<mkldnn::memory::primitive::at> inputs_primitive;
......@@ -154,9 +155,11 @@ size_t MKLDNNEmitter::build_elementwise_add(const mkldnn::memory::desc& input0_d
inputs_primitive.push_back(*mkldnn_primitives[input1_data_index]);
// elementwise sum primtive descriptor
mkldnn::sum::primitive_desc sum_pd = mkldnn::sum::primitive_desc(result_desc, scale_vector, inputs_pd);
mkldnn::sum::primitive_desc sum_pd =
mkldnn::sum::primitive_desc(result_desc, scale_vector, inputs_pd);
// sum primitive
size_t add_index = insert_primitive(new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index]));
size_t add_index = insert_primitive(
new mkldnn::sum(sum_pd, inputs_primitive, *mkldnn_primitives[result_index]));
primitive_deps[add_index] = {input1_data_index, input0_data_index, result_index};
return add_index;
......
......@@ -68,11 +68,12 @@ namespace ngraph
const ngraph::CoordinateDiff& padding_below,
const ngraph::CoordinateDiff& padding_above);
size_t build_elementwise_add(const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
const mkldnn::memory::desc& result_desc,
const std::vector<float>& scale_vector,
const std::vector<mkldnn::memory::primitive_desc>& input_pd);
private:
std::shared_ptr<CPU_ExternalFunction> external_function;
......
......@@ -43,7 +43,6 @@ namespace ngraph
{
namespace pass
{
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Add)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment