Commit acc02402 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Port Relu

parent 36c213af
......@@ -3039,56 +3039,24 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Relu)
{
const auto& arg_shape = args[0].get_shape();
const auto& result_shape = out[0].get_shape();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
const string& et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string(
args[0].get_element_type());
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
writer << "{\n";
writer.indent++;
size_t relu_index = mkldnn_emitter->build_relu_forward(input_desc, result_desc);
writer << "try {\n";
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
writer << "memory::desc input_data_desc = memory::desc({" << join(arg_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(input_format)
<< ");\n";
writer << "memory::desc result_desc = memory::desc({" << join(result_shape)
<< "}, " << et << ", "
<< runtime::cpu::mkldnn_utils::get_mkldnn_format_string(result_format)
<< ");\n";
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "memory input_data = memory({input_data_desc, cpu_engine}, "
<< args[0].get_name() << ");\n";
writer << "memory result = memory({result_desc, cpu_engine}, "
<< out[0].get_name() << ");\n";
writer << "relu_forward::desc relu_fwd_desc = "
"relu_forward::desc(prop_kind::forward_training, "
"algorithm::eltwise_relu, input_data_desc, 0, 0);\n";
writer << "relu_forward::primitive_desc relu_prim_desc = "
"relu_forward::primitive_desc(relu_fwd_desc, cpu_engine);\n";
writer << "relu_forward relu_fwd= relu_forward(relu_prim_desc, input_data, "
"result);\n";
writer << "stream s = stream(stream::kind::eager);\n"
"s.submit({relu_fwd}).wait();\n";
writer.indent--;
writer << "} catch (const mkldnn::error& e) {\n";
writer.indent++;
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n";
writer.indent--;
writer << "}\n";
writer.indent--;
writer << "}\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(relu_index) << ");\n";
}
else
{
......
......@@ -264,6 +264,22 @@ size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
return primitive_index;
}
size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_forward(
{{mkldnn::prop_kind::forward_training, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine},
*mkldnn_primitives[input_index],
*mkldnn_primitives[result_index]));
primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
......
......@@ -98,6 +98,9 @@ namespace ngraph
size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment