Commit 566af28b authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support negative reduction axes in reduction ops shape inference (#4081)

parent 10195034
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/util/arithmetic_reduction.hpp" #include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -64,7 +65,6 @@ void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_ ...@@ -64,7 +65,6 @@ void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_
void op::util::ArithmeticReduction::validate_and_infer_types() void op::util::ArithmeticReduction::validate_and_infer_types()
{ {
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
...@@ -72,22 +72,32 @@ void op::util::ArithmeticReduction::validate_and_infer_types() ...@@ -72,22 +72,32 @@ void op::util::ArithmeticReduction::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant()) if (input_rank.is_static() && reduction_axes_constant())
{ {
std::vector<Dimension> dims; AxisSet reduction_axes;
auto reduction_axes_val =
for (auto axis : reduction_axes) as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{ {
NODE_VALIDATION_CHECK(this, try
axis < size_t(input_rank), {
"Reduction axis (", axis = normalize_axis(this, axis, size_t(input_rank));
axis, }
") is out of bounds ", catch (const ngraph_error& err)
"(argument shape: ", {
input_shape, NODE_VALIDATION_CHECK(this,
", reduction axes: ", false,
reduction_axes, "Reduction axis (",
")"); axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
} }
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/util/logical_reduction.hpp" #include "ngraph/op/util/logical_reduction.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -63,7 +64,6 @@ void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axe ...@@ -63,7 +64,6 @@ void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axe
void op::util::LogicalReduction::validate_and_infer_types() void op::util::LogicalReduction::validate_and_infer_types()
{ {
auto reduction_axes = get_reduction_axes();
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank(); auto input_rank = input_shape.rank();
...@@ -71,22 +71,32 @@ void op::util::LogicalReduction::validate_and_infer_types() ...@@ -71,22 +71,32 @@ void op::util::LogicalReduction::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant()) if (input_rank.is_static() && reduction_axes_constant())
{ {
std::vector<Dimension> dims; AxisSet reduction_axes;
auto reduction_axes_val =
for (auto axis : reduction_axes) as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{ {
NODE_VALIDATION_CHECK(this, try
axis < size_t(input_rank), {
"Reduction axis (", axis = normalize_axis(this, axis, size_t(input_rank));
axis, }
") is out of bounds ", catch (const ngraph_error& err)
"(argument shape: ", {
input_shape, NODE_VALIDATION_CHECK(this,
", reduction axes: ", false,
reduction_axes, "Reduction axis (",
")"); axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
} }
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/util/logical_reduction_keep_dims.hpp" #include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -33,7 +34,6 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types() ...@@ -33,7 +34,6 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{ {
if (m_keep_dims) if (m_keep_dims)
{ {
const auto reduction_axes = get_reduction_axes();
const auto input_shape = get_input_partial_shape(0); const auto input_shape = get_input_partial_shape(0);
const auto input_rank = input_shape.rank(); const auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()}; PartialShape result_shape{PartialShape::dynamic()};
...@@ -45,20 +45,32 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types() ...@@ -45,20 +45,32 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
if (input_rank.is_static() && reduction_axes_constant()) if (input_rank.is_static() && reduction_axes_constant())
{ {
std::vector<Dimension> dims; AxisSet reduction_axes;
for (const auto axis : reduction_axes) auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
for (auto axis : reduction_axes_val)
{ {
NODE_VALIDATION_CHECK(this, try
axis < size_t(input_rank), {
"Reduction axis (", axis = normalize_axis(this, axis, size_t(input_rank));
axis, }
") is out of bounds ", catch (const ngraph_error& err)
"(argument shape: ", {
input_shape, NODE_VALIDATION_CHECK(this,
", reduction axes: ", false,
reduction_axes, "Reduction axis (",
")"); axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
reduction_axes.insert(axis);
} }
std::vector<Dimension> dims;
for (size_t i = 0; i < size_t(input_rank); i++) for (size_t i = 0; i < size_t(input_rank); i++)
{ {
if (reduction_axes.count(i) == 0) if (reduction_axes.count(i) == 0)
......
...@@ -126,3 +126,14 @@ TEST(type_prop, sum_partial_rank_static_dynamic_axes_oob) ...@@ -126,3 +126,14 @@ TEST(type_prop, sum_partial_rank_static_dynamic_axes_oob)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
TEST(type_prop, sum_partial_negative_axes)
{
auto param =
make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic(), 4, 5});
auto summation_axes = op::Constant::create(element::i64, Shape{2}, {-3, -2});
auto sum = make_shared<op::Sum>(param, summation_axes);
EXPECT_EQ(sum->get_output_element_type(0), element::f32);
EXPECT_EQ(sum->get_shape(), (Shape{1, 2, 5}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment