Commit 76c73c91 authored by Ayan Moitra, committed by Scott Cyphers

General-purpose recurrent reshape elimination pass (#2665)

* [WIP] First commit

* Incremental code changes

* Incremental code changes

* Further mods

* Improve src + add more tests

* Another test added

* clang

* Added NGRAPH_DEBUG statements

* Incorporate Xiaoyu's and Scott's comments

* Incorporate Adam's comments in tests

* Incorporate Adam's comments

* Add Jayaram's comments
parent f9d0bd57
@@ -338,6 +338,8 @@ set (SRC
pass/zero_dim_tensor_elimination.hpp
pass/concat_fusion.hpp
pass/concat_fusion.cpp
pass/pass_util.hpp
pass/pass_util.cpp
pattern/matcher.cpp
pattern/matcher.hpp
pattern/op/any.hpp
...
@@ -53,18 +53,8 @@ namespace
bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op)
{
-   auto users = op->get_users(true);
-   std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
-   size_t num_unique_users = user_set.size();
-   if (num_unique_users == 1)
-   {
-       return true;
-   }
-   else
-   {
-       NGRAPH_DEBUG << "self_concat_fusion: " << op->get_name() << " has fan out\n";
-       return false;
-   }
+   auto no_fan_out = ngraph::pass::get_no_fan_out_function();
+   return no_fan_out(op);
}
bool valid_self_concat(const std::shared_ptr<Node>& Op)
...
@@ -18,6 +18,7 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_util.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/pass_util.hpp"
using namespace std;
using namespace ngraph;
std::function<bool(std::shared_ptr<Node>)> ngraph::pass::get_no_fan_out_function()
{
auto ret_fun = [](std::shared_ptr<Node> n) {
auto users = n->get_users(true);
std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
size_t num_unique_users = user_set.size();
if (num_unique_users == 1)
{
return true;
}
else
{
NGRAPH_DEBUG << n->get_name() << " has fan out\n";
return false;
}
};
return ret_fun;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <cstdlib> // llvm 8.1 gets confused about `malloc` otherwise
#include <functional>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include "ngraph/node.hpp"
namespace ngraph
{
namespace pass
{
std::function<bool(std::shared_ptr<Node>)> get_no_fan_out_function();
}
}
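Aside (editor's illustration, not part of the diff): get_no_fan_out_function() returns a predicate that accepts a node only when it has exactly one unique user. A minimal sketch of the intended use, mirroring the pattern Label construction in construct_recurrent_reshape below; the helper name make_no_fan_out_reshape_label is hypothetical:
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/pass_util.hpp"
#include "ngraph/pattern/op/label.hpp"
using namespace ngraph;
// Hypothetical helper: build a pattern Label that only binds reshapes without fan-out.
std::shared_ptr<pattern::op::Label> make_no_fan_out_reshape_label()
{
    auto arg = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto reshape = std::make_shared<op::Reshape>(arg, AxisVector{0}, Shape{1, 3});
    // The predicate rejects a candidate reshape that feeds more than one unique user.
    return std::make_shared<pattern::op::Label>(
        reshape, pass::get_no_fan_out_function(), NodeVector{reshape});
}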
@@ -24,7 +24,6 @@
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
@@ -192,3 +191,105 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
auto m = make_shared<pattern::Matcher>(preshape, callback);
this->add_matcher(m);
}
void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
{
Shape shape_op{3};
Shape shape_r{1, 3};
auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape = make_shared<op::Reshape>(op, AxisVector{0}, shape_r);
auto reshape_label =
make_shared<pattern::op::Label>(reshape, get_no_fan_out_function(), NodeVector{reshape});
auto callback = [op, reshape_label](pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << "In callback for construct_recurrent_reshape against node = "
<< reshape_label->get_argument(0)->get_name();
auto reshape_node_vector = m.get_bound_nodes_for_pattern(reshape_label);
// The recurrent matcher returns the bound nodes in reverse order; reverse the
// vector so it reads from the first reshape in the chain to the last
std::reverse(std::begin(reshape_node_vector), std::end(reshape_node_vector));
auto first_bound_reshape_op = reshape_node_vector.front();
auto driver_op = first_bound_reshape_op->get_argument(0);
auto last_bound_reshape_op = reshape_node_vector.back();
// The last reshape in the chain is allowed to have fan-out, but the recurrent matcher
// discards any reshape that has fan-out. Check whether the user of the last bound
// reshape is itself a reshape and, if so, append it to the chain
auto user_of_last_bound_reshape_op = last_bound_reshape_op->get_users(true)[0];
if (std::dynamic_pointer_cast<op::Reshape>(user_of_last_bound_reshape_op))
{
reshape_node_vector.push_back(user_of_last_bound_reshape_op);
last_bound_reshape_op = reshape_node_vector.back();
}
// Return if the recurrent matcher matches only one reshape
if (reshape_node_vector.size() == 1)
{
return false;
}
// Not every reshape in the chain can be fused with its neighbors. Only a run consisting
// of one reshape (with any axis order) followed by reshapes with the default axis order
// can be fused into a single reshape, so split the chain into such subpatterns here:
std::vector<NodeVector> sub_patterns{NodeVector{first_bound_reshape_op}};
for (auto it = std::next(reshape_node_vector.begin()); it != reshape_node_vector.end();
it++)
{
auto r = std::dynamic_pointer_cast<op::Reshape>(*it);
// The recurrent matcher should only have bound reshapes; bail out if the
// matched node is not a reshape
if (!r)
{
NGRAPH_DEBUG
<< "Incorrect match. Something went wrong. Non-reshape op has been matched";
return false;
}
auto default_order_r = get_default_order(r->get_input_shape(0));
if (r->get_input_order() == default_order_r)
{
sub_patterns.back().push_back(r);
}
else
{
NGRAPH_DEBUG << r->get_name() << "does not have default axis order. "
<< "It might be part of a different subpattern";
sub_patterns.push_back(NodeVector{r});
}
}
bool modify_graph = false;
// Replace the patterns
for (auto sub_pattern : sub_patterns)
{
// Do not consider subpatterns with just one reshape in them
if (sub_pattern.size() == 1)
{
continue;
}
auto first_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.front());
auto input_to_first_reshape = first_reshape->get_argument(0);
auto last_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.back());
auto new_input_order = first_reshape->get_input_order();
auto new_out_shape = last_reshape->get_shape();
auto new_reshape = std::make_shared<op::Reshape>(
input_to_first_reshape, new_input_order, new_out_shape);
replace_node(last_reshape, new_reshape);
modify_graph = true;
}
return modify_graph;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
reshape_label, op, empty_correlated_matches, callback);
this->add_matcher(m);
}
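Aside (editor's illustration, not part of the diff): the replacement above is valid because a Reshape with the default input order only re-reads its argument's row-major layout, so one reshape (with any axis order) followed by default-order reshapes collapses into a single reshape that keeps the first reshape's input order and the last reshape's output shape. A sketch with hypothetical shapes:
// t has shape {2, 3, 4}.
auto t = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
// Transpose to {4, 3, 2}, then lay the result out row-major as {4, 6}.
auto r1 = std::make_shared<op::Reshape>(t, AxisVector{2, 1, 0}, Shape{4, 6});
// Default axis order: re-reads r1's row-major buffer as {24}; no data movement.
auto r2 = std::make_shared<op::Reshape>(r1, AxisVector{0, 1}, Shape{24});
// The callback replaces r2 with this single, equivalent reshape of t:
auto fused = std::make_shared<op::Reshape>(t, AxisVector{2, 1, 0}, Shape{24});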
@@ -17,12 +17,14 @@
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass_util.hpp"
namespace ngraph
{
namespace pass
{
class ReshapeElimination;
class RecurrentReshapeElimination;
}
}
@@ -42,3 +44,16 @@ private:
void construct_identity_reshape_pattern();
void construct_reshapex2_pattern();
};
class ngraph::pass::RecurrentReshapeElimination : public ngraph::pass::RecurrentGraphRewrite
{
public:
RecurrentReshapeElimination()
: RecurrentGraphRewrite()
{
construct_recurrent_reshape();
}
private:
void construct_recurrent_reshape();
};
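Aside (usage sketch, not part of the diff): the new pass plugs into the pass manager like any other rewrite pass; the tests below register it the same way, optionally followed by ReshapeElimination to clean up any remaining identity reshape. The helper name run_recurrent_reshape_elimination is hypothetical:
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
using namespace ngraph;
void run_recurrent_reshape_elimination(std::shared_ptr<Function> f)
{
    pass::Manager pass_manager;
    // Collapse each chain of reshapes into a single reshape.
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // Then remove any identity reshape left behind (see the recurrent_reshapes_elimination test).
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.run_passes(f);
}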
@@ -29,13 +29,17 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp" #include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/matcher.hpp" #include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -106,3 +110,325 @@ TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args) ...@@ -106,3 +110,325 @@ TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args)
ASSERT_EQ(gdot->get_argument(1)->get_argument(0), W); ASSERT_EQ(gdot->get_argument(1)->get_argument(0), W);
ASSERT_EQ(gdot->get_shape(), (Shape{1, 2})); ASSERT_EQ(gdot->get_shape(), (Shape{1, 2}));
} }
TEST(reshape_elimination, recurrent_reshapes)
{
Shape shape_a{2, 2, 3, 3, 2, 4};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r_1{3, 2, 2, 4, 6};
Shape shape_r_2{6, 8, 3, 2};
Shape shape_r_3{6, 8, 6};
Shape shape_r_4{6, 2, 2, 2, 6};
Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
Shape shape_r_6{48, 6};
auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
auto f = make_shared<Function>(r_6, ParameterVector{A});
return f;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 1);
}
TEST(reshape_elimination, recurrent_reshapes_elimination)
{
Shape shape_a{2, 2, 3, 3, 2, 4};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r_1{3, 2, 2, 4, 6};
Shape shape_r_2{6, 8, 3, 2};
Shape shape_r_3{6, 8, 6};
Shape shape_r_4{6, 2, 2, 2, 6};
Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
Shape shape_r_6{48, 6};
Shape shape_r_7{2, 2, 3, 3, 2, 4};
auto r_1 = make_shared<op::Reshape>(A, AxisVector{0, 1, 2, 3, 4, 5}, shape_r_1);
auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
auto r_7 = make_shared<op::Reshape>(r_6, AxisVector{0, 1}, shape_r_7);
auto f = make_shared<Function>(r_7, ParameterVector{A});
return f;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_elimination.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_1_recurrent_reshapes_elimination.pdf");
pass_manager.register_pass<pass::ReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_2_recurrent_reshapes_elimination.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 0);
}
TEST(reshape_elimination, recurrent_reshapes_fan_out)
{
Shape shape_a{4, 6, 10, 2};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r_1{6, 4, 5, 4};
Shape shape_r_2{24, 20};
auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{0, 3, 2, 1}, shape_r_1);
auto reshape_2 = make_shared<op::Reshape>(reshape_1, AxisVector{0, 1, 2, 3}, shape_r_2);
auto reshape_3 = make_shared<op::Reshape>(reshape_2, AxisVector{0, 1}, shape_a);
auto f_ = make_shared<Function>(NodeVector{reshape_2, reshape_3}, ParameterVector{A});
return f_;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_fan_out.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes_fan_out.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2);
}
TEST(reshape_elimination, recurrent_reshapes_fan_out_at_end)
{
Shape shape_a{12, 8, 1, 1};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{0, 3, 2, 1}, Shape{4, 3, 8, 1});
auto reshape_2 = make_shared<op::Reshape>(reshape_1, AxisVector{0, 1, 2, 3}, shape_a);
auto reshape_3 =
make_shared<op::Reshape>(reshape_2, AxisVector{0, 1, 2, 3}, Shape{4, 3, 8, 1});
auto abs_1 = make_shared<op::Abs>(reshape_3);
auto f_ = make_shared<Function>(NodeVector{abs_1, reshape_3}, ParameterVector{A});
return f_;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_fan_out_at_end.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes_fan_out_at_end.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 1);
}
TEST(reshape_elimination, recurrent_reshapes_multiple_fusions)
{
Shape shape_a{2, 2, 3, 3, 2, 4};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r_1{3, 2, 2, 4, 6};
Shape shape_r_2{6, 8, 3, 2};
Shape shape_r_3{6, 8, 6};
Shape shape_r_4{6, 2, 2, 2, 6};
Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
Shape shape_r_6{48, 6};
auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{1, 0, 2}, shape_r_4);
auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
auto f = make_shared<Function>(r_6, ParameterVector{A});
return f;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>(
// "before_recurrent_reshapes_multiple_fusions.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>(
// "after_recurrent_reshapes_multiple_fusions.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2);
}
TEST(reshape_elimination, nonrecurrent_reshapes)
{
Shape shape_a{8, 6, 1, 1};
Shape shape_r{2, 24};
auto generate_func = [shape_a, shape_r]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{3, 0, 2, 1}, shape_r);
auto abs_1 = make_shared<op::Abs>(reshape_1);
auto reshape_2 = make_shared<op::Reshape>(abs_1, AxisVector{0, 1}, shape_a);
auto abs_2 = make_shared<op::Abs>(reshape_2);
auto reshape_3 = make_shared<op::Reshape>(abs_2, AxisVector{0, 1, 2, 3}, shape_a);
auto f_ = make_shared<Function>(NodeVector{reshape_3}, ParameterVector{A});
return f_;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>("before_nonrecurrent_reshapes.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>("after_nonrecurrent_reshapes.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 3);
}
TEST(reshape_elimination, recurrent_reshapes_multiple_branches)
{
Shape shape_a{2, 2, 3, 3, 2, 4};
auto generate_func = [shape_a]() {
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r_1{3, 2, 2, 4, 6};
Shape shape_r_2{6, 8, 3, 2};
Shape shape_r_3{6, 8, 6};
Shape shape_r_4{6, 2, 2, 2, 6};
Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
Shape shape_r_6{48, 6};
auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
auto r_7 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_2);
auto r_8 = make_shared<op::Reshape>(r_7, AxisVector{0, 1, 2, 3}, shape_r_3);
auto f = make_shared<Function>(NodeVector{r_6, r_8}, ParameterVector{A});
return f;
};
auto baseline_f = generate_func();
auto optimized_f = generate_func();
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
pass::Manager pass_manager;
// pass_manager.register_pass<pass::VisualizeTree>(
// "before_recurrent_reshapes_multiple_branches.pdf");
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
// pass_manager.register_pass<pass::VisualizeTree>(
// "after_recurrent_reshapes_multiple_branches.pdf");
pass_manager.run_passes(optimized_f);
test::Uniform<float> rng(0.0f, 100.0f);
vector<vector<float>> args;
vector<float> tensor_val(shape_size(baseline_input_shape));
rng.initialize(tensor_val);
args.push_back(tensor_val);
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2);
}