Commit 20bd8bbc authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Replace relu implementation with element_wise (#2295)

parent f21eeb8d
......@@ -811,11 +811,13 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_forward(
{{mkldnn::prop_kind::forward_training, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
executor::global_cpu_engine},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
const float negative_slope = 0.0f;
auto relu_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, negative_slope);
auto relu_pd = mkldnn::eltwise_forward::primitive_desc(relu_desc, executor::global_cpu_engine);
size_t primitive_index = insert_primitive(new mkldnn::eltwise_forward(
relu_pd, *m_mkldnn_primitives[input_index], *m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
......@@ -829,14 +831,23 @@ size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_backward(
{{mkldnn::algorithm::eltwise_relu, delta_desc, input_desc, 0, 0},
executor::global_cpu_engine,
{{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0},
executor::global_cpu_engine}},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
/* Backward relu */
const float negative_slope = 0.0f;
auto relu_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, negative_slope);
auto relu_pd = mkldnn::eltwise_forward::primitive_desc(relu_desc, executor::global_cpu_engine);
/* create backward relu primitive_descriptor */
auto relu_bwd_desc = mkldnn::eltwise_backward::desc(
mkldnn::algorithm::eltwise_relu, result_desc, input_desc, negative_slope);
auto relu_bwd_pd = mkldnn::eltwise_backward::primitive_desc(
relu_bwd_desc, executor::global_cpu_engine, relu_pd);
size_t primitive_index =
insert_primitive(new mkldnn::eltwise_backward(relu_bwd_pd,
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
return primitive_index;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment