Unverified Commit b007440a authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by GitHub

Squeeze shape inference for dynamic input shape (#4451)

* Squeeze shape inference for dynamic input shape

* Style

* diyessi comments resolved !PR4451

* uint64_t

* [ Unsqueeze ] Dynamic shape inference

* Adjusted tests to dynamic shapes.

* Squeeze refactor.

* Unsqueeze refactor.

* Added test for axis with invaild value.

* Changed variables' names.
Co-authored-by: 's avatarEwa21 <ewa.tusien@intel.com>
parent d0df9e0c
......@@ -38,25 +38,34 @@ void op::Squeeze::pre_validate_and_infer_types()
auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr();
if (data.get_partial_shape().is_dynamic() || !axes_node->is_constant())
bool data_has_dynamic_rank = data.get_partial_shape().rank().is_dynamic();
bool data_has_dynamic_shape = data.get_partial_shape().is_dynamic();
auto axes_constant = as_type_ptr<op::v0::Constant>(axes_node);
bool axes_is_empty_constant =
(axes_constant) ? axes_constant->cast_vector<int64_t>().empty() : false;
if (data_has_dynamic_rank || !axes_constant ||
(data_has_dynamic_shape && axes_is_empty_constant))
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
auto data_shape = data.get_shape();
auto data_partial_shape = data.get_partial_shape();
uint64_t data_rank = data_partial_shape.rank().get_length();
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = normalize_axes(
this->description(), axes_constant->cast_vector<int64_t>(), data_shape.size());
auto axes =
normalize_axes(this->description(), axes_constant->cast_vector<int64_t>(), data_rank);
// Prepare set of unique axes marked to be removed from input data.
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
if (axes.empty())
vector<uint64_t> axes_to_squeeze(data_rank);
if (axes_is_empty_constant)
{
auto data_shape = data.get_shape();
// Default behaviour is to remove all single dimension axes.
for (size_t idx = 0; idx < data_shape.size(); ++idx)
for (uint64_t idx = 0; idx < data_rank; ++idx)
{
if (data_shape.at(idx) == 1)
{
......@@ -73,24 +82,27 @@ void op::Squeeze::pre_validate_and_infer_types()
set<size_t, greater<size_t>> unique_axes(begin(axes), end(axes));
for (uint64_t axis : unique_axes)
{
if (!data_has_dynamic_shape)
{
auto data_shape = data.get_shape();
NODE_VALIDATION_CHECK(
this,
(data_shape.at(axis) == 1),
"provided axis value is invalid. Only axes of size 1 may be removed.");
}
axes_to_squeeze.at(axis) = 1;
}
}
Shape output_data_shape;
for (size_t idx = 0; idx < data_shape.size(); ++idx)
vector<Dimension> output_data_shape;
for (uint64_t idx = 0; idx < data_rank; ++idx)
{
if (axes_to_squeeze.at(idx) == 0)
{
output_data_shape.push_back(data_shape.at(idx));
output_data_shape.push_back(data_partial_shape[idx]);
}
}
set_output_type(0, get_input_element_type(0), output_data_shape);
set_output_type(0, get_input_element_type(0), PartialShape(output_data_shape));
}
bool ngraph::op::v0::Squeeze::visit_attributes(AttributeVisitor& visitor)
......
......@@ -36,8 +36,10 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
void op::Unsqueeze::pre_validate_and_infer_types()
{
const auto data = input_value(0);
auto data_partial_shape = data.get_partial_shape();
const auto data_rank = data_partial_shape.rank();
const auto axes_node = input_value(1).get_node_shared_ptr();
const auto data_rank = data.get_partial_shape().rank();
if (data_rank.is_dynamic() || !axes_node->is_constant())
{
......@@ -45,10 +47,12 @@ void op::Unsqueeze::pre_validate_and_infer_types()
return;
}
uint64_t data_rank_value = data_partial_shape.rank().get_length();
// Get value of axes from Constant
const auto axes_constant = as_type_ptr<op::Constant>(axes_node);
const auto axes_constant = as_type_ptr<op::v0::Constant>(axes_node);
const auto axes_values = axes_constant->cast_vector<int64_t>();
const auto expanded_rank = data_rank.get_length() + axes_values.size();
const auto expanded_rank = data_rank_value + axes_values.size();
auto axes = normalize_axes(this->description(), axes_values, expanded_rank);
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
......@@ -56,27 +60,17 @@ void op::Unsqueeze::pre_validate_and_infer_types()
axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
"'axes' input has a duplicate axis.");
if (data.get_partial_shape().is_dynamic())
{
set_output_type(0,
get_input_element_type(0),
PartialShape::dynamic(data.get_partial_shape().rank() + axes.size()));
return;
}
auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
vector<Dimension> output_shape{data_partial_shape};
for (auto axis : axes)
{
NODE_VALIDATION_CHECK(
this, axis <= data_shape.size(), "provided 'axes' value ", axis, " is not valid.");
this, axis <= expanded_rank, "provided 'axes' value ", axis, " is not valid.");
data_shape.insert(next(begin(data_shape), axis), 1);
output_shape.insert(next(begin(output_shape), axis), 1);
}
set_output_type(0, get_input_element_type(0), data_shape);
set_output_type(0, get_input_element_type(0), PartialShape{output_shape});
}
NodeVector op::Unsqueeze::decompose_op() const
......
......@@ -46,7 +46,8 @@ TEST(type_prop, squeeze_dynamic)
auto squeeze = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
axes_node = make_shared<ngraph::op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
auto squeeze_default_axes = make_shared<op::Squeeze>(param, axes_node);
......@@ -55,3 +56,25 @@ TEST(type_prop, squeeze_dynamic)
EXPECT_TRUE(
squeeze_default_axes->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, squeeze_axes_invalid_value)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{0, 2});
try
{
auto squeeze = make_shared<op::Squeeze>(param, axes_node);
FAIL() << "Squeeze axis invalid value not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"provided axis value is invalid. Only axes of size 1 may be removed.");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
......@@ -40,5 +40,12 @@ TEST(type_prop, unsqueeze_dynamic)
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
EXPECT_TRUE(unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(7)));
EXPECT_TRUE(
unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(),
1,
1,
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment