Commit 34aae47c authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4d: Broadcast (#1783)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Broadcast; implement unit tests for same

* Remove inapplicable TODO

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness
parent 9aba28dc
......@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void op::Broadcast::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
infer_shape();
for (auto axis : m_broadcast_axes)
{
return;
NODE_VALIDATION_ASSERT(this, axis < m_shape.size())
<< "Broadcast axis index (" << axis << ") exceeds specified output shape rank "
<< "(broadcast axes: " << m_broadcast_axes << ", output shape: " << m_shape << ").";
}
infer_shape();
Shape target_shape = m_shape;
Shape required_input_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
NODE_VALIDATION_ASSERT(this, *i < target_shape.size())
<< "Broadcast axis index (" << *i << ") exceeds target shape rank "
<< "(broadcast axes: " << m_broadcast_axes << ", target shape: " << target_shape
<< ").";
target_shape.erase(target_shape.begin() + *i);
required_input_shape.erase(required_input_shape.begin() + *i);
}
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// or there is a mismatch with one of the pre-broadcast axis lengths
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
// where target_shape[i] != arg->get_shape[i]).
NODE_VALIDATION_ASSERT(this, target_shape == get_input_shape(0))
<< "Broadcast argument shape, target shape, and axes are incompatible "
<< "(argument shape: " << get_input_shape(0) << ", target shape: " << m_shape
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_ASSERT(this, get_input_partial_shape(0).compatible(required_input_shape))
<< "Broadcast argument shape, specified output shape, and axes are incompatible "
<< "(argument shape: " << get_input_partial_shape(0) << ", output shape: " << m_shape
<< ", broadcast axes: " << m_broadcast_axes << ").";
set_output_type(0, get_input_element_type(0), m_shape);
......
......@@ -25,12 +25,8 @@ using namespace ngraph;
// Non-fatally checks that `needle` occurs as a substring of `haystack`,
// via gtest's IsSubstring predicate-formatter.
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
//
// Tests for broadcast.
//
// Basic static-shape case: type and shape are deduced from the argument.
// NOTE(review): a diff hunk header was spliced into this test in the scrape and
// the element-type assertion was elided by diff context; restored here to match
// the sibling broadcast tests — confirm against the upstream file.
TEST(type_prop, broadcast_deduce)
{
    // Deduce type
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    Shape bc_shape{2, 3, 4};
    auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
    ASSERT_EQ(bc->get_element_type(), element::f32);
    ASSERT_EQ(bc->get_shape(), bc_shape);
}
// A broadcast axis at or beyond the rank of the specified output shape must be rejected.
TEST(type_prop, broadcast_axes_oob)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    Shape output_shape{2, 3, 4};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1, 3});
        FAIL() << "Broadcast axis out of bounds not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             "Broadcast axis index (3) exceeds specified output shape rank");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Output shape with one more dimension than argument-rank + axes must be rejected.
TEST(type_prop, broadcast_shape_mismatch_wrong_rank)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    Shape output_shape{2, 3, 4, 5};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});
        FAIL() << "Output shape mismatch (wrong rank) not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            "Broadcast argument shape, specified output shape, and axes are incompatible");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A non-broadcast dimension that disagrees with the argument shape must be rejected.
TEST(type_prop, broadcast_shape_mismatch_wrong_size)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    Shape output_shape{2, 3, 5};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});
        FAIL() << "Output shape mismatch (wrong size) not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            "Broadcast argument shape, specified output shape, and axes are incompatible");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A rank-dynamic argument is compatible with any output shape/axes combination,
// and the deduced output is fully static.
TEST(type_prop, broadcast_partial_rank_dynamic_ok)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    Shape output_shape{2, 3, 4};

    auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});

    ASSERT_EQ(broadcast->get_element_type(), element::f32);
    ASSERT_EQ(broadcast->get_shape(), output_shape);
}
// Out-of-bounds axes are caught even when the argument's rank is dynamic,
// since the check depends only on the specified output shape.
TEST(type_prop, broadcast_partial_rank_dynamic_axes_oob)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    Shape output_shape{2, 3, 4};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1, 3});
        FAIL() << "Broadcast axis out of bounds not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             "Broadcast axis index (3) exceeds specified output shape rank");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Rank-static argument with one dynamic dimension: still compatible, and the
// deduced output shape is the fully static specified shape.
TEST(type_prop, broadcast_partial_rank_static_dynamic_ok)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
    Shape output_shape{2, 3, 4};

    auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});

    ASSERT_EQ(broadcast->get_element_type(), element::f32);
    ASSERT_EQ(broadcast->get_shape(), output_shape);
}
// Out-of-bounds axes are caught for a rank-static argument with a dynamic dimension.
TEST(type_prop, broadcast_partial_rank_static_dynamic_axes_oob)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
    Shape output_shape{2, 3, 4};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1, 3});
        FAIL() << "Broadcast axis out of bounds not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             "Broadcast axis index (3) exceeds specified output shape rank");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Rank mismatch is detected even when one argument dimension is dynamic.
TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_rank)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
    Shape output_shape{2, 3, 4, 5};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});
        FAIL() << "Output shape mismatch (wrong rank) not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            "Broadcast argument shape, specified output shape, and axes are incompatible");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A mismatched static dimension is detected even when another argument
// dimension is dynamic.
TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
{
    auto arg = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
    Shape output_shape{2, 3, 5};

    try
    {
        auto broadcast = make_shared<op::Broadcast>(arg, output_shape, AxisSet{1});
        FAIL() << "Output shape mismatch (wrong size) not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            "Broadcast argument shape, specified output shape, and axes are incompatible");
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, batchnorm_rank_less_than_2)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{1});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment