Commit 64479eb0 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Constant folding support for V1 Reshape (#3833)

* Constant folding support for V1 Reshape

* Fixed constant folding tests
parent 73fff9f4
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <numeric> #include <numeric>
#include "constant_folding.hpp" #include "constant_folding.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -26,7 +26,7 @@ using namespace ngraph; ...@@ -26,7 +26,7 @@ using namespace ngraph;
template <class T> template <class T>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data, shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::DynReshape> dyn_reshape) shared_ptr<op::v1::Reshape> dyn_reshape)
{ {
auto out_shape = dyn_reshape->get_shape(); auto out_shape = dyn_reshape->get_shape();
...@@ -50,7 +50,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape() ...@@ -50,7 +50,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>()); element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_shape_label = auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>()); make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(constant_data_label, constant_shape_label); auto dyn_reshape = make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label);
// Note: No need to capture or consider constant_shape_label, because // Note: No need to capture or consider constant_shape_label, because
// shape propagation will have transferred the info to dyn_reshape's // shape propagation will have transferred the info to dyn_reshape's
...@@ -63,7 +63,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape() ...@@ -63,7 +63,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
auto constant_data_match = auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]); static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto dyn_reshape_match = static_pointer_cast<op::DynReshape>(m.get_match_root()); auto dyn_reshape_match = static_pointer_cast<op::v1::Reshape>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_reshape_match)); NGRAPH_CHECK(revalidate_and_ensure_static(dyn_reshape_match));
......
...@@ -1468,14 +1468,14 @@ TEST(constant_folding, constant_dyn_reshape) ...@@ -1468,14 +1468,14 @@ TEST(constant_folding, constant_dyn_reshape)
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in); auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape); auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
auto dyn_reshape = make_shared<op::DynReshape>(constant_in, constant_shape); auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{}); auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>(); pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0); ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1); ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0)); auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
...@@ -1492,9 +1492,9 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant) ...@@ -1492,9 +1492,9 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
Shape shape_shape{3}; Shape shape_shape{3};
// We're going to add these two together elementwise to get {2, 4, 1}. // We're going to add these two together elementwise to get {2, 4, 1}.
// This means that when ConstantFolding starts, DynReshape will not yet // This means that when ConstantFolding starts, v1::Reshape will not yet
// have static output shape. But by the time the Add op is folded, the // have static output shape. But by the time the Add op is folded, the
// DynReshape's shape should be inferrable. // v1::Reshape's shape should be inferrable.
vector<int64_t> values_shape_a{1, 3, 0}; vector<int64_t> values_shape_a{1, 3, 0};
vector<int64_t> values_shape_b{1, 1, 1}; vector<int64_t> values_shape_b{1, 1, 1};
...@@ -1502,7 +1502,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant) ...@@ -1502,7 +1502,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a); auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b); auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
auto dyn_reshape = auto dyn_reshape =
make_shared<op::DynReshape>(constant_in, constant_shape_a + constant_shape_b); make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{}); auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic()); ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());
...@@ -1511,7 +1511,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant) ...@@ -1511,7 +1511,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
pass_manager.register_pass<pass::ConstantFolding>(); pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0); ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1); ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0)); auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment