Commit e92ee04c authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Partial Shapes and Types, Part 4h: IndexReduction (#1829)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial type/shape propagation for GetOutputElement

* Implement partial type/shape validation for IndexReduction, and unit tests

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Review comments
parent 30df706f
...@@ -30,19 +30,41 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type, ...@@ -30,19 +30,41 @@ op::util::IndexReduction::IndexReduction(const std::string& node_type,
, m_index_element_type(index_element_type) , m_index_element_type(index_element_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto rank = arg->get_shape().size(); }
void op::util::IndexReduction::validate_and_infer_types()
{
const PartialShape& arg_shape = get_input_partial_shape(0);
Rank rank = arg_shape.rank();
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || size_t(rank) >= 1)
<< "Argument rank is zero.";
NODE_VALIDATION_ASSERT(this, rank.is_dynamic() || m_axis < size_t(rank))
<< "Reduction axis (" << m_axis << ") is not less than argument rank (" << rank << ").";
NODE_VALIDATION_ASSERT(
this, m_index_element_type == element::i32 || m_index_element_type == element::i64)
<< "Index element is neither i64 or i32.";
PartialShape output_shape{PartialShape::dynamic()};
if (!rank.is_dynamic())
{
std::vector<Dimension> output_dims(size_t(rank) - 1);
size_t j = 0;
NODE_VALIDATION_ASSERT(this, rank >= 1) << "Argument rank must be at least 1"; for (size_t i = 0; i < size_t(rank) - 1; i++)
NODE_VALIDATION_ASSERT(this, axis < rank) << "Axis " << axis << " is greater than rank of " {
<< rank; if (j == m_axis)
NODE_VALIDATION_ASSERT(this, {
index_element_type == element::i32 || index_element_type == element::i64) j++;
<< "Index element type must be i64 or i32"; }
output_dims[i] = arg_shape[j++];
}
Shape output_shape = arg->get_shape(); output_shape = PartialShape(output_dims);
output_shape.erase(output_shape.begin() + axis); }
set_output_type(0, index_element_type, output_shape); set_output_type(0, m_index_element_type, output_shape);
} }
void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& adjoints, void op::util::IndexReduction::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -37,6 +37,8 @@ namespace ngraph ...@@ -37,6 +37,8 @@ namespace ngraph
protected: protected:
size_t m_axis; size_t m_axis;
element::Type m_index_element_type; element::Type m_index_element_type;
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
......
...@@ -8953,7 +8953,7 @@ TEST(type_prop, index_reduction_scalar) ...@@ -8953,7 +8953,7 @@ TEST(type_prop, index_reduction_scalar)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Argument rank must be at least 1"); EXPECT_HAS_SUBSTRING(error.what(), "Argument rank is zero");
} }
catch (...) catch (...)
{ {
...@@ -8972,7 +8972,7 @@ TEST(type_prop, index_reduction_invalid_rank) ...@@ -8972,7 +8972,7 @@ TEST(type_prop, index_reduction_invalid_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "is greater than rank of"); EXPECT_HAS_SUBSTRING(error.what(), "Reduction axis (2) is not less than argument rank (2)");
} }
catch (...) catch (...)
{ {
...@@ -8991,7 +8991,7 @@ TEST(type_prop, index_reduction_invalid_index_type) ...@@ -8991,7 +8991,7 @@ TEST(type_prop, index_reduction_invalid_index_type)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Index element type must be"); EXPECT_HAS_SUBSTRING(error.what(), "Index element is neither i64 or i32");
} }
catch (...) catch (...)
{ {
...@@ -8999,6 +8999,108 @@ TEST(type_prop, index_reduction_invalid_index_type) ...@@ -8999,6 +8999,108 @@ TEST(type_prop, index_reduction_invalid_index_type)
} }
} }
TEST(type_prop, index_reduction_partial_rank_dynamic_output_et_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
size_t axis = 228;
auto output_et = element::dynamic;
try
{
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
FAIL() << "Invalid output type of element::dynamic not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Index element is neither i64 or i32");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_partial_rank_dynamic_output_et_invalid)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
size_t axis = 228;
auto output_et = element::dynamic;
try
{
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
FAIL() << "Invalid output type of element::f32 not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Index element is neither i64 or i32");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_partial_rank_dynamic_ok)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
size_t axis = 228;
auto output_et = element::i32;
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
ASSERT_EQ(argmax->get_output_element_type(0), element::i32);
ASSERT_TRUE(argmax->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, index_reduction_partial_rank_static_dynamic_axis_oob)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3, 4});
size_t axis = 4;
auto output_et = element::i32;
try
{
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
FAIL() << "Out-of-bounds reduction axis not detected (rank-static dynamic argument)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axis (4) is not less than argument rank (4)");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, index_reduction_partial_rank_static_dynamic_ok)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3, 4});
size_t axis = 2;
auto output_et = element::i32;
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
ASSERT_EQ(argmax->get_output_element_type(0), element::i32);
ASSERT_TRUE(
argmax->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 4}));
}
TEST(type_prop, index_reduction_partial_et_dynamic_rank_static_dynamic_ok)
{
auto a =
make_shared<op::Parameter>(element::dynamic, PartialShape{Dimension::dynamic(), 2, 3, 4});
size_t axis = 2;
auto output_et = element::i32;
auto argmax = make_shared<op::ArgMax>(a, axis, output_et);
ASSERT_EQ(argmax->get_output_element_type(0), element::i32);
ASSERT_TRUE(
argmax->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 2, 4}));
}
TEST(type_prop, topk_invalid_rank) TEST(type_prop, topk_invalid_rank)
{ {
auto a = make_shared<op::Parameter>(element::f32, Shape{}); auto a = make_shared<op::Parameter>(element::f32, Shape{});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment