Unverified Commit 3feb4264 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Merge branch 'master' into ayzhuang/in-place-concat

parents 46f4acae 9aba28dc
......@@ -32,58 +32,49 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0);
size_t expected_rank = first_input_shape.size();
element::Type expected_et = get_input_element_type(0);
PartialShape inputs_shape_scheme{PartialShape::dynamic()};
element::Type inputs_et{element::dynamic};
Dimension concatenation_axis_output_dim{0};
for (auto i = 1; i < get_inputs().size(); i++)
for (auto i = 0; i < get_inputs().size(); i++)
{
NODE_VALIDATION_ASSERT(this, get_input_shape(i).size() == expected_rank)
<< "Not all arguments have the same rank: argument 0 has shape " << first_input_shape
<< " of rank " << expected_rank << " but argument " << i << " has shape "
<< get_input_shape(i) << " of rank " << get_input_shape(i).size() << ".";
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == expected_et)
<< "Not all arguments have the same element type: argument 0 has element type "
<< expected_et << " but argument " << i << " has element type "
<< get_input_element_type(i) << ".";
PartialShape this_input_shape = get_input_partial_shape(i);
Dimension this_input_rank = this_input_shape.rank();
if (this_input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < size_t(this_input_rank))
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds for "
<< "argument " << i << ", which has shape " << this_input_shape << ".";
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape))
<< "Argument shapes are inconsistent; they must have the same rank, and must have "
<< "equal dimension everywhere except on the concatenation axis (axis "
<< m_concatenation_axis << ").";
NODE_VALIDATION_ASSERT(
this, element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
}
else
{
concatenation_axis_output_dim += Dimension::dynamic();
}
}
NODE_VALIDATION_ASSERT(this, m_concatenation_axis < expected_rank)
<< "Concatenation axis (" << m_concatenation_axis << ") is out of bounds (inputs have rank "
<< expected_rank << ").";
size_t concatenation_axis_output_length = first_input_shape.at(m_concatenation_axis);
PartialShape concatenated_shape = inputs_shape_scheme;
for (auto i = 1; i < get_inputs().size(); i++)
if (concatenated_shape.rank().is_static())
{
for (auto j = 0; j < get_input_shape(i).size(); j++)
{
if (j != m_concatenation_axis)
{
NODE_VALIDATION_ASSERT(this, first_input_shape[j] == get_input_shape(i)[j])
<< "Dimensions of argument " << i << " do not match for axis " << j
<< " (expected " << first_input_shape[j] << ", got " << get_input_shape(i)[j]
<< ").";
}
else
{
concatenation_axis_output_length += get_input_shape(i)[j];
}
}
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
}
Shape concatenated_shape = first_input_shape;
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_length;
set_output_type(0, expected_et, concatenated_shape);
set_output_type(0, inputs_et, concatenated_shape);
}
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment