Unverified Commit 3feb4264 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Merge branch 'master' into ayzhuang/in-place-concat

parents 46f4acae 9aba28dc
...@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required."; NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0); PartialShape inputs_shape_scheme{PartialShape::dynamic()};
size_t expected_rank = first_input_shape.size(); element::Type inputs_et{element::dynamic};
element::Type expected_et = get_input_element_type(0); Dimension concatenation_axis_output_dim{0};
for (auto i = 1; i < get_inputs().size(); i++) for (auto i = 0; i < get_inputs().size(); i++)
{ {
NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank) PartialShape this_input_shape = get_input_partial_shape(i);
<< "Not all arguments have the same rank: argument 0 has shape " << first_input_shape Dimension this_input_rank = this_input_shape.rank();
<< " of rank " << expected_rank << " but argument " << i << " has shape " if (this_input_rank.is_static())
<< get_input_shape(i) << " of rank " << get_input_shape(i).size() << "."; {
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et) << "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
<< "Not all arguments have the same element type: argument 0 has element type " << "argument " << i << ", which has shape " << this_input_shape << ".";
<< expected_et << " but argument " << i << " has element type "
<< get_input_element_type(i) << "."; concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
}
else
{
concatenation_axis_output_dim += Dimension::dynamic();
}
} }
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank) PartialShape concatenated_shape = inputs_shape_scheme;
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
<< expected_rank << ").";
size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
for (auto i = 1; i < get_inputs().size(); i++) if (concatenated_shape.rank().is_static())
{ {
for (auto j = 0; j < get_input_shape(i).size(); j++) concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
{
if (j != m_concatenation_axis)
{
NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j])
<< "Dimensions of argument " << i << " do not match for axis " << j
<< " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
<< ").";
}
else
{
concatenation_axis_output_length += get_input_shape(i)[j];
}
}
} }
Shape concatenated_shape = first_input_shape; set_output_type(0, inputs_et, concatenated_shape);
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
set_output_type(0, expected_et, concatenated_shape);
} }
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment