Commit 64479eb0 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Constant folding support for V1 Reshape (#3833)

* Constant folding support for V1 Reshape

* Fixed constant folding tests
parent 73fff9f4
......@@ -17,7 +17,7 @@
#include <numeric>
#include "constant_folding.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -26,7 +26,7 @@ using namespace ngraph;
template <class T>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::DynReshape> dyn_reshape)
shared_ptr<op::v1::Reshape> dyn_reshape)
{
auto out_shape = dyn_reshape->get_shape();
......@@ -50,7 +50,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(constant_data_label, constant_shape_label);
auto dyn_reshape = make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label);
// Note: No need to capture or consider constant_shape_label, because
// shape propagation will have transferred the info to dyn_reshape's
......@@ -63,7 +63,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto dyn_reshape_match = static_pointer_cast<op::DynReshape>(m.get_match_root());
auto dyn_reshape_match = static_pointer_cast<op::v1::Reshape>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_reshape_match));
......
......@@ -1468,14 +1468,14 @@ TEST(constant_folding, constant_dyn_reshape)
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
auto dyn_reshape = make_shared<op::DynReshape>(constant_in, constant_shape);
auto dyn_reshape = make_shared<op::v1::Reshape>(constant_in, constant_shape);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
......@@ -1492,9 +1492,9 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
Shape shape_shape{3};
// We're going to add these two together elementwise to get {2, 4, 1}.
// This means that when ConstantFolding starts, DynReshape will not yet
// This means that when ConstantFolding starts, v1::Reshape will not yet
// have static output shape. But by the time the Add op is folded, the
// DynReshape's shape should be inferrable.
// v1::Reshape's shape should be inferrable.
vector<int64_t> values_shape_a{1, 3, 0};
vector<int64_t> values_shape_b{1, 1, 1};
......@@ -1502,7 +1502,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
auto dyn_reshape =
make_shared<op::DynReshape>(constant_in, constant_shape_a + constant_shape_b);
make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());
......@@ -1511,7 +1511,7 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment