Commit 75379523 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF updates: Reshape, DynReshape, Transpose (#3338)

* Update Reshape CF to support all ETs

* Add CF for DynReshape

* Add CF for Transpose

* Add #include <numeric>, for std::iota

* style, oops
parent 30c7028f
This diff is collapsed.
......@@ -44,7 +44,9 @@ public:
REVERSE,
PRODUCT,
SUM,
CONCAT
CONCAT,
DYN_RESHAPE,
TRANSPOSE
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -64,6 +66,8 @@ public:
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
construct_constant_dyn_reshape();
construct_constant_transpose();
}
//this allows to specify the order in which matchers will be run
......@@ -90,6 +94,8 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
}
}
}
......@@ -108,6 +114,8 @@ private:
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
void construct_constant_dyn_reshape();
void construct_constant_transpose();
ngraph::BuildNodeExecutorMap m_cfmap;
};
......@@ -739,6 +739,63 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
// Verifies that constant folding collapses DynReshape(Constant, Constant)
// into a single Constant node carrying the reshaped data.
TEST(constant_folding, constant_dyn_reshape)
{
const Shape input_shape{2, 4};
const vector<float> input_values{0, 1, 2, 3, 4, 5, 6, 7};
// Target shape {2, 4, 1} holds the same element count, so the folded
// constant must reproduce the input values unchanged and in order.
const vector<int64_t> target_shape{2, 4, 1};
auto data_const = make_shared<op::Constant>(element::f32, input_shape, input_values);
auto shape_const = make_shared<op::Constant>(element::i64, Shape{3}, target_shape);
auto dyn_reshape = make_shared<op::DynReshape>(data_const, shape_const);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
// After the pass, the DynReshape must be gone and exactly one Constant remain.
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto folded =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(folded);
auto folded_values = folded->get_vector<float>();
ASSERT_TRUE(test::all_close_f(input_values, folded_values, MIN_FLOAT_TOLERANCE_BITS));
}
// Verifies that constant folding collapses Transpose(Constant, Constant)
// into a single Constant node carrying the permuted data.
TEST(constant_folding, constant_transpose)
{
const Shape input_shape{2, 4};
const vector<double> input_values{0, 1, 2, 3, 4, 5, 6, 7};
// Permutation {1, 0} swaps the two axes: a 2x4 row-major buffer becomes 4x2.
const vector<int64_t> axis_order{1, 0};
auto data_const = make_shared<op::Constant>(element::f64, input_shape, input_values);
auto perm_const = make_shared<op::Constant>(element::i64, Shape{2}, axis_order);
auto transpose = make_shared<op::Transpose>(data_const, perm_const);
auto f = make_shared<Function>(transpose, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
// After the pass, the Transpose must be gone and exactly one Constant remain.
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto folded =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(folded);
auto folded_values = folded->get_vector<double>();
// Expected layout of the 2x4 input read column-first (i.e. transposed).
const vector<double> expected{0, 4, 1, 5, 2, 6, 3, 7};
ASSERT_TRUE(test::all_close_f(expected, folded_values, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment