Commit 36c213af authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Port ConvertLayout

parent a6be909f
...@@ -2933,31 +2933,21 @@ namespace ngraph ...@@ -2933,31 +2933,21 @@ namespace ngraph
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*input_tvl).get_mkldnn_format(); dynamic_cast<runtime::cpu::LayoutDescriptor&>(*input_tvl).get_mkldnn_format();
auto output_format = auto output_format =
dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format(); dynamic_cast<runtime::cpu::LayoutDescriptor&>(*output_tvl).get_mkldnn_format();
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
writer << "{\n"; auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
writer.indent++; auto input_desc = mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto result_desc = mkldnn_emitter->build_memory_descriptor(out[0], output_format);
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; size_t reorder_index = mkldnn_emitter->build_reorder(input_desc, result_desc);
writer << "memory::desc input_desc = memory::desc({" << join(args[0].get_shape())
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc output_desc = memory::desc({" << join(out[0].get_shape())
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(output_format)
<< ");\n";
writer << "memory input = memory({input_desc, cpu_engine}, " << args[0].get_name()
<< ");\n";
writer << "memory output = memory({output_desc, cpu_engine}, " << out[0].get_name()
<< ");\n";
writer << "reorder prim = reorder(input, output);\n";
writer << "stream s = stream(stream::kind::eager);\n" auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
<< "s.submit({prim}).wait();\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
writer.indent--; << args[0].get_name() << ");\n";
writer << "}\n"; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(reorder_index) << ");\n";
} }
template <> template <>
......
...@@ -251,6 +251,19 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm, ...@@ -251,6 +251,19 @@ size_t MKLDNNEmitter::build_pooling_forward(mkldnn::algorithm pooling_algorithm,
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(
new mkldnn::reorder(*mkldnn_primitives[input_index], *mkldnn_primitives[result_index]));
primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add( size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
...@@ -95,6 +95,9 @@ namespace ngraph ...@@ -95,6 +95,9 @@ namespace ngraph
const ngraph::Shape& padding_below, const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above); const ngraph::Shape& padding_above);
size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_elementwise_add( size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment