Commit 20bd8bbc authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Replace relu implementation with element_wise (#2295)

parent f21eeb8d
...@@ -811,11 +811,13 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc, ...@@ -811,11 +811,13 @@ size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
size_t input_index = build_memory_primitive(input_desc); size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc); size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_forward( const float negative_slope = 0.0f;
{{mkldnn::prop_kind::forward_training, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0}, auto relu_desc = mkldnn::eltwise_forward::desc(
executor::global_cpu_engine}, mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, negative_slope);
*m_mkldnn_primitives[input_index], auto relu_pd = mkldnn::eltwise_forward::primitive_desc(relu_desc, executor::global_cpu_engine);
*m_mkldnn_primitives[result_index]));
size_t primitive_index = insert_primitive(new mkldnn::eltwise_forward(
relu_pd, *m_mkldnn_primitives[input_index], *m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index}; m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index; return primitive_index;
...@@ -829,11 +831,20 @@ size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc ...@@ -829,11 +831,20 @@ size_t MKLDNNEmitter::build_relu_backward(const mkldnn::memory::desc& input_desc
size_t delta_index = build_memory_primitive(delta_desc); size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc); size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index = insert_primitive(new mkldnn::relu_backward( /* Backward relu */
{{mkldnn::algorithm::eltwise_relu, delta_desc, input_desc, 0, 0}, const float negative_slope = 0.0f;
executor::global_cpu_engine, auto relu_desc = mkldnn::eltwise_forward::desc(
{{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, 0, 0}, mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_relu, input_desc, negative_slope);
executor::global_cpu_engine}}, auto relu_pd = mkldnn::eltwise_forward::primitive_desc(relu_desc, executor::global_cpu_engine);
/* create backward relu primitive_descriptor */
auto relu_bwd_desc = mkldnn::eltwise_backward::desc(
mkldnn::algorithm::eltwise_relu, result_desc, input_desc, negative_slope);
auto relu_bwd_pd = mkldnn::eltwise_backward::primitive_desc(
relu_bwd_desc, executor::global_cpu_engine, relu_pd);
size_t primitive_index =
insert_primitive(new mkldnn::eltwise_backward(relu_bwd_pd,
*m_mkldnn_primitives[input_index], *m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index], *m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index])); *m_mkldnn_primitives[result_index]));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment