Commit 566af28b authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support negative reduction axes in reduction ops shape inference (#4081)

parent 10195034
......@@ -16,6 +16,7 @@
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -64,7 +65,6 @@ void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_
void op::util::ArithmeticReduction::validate_and_infer_types()
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
......@@ -72,22 +72,32 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
AxisSet reduction_axes;
auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
}
catch (const ngraph_error& err)
{
NODE_VALIDATION_CHECK(this,
false,
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
}
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
......
......@@ -16,6 +16,7 @@
#include "ngraph/op/util/logical_reduction.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -63,7 +64,6 @@ void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axe
void op::util::LogicalReduction::validate_and_infer_types()
{
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
......@@ -71,22 +71,32 @@ void op::util::LogicalReduction::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (auto axis : reduction_axes)
AxisSet reduction_axes;
auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
}
catch (const ngraph_error& err)
{
NODE_VALIDATION_CHECK(this,
false,
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
}
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
......
......@@ -16,6 +16,7 @@
#include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -33,7 +34,6 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
{
const auto reduction_axes = get_reduction_axes();
const auto input_shape = get_input_partial_shape(0);
const auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
......@@ -45,20 +45,32 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
for (const auto axis : reduction_axes)
AxisSet reduction_axes;
auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
try
{
axis = normalize_axis(this, axis, size_t(input_rank));
}
catch (const ngraph_error& err)
{
NODE_VALIDATION_CHECK(this,
false,
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
}
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++)
{
if (reduction_axes.count(i) == 0)
......
......@@ -126,3 +126,14 @@ TEST(type_prop, sum_partial_rank_static_dynamic_axes_oob)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, sum_partial_negative_axes)
{
auto param =
make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic(), 4, 5});
auto summation_axes = op::Constant::create(element::i64, Shape{2}, {-3, -2});
auto sum = make_shared<op::Sum>(param, summation_axes);
EXPECT_EQ(sum->get_output_element_type(0), element::f32);
EXPECT_EQ(sum->get_shape(), (Shape{1, 2, 5}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment