Unverified Commit b007440a authored by Evgenya Stepyreva, committed by GitHub

Squeeze shape inference for dynamic input shape (#4451)

* Squeeze shape inference for dynamic input shape

* Style

* diyessi comments resolved !PR4451

* uint64_t

* [ Unsqueeze ] Dynamic shape inference

* Adjusted tests to dynamic shapes.

* Squeeze refactor.

* Unsqueeze refactor.

* Added test for axis with invalid value.

* Changed variables' names.
Co-authored-by: Ewa21 <ewa.tusien@intel.com>
parent d0df9e0c
@@ -38,25 +38,34 @@ void op::Squeeze::pre_validate_and_infer_types()
     auto data = input_value(0);
     auto axes_node = input_value(1).get_node_shared_ptr();
 
-    if (data.get_partial_shape().is_dynamic() || !axes_node->is_constant())
+    bool data_has_dynamic_rank = data.get_partial_shape().rank().is_dynamic();
+    bool data_has_dynamic_shape = data.get_partial_shape().is_dynamic();
+
+    auto axes_constant = as_type_ptr<op::v0::Constant>(axes_node);
+    bool axes_is_empty_constant =
+        (axes_constant) ? axes_constant->cast_vector<int64_t>().empty() : false;
+
+    if (data_has_dynamic_rank || !axes_constant ||
+        (data_has_dynamic_shape && axes_is_empty_constant))
     {
         set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
         return;
     }
 
-    auto data_shape = data.get_shape();
+    auto data_partial_shape = data.get_partial_shape();
+    uint64_t data_rank = data_partial_shape.rank().get_length();
 
     // Get value of axes from Constant
-    auto axes_constant = as_type_ptr<op::Constant>(axes_node);
-    auto axes = normalize_axes(
-        this->description(), axes_constant->cast_vector<int64_t>(), data_shape.size());
+    auto axes =
+        normalize_axes(this->description(), axes_constant->cast_vector<int64_t>(), data_rank);
 
     // Prepare set of unique axes marked to be removed from input data.
-    std::vector<uint64_t> axes_to_squeeze(data_shape.size());
-    if (axes.empty())
+    vector<uint64_t> axes_to_squeeze(data_rank);
+    if (axes_is_empty_constant)
     {
+        auto data_shape = data.get_shape();
         // Default behaviour is to remove all single dimension axes.
-        for (size_t idx = 0; idx < data_shape.size(); ++idx)
+        for (uint64_t idx = 0; idx < data_rank; ++idx)
         {
             if (data_shape.at(idx) == 1)
             {
@@ -73,24 +82,27 @@ void op::Squeeze::pre_validate_and_infer_types()
         set<size_t, greater<size_t>> unique_axes(begin(axes), end(axes));
         for (uint64_t axis : unique_axes)
         {
-            NODE_VALIDATION_CHECK(
-                this,
-                (data_shape.at(axis) == 1),
-                "provided axis value is invalid. Only axes of size 1 may be removed.");
+            if (!data_has_dynamic_shape)
+            {
+                auto data_shape = data.get_shape();
+                NODE_VALIDATION_CHECK(
+                    this,
+                    (data_shape.at(axis) == 1),
+                    "provided axis value is invalid. Only axes of size 1 may be removed.");
+            }
             axes_to_squeeze.at(axis) = 1;
         }
     }
 
-    Shape output_data_shape;
-    for (size_t idx = 0; idx < data_shape.size(); ++idx)
+    vector<Dimension> output_data_shape;
+    for (uint64_t idx = 0; idx < data_rank; ++idx)
     {
         if (axes_to_squeeze.at(idx) == 0)
         {
-            output_data_shape.push_back(data_shape.at(idx));
+            output_data_shape.push_back(data_partial_shape[idx]);
        }
    }
-    set_output_type(0, get_input_element_type(0), output_data_shape);
+    set_output_type(0, get_input_element_type(0), PartialShape(output_data_shape));
 }
 
 bool ngraph::op::v0::Squeeze::visit_attributes(AttributeVisitor& visitor)
...
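For orientation, a minimal sketch of the behaviour the updated Squeeze inference is expected to produce on an input with static rank but a dynamic dimension. This snippet is not part of the commit; it only assumes the public nGraph API already exercised by the type_prop tests further down.

#include "ngraph/ngraph.hpp"

using namespace ngraph;
using namespace std;

int main()
{
    // Input has static rank 3, but the first dimension is dynamic.
    auto param = make_shared<op::Parameter>(element::f32,
                                            PartialShape{Dimension::dynamic(), 1, 3});
    // Explicit, non-empty axes constant: remove axis 1 only.
    auto axes = make_shared<op::Constant>(element::u64, Shape{1}, vector<int64_t>{1});

    auto squeeze = make_shared<op::Squeeze>(param, axes);

    // With this change the dynamic dimension is carried through instead of
    // the whole output collapsing to PartialShape::dynamic():
    // the expected output partial shape is {?, 3}.
    cout << squeeze->get_output_partial_shape(0) << endl;
    return 0;
}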
@@ -36,8 +36,10 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
 void op::Unsqueeze::pre_validate_and_infer_types()
 {
     const auto data = input_value(0);
+    auto data_partial_shape = data.get_partial_shape();
+    const auto data_rank = data_partial_shape.rank();
+
     const auto axes_node = input_value(1).get_node_shared_ptr();
-    const auto data_rank = data.get_partial_shape().rank();
 
     if (data_rank.is_dynamic() || !axes_node->is_constant())
     {
@@ -45,10 +47,12 @@ void op::Unsqueeze::pre_validate_and_infer_types()
         return;
     }
 
+    uint64_t data_rank_value = data_partial_shape.rank().get_length();
+
     // Get value of axes from Constant
-    const auto axes_constant = as_type_ptr<op::Constant>(axes_node);
+    const auto axes_constant = as_type_ptr<op::v0::Constant>(axes_node);
     const auto axes_values = axes_constant->cast_vector<int64_t>();
-    const auto expanded_rank = data_rank.get_length() + axes_values.size();
+    const auto expanded_rank = data_rank_value + axes_values.size();
     auto axes = normalize_axes(this->description(), axes_values, expanded_rank);
 
     NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
@@ -56,27 +60,17 @@ void op::Unsqueeze::pre_validate_and_infer_types()
                           axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
                           "'axes' input has a duplicate axis.");
 
-    if (data.get_partial_shape().is_dynamic())
-    {
-        set_output_type(0,
-                        get_input_element_type(0),
-                        PartialShape::dynamic(data.get_partial_shape().rank() + axes.size()));
-        return;
-    }
-
-    auto data_shape = data.get_shape();
-
     sort(begin(axes), end(axes), less<int64_t>());
 
-    AxisVector input_order{ngraph::get_default_order(data_shape.size())};
-
+    vector<Dimension> output_shape{data_partial_shape};
     for (auto axis : axes)
     {
         NODE_VALIDATION_CHECK(
-            this, axis <= data_shape.size(), "provided 'axes' value ", axis, " is not valid.");
+            this, axis <= expanded_rank, "provided 'axes' value ", axis, " is not valid.");
 
-        data_shape.insert(next(begin(data_shape), axis), 1);
+        output_shape.insert(next(begin(output_shape), axis), 1);
     }
-    set_output_type(0, get_input_element_type(0), data_shape);
+    set_output_type(0, get_input_element_type(0), PartialShape{output_shape});
 }
 
 NodeVector op::Unsqueeze::decompose_op() const
...
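Likewise, a minimal sketch (again not part of the commit, same API assumptions) of the updated Unsqueeze inference on a partially dynamic input:

#include "ngraph/ngraph.hpp"

using namespace ngraph;
using namespace std;

int main()
{
    // Input has static rank 2 with a dynamic first dimension.
    auto param = make_shared<op::Parameter>(element::f32,
                                            PartialShape{Dimension::dynamic(), 2});
    // Insert new size-1 axes at positions 0 and 2 of the expanded shape.
    auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});

    auto unsqueeze = make_shared<op::Unsqueeze>(param, axes);

    // Previously the output collapsed to PartialShape::dynamic(4); now the
    // expected output partial shape is {1, ?, 1, 2}.
    cout << unsqueeze->get_output_partial_shape(0) << endl;
    return 0;
}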
@@ -46,7 +46,8 @@ TEST(type_prop, squeeze_dynamic)
     auto squeeze = make_shared<op::Squeeze>(param, axes_node);
 
     ASSERT_EQ(squeeze->get_element_type(), element::f32);
-    EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
+    EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
 
     axes_node = make_shared<ngraph::op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
     auto squeeze_default_axes = make_shared<op::Squeeze>(param, axes_node);
@@ -55,3 +56,25 @@ TEST(type_prop, squeeze_dynamic)
     EXPECT_TRUE(
         squeeze_default_axes->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
 }
+
+TEST(type_prop, squeeze_axes_invalid_value)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
+    auto axes_node =
+        make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{0, 2});
+
+    try
+    {
+        auto squeeze = make_shared<op::Squeeze>(param, axes_node);
+        FAIL() << "Squeeze axis invalid value not detected";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             "provided axis value is invalid. Only axes of size 1 may be removed.");
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
@@ -40,5 +40,12 @@ TEST(type_prop, unsqueeze_dynamic)
     auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
 
     ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
-    EXPECT_TRUE(unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(7)));
+    EXPECT_TRUE(
+        unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(),
+                                                                        1,
+                                                                        1,
+                                                                        Dimension::dynamic(),
+                                                                        Dimension::dynamic(),
+                                                                        Dimension::dynamic(),
+                                                                        Dimension::dynamic()}));
 }