Commit 2e197c84 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

- emit mlkddn for LRN only for normalizing acorss/within channel (#4074)

*  - emit mlkddn for LRN only for normalizing acorss/within channel

* - remove explicit checking of axes in lrn builder
Co-authored-by: 's avatarJayaram Bobba <jayaram.bobba@intel.com>
parent 2fdb2676
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
AxisSet axes = lrn->get_reduction_axes(); AxisSet axes = lrn->get_reduction_axes();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node) && axes == AxisSet{1}) if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node); auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node);
......
...@@ -548,11 +548,14 @@ namespace ngraph ...@@ -548,11 +548,14 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN) void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{ {
(void)external_function; (void)external_function;
auto lrn = static_cast<ngraph::op::LRN*>(node);
AxisSet axes = lrn->get_reduction_axes();
auto arg0_shape = node->get_input_shape(0); auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size(); auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0); auto result_shape = node->get_output_shape(0);
if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32) if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32 &&
axes == AxisSet{1})
{ {
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node); runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment