Commit 9aba28dc authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Partial Shapes and Types, Part 4b: Concat (#1778)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goophup in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Implement partial shape/type propagation for AllReduce

* Implement partial shape/type propagation for Reshape

* Remove unneeded validate_punt from Result

* Implement partial shape/type propagation for Reverse

* Implement partial shape/type validation for ReverseSequence

* Implement partial shape/type validation for ArithmeticReduction

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Fix comment typo

* Rewrite operator+ and operator* for Dimension as members, not friends

* Formatting tweak

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Fix error message

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item

* Add unit tests for partial shape/type propagation with ReverseSequence

* Add unit tests for partial-shape/type propagation for ArithmeticReduction (via Sum)

* Implement partial shape/type validation for concat

* Fix for a corner case in concat propagation of dynamic shapes; unit tests for concat propagation of dynamic shapes

* Implement partial type/shape propagation for GetOutputElement

* Function signatures

* Add implementations, unit tests for relaxes/refines functions

* Generalize project/reduce/inject functions to cover PartialShape, move to shape_util.[ch]pp

* Deal with std::find_if #include issues

* Fix more include madness

* Remove validate-punt-if-dynamic test because it uses Concat
parent d3d27108
......@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0);
size_t expected_rank = first_input_shape.size();
element::Type expected_et = get_input_element_type(0);
PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic};
Dimension concatenation_axis_output_dim{0};
for (auto i = 1; i < get_inputs().size(); i++)
for (auto i = 0; i < get_inputs().size(); i++)
{
NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank)
<< "Not all arguments have the same rank: argument 0 has shape " << first_input_shape
<< " of rank " << expected_rank << " but argument " << i << " has shape "
<< get_input_shape(i) << " of rank " << get_input_shape(i).size() << ".";
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et)
<< "Not all arguments have the same element type: argument 0 has element type "
<< expected_et << " but argument " << i << " has element type "
<< get_input_element_type(i) << ".";
PartialShape this_input_shape = get_input_partial_shape(i);
Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
<< "argument " << i << ", which has shape " << this_input_shape << ".";
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
}
else
{
concatenation_axis_output_dim += Dimension::dynamic();
}
}
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank)
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
<< expected_rank << ").";
size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
PartialShape concatenated_shape = inputs_shape_scheme;
for (auto i = 1; i < get_inputs().size(); i++)
if (concatenated_shape.rank().is_static())
{
for (auto j = 0; j < get_input_shape(i).size(); j++)
{
if (j != m_concatenation_axis)
{
NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j])
<< "Dimensions of argument " << i << " do not match for axis " << j
<< " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
<< ").";
}
else
{
concatenation_axis_output_length += get_input_shape(i)[j];
}
}
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
}
Shape concatenated_shape = first_input_shape;
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
set_output_type(0, expected_et, concatenated_shape);
set_output_type(0, inputs_et, concatenated_shape);
}
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment