Unverified Commit 9cdbf129 authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by GitHub

[ VariadicSplit ] Dynamic shape inference (#4462)

* [ VariadicSplit ] Dynamic shape inference

* Small code fixes

* Add tests for partial shape inference

* Style-apply

* Style-apply
Co-authored-by: 's avatarMichal Karzynski <michal.karzynski@intel.com>
Co-authored-by: 's avatarraramer01 <rebecca.ramer@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 21645142
......@@ -59,14 +59,14 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
auto axis_input = input_value(1).get_node_shared_ptr();
auto split_lengths_input = input_value(2).get_node_shared_ptr();
auto data_shape = data.get_partial_shape();
auto data_type = data.get_element_type();
const auto& data_type = data.get_element_type();
set_output_size(num_outputs);
if (data_shape.is_static() && axis_input->is_constant() &&
if (data_shape.rank().is_static() && axis_input->is_constant() &&
split_lengths_input->is_constant())
{
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis_val = axis_input->cast_vector<int64_t>()[0];
const auto axis_input_constant = as_type_ptr<op::Constant>(axis_input);
auto axis_val = axis_input_constant->cast_vector<int64_t>()[0];
// Adjust split axis in case of negatives
int64_t axis = ngraph::normalize_axis(this, axis_val, data_shape.rank());
......@@ -99,25 +99,31 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
sum_of_splits += split_lengths[i];
}
}
auto data_shape_dims = vector<Dimension>{data.get_partial_shape()};
auto dimension_at_axis = data_shape_dims.at(axis);
if (negative_one > 0)
if (negative_one >= 0 && dimension_at_axis.is_static())
{
split_lengths[negative_one] = data_shape[axis].get_length() - sum_of_splits;
split_lengths[negative_one] = dimension_at_axis.get_length() - sum_of_splits;
sum_of_splits += split_lengths[negative_one];
}
if (data_shape[axis].is_static())
{
NODE_VALIDATION_CHECK(this,
sum_of_splits == data_shape[axis].get_length(),
"Total length of splits: ",
sum_of_splits,
" must match the length of the chosen axis: ",
data_shape[axis].get_length());
data_shape[axis]);
}
for (size_t output{0}; output < num_outputs; ++output)
{
auto tmp_shape = data_shape.to_shape();
tmp_shape.at(axis) = split_lengths.at(output);
set_output_type(output, data_type, tmp_shape);
auto output_split_dim = split_lengths.at(output) == -1 ? Dimension::dynamic()
: split_lengths.at(output);
auto tmp_shape = data_shape_dims;
tmp_shape.at(axis) = output_split_dim;
set_output_type(output, data_type, PartialShape{tmp_shape});
}
}
else
......
......@@ -41,6 +41,14 @@ TEST(type_prop, variadic_split)
.get_shape(),
(Shape{3, 6}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {-1, 7, 2}))
->output(0)
.get_shape(),
(Shape{3, 6}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, Shape{12, 1, 6}),
op::Constant::create<int64_t>(element::i64, Shape{1}, {2}),
......@@ -148,3 +156,43 @@ TEST(type_prop, variadic_split_splits_multiple_negatives)
std::string("Cannot infer split with multiple -1 values at 0 and 1"));
}
}
TEST(type_prop, variadic_split_shape_partially_dynamic)
{
// Variadic split shape {12,?} into {7,?}, {3,?} and {2,?}
auto var_split1 = make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, PartialShape{12, Dimension()}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {7, -1, 2}));
EXPECT_TRUE(
var_split1->get_output_partial_shape(0).same_scheme(PartialShape{7, Dimension::dynamic()}));
EXPECT_TRUE(
var_split1->get_output_partial_shape(1).same_scheme(PartialShape{3, Dimension::dynamic()}));
EXPECT_TRUE(
var_split1->get_output_partial_shape(2).same_scheme(PartialShape{2, Dimension::dynamic()}));
// Variadic split shape {?,?,6} into {?,?,3}, {?,?,1} and {?,?,2}
auto var_split2 = make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, PartialShape{Dimension(), Dimension(), 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {3, 1, 2}));
EXPECT_TRUE(var_split2->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3}));
EXPECT_TRUE(var_split2->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1}));
EXPECT_TRUE(var_split2->get_output_partial_shape(2).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 2}));
// Variadic split shape {?,6} into {?,6}, and {?,0}
auto var_split3 = make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, PartialShape{Dimension(), 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {1}),
op::Constant::create<int64_t>(element::i64, Shape{2}, {6, 0}));
EXPECT_TRUE(
var_split3->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 6}));
EXPECT_TRUE(
var_split3->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 0}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment