Commit 2e197c84 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

- emit mlkddn for LRN only for normalizing acorss/within channel (#4074)

*  - emit mlkddn for LRN only for normalizing acorss/within channel

* - remove explicit checking of axes in lrn builder
Co-authored-by: 's avatarJayaram Bobba <jayaram.bobba@intel.com>
parent 2fdb2676
......@@ -42,7 +42,7 @@ namespace ngraph
AxisSet axes = lrn->get_reduction_axes();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node) && axes == AxisSet{1})
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node);
......
......@@ -548,11 +548,14 @@ namespace ngraph
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
(void)external_function;
auto lrn = static_cast<ngraph::op::LRN*>(node);
AxisSet axes = lrn->get_reduction_axes();
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32)
if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32 &&
axes == AxisSet{1})
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment