Commit 33a6a7d0 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Michał Karzyński

[SPEC] Support negative axis and negative split specification in v1::VariadicSplit (#3975)

parent 0efac225
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/variadic_split.hpp" #include "ngraph/op/variadic_split.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -38,11 +39,17 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types() ...@@ -38,11 +39,17 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
set_input_is_relevant_to_value(1); set_input_is_relevant_to_value(1);
set_input_is_relevant_to_value(2); set_input_is_relevant_to_value(2);
auto split_lengths_pshape_rank = get_input_partial_shape(2).rank(); auto split_lengths_pshape = get_input_partial_shape(2);
if (split_lengths_pshape_rank.is_static()) if (split_lengths_pshape.is_static())
{ {
auto num_outputs = static_cast<size_t>(split_lengths_pshape_rank); NODE_VALIDATION_CHECK(this,
static_cast<size_t>(split_lengths_pshape.rank()) == 1,
"Split lengths should be a 1-D tensor. Got ",
split_lengths_pshape.rank(),
" instead.");
auto num_outputs = static_cast<size_t>(split_lengths_pshape[0]);
auto data = input_value(0); auto data = input_value(0);
auto axis_input = input_value(1).get_node_shared_ptr(); auto axis_input = input_value(1).get_node_shared_ptr();
auto split_lengths_input = input_value(2).get_node_shared_ptr(); auto split_lengths_input = input_value(2).get_node_shared_ptr();
...@@ -53,31 +60,58 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types() ...@@ -53,31 +60,58 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
if (data_shape.is_static() && axis_input->is_constant() && if (data_shape.is_static() && axis_input->is_constant() &&
split_lengths_input->is_constant()) split_lengths_input->is_constant())
{ {
auto axis = as_type_ptr<op::Constant>(axis_input)->get_vector<size_t>()[0]; auto data_rank = static_cast<size_t>(data_shape.rank());
auto split_lengths = as_type_ptr<op::Constant>(axis_input)->get_vector<size_t>(); auto axis_val = as_type_ptr<op::Constant>(axis_input)->get_vector<int64_t>()[0];
auto splits_length = std::accumulate(split_lengths.begin(), split_lengths.end(), 0UL); // Adjust split axis in case of negatives
int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
NODE_VALIDATION_CHECK(this, axis > 0, "Provided axis:", axis, " can not be negative"); auto split_lengths =
auto data_rank = static_cast<size_t>(data_shape.rank()); as_type_ptr<op::Constant>(split_lengths_input)->get_vector<int64_t>();
// Adjust split lengths in case of negatives
size_t sum_of_splits = 0;
int64_t negative_one = -1;
for (size_t i = 0; i < split_lengths.size(); i++)
{
NODE_VALIDATION_CHECK(this,
split_lengths[i] >= -1,
"Invalid value ",
split_lengths[i],
" in split lengths input. Should be >= -1.");
if (split_lengths[i] == -1)
{
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axis < data_rank, negative_one == -1,
"Provided axis:", "Cannot infer split with multiple -1 values at ",
axis, negative_one,
" can not be higher than input data rank: ", " and ",
data_rank); i);
negative_one = i;
}
else
{
sum_of_splits += split_lengths[i];
}
}
if (negative_one > 0)
{
split_lengths[negative_one] = static_cast<size_t>(data_shape[axis]) - sum_of_splits;
sum_of_splits += split_lengths[negative_one];
}
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
splits_length == static_cast<size_t>(data_shape[axis]), sum_of_splits == static_cast<size_t>(data_shape[axis]),
"Total length of splits:", "Total length of splits: ",
splits_length, sum_of_splits,
" does not sum to length of the choosen axis: ", " must match the length of the chosen axis: ",
static_cast<size_t>(data_shape[axis])); static_cast<size_t>(data_shape[axis]));
for (size_t output{0}; output < num_outputs; ++output) for (size_t output{0}; output < num_outputs; ++output)
{ {
auto tmp_shape = data_shape.to_shape(); auto tmp_shape = data_shape.to_shape();
tmp_shape.at(axis) = split_lengths.at(axis); tmp_shape.at(axis) = split_lengths.at(output);
set_output_type(output, data_type, tmp_shape); set_output_type(output, data_type, tmp_shape);
} }
} }
......
...@@ -819,9 +819,9 @@ int64_t ngraph::normalize_axis(const std::string& node_description, ...@@ -819,9 +819,9 @@ int64_t ngraph::normalize_axis(const std::string& node_description,
// Accepted range of value for axis is [axis_range_min, axis_range_max]. // Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)), NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node_description, node_description,
"Parameter axis ", " Parameter axis ",
axis, axis,
" out of the tensor rank [-", " out of the tensor rank range [",
axis_range_min, axis_range_min,
", ", ", ",
axis_range_max, axis_range_max,
......
...@@ -187,6 +187,7 @@ set(SRC ...@@ -187,6 +187,7 @@ set(SRC
type_prop/transpose.cpp type_prop/transpose.cpp
type_prop/unary_elementwise.cpp type_prop/unary_elementwise.cpp
type_prop/unsqueeze.cpp type_prop/unsqueeze.cpp
type_prop/variadic_split.cpp
type_prop_benchmark.cpp type_prop_benchmark.cpp
type_prop_layers.cpp type_prop_layers.cpp
util.cpp util.cpp
......
...@@ -139,6 +139,6 @@ TEST(opset, check_opset1) ...@@ -139,6 +139,6 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v1::TopK, opset1::TopK) CHECK_OPSET(op::v1::TopK, opset1::TopK)
CHECK_OPSET(op::v0::Transpose, opset1::Transpose) CHECK_OPSET(op::v0::Transpose, opset1::Transpose)
CHECK_OPSET(op::v0::Unsqueeze, opset1::Unsqueeze) CHECK_OPSET(op::v0::Unsqueeze, opset1::Unsqueeze)
// TODO using op::v0::VariadicSplit CHECK_OPSET(op::v1::VariadicSplit, opset1::VariadicSplit)
CHECK_OPSET(op::v0::Xor, opset1::Xor) CHECK_OPSET(op::v0::Xor, opset1::Xor)
} }
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, variadic_split)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
EXPECT_EQ(split->outputs().size(), 2);
EXPECT_EQ(split->output(0).get_shape(), (Shape{2, 2}));
EXPECT_EQ(split->output(1).get_shape(), (Shape{2, 4}));
EXPECT_EQ(split->output(0).get_element_type(), element::i32);
EXPECT_EQ(split->output(1).get_element_type(), element::i32);
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {7, -1, 2}))
->output(1)
.get_shape(),
(Shape{3, 6}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, Shape{12, 1, 6}),
op::Constant::create<int64_t>(element::i64, Shape{1}, {2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {3, 1, 2}))
->output(2)
.get_shape(),
(Shape{12, 1, 2}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{1}, {1}),
op::Constant::create<int64_t>(element::i64, Shape{2}, {6, 0}))
->output(1)
.get_shape(),
(Shape{12, 0}));
}
TEST(type_prop, variadic_split_splits_rank)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try
{
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{1, 2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Split lengths should be a 1-D tensor. Got 2 instead."));
}
}
TEST(type_prop, variadic_split_incorrect_sum)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try
{
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 6});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Total length of splits: 7 must match the length of the chosen axis: 6"));
}
}
TEST(type_prop, variadic_split_incorrect_axis)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try
{
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {-5});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("Parameter axis -5 out of the tensor rank range [-2, 1]."));
}
}
TEST(type_prop, variadic_split_splits_invalid_negative)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try
{
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {-2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("Invalid value -2 in split lengths input. Should be >= -1."));
}
}
TEST(type_prop, variadic_split_splits_multiple_negatives)
{
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try
{
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{3}, {-1, -1, 3});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Cannot infer split with multiple -1 values at 0 and 1"));
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment