Commit f349593d authored by mozga-intel, committed by Scott Cyphers

Concat operator, negative indexing support (#3708)

* The Concat operator is updated to support dynamic shapes:
1) Added a new concat_negative_indexing test
2) Replaced size_t with int64_t for the axis
3) Added support for negative indexing: a negative axis is normalized by adding the input rank (see the sketch after this message)
   if (axis < 0) { axis = axis + int64_t(this_input_rank); }

* Remove unwanted #include "ngraph/op/constant.hpp" header

* Refactoring:
1) Renamed the variable m_concatenation_axis to m_axis

* The concat negative-indexing test is adjusted to use dynamic-shape tensors:
each input {a, b, c} is built from auto pshape = PartialShape::dynamic(), so the result tensor also has a dynamic shape

* The backend is created with dynamic-shape support enabled:
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);

* A rank-dynamic output shape is now supported by concat:
set_output_type(0, inputs_et, PartialShape::dynamic(concatenation_axis_output_dim));

* The element-type NODE_VALIDATION_CHECK was moved up so that it also runs for dynamic shapes

* [Test] The shape of the output tensor was changed
[Concat CPU] Added support for negative indexing on the CPU backend

* Review changes:
1) Added axis re-calculation for the reference version of concat
2) The original axis value is not replaced (the normalized axis is stored separately)

* Review changes: support for a negative axis

* A comment describing the axis variables was added to concat.hpp
Removed an unused variable
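A minimal sketch of the negative-axis normalization described above (an editor's illustration; normalize_axis is a hypothetical helper name, not code from this commit):

   #include <cstdint>
   #include <stdexcept>

   // A negative axis counts from the end: for rank 2, axis -1 means axis 1.
   int64_t normalize_axis(int64_t axis, int64_t rank)
   {
       if (axis < 0)
       {
           axis += rank; // e.g. axis = -1, rank = 2  ->  axis = 1
       }
       if (axis < 0 || axis >= rank)
       {
           throw std::out_of_range("Concatenation axis is out of bounds");
       }
       return axis;
   }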
parent 782afe31
@@ -24,14 +24,14 @@ using namespace ngraph;
 constexpr NodeTypeInfo op::Concat::type_info;

-op::Concat::Concat(const OutputVector& args, size_t axis)
+op::Concat::Concat(const OutputVector& args, int64_t axis)
     : Op(args)
     , m_axis(axis)
 {
     constructor_validate_and_infer_types();
 }

-op::Concat::Concat(const NodeVector& args, size_t axis)
+op::Concat::Concat(const NodeVector& args, int64_t axis)
     : Concat(as_output_vector(args), axis)
 {
 }
@@ -46,14 +46,24 @@ void op::Concat::validate_and_infer_types()
     for (uint64_t i = 0; i < get_input_size(); i++)
     {
+        NODE_VALIDATION_CHECK(this,
+                              element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
+                              "Argument element types are inconsistent.");
         PartialShape this_input_shape = get_input_partial_shape(i);
         Dimension this_input_rank = this_input_shape.rank();

         if (this_input_rank.is_static())
         {
+            if (get_concatenation_axis() < 0)
+            {
+                set_concatenation_axis(get_axis() < 0
+                                           ? get_axis() + static_cast<int64_t>(this_input_rank)
+                                           : get_axis());
+            }
+            auto concat_axis = get_concatenation_axis();
             NODE_VALIDATION_CHECK(this,
-                                  m_axis < size_t(this_input_rank),
+                                  concat_axis < static_cast<int64_t>(this_input_rank),
                                   "Concatenation axis (",
-                                  m_axis,
+                                  concat_axis,
                                   ") is out of bounds for ",
                                   "argument ",
                                   i,
@@ -61,36 +71,33 @@ void op::Concat::validate_and_infer_types()
                                   this_input_shape,
                                   ".");

-            concatenation_axis_output_dim += this_input_shape[m_axis];
-            this_input_shape[m_axis] = Dimension::dynamic();
+            concatenation_axis_output_dim += this_input_shape[concat_axis];
+            this_input_shape[concat_axis] = Dimension::dynamic();

             NODE_VALIDATION_CHECK(
                 this,
                 PartialShape::merge_into(inputs_shape_scheme, this_input_shape),
                 "Argument shapes are inconsistent; they must have the same rank, and must have ",
                 "equal dimension everywhere except on the concatenation axis (axis ",
-                m_axis,
+                concat_axis,
                 ").");
-
-            NODE_VALIDATION_CHECK(
-                this,
-                element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
-                "Argument element types are inconsistent.");
         }
         else
         {
             concatenation_axis_output_dim += Dimension::dynamic();
         }
     }

     PartialShape concatenated_shape = inputs_shape_scheme;

     if (concatenated_shape.rank().is_static())
     {
-        concatenated_shape[m_axis] = concatenation_axis_output_dim;
+        concatenated_shape[get_concatenation_axis()] = concatenation_axis_output_dim;
+        set_output_type(0, inputs_et, concatenated_shape);
     }
-
-    set_output_type(0, inputs_et, concatenated_shape);
+    else
+    {
+        set_output_type(0, inputs_et, PartialShape::dynamic(concatenation_axis_output_dim));
+    }
 }
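// Editor's worked example of the inference above (an illustration, not part of
// the commit): concatenating inputs of shape {2,2}, {2,3}, {2,3} with axis = -1,
// the first static-rank input normalizes the axis to -1 + 2 = 1; the axis-1
// extents accumulate as 2 + 3 + 3 = 8, every other dimension must merge (both
// are 2 here), so the inferred output shape is {2,8}.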
 shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) const

@@ -118,7 +125,6 @@ void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
         auto slice_width = arg_shape[m_axis];
-
         size_t next_pos = pos + slice_width;
         arg_delta_slice_lower[m_axis] = pos;
         arg_delta_slice_upper[m_axis] = next_pos;
......
@@ -37,13 +37,13 @@ namespace ngraph
            ///
            /// \param args The outputs producing the input tensors.
            /// \param axis The axis along which to concatenate the input tensors.
-           Concat(const OutputVector& args, size_t axis);
+           Concat(const OutputVector& args, int64_t axis);

            /// \brief Constructs a concatenation operation.
            ///
            /// \param args The nodes producing the input tensors.
            /// \param axis The axis along which to concatenate the input tensors.
-           Concat(const NodeVector& args, size_t axis);
+           Concat(const NodeVector& args, int64_t axis);

            void validate_and_infer_types() override;
@@ -51,15 +51,21 @@ namespace ngraph
                copy_with_new_args(const NodeVector& new_args) const override;

            /// \return The concatenation axis.
-           size_t get_concatenation_axis() const { return get_axis(); }
-           void set_concatenation_axis(size_t concatenation_axis) { set_axis(concatenation_axis); }
+           int64_t get_concatenation_axis() const { return m_concat_axis; }
+           void set_concatenation_axis(int64_t concatenation_axis)
+           {
+               m_concat_axis = concatenation_axis;
+           }
            /// \return The concatenation axis.
-           size_t get_axis() const { return m_axis; }
-           void set_axis(size_t axis) { m_axis = axis; }
+           int64_t get_axis() const { return m_axis; }
+           void set_axis(int64_t axis) { m_axis = axis; }

        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

-           size_t m_axis;
+           /// \brief m_axis stores the user-supplied axis, which may be negative; it is never modified
+           int64_t m_axis;
+           /// \brief m_concat_axis stores the normalized axis: m_axis, plus the input rank when m_axis is negative
+           int64_t m_concat_axis = -1;
        };
    }
}
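// Editor's sketch of the contract between the two members, inferred from the
// diff above (illustrative usage, not code from the commit):
//
//     op::Concat concat(args, -1);        // m_axis == -1, m_concat_axis == -1 (unset)
//     concat.validate_and_infer_types();  // with rank-2 inputs, m_concat_axis becomes 1
//     concat.get_axis();                  // still -1: the original value is preserved
//     concat.get_concatenation_axis();    // 1: the normalized, non-negative axis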
@@ -165,7 +165,7 @@ static bool simplify_concat(shared_ptr<Node> n)
     }

     auto concat = static_pointer_cast<op::Concat>(n);
-    size_t concat_axis = concat->get_concatenation_axis();
+    auto concat_axis = concat->get_concatenation_axis();

     auto slice_shape = branch_tip->get_users(true).at(0)->get_shape();
     size_t slice_axis = numeric_limits<size_t>::max();
......
@@ -38,7 +38,7 @@
                          std::vector<Shape> input_shapes,
                          void* output,
                          Shape output_shape,
-                         size_t axis)
+                         int64_t axis)
            {
                Eigen::array<Eigen::Index, Rank> out_dims;
                for (int i = 0; i < Rank; i++)
@@ -50,7 +50,6 @@ namespace ngraph
                Eigen::array<Eigen::Index, Rank> in_dims, concat_pos;
                concat_pos.fill(static_cast<Eigen::Index>(0));
-
                for (size_t i = 0; i < input_shapes.size(); i++)
                {
                    for (int j = 0; j < Rank; j++)
......
@@ -1201,7 +1201,7 @@ namespace ngraph
                auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);

-               size_t concat_dim = concat->get_concatenation_axis();
+               auto concat_dim = concat->get_concatenation_axis();

                mkldnn::primitive_attr attr;
                attr.set_scratchpad_mode(mkldnn::scratchpad_mode::user);
@@ -1472,7 +1472,7 @@ namespace ngraph
                auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);

-               size_t concat_dim = concat->get_concatenation_axis();
+               auto concat_dim = concat->get_concatenation_axis();

                // concat primitive descriptor
                return mkldnn::concat::primitive_desc(
......
@@ -2452,7 +2452,7 @@ namespace ngraph
                vector<memory::desc>& o_mds)
            {
                auto concat = static_cast<const T*>(node.get());
-               size_t concat_dim = concat->get_concatenation_axis();
+               auto concat_dim = concat->get_concatenation_axis();

                auto result_desc = mkldnn_utils::create_default_mkldnn_md(
                    node.get(), 0, true, memory::FORMAT::any);
 #if MKLDNN_VERSION_MAJOR < 1
......
@@ -32,12 +32,11 @@ namespace ngraph
                       T* out,
                       const std::vector<Shape>& in_shapes,
                       const Shape& out_shape,
-                      size_t concatenation_axis)
+                      int64_t concatenation_axis)
            {
                // We will copy the inputs to the output one at a time. As we go, we will move out
                // along the concatenation axis, starting at 0.
                size_t concatenation_pos = 0;
-
                for (size_t i = 0; i < args.size(); i++)
                {
                    // CoordinateTransform gets confused when the last input has a zero-size dim, so
......
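// Editor's sketch of the copy scheme described in the comment above -- a
// simplified, row-major stand-in for the CoordinateTransform-based loop in
// runtime::reference::concat (concat_sketch and its layout are assumptions,
// not code from this commit):
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
void concat_sketch(const std::vector<const T*>& args,
                   T* out,
                   const std::vector<std::vector<size_t>>& in_shapes,
                   const std::vector<size_t>& out_shape,
                   int64_t axis)
{
    if (axis < 0)
    {
        axis += static_cast<int64_t>(out_shape.size()); // negative-axis rule
    }
    // Collapse each tensor to [outer, axis_dim, inner]; the non-axis dims are
    // equal across inputs, so outer and inner are shared.
    size_t outer = 1, inner = 1;
    for (int64_t d = 0; d < axis; d++)
        outer *= out_shape[d];
    for (size_t d = axis + 1; d < out_shape.size(); d++)
        inner *= out_shape[d];

    size_t out_axis = out_shape[axis];
    size_t pos = 0; // running write offset along the concatenation axis
    for (size_t i = 0; i < args.size(); i++)
    {
        size_t in_axis = in_shapes[i][axis];
        for (size_t o = 0; o < outer; o++)
            for (size_t a = 0; a < in_axis; a++)
                std::copy_n(args[i] + (o * in_axis + a) * inner,
                            inner,
                            out + (o * out_axis + pos + a) * inner);
        pos += in_axis; // the next input starts where this one ended
    }
}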
@@ -27,6 +27,35 @@ using namespace ngraph;
 static string s_manifest = "${MANIFEST}";

+NGRAPH_TEST(${BACKEND_NAME}, concat_negative_axis)
+{
+    auto pshape_a = PartialShape::dynamic();
+    auto A = make_shared<op::Parameter>(element::f32, pshape_a);
+    auto pshape_b = PartialShape::dynamic();
+    auto B = make_shared<op::Parameter>(element::f32, pshape_b);
+    auto pshape_c = PartialShape::dynamic();
+    auto C = make_shared<op::Parameter>(element::f32, pshape_c);
+    auto pshape_r = PartialShape::dynamic();
+    auto f = make_shared<Function>(make_shared<op::Concat>(NodeVector{A, B, C}, -1),
+                                   ParameterVector{A, B, C});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::f32, Shape{2, 2});
+    copy_data(a, vector<float>{2, 4, 8, 16});
+    auto b = backend->create_tensor(element::f32, Shape{2, 3});
+    copy_data(b, vector<float>{1, 2, 4, 8, 16, 32});
+    auto c = backend->create_tensor(element::f32, Shape{2, 3});
+    copy_data(c, vector<float>{2, 3, 5, 7, 11, 13});
+    auto result = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a, b, c});
+    ASSERT_EQ(result->get_shape(), (Shape{2, 8}));
+    EXPECT_TRUE(
+        test::all_close_f((vector<float>{2, 4, 1, 2, 4, 2, 3, 5, 8, 16, 8, 16, 32, 7, 11, 13}),
+                          read_vector<float>(result)));
+}
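// Editor's note on the expected values above: with row-major layout and
// axis -1, row 0 of the result is [2, 4 | 1, 2, 4 | 2, 3, 5] and row 1 is
// [8, 16 | 8, 16, 32 | 7, 11, 13], which flattens to the asserted vector.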
 NGRAPH_TEST(${BACKEND_NAME}, concat_matrix_colwise)
 {
     Shape shape_a{2, 2};
......