Unverified Commit be111fdb authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by GitHub

Dynamic Squeeze/Unsqueeze type and rank propagation (#4270)

* Squeeze/Unsqueeze dynamic input type/rank infer

* Unit-tests

* style

* Removed squeeze Rank propagation

* Fixed comment

* Revert comment back

* Comment resolved

* Style fixes

* Moved unsqueeze axis check

* Style

* Discussion resolved

* Style

* Assert in decompose_op, if output shape is not static

* Style
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 1e6224f0
...@@ -32,20 +32,20 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -32,20 +32,20 @@ op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
NodeVector op::Squeeze::decompose_op() const void op::Squeeze::pre_validate_and_infer_types()
{ {
auto data = input_value(0); auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes. if (data.get_partial_shape().is_dynamic() || !axes_node->is_constant())
NODE_VALIDATION_CHECK(this, {
axes_node->is_constant(), set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
"doesn't support 'axes' input of other type than a Constant."); return;
}
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size()); std::vector<uint64_t> axes_to_squeeze(data_shape.size());
...@@ -87,6 +87,18 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -87,6 +87,18 @@ NodeVector op::Squeeze::decompose_op() const
} }
} }
set_output_type(0, get_input_element_type(0), output_data_shape);
}
NodeVector op::Squeeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_data_shape = get_output_shape(0);
AxisVector input_order{get_default_order(data_shape.size())}; AxisVector input_order{get_default_order(data_shape.size())};
return {make_shared<op::Reshape>(data, input_order, output_data_shape)}; return {make_shared<op::Reshape>(data, input_order, output_data_shape)};
} }
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
Squeeze(const Output<Node>& data, const Output<Node>& axes); Squeeze(const Output<Node>& data, const Output<Node>& axes);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -33,31 +33,35 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes) ...@@ -33,31 +33,35 @@ op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
} }
void op::Unsqueeze::pre_validate_and_infer_types() void op::Unsqueeze::pre_validate_and_infer_types()
{
auto axes_node = input_value(1).get_node_shared_ptr();
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
}
NodeVector op::Unsqueeze::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
auto axes_node = input_value(1).get_node_shared_ptr(); auto axes_node = input_value(1).get_node_shared_ptr();
if (data.get_partial_shape().rank().is_dynamic() || !axes_node->is_constant())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;
}
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->cast_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory."); NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axes.size() == set<int64_t>(begin(axes), end(axes)).size(), axes.size() == set<int64_t>(begin(axes), end(axes)).size(),
"'axes' input has a duplicate axis."); "'axes' input has a duplicate axis.");
if (data.get_partial_shape().is_dynamic())
{
set_output_type(0,
get_input_element_type(0),
PartialShape::dynamic(data.get_partial_shape().rank() + axes.size()));
return;
}
auto data_shape = data.get_shape();
sort(begin(axes), end(axes), less<int64_t>()); sort(begin(axes), end(axes), less<int64_t>());
AxisVector input_order{ngraph::get_default_order(data_shape.size())}; AxisVector input_order{ngraph::get_default_order(data_shape.size())};
...@@ -69,8 +73,20 @@ NodeVector op::Unsqueeze::decompose_op() const ...@@ -69,8 +73,20 @@ NodeVector op::Unsqueeze::decompose_op() const
data_shape.insert(next(begin(data_shape), axis), 1); data_shape.insert(next(begin(data_shape), axis), 1);
} }
set_output_type(0, get_input_element_type(0), data_shape);
}
return {make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; NodeVector op::Unsqueeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_shape = get_output_shape(0);
AxisVector input_order{ngraph::get_default_order(data_shape.size())};
return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)};
} }
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -37,3 +37,21 @@ TEST(type_prop, squeeze) ...@@ -37,3 +37,21 @@ TEST(type_prop, squeeze)
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32); ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
ASSERT_EQ(squeeze_default_axes->get_shape(), (Shape{4, 4, 8})); ASSERT_EQ(squeeze_default_axes->get_shape(), (Shape{4, 4, 8}));
} }
TEST(type_prop, squeeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(6));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{0, 2});
auto squeeze = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_TRUE(squeeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
axes_node = make_shared<ngraph::op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
auto squeeze_default_axes = make_shared<op::Squeeze>(param, axes_node);
ASSERT_EQ(squeeze_default_axes->get_element_type(), element::f32);
EXPECT_TRUE(
squeeze_default_axes->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
...@@ -26,8 +26,19 @@ TEST(type_prop, unsqueeze) ...@@ -26,8 +26,19 @@ TEST(type_prop, unsqueeze)
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8}); auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
auto axes_node = auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2}); make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(param, axes_node); auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(squeeze->get_element_type(), element::f32); ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
ASSERT_EQ(squeeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8})); ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
}
TEST(type_prop, unsqueeze_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
EXPECT_TRUE(unsqueeze->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(7)));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment