Unverified Commit 5a5579f7 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Handle negative axis in ArithmeticReductionKeepDims (#4393)

Co-authored-by: 's avatarAshok Emani <ashok.emani@intel.com>
parent 6f680b9e
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp" #include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -33,7 +34,6 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types() ...@@ -33,7 +34,6 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{ {
if (m_keep_dims) if (m_keep_dims)
{ {
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()}; PartialShape result_shape{PartialShape::dynamic()};
...@@ -43,20 +43,32 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types() ...@@ -43,20 +43,32 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant()) if (input_rank.is_static() && reduction_axes_constant())
{ {
std::vector<Dimension> dims; AxisSet reduction_axes;
for (auto axis : reduction_axes) auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->cast_vector<int64_t>();
for (auto axis : reduction_axes_val)
{ {
NODE_VALIDATION_CHECK(this, try
axis < size_t(input_rank), {
"Reduction axis (", axis = normalize_axis(this, axis, input_rank);
axis, }
") is out of bounds ", catch (const ngraph_error&)
"(argument shape: ", {
input_shape, NODE_VALIDATION_CHECK(this,
", reduction axes: ", false,
reduction_axes, "Reduction axis (",
")"); axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
} }
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment