Commit 779a9300 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[SPEC] Adjust Concat to specification (#3582)

* Changed name of concat axis

* Clang styles applied

* Revert "Clang styles applied"

This reverts commit 4c1d3f4436765ce9eafa00140365d5d3e358eab1.

* Revert "Changed name of concat axis"

This reverts commit cdfe6638777509f21b199ce78fc2fe28cec15d45.

* Introduced alias methods

* Updated changes.md
parent bccb1ec8
...@@ -68,6 +68,11 @@ arguments now take type `CoordinateDiff` instead of `Shape`. `CoordinateDiff` is ...@@ -68,6 +68,11 @@ arguments now take type `CoordinateDiff` instead of `Shape`. `CoordinateDiff` is
`std::vector<std::ptrdiff_t>`, which "is like `size_t` but is allowed to be negative". Callers may `std::vector<std::ptrdiff_t>`, which "is like `size_t` but is allowed to be negative". Callers may
need to be adapted. need to be adapted.
## Changes to Concat op
* `get_concatenation_axis` was renamed to `get_axis`. To provide backward compatibility, `get_concatenation_axis` is now an alias of the `get_axis` method.
* `set_concatenation_axis` was renamed to `set_axis`. To provide backward compatibility, `set_concatenation_axis` is now an alias of the `set_axis` method.
## `Parameter` and `Function` no longer take a type argument. ## `Parameter` and `Function` no longer take a type argument.
## Changes to Tensor read and write methods ## Changes to Tensor read and write methods
......
...@@ -24,15 +24,15 @@ using namespace ngraph; ...@@ -24,15 +24,15 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Concat::type_info; constexpr NodeTypeInfo op::Concat::type_info;
/// \brief Constructs a concatenation operation over the given output tensors.
///
/// \param args The outputs producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
//
// Validation/inference runs immediately so the op's output shape and element
// type are available as soon as the node is constructed.
op::Concat::Concat(const OutputVector& args, size_t axis)
    : Op(args)
    , m_axis(axis)
{
    constructor_validate_and_infer_types();
}
/// \brief Constructs a concatenation operation from whole nodes.
///
/// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
//
// Convenience overload: converts the NodeVector to an OutputVector and
// delegates to the primary constructor.
op::Concat::Concat(const NodeVector& args, size_t axis)
    : Concat(as_output_vector(args), axis)
{
}
...@@ -51,9 +51,9 @@ void op::Concat::validate_and_infer_types() ...@@ -51,9 +51,9 @@ void op::Concat::validate_and_infer_types()
if (this_input_rank.is_static()) if (this_input_rank.is_static())
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
m_concatenation_axis < size_t(this_input_rank), m_axis < size_t(this_input_rank),
"Concatenation axis (", "Concatenation axis (",
m_concatenation_axis, m_axis,
") is out of bounds for ", ") is out of bounds for ",
"argument ", "argument ",
i, i,
...@@ -61,15 +61,15 @@ void op::Concat::validate_and_infer_types() ...@@ -61,15 +61,15 @@ void op::Concat::validate_and_infer_types()
this_input_shape, this_input_shape,
"."); ".");
concatenation_axis_output_dim += this_input_shape[m_concatenation_axis]; concatenation_axis_output_dim += this_input_shape[m_axis];
this_input_shape[m_concatenation_axis] = Dimension::dynamic(); this_input_shape[m_axis] = Dimension::dynamic();
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
PartialShape::merge_into(inputs_shape_scheme, this_input_shape), PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
"Argument shapes are inconsistent; they must have the same rank, and must have ", "Argument shapes are inconsistent; they must have the same rank, and must have ",
"equal dimension everywhere except on the concatenation axis (axis ", "equal dimension everywhere except on the concatenation axis (axis ",
m_concatenation_axis, m_axis,
")."); ").");
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
...@@ -87,7 +87,7 @@ void op::Concat::validate_and_infer_types() ...@@ -87,7 +87,7 @@ void op::Concat::validate_and_infer_types()
if (concatenated_shape.rank().is_static()) if (concatenated_shape.rank().is_static())
{ {
concatenated_shape[m_concatenation_axis] = concatenation_axis_output_dim; concatenated_shape[m_axis] = concatenation_axis_output_dim;
} }
set_output_type(0, inputs_et, concatenated_shape); set_output_type(0, inputs_et, concatenated_shape);
...@@ -96,7 +96,7 @@ void op::Concat::validate_and_infer_types() ...@@ -96,7 +96,7 @@ void op::Concat::validate_and_infer_types()
/// \brief Creates a copy of this node with the same axis but new input nodes.
///
/// \param new_args The replacement input nodes.
/// \return A freshly constructed Concat over new_args, concatenating on m_axis.
shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const
{
    // TODO(amprocte): Should we check the new_args count here?
    return make_shared<Concat>(new_args, m_axis);
}
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...@@ -115,12 +115,12 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -115,12 +115,12 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
{ {
auto arg_shape = value.get_shape(); auto arg_shape = value.get_shape();
auto slice_width = arg_shape[m_concatenation_axis]; auto slice_width = arg_shape[m_axis];
size_t next_pos = pos + slice_width; size_t next_pos = pos + slice_width;
arg_delta_slice_lower[m_concatenation_axis] = pos; arg_delta_slice_lower[m_axis] = pos;
arg_delta_slice_upper[m_concatenation_axis] = next_pos; arg_delta_slice_upper[m_axis] = next_pos;
adjoints.add_delta( adjoints.add_delta(
value, value,
......
...@@ -36,14 +36,14 @@ namespace ngraph ...@@ -36,14 +36,14 @@ namespace ngraph
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, size_t axis);

/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param axis The axis along which to concatenate the input tensors.
Concat(const NodeVector& args, size_t axis);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -51,16 +51,15 @@ namespace ngraph ...@@ -51,16 +51,15 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
/// \return The concatenation axis.
/// \deprecated Backward-compatible alias of get_axis().
size_t get_concatenation_axis() const { return get_axis(); }
/// \brief Sets the concatenation axis.
/// \deprecated Backward-compatible alias of set_axis().
void set_concatenation_axis(size_t concatenation_axis) { set_axis(concatenation_axis); }
/// \return The concatenation axis.
size_t get_axis() const { return m_axis; }
/// \brief Sets the concatenation axis.
void set_axis(size_t axis) { m_axis = axis; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
size_t m_concatenation_axis; size_t m_axis;
}; };
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.