Commit 7ce15121 authored by Jayaram Bobba's avatar Jayaram Bobba

Move Relu backprop to MKLDNN emitter

parent a5e29489
...@@ -3258,73 +3258,27 @@ namespace ngraph ...@@ -3258,73 +3258,27 @@ namespace ngraph
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string( auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
args[0].get_element_type()); auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto input_format = auto delta_desc = mkldnn_emitter->build_memory_descriptor(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0); args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto delta_format = auto result_desc = mkldnn_emitter->build_memory_descriptor(
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1); out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
if (!runtime::cpu::mkldnn_utils::compare_mkldnn_formats(input_format,
delta_format))
{
throw ngraph_error(
"mkldnn emitter: Relu backprop fprop input and delta layouts should be "
"the same");
}
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
writer << "{\n"; size_t relu_index =
writer.indent++; mkldnn_emitter->build_relu_backward(input_desc, delta_desc, result_desc);
writer << "try {\n"; auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
writer.indent++; writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; << ", " << args[0].get_name() << ");\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape) writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< "}, " << et << ", " << ", " << args[1].get_name() << ");\n";
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format) writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ");\n"; << ", " << out[0].get_name() << ");\n";
writer << "memory::desc delta_data_desc = memory::desc({"
<< join(args[1].get_shape()) << "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(delta_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
<< ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, " writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< args[0].get_name() << ");\n"; << to_string(relu_index) << ");\n";
writer << "memory delta_data = memory({delta_data_desc, cpu_engine}, "
<< args[1].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "relu_forward::desc relu_fwd_desc = "
"relu_forward::desc(prop_kind::forward, "
"algorithm::eltwise_relu, input_data_desc, 0, 0);\n";
writer << "relu_forward::primitive_desc relu_fwd_prim_desc = "
"relu_forward::primitive_desc(relu_fwd_desc, cpu_engine);\n";
writer << "relu_backward::desc relu_bwd_desc = "
"relu_backward::desc(algorithm::eltwise_relu, "
"delta_data_desc, input_data_desc, 0, 0);\n";
writer << "relu_backward::primitive_desc relu_bdw_prim_desc = "
"relu_backward::primitive_desc(relu_bwd_desc, cpu_engine, "
"relu_fwd_prim_desc);\n";
writer
<< "relu_backward relu_bwd= relu_backward(relu_bdw_prim_desc, input_data, "
"delta_data, result);\n";
writer << "stream s = stream(stream::kind::eager);\n"
"s.submit({relu_bwd}).wait();\n";
writer.indent--;
writer << "} catch (const mkldnn::error& e) {\n";
writer.indent++;
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n";
writer.indent--;
writer << "}\n";
writer.indent--;
writer << "}\n";
} }
else else
{ {
......
...@@ -360,6 +360,27 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc, ...@@ -360,6 +360,27 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_backward(
{{mkldnn::algorithm::eltwise_relu, delta_desc, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine,
{{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_desc, size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc) const mkldnn::memory::desc& result_desc)
{ {
......
...@@ -119,6 +119,10 @@ namespace ngraph ...@@ -119,6 +119,10 @@ namespace ngraph
size_t build_relu_forward(const mkldnn::memory::desc& input_desc, size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
size_t build_relu_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc);
size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc, size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment