Commit fb49e0c2 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4e: Dot (#1787)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for Dot

* Implement unit tests for partial-shape/type propagation for Dot

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Some light reformatting

* Review comments
parent 982889f5
...@@ -47,57 +47,90 @@ op::Dot::Dot(const shared_ptr<Node>& arg0, ...@@ -47,57 +47,90 @@ op::Dot::Dot(const shared_ptr<Node>& arg0,
void op::Dot::validate_and_infer_types() void op::Dot::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic()) element::Type result_et;
{
return; NODE_VALIDATION_ASSERT(
} this, element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)))
<< "Arguments do not have the same element type (arg0 element type: "
<< get_input_element_type(0) << ", arg1 element type: " << get_input_element_type(1)
<< ").";
const PartialShape& arg0_shape = get_input_partial_shape(0);
const PartialShape& arg1_shape = get_input_partial_shape(1);
auto& input_0 = get_inputs().at(0); // If an explicit value was not passed for reduction axis count at construction time, we have
auto& input_1 = get_inputs().at(1); // some extra work to do.
//
// - If one of the arguments is known to be scalar, the count is 0.
// - If both of the arguments are known to be nonscalar, the count is 1.
// - Otherwise, the count is unknown.
bool reduction_axes_ambiguous = !m_has_reduction_axes_count;
if (!m_has_reduction_axes_count) if (reduction_axes_ambiguous)
{ {
m_reduction_axes_count = if (arg0_shape.rank().same_scheme(0) || arg1_shape.rank().same_scheme(0))
(input_0.get_shape().size() == 0 || input_1.get_shape().size() == 0) ? 0 : 1; {
m_reduction_axes_count = 0;
reduction_axes_ambiguous = false;
}
else if (arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
{
m_reduction_axes_count = 1;
reduction_axes_ambiguous = false;
}
} }
NODE_VALIDATION_ASSERT(this, input_0.get_element_type() == input_1.get_element_type()) PartialShape result_shape;
<< "Arguments do not have the same element type (arg0 element type: "
<< input_0.get_element_type() << ", arg1 element type: " << input_1.get_element_type()
<< ").";
Shape input_0_shape = input_0.get_shape(); NODE_VALIDATION_ASSERT(this,
Shape input_1_shape = input_1.get_shape(); reduction_axes_ambiguous || arg0_shape.rank().is_dynamic() ||
m_reduction_axes_count <= size_t(arg0_shape.rank()))
<< "Reduction axes count (" << m_reduction_axes_count
<< ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_ASSERT(this,
m_reduction_axes_count <= input_0_shape.size() && reduction_axes_ambiguous || arg1_shape.rank().is_dynamic() ||
m_reduction_axes_count <= input_1_shape.size()) m_reduction_axes_count <= size_t(arg1_shape.rank()))
<< "Reduction axes count (" << m_reduction_axes_count << "Reduction axes count (" << m_reduction_axes_count
<< ") is too large (arg0 shape: " << input_0_shape << ", arg1 shape: " << input_1_shape << ") is too large (arg0 shape: " << arg0_shape << ", arg1 shape: " << arg1_shape << ").";
<< ").";
for (size_t i = 0; i < m_reduction_axes_count; i++) if (!reduction_axes_ambiguous && arg0_shape.rank().is_static() && arg1_shape.rank().is_static())
{ {
size_t axis_index_arg0 = input_0_shape.size() - m_reduction_axes_count + i; for (size_t i = 0; i < m_reduction_axes_count; i++)
size_t axis_index_arg1 = i; {
size_t axis_index_arg0 = size_t(arg0_shape.rank()) - m_reduction_axes_count + i;
NODE_VALIDATION_ASSERT(this, size_t axis_index_arg1 = i;
input_0_shape[axis_index_arg0] == input_1_shape[axis_index_arg1])
<< "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1 NODE_VALIDATION_ASSERT(
<< " from arg1) " this, arg0_shape[axis_index_arg0].compatible(arg1_shape[axis_index_arg1]))
<< "do not have same length (arg0 shape: " << input_0_shape << "Paired axes (axis " << axis_index_arg0 << " from arg0, axis " << axis_index_arg1
<< ", arg1 shape: " << input_1_shape << ", " << " from arg1) do not have same length (arg0 shape: " << arg0_shape
<< "reduction axes count: " << m_reduction_axes_count << ")."; << ", arg1 shape: " << arg1_shape
<< ", reduction axes count: " << m_reduction_axes_count << ").";
}
std::vector<Dimension> result_dims(size_t(arg0_shape.rank()) + size_t(arg1_shape.rank()) -
2 * m_reduction_axes_count);
size_t i = 0;
for (size_t j = 0; j < size_t(arg0_shape.rank()) - m_reduction_axes_count; j++)
{
result_dims[i++] = arg0_shape[j];
}
for (size_t j = m_reduction_axes_count; j < size_t(arg1_shape.rank()); j++)
{
result_dims[i++] = arg1_shape[j];
}
result_shape = PartialShape(result_dims);
}
else
{
result_shape = PartialShape::dynamic();
} }
Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * m_reduction_axes_count); set_output_type(0, result_et, result_shape);
copy(input_0_shape.begin(), input_0_shape.end() - m_reduction_axes_count, result_shape.begin());
copy(input_1_shape.begin() + m_reduction_axes_count,
input_1_shape.end(),
result_shape.begin() + (input_0_shape.size() - m_reduction_axes_count));
set_output_type(0, input_0.get_element_type(), result_shape);
} }
shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n, shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
......
...@@ -853,6 +853,232 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch) ...@@ -853,6 +853,232 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
} }
} }
TEST(type_prop, dot_partial_both_rank_dynamic_axis_count_implicit)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, dot_partial_both_rank_dynamic_axis_count_explicit)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1, /*reduction axis count=*/1234);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, dot_partial_left_rank_dynamic_right_rank_static_dynamic_axis_count_implicit)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto d = make_shared<op::Dot>(param0, param1);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, dot_partial_left_rank_dynamic_right_rank_static_dynamic_axis_count_explicit_ok)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/3);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop,
dot_partial_left_rank_dynamic_right_rank_static_dynamic_axis_count_explicit_too_many)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
try
{
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/4);
FAIL()
<< "Too many reduction axes not detected (rank-dynamic/rank-static dynamic operands)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axes count (4) is too large");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_partial_left_rank_static_dynamic_right_rank_dynamic_axis_count_implicit)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, dot_partial_left_rank_static_dynamic_right_rank_dynamic_axis_count_explicit_ok)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/3);
ASSERT_TRUE(d->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop,
dot_partial_left_rank_static_dynamic_right_rank_dynamic_axis_count_explicit_too_many)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
try
{
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/4);
FAIL()
<< "Too many reduction axes not detected (rank-dynamic/rank-static dynamic operands)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axes count (4) is too large");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop,
dot_partial_left_rank_static_dynamic_right_rank_static_dynamic_axis_count_implicit_1_ok)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 2});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), 4, Dimension::dynamic(), 5});
auto d = make_shared<op::Dot>(param0, param1);
ASSERT_TRUE(d->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 2, Dimension::dynamic(), 4, Dimension::dynamic(), 5}));
}
TEST(type_prop,
dot_partial_left_rank_static_dynamic_right_rank_static_dynamic_axis_count_implicit_0_ok)
{
auto param0 = make_shared<op::Parameter>(element::f32, PartialShape{});
auto param1 = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), 4, Dimension::dynamic(), 5});
auto d = make_shared<op::Dot>(param0, param1);
ASSERT_TRUE(d->get_output_partial_shape(0).same_scheme(
PartialShape{2, Dimension::dynamic(), 4, Dimension::dynamic(), 5}));
}
TEST(
type_prop,
dot_partial_left_rank_static_dynamic_right_rank_static_dynamic_axis_count_explicit_too_many_for_left)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3, 5, 6});
try
{
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/4);
FAIL() << "Too many reduction axes not detected (rank-static dynamic/rank-static dynamic "
"operands)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axes count (4) is too large");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
dot_partial_left_rank_static_dynamic_right_rank_static_dynamic_axis_count_explicit_too_many_for_right)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3, 5, 6});
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
try
{
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/4);
FAIL() << "Too many reduction axes not detected (rank-static dynamic/rank-static dynamic "
"operands)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axes count (4) is too large");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(
type_prop,
dot_partial_left_rank_static_dynamic_right_rank_static_dynamic_axis_count_explicit_too_many_for_both)
{
auto param0 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
auto param1 =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 2, 3});
try
{
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/4);
FAIL() << "Too many reduction axes not detected (rank-static dynamic/rank-static dynamic "
"operands)";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Reduction axes count (4) is too large");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_partial_left_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/3);
ASSERT_EQ(d->get_output_element_type(0), element::f32);
}
TEST(type_prop, dot_partial_right_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/3);
ASSERT_EQ(d->get_output_element_type(0), element::i32);
}
TEST(type_prop, dot_partial_both_et_dynamic)
{
auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto d = make_shared<op::Dot>(param0, param1, /* reduction axis count=*/3);
ASSERT_EQ(d->get_output_element_type(0), element::dynamic);
}
// //
// Tests for binary elementwise ops. // Tests for binary elementwise ops.
// //
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment