Commit 9aba28dc authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Partial Shapes and Types, Part 4b: Concat (#1778)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for concat

* Fix for a corner case in concat propagation of dynamic shapes; unit tests for concat propagation of dynamic shapes

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Remove validate-punt-if-dynamic test because it uses Concat
parent d3d27108
...@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required."; NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0); PartialShape inputs_shape_scheme{PartialShape::dynamic()};
size_t expected_rank = first_input_shape.size(); element::Type inputs_et{element::dynamic};
element::Type expected_et = get_input_element_type(0); Dimension concatenation_axis_output_dim{0};
for (auto i = 1; i < get_inputs().size(); i++) for (auto i = 0; i < get_inputs().size(); i++)
{ {
NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank) PartialShape this_input_shape = get_input_partial_shape(i);
<< "Not all arguments have the same rank: argument 0 has shape " << first_input_shape Dimension this_input_rank = this_input_shape.rank();
<< " of rank " << expected_rank << " but argument " << i << " has shape " if (this_input_rank.is_static())
<< get_input_shape(i) << " of rank " << get_input_shape(i).size() << ".";
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et)
<< "Not all arguments have the same element type: argument 0 has element type "
<< expected_et << " but argument " << i << " has element type "
<< get_input_element_type(i) << ".";
}
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank)
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
<< expected_rank << ").";
size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
for (auto i = 1; i < get_inputs().size(); i++)
{ {
for (auto j = 0; j < get_input_shape(i).size(); j++) NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
{ << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
if (j != m_concatenation_axis) << "argument " << i << ", which has shape " << this_input_shape << ".";
{
NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j]) concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
<< "Dimensions of argument " << i << " do not match for axis " << j this_input_shape[m_concatenation_axis] = Dimension::dynamic();
<< " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
<< ")."; NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
} }
else else
{ {
concatenation_axis_output_length += get_input_shape(i)[j]; concatenation_axis_output_dim += Dimension::dynamic();
}
} }
} }
Shape concatenated_shape = first_input_shape; PartialShape concatenated_shape = inputs_shape_scheme;
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
if (concatenated_shape.rank().is_static())
{
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
}
set_output_type(0, expected_et, concatenated_shape); set_output_type(0, inputs_et, concatenated_shape);
} }
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -243,7 +243,10 @@ TEST(type_prop, concat_deduce_wrong_rank) ...@@ -243,7 +243,10 @@ TEST(type_prop, concat_deduce_wrong_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Not all arguments have the same rank")); EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
} }
catch (...) catch (...)
{ {
...@@ -264,8 +267,10 @@ TEST(type_prop, concat_deduce_wrong_shape) ...@@ -264,8 +267,10 @@ TEST(type_prop, concat_deduce_wrong_shape)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Dimensions of argument 2 do not match for axis 2")); error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
} }
catch (...) catch (...)
{ {
...@@ -286,9 +291,7 @@ TEST(type_prop, concat_deduce_axis_oob) ...@@ -286,9 +291,7 @@ TEST(type_prop, concat_deduce_axis_oob)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (3) is out of bounds"));
error.what(),
std::string("Concatenation axis (3) is out of bounds (inputs have rank 3)"));
} }
catch (...) catch (...)
{ {
...@@ -320,8 +323,239 @@ TEST(type_prop, concat_deduce_elem_type_mismatch) ...@@ -320,8 +323,239 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
std::string("Not all arguments have the same element type")); }
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_et_consistent)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::dynamic, Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_EQ(c->get_element_type(), element::f32);
ASSERT_EQ(c->get_shape(), (Shape{2, 12, 4}));
}
TEST(type_prop, concat_partial_et_inconsistent)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::dynamic, Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::i32, Shape{2, 2, 4});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent element types not detected (some dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument element types are inconsistent"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_all_rank_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_TRUE(c->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_consistent)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_TRUE(
c->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_rank_inconsistent)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic(), 4});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent ranks not detected (some args rank-dynamic, some args rank-static "
"dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_dynamic_dims_inconsistent)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{3, 3, Dimension::dynamic()});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent dimensions not detected (some args rank-dynamic, some args "
"rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
concat_partial_some_rank_dynamic_others_rank_static_dynamic_dims_intransitively_inconsistent)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()});
auto param3 =
make_shared<op::Parameter>(element::f32, PartialShape{3, 3, Dimension::dynamic()});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2, param3}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent dimensions not detected (some args rank-dynamic, some args "
"rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_some_rank_dynamic_others_rank_static_with_concat_axis_static)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape{2, 2, 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_TRUE(
c->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop,
concat_partial_some_rank_dynamic_others_rank_static_with_concat_axis_static_dims_inconsistent)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape{2, 2, 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{3, 3, Dimension::dynamic()});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent dimensions not detected (some args rank-dynamic, some args "
"rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_partial_all_static_with_concat_axis_static_compatible_result_static)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape{2, 2, 3});
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4, 3});
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_EQ(c->get_shape(), (Shape{2, 9, 3}));
}
TEST(type_prop, concat_partial_all_static_with_concat_axis_static_compatible_result_dynamic)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 2, Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 4, Dimension::dynamic()});
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
ASSERT_TRUE(
c->get_output_partial_shape(0).same_scheme(PartialShape{2, 9, Dimension::dynamic()}));
}
TEST(type_prop, concat_partial_all_static_with_concat_axis_static_dims_incompatible)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape{2, 2, 3});
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4, 3});
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{3, 3, Dimension::dynamic()});
try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Inconsistent dimensions not detected (some args rank-dynamic, some args "
"rank-static dynamic)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis"));
} }
catch (...) catch (...)
{ {
...@@ -8496,20 +8730,3 @@ TEST(type_prop, dequantize_offset_shape_mismatch_different_rank_fails) ...@@ -8496,20 +8730,3 @@ TEST(type_prop, dequantize_offset_shape_mismatch_different_rank_fails)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
//
// This is testing a temporary hack for ops that do not yet support partial-shape validation.
// The graph we construct here is bogus, but because there is some partiality in the input shapes,
// it should still pass validation but set the output shape and element types to be dynamic.
//
TEST(type_prop, validate_punt_if_dynamic)
{
auto a = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::u32, PartialShape{1, Dimension::dynamic(), 3});
auto c = make_shared<op::Parameter>(element::i32, Shape{1, 8, 3});
auto concat = make_shared<op::Concat>(NodeVector{a, b, c}, /*concatenation axis=*/1234);
ASSERT_EQ(concat->get_output_size(), 1);
ASSERT_TRUE(concat->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(concat->get_output_element_type(0).is_dynamic());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment