Commit 9aba28dc authored by Adam Procter, committed by Scott Cyphers

Partial Shapes and Types, Part 4b: Concat (#1778)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions (their unification semantics are sketched after this list)

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; deliberately fail unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for concat

* Fix for a corner case in concat propagation of dynamic shapes; unit tests for concat propagation of dynamic shapes

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Remove validate-punt-if-dynamic test because it uses Concat
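
The implementation below leans on the two unification helpers exercised by the "merge" commits above: PartialShape::merge_into and element::Type::merge. A minimal sketch of their semantics as used here (both calls appear verbatim in the diff below; the initializer-list PartialShape constructor, the header paths, and the wrapper function are assumptions for illustration, not repo code):

    #include "ngraph/partial_shape.hpp"
    #include "ngraph/type/element_type.hpp"

    using namespace ngraph;

    void merge_semantics_sketch() // hypothetical illustration
    {
        // Dynamic dimensions unify with anything; static dimensions must agree.
        PartialShape s{2, Dimension::dynamic(), 3};
        PartialShape t{2, 4, Dimension::dynamic()};
        bool ok = PartialShape::merge_into(s, t);
        // ok == true, and s has been refined to {2,4,3}.

        PartialShape u{2, 5, 3};
        bool consistent = PartialShape::merge_into(s, u);
        // consistent == false: dimension 1 conflicts (4 vs. 5).

        // Element types merge the same way: dynamic unifies with any type.
        element::Type et{element::dynamic};
        bool et_ok = element::Type::merge(et, et, element::f32);
        // et_ok == true and et is now element::f32; merging two different
        // static types (e.g. f32 with i32) would fail.
    }

Concat's validation uses exactly this pattern: each argument's shape, with the concatenation axis wildcarded to dynamic, is merged into a running shape scheme, so any rank or off-axis dimension mismatch surfaces as a single merge failure.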
parent d3d27108
@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
 
 void op::Concat::validate_and_infer_types()
 {
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
     NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
 
-    Shape first_input_shape = get_input_shape(0);
-    size_t expected_rank = first_input_shape.size();
-    element::Type expected_et = get_input_element_type(0);
+    PartialShape inputs_shape_scheme{PartialShape::dynamic()};
+    element::Type inputs_et{element::dynamic};
+    Dimension concatenation_axis_output_dim{0};
 
-    for (auto i = 1; i < get_inputs().size(); i++)
+    for (auto i = 0; i < get_inputs().size(); i++)
     {
-        NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank)
-            << "Not all arguments have the same rank: argument 0 has shape " << first_input_shape
-            << " of rank " << expected_rank << " but argument " << i << " has shape "
-            << get_input_shape(i) << " of rank " << get_input_shape(i).size() << ".";
-
-        NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et)
-            << "Not all arguments have the same element type: argument 0 has element type "
-            << expected_et << " but argument " << i << " has element type "
-            << get_input_element_type(i) << ".";
-    }
-
-    NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank)
-        << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
-        << expected_rank << ").";
-
-    size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
+        PartialShape this_input_shape = get_input_partial_shape(i);
+        Dimension this_input_rank = this_input_shape.rank();
 
-    for (auto i = 1; i < get_inputs().size(); i++)
-    {
-        for (auto j = 0; j < get_input_shape(i).size(); j++)
+        if (this_input_rank.is_static())
         {
-            if (j != m_concatenation_axis)
-            {
-                NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j])
-                    << "Dimensions of argument " << i << " do not match for axis " << j
-                    << " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
-                    << ").";
-            }
-            else
-            {
-                concatenation_axis_output_length += get_input_shape(i)[j];
-            }
+            NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
+                << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
+                << "argument " << i << ", which has shape " << this_input_shape << ".";
+
+            concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
+            this_input_shape[m_concatenation_axis] = Dimension::dynamic();
+
+            NODE_VALIDATION_ASSERT(this,
+                                   PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
+                << "Argument shapes are inconsistent; they must have the same rank, and must have "
+                << "equal dimension everywhere except on the concatenation axis (axis "
+                << m_concatenation_axis << ").";
+
+            NODE_VALIDATION_ASSERT(
+                this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
+                << "Argument element types are inconsistent.";
+        }
+        else
+        {
+            concatenation_axis_output_dim += Dimension::dynamic();
         }
     }
 
-    Shape concatenated_shape = first_input_shape;
-    concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
+    PartialShape concatenated_shape = inputs_shape_scheme;
 
-    set_output_type(0, expected_et, concatenated_shape);
+    if (concatenated_shape.rank().is_static())
+    {
+        concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
+    }
+
+    set_output_type(0, inputs_et, concatenated_shape);
 }
 
 shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
...
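
Net effect: the concatenation-axis lengths are summed with Dimension arithmetic (a dynamic addend makes the whole sum dynamic), while every other axis is merged across arguments. A hedged usage sketch of the resulting inference (Parameter's PartialShape constructor and get_output_partial_shape are assumed from earlier parts of this series; the umbrella header and wrapper function are illustrative):

    #include <memory>

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void concat_inference_sketch() // hypothetical illustration
    {
        // Two rank-3 arguments concatenated on axis 0; each argument knows
        // some dimensions the other does not.
        auto a = std::make_shared<op::Parameter>(
            element::f32, PartialShape{2, Dimension::dynamic(), 3});
        auto b = std::make_shared<op::Parameter>(
            element::f32, PartialShape{3, 4, Dimension::dynamic()});
        auto concat = std::make_shared<op::Concat>(NodeVector{a, b}, 0);

        // Axis 0: 2 + 3 = 5. Axes 1 and 2 are merged across arguments, so each
        // dynamic dimension is filled in by the argument that knows it:
        // the inferred output shape is {5, 4, 3}.
        PartialShape out = concat->get_output_partial_shape(0);

        // If either axis-0 length were dynamic, only the sum would become
        // dynamic: {?, 4, 3}.
    }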