Commit b079f266 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4g: Select (#1800)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Select; implement unit tests for same

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness
parent 34aae47c
......@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
: Op("Select", check_single_output_args({arg0, arg1, arg2}))
{
constructor_validate_and_infer_types();
}
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
auto& input_2 = get_inputs().at(2);
NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == element::boolean)
void op::Select::validate_and_infer_types()
{
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::boolean)
<< "Argument 0 does not have boolean element type (element type: "
<< input_0.get_element_type() << ").";
<< get_input_element_type(0) << ").";
NODE_VALIDATION_ASSERT(this,
input_0.get_shape() == input_1.get_shape() &&
input_0.get_shape() == input_2.get_shape())
<< "Arguments do not all have the same shape (arg0 shape: " << input_0.get_shape()
<< ", arg1 shape: " << input_1.get_shape() << ", arg2 shape: " << input_2.get_shape()
<< ").";
PartialShape result_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(1)))
<< "Argument shapes are inconsistent.";
NODE_VALIDATION_ASSERT(this, PartialShape::merge_into(result_shape, get_input_partial_shape(2)))
<< "Argument shapes are inconsistent.";
element::Type result_et;
NODE_VALIDATION_ASSERT(this, input_1.get_element_type() == input_2.get_element_type())
<< "Arguments 1 and 2 do not have the same element type (arg1 type: "
<< input_1.get_element_type() << ", arg2 type: " << input_2.get_element_type() << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(result_et, get_input_element_type(1), get_input_element_type(2)))
<< "Argument 1 and 2 element types are inconsistent.";
set_output_type(0, input_1.get_element_type(), input_1.get_shape());
set_output_type(0, result_et, result_shape);
}
shared_ptr<Node> op::Select::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -1114,7 +1114,7 @@ TEST(type_prop, select_shape_mismatch_a)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......@@ -1135,7 +1135,7 @@ TEST(type_prop, select_shape_mismatch_b)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......@@ -1156,7 +1156,7 @@ TEST(type_prop, select_shape_mismatch_c)
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Arguments do not all have the same shape"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......@@ -1200,7 +1200,160 @@ TEST(type_prop, select_elem_mismatch_bc)
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Arguments 1 and 2 do not have the same element type"));
std::string("Argument 1 and 2 element types are inconsistent"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_partial_all_rank_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_et_dynamic_arg1_arg2_et_mismatch)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
try
{
auto sel = make_shared<op::Select>(param0, param1, param2);
FAIL() << "Did not detect mismatched element types for args 1 and 2 (element type-dynamic "
"arg0)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument 1 and 2 element types are inconsistent"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg2_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_arg2_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::dynamic);
ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, select_partial_arg0_rank_dynamic_static_arg1_arg2_rank_dynamic_ok)
{
auto param0 =
make_shared<op::Parameter>(element::boolean, PartialShape{2, Dimension::dynamic(), 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_arg1_rank_dynamic_static_arg0_arg2_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_arg2_rank_dynamic_static_arg0_arg1_rank_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(
sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
}
TEST(type_prop, select_partial_all_rank_static_dynamic_ok)
{
auto param0 = make_shared<op::Parameter>(
element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
auto param2 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3});
auto sel = make_shared<op::Select>(param0, param1, param2);
ASSERT_EQ(sel->get_output_element_type(0), element::f32);
ASSERT_TRUE(sel->get_output_partial_shape(0).is_static());
ASSERT_EQ(sel->get_output_shape(0), (Shape{2, 8, 3}));
}
TEST(type_prop, select_partial_all_rank_static_intransitive_incompatibility)
{
auto param0 = make_shared<op::Parameter>(
element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
auto param2 =
make_shared<op::Parameter>(element::f32, PartialShape{3, Dimension::dynamic(), 3});
try
{
auto sel = make_shared<op::Select>(param0, param1, param2);
FAIL() << "Did not detect intransitive partial-shape incompatibility";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment