Commit 4552f48a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fallback to reference kernel for failing tests (#4047)

parent 566af28b
...@@ -40,7 +40,9 @@ namespace ngraph ...@@ -40,7 +40,9 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name()); auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name()); auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node)) AxisSet axes = lrn->get_reduction_axes();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node) && axes == AxisSet{1})
{ {
auto& mkldnn_emitter = external_function->get_mkldnn_emitter(); auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node); auto lrn_desc = mkldnn_emitter->get_lrn_forward_desc(node);
...@@ -77,7 +79,6 @@ namespace ngraph ...@@ -77,7 +79,6 @@ namespace ngraph
} }
else else
{ {
AxisSet axes = lrn->get_reduction_axes();
double alpha = lrn->get_alpha(); double alpha = lrn->get_alpha();
double beta = lrn->get_beta(); double beta = lrn->get_beta();
double bias = lrn->get_bias(); double bias = lrn->get_bias();
......
...@@ -11,16 +11,6 @@ max_3d_to_scalar_int32 ...@@ -11,16 +11,6 @@ max_3d_to_scalar_int32
send_recv send_recv
send_recv_ring send_recv_ring
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
lrn_2d_across_empty
lrn_2d_across_outermost_axis
# ONNX TopK with dynamic K # ONNX TopK with dynamic K
top_k_opset_10 top_k_opset_10
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment