Commit 779a9300 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Adjust Concat to specification (#3582)

* Changed name of concat axis

* Clang styles applied

* Revert "Clang styles applied"

This reverts commit 4c1d3f4436765ce9eafa00140365d5d3e358eab1.

* Revert "Changed name of concat axis"

This reverts commit cdfe6638777509f21b199ce78fc2fe28cec15d45.

* Introduced alias methods

* Updated changes.md
parent bccb1ec8
......@@ -68,6 +68,11 @@ arguments now take type `CoordinateDiff` instead of `Shape`. `CoordinateDiff` is
`std::vector<std::ptrdiff_t>`, which "is like `size_t` but is allowed to be negative". Callers may
need to be adapted.
## Changes to Concat op
* `get_concatenation_axis` was renamed to `get_axis`. In order to provide backward compatibility `get_concatenation_axis` is now alis of `get_axis` method
* `set_concatenation_axis` was renamed to `set_axis`. In order to provide backward compatibility `set_concatenation_axis` is now alis of `set_axis` method
## `Parameter` and `Function` no longer take a type argument.
## Changes to Tensor read and write methods
......
......@@ -24,15 +24,15 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Concat::type_info;
op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
op::Concat::Concat(const OutputVector& args, size_t axis)
: Op(args)
, m_concatenation_axis(concatenation_axis)
, m_axis(axis)
{
constructor_validate_and_infer_types();
}
op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
: Concat(as_output_vector(args), concatenation_axis)
op::Concat::Concat(const NodeVector& args, size_t axis)
: Concat(as_output_vector(args), axis)
{
}
......@@ -51,9 +51,9 @@ void op::Concat::validate_and_infer_types()
if (this_input_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
m_concatenation_axis < size_t(this_input_rank),
m_axis < size_t(this_input_rank),
"Concatenation axis (",
m_concatenation_axis,
m_axis,
") is out of bounds for ",
"argument ",
i,
......@@ -61,15 +61,15 @@ void op::Concat::validate_and_infer_types()
this_input_shape,
".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic();
concatenation_axis_output_dim += this_input_shape[m_axis];
this_input_shape[m_axis] = Dimension::dynamic();
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
"Argument shapes are inconsistent; they must have the same rank, and must have ",
"equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis,
m_axis,
").");
NODE_VALIDATION_CHECK(
......@@ -87,7 +87,7 @@ void op::Concat::validate_and_infer_types()
if (concatenated_shape.rank().is_static())
{
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim;
concatenated_shape[m_axis] = concatenation_axis_output_dim;
}
set_output_type(0, inputs_et, concatenated_shape);
......@@ -96,7 +96,7 @@ void op::Concat::validate_and_infer_types()
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
{
// TODO(amprocte): Should we check the new_args count here?
return make_shared<Concat>(new_args, m_concatenation_axis);
return make_shared<Concat>(new_args, m_axis);
}
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......@@ -115,12 +115,12 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
{
auto arg_shape = value.get_shape();
auto slice_width = arg_shape[m_concatenation_axis];
auto slice_width = arg_shape[m_axis];
size_t next_pos = pos + slice_width;
arg_delta_slice_lower[m_concatenation_axis] = pos;
arg_delta_slice_upper[m_concatenation_axis] = next_pos;
arg_delta_slice_lower[m_axis] = pos;
arg_delta_slice_upper[m_axis] = next_pos;
adjoints.add_delta(
value,
......
......@@ -36,14 +36,14 @@ namespace ngraph
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, size_t concatenation_axis);
/// \param axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, size_t axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, size_t concatenation_axis);
/// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, size_t axis);
void validate_and_infer_types() override;
......@@ -51,16 +51,15 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
void set_concatenation_axis(size_t concatenation_axis)
{
m_concatenation_axis = concatenation_axis;
}
size_t get_concatenation_axis() const { return get_axis(); }
void set_concatenation_axis(size_t concatenation_axis) { set_axis(concatenation_axis); }
/// \return The concatenation axis.
size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
size_t m_concatenation_axis;
size_t m_axis;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment