Commit de462cba authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 4f: Pad (#1799)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types (see the usage sketch after the commit message)

* Fix a goof-up in the deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Pad; add unit tests for same

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Formatting
parent 6e7d9e3b
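For readers joining the series mid-stream, here is a rough usage sketch of the primitives the commits above keep referring to: Dimension models a possibly-unknown extent, PartialShape a possibly-unknown shape, and element::Type::merge unifies element types (the "merging of undetermined element types" noted above). All of these names appear in the diff below except element::dynamic and the umbrella include, which are assumptions about the library surface at this revision; treat this as illustrative, not canonical.

    #include <iostream>

    #include "ngraph/ngraph.hpp" // umbrella header; exact include paths may differ by revision

    using namespace ngraph;

    int main()
    {
        // A Dimension is either static (a known extent) or dynamic (unknown).
        Dimension three{3};
        Dimension unknown = Dimension::dynamic();

        // Arithmetic propagates dynamic-ness: the member operator+ from the
        // commit list yields a dynamic result if either operand is dynamic.
        Dimension five = three + 2;            // static, value 5
        Dimension still_unknown = unknown + 2; // dynamic

        // A PartialShape is rank-dynamic, or rank-static with mixed dimensions.
        PartialShape any = PartialShape::dynamic();
        PartialShape mixed{3, 5, Dimension::dynamic()};

        // compatible() asks whether two shapes could describe the same tensor;
        // dynamic dimensions (and dynamic rank) are compatible with anything.
        std::cout << mixed.compatible(PartialShape{3, 5, 7}) << "\n"; // 1
        std::cout << any.compatible(mixed) << "\n";                   // 1
        std::cout << five.is_static() << "\n";                        // 1
        std::cout << still_unknown.is_static() << "\n";               // 0

        // Element types unify the same way; merge() fails only on a real clash,
        // so a dynamic ("undetermined") type unifies with any concrete one.
        element::Type result_et;
        std::cout << element::Type::merge(result_et, element::f32, element::dynamic)
                  << "\n"; // 1, and result_et is now element::f32
    }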
@@ -32,40 +32,56 @@ op::Pad::Pad(const shared_ptr<Node>& arg,
     , m_padding_interior(padding_interior)
 {
     constructor_validate_and_infer_types();
+}
+
+void op::Pad::validate_and_infer_types()
+{
+    element::Type result_et;
 
-    NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == get_input_element_type(1))
+    NODE_VALIDATION_ASSERT(
+        this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
         << "Argument element types do not match (arg0 element type: " << get_input_element_type(0)
         << ", arg1 element type: " << get_input_element_type(1) << ").";
 
-    NODE_VALIDATION_ASSERT(this, get_input_shape(1) == Shape{})
-        << "Argument for padding value is not a scalar (shape: " << get_input_shape(1) << ").";
+    NODE_VALIDATION_ASSERT(this, get_input_partial_shape(1).compatible(PartialShape{}))
+        << "Argument for padding value is not a scalar (shape: " << get_input_partial_shape(1)
+        << ").";
 
-    auto arg_shape = get_input_shape(0);
+    auto arg_shape = get_input_partial_shape(0);
 
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_below.size())
-        << "Rank for padding below does not match the rank of the data argument (padding below: "
-        << padding_below << ", data argument shape: " << arg_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_above.size())
-        << "Rank for padding above does not match the rank of the data argument (padding above: "
-        << padding_above << ", data argument shape: " << arg_shape << ").";
-
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() == padding_interior.size())
-        << "Rank for interior padding does not match the rank of the data argument (interior "
-           "padding: "
-        << padding_interior << ", data argument shape: " << arg_shape << ").";
+    NODE_VALIDATION_ASSERT(this,
+                           m_padding_below.size() == m_padding_above.size() &&
+                               m_padding_below.size() == m_padding_interior.size())
+        << "Ranks for padding below (" << m_padding_below << "), padding above (" << m_padding_above
+        << ") and interior padding (" << m_padding_interior << ") "
+        << "do not match.";
+
+    size_t implied_rank = m_padding_below.size();
+
+    NODE_VALIDATION_ASSERT(this, arg_shape.rank().compatible(implied_rank))
+        << "Rank for padding below/padding above/interior padding does not match the rank of the "
+        << "data argument (padding below: " << m_padding_below
+        << ", padding above: " << m_padding_above << ", interior padding: " << m_padding_interior
+        << ").";
 
-    Shape result_shape;
-
-    for (size_t i = 0; i < arg_shape.size(); i++)
-    {
-        result_shape.push_back(
-            padding_below[i] +
-            subtract_or_zero(arg_shape[i] * (padding_interior[i] + 1), padding_interior[i]) +
-            padding_above[i]);
-    }
+    std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
+
+    if (arg_shape.rank().is_static())
+    {
+        for (size_t i = 0; i < implied_rank; i++)
+        {
+            if (arg_shape[i].is_static())
+            {
+                result_dims[i] =
+                    m_padding_below[i] +
+                    subtract_or_zero(size_t(arg_shape[i]) * (m_padding_interior[i] + 1),
+                                     m_padding_interior[i]) +
+                    m_padding_above[i];
+            }
+        }
+    }
 
-    set_output_type(0, get_input_element_type(0), result_shape);
+    set_output_type(0, result_et, PartialShape(result_dims));
 }
 
 shared_ptr<Node> op::Pad::copy_with_new_args(const NodeVector& new_args) const
...
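To make the shape arithmetic in the new validate_and_infer_types concrete: each static output dimension is padding_below[i] + subtract_or_zero(d * (interior + 1), interior) + padding_above[i], i.e. d data elements with interior gap elements between each adjacent pair, clamped at zero for d == 0. A minimal standalone sketch follows; padded_dim is a hypothetical helper, and subtract_or_zero's clamp-to-zero behavior is inferred from its name.

    #include <cstddef>
    #include <iostream>

    // Hypothetical standalone version of the formula in validate_and_infer_types:
    //   out = below + subtract_or_zero(d * (interior + 1), interior) + above
    static std::size_t padded_dim(std::size_t d,
                                  std::size_t below,
                                  std::size_t above,
                                  std::size_t interior)
    {
        std::size_t scaled = d * (interior + 1); // each element plus a trailing gap
        std::size_t inner = scaled > interior ? scaled - interior : 0; // drop the last gap
        return below + inner + above;
    }

    int main()
    {
        // Mirrors the unit test below: data {3, 5, ?}, below {2, 4, 6},
        // above {8, 2, 3}, interior {1, 0, 1} => output {15, 11, ?}.
        std::cout << padded_dim(3, 2, 8, 1) << "\n"; // 15
        std::cout << padded_dim(5, 4, 2, 0) << "\n"; // 11
    }

The dynamic case falls out naturally: when arg_shape[i] is dynamic, the loop in the diff leaves result_dims[i] as Dimension::dynamic(), which is exactly what the rank-static tests below assert.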
@@ -51,6 +51,7 @@ namespace ngraph
             virtual std::shared_ptr<Node> get_default_value() const override;
 
         protected:
+            void validate_and_infer_types() override;
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
             Shape m_padding_below;
...
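The header change above is what routes validation through the new virtual. As the diff suggests, the pattern is two-phase: each op's constructor calls constructor_validate_and_infer_types(), which presumably invokes the virtual validate_and_infer_types(). A simplified, self-contained sketch of why that call must come from the derived constructor (the class bodies here are stand-ins, not the real nGraph declarations):

    #include <iostream>

    struct Node
    {
        virtual ~Node() = default;

        // Base-class default shape/type inference hook.
        virtual void validate_and_infer_types() { std::cout << "Node default\n"; }

    protected:
        // Called from each concrete op's constructor body. By that point the
        // object's dynamic type is the derived op, so the virtual call below
        // dispatches to the override; it would NOT if this ran inside Node's
        // own constructor, where the dynamic type is still Node.
        void constructor_validate_and_infer_types() { validate_and_infer_types(); }
    };

    struct Pad : Node
    {
        Pad() { constructor_validate_and_infer_types(); }
        void validate_and_infer_types() override { std::cout << "Pad validation\n"; }
    };

    int main()
    {
        Pad p; // prints "Pad validation"
    }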
@@ -6691,7 +6691,8 @@ TEST(type_prop, pad_deduce_below_padding_wrong_rank)
     {
         EXPECT_HAS_SUBSTRING(
             error.what(),
-            std::string("Rank for padding below does not match the rank of the data argument"));
+            std::string("Ranks for padding below (Shape{5, 3, 0, 6}), padding above (Shape{6, 9, "
+                        "4}) and interior padding (Shape{2, 3, 0}) do not match"));
     }
     catch (...)
     {
@@ -6717,9 +6718,10 @@ TEST(type_prop, pad_deduce_above_padding_wrong_rank)
     }
     catch (const NodeValidationError& error)
     {
-        EXPECT_HAS_SUBSTRING(
-            error.what(),
-            std::string("Rank for padding above does not match the rank of the data argument"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Ranks for padding below (Shape{5, 3, 0}), "
+                                         "padding above (Shape{6, 9}) and interior "
+                                         "padding (Shape{2, 3, 0}) do not match"));
     }
     catch (...)
     {
@@ -6747,7 +6749,158 @@ TEST(type_prop, pad_deduce_interior_padding_wrong_rank)
     {
         EXPECT_HAS_SUBSTRING(
             error.what(),
-            std::string("Rank for interior padding does not match the rank of the data argument"));
+            std::string("Ranks for padding below (Shape{5, 3, 0}), padding above (Shape{6, 9, 4}) "
+                        "and interior padding (Shape{2, 3, 0, 9, 3}) do not match"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, pad_partial_data_rank_dynamic_padding_rank_dynamic_ok)
+{
+    auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3};
+    Shape padding_interior{1, 0, 1};
+    auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+    ASSERT_EQ(pad->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
+        PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
+}
+
+TEST(type_prop, pad_partial_data_rank_dynamic_padding_rank_dynamic_attribs_rank_inconsistent)
+{
+    auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3, 0};
+    Shape padding_interior{1, 0, 1};
+    try
+    {
+        auto pad =
+            make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+        FAIL() << "Inconsistent attribute ranks not detected (rank-dynamic/rank-dynamic arguments)";
+    }
+    catch (const NodeValidationError& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Ranks for padding below (Shape{2, 4, 6}), padding above (Shape{8, 2, 3, "
+                        "0}) and interior padding (Shape{1, 0, 1}) do not match"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, pad_partial_data_rank_static_dynamic_padding_rank_dynamic_ok)
+{
+    auto param0 = make_shared<op::Parameter>(
+        element::f32,
+        PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
+    auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3};
+    Shape padding_interior{1, 0, 1};
+    auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+    ASSERT_EQ(pad->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
+        PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
+}
+
+TEST(type_prop, pad_partial_data_rank_static_dynamic_some_dims_known_padding_rank_dynamic_ok)
+{
+    auto param0 =
+        make_shared<op::Parameter>(element::f32, PartialShape{3, 5, Dimension::dynamic()});
+    auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3};
+    Shape padding_interior{1, 0, 1};
+    auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+    ASSERT_EQ(pad->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(
+        pad->get_output_partial_shape(0).same_scheme(PartialShape{15, 11, Dimension::dynamic()}));
+}
+
+TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_ok)
+{
+    auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto param1 = make_shared<op::Parameter>(element::f32, Shape{});
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3};
+    Shape padding_interior{1, 0, 1};
+    auto pad = make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+    ASSERT_EQ(pad->get_output_element_type(0), element::f32);
+    ASSERT_TRUE(pad->get_output_partial_shape(0).same_scheme(
+        PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
+}
+
+TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_wrong_padding_rank)
+{
+    auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 8});
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3};
+    Shape padding_interior{1, 0, 1};
+    try
+    {
+        auto pad =
+            make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+        FAIL() << "Wrong padding rank not detected (rank-dynamic/static arguments)";
+    }
+    catch (const NodeValidationError& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Argument for padding value is not a scalar (shape: {2,3,8})"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, pad_partial_data_rank_dynamic_padding_static_attribs_rank_inconsistent)
+{
+    auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto param1 = make_shared<op::Parameter>(element::f32, Shape{});
+    Shape padding_below{2, 4, 6};
+    Shape padding_above{8, 2, 3, 4};
+    Shape padding_interior{1, 0, 1};
+    try
+    {
+        auto pad =
+            make_shared<op::Pad>(param0, param1, padding_below, padding_above, padding_interior);
+        FAIL() << "Wrong padding rank not detected (rank-dynamic/static arguments)";
+    }
+    catch (const NodeValidationError& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Ranks for padding below (Shape{2, 4, 6}), padding above (Shape{8, 2, 3, "
+                        "4}) and interior padding (Shape{1, 0, 1}) do not match"));
     }
     catch (...)
     {
...