Unverified Commit 5a5579f7 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by GitHub

Handle negative axis in ArithmeticReductionKeepDims (#4393)

Co-authored-by: 's avatarAshok Emani <ashok.emani@intel.com>
parent 6f680b9e
......@@ -16,6 +16,7 @@
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -33,7 +34,6 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
......@@ -43,20 +43,32 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
AxisSet reduction_axes;
auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->cast_vector<int64_t>();
for (auto axis : reduction_axes_val)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
try
{
axis = normalize_axis(this, axis, input_rank);
}
catch (const ngraph_error&)
{
NODE_VALIDATION_CHECK(this,
false,
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
}
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment