Commit ba96aa0f authored by Ayan Moitra, committed by Scott Cyphers

Concat Elimination and Self Concat Fusion pass (#2634)

* [WIP] First commit

* Remove some commented code

* Further changes

* Further changes

* Add method to remove patterns with just one concat

* Add tests

* Add more tests

* Fix fan out case

* refactor code

* refactor code

* Added NGRAPH_DEBUG statements

* Use INTERPRETER as backend instead of CPU...travis build failure

* clang

* minor edit

* add more checks in the tests

* Incorporate Bob's comment

* Removed some NGRAPH_DEBUG statements and incorporated Pruthvi's comment

* Incorporate Xiaoyu's comments

* some refactoring
parent abd1c70d
...@@ -320,6 +320,8 @@ set (SRC ...@@ -320,6 +320,8 @@ set (SRC
pass/zero_dim_tensor_elimination.cpp pass/zero_dim_tensor_elimination.cpp
pass/zero_dim_tensor_elimination.hpp pass/zero_dim_tensor_elimination.hpp
pass/zero_dim_tensor_elimination.hpp pass/zero_dim_tensor_elimination.hpp
pass/concat_fusion.hpp
pass/concat_fusion.cpp
pattern/matcher.cpp pattern/matcher.cpp
pattern/matcher.hpp pattern/matcher.hpp
pattern/op/any.hpp pattern/op/any.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "concat_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
namespace
{
// Returns true when every argument of `op` is the same node, i.e. the op
// has at least one argument and all of them are identical pointers.
bool check_self_concat_op(const std::shared_ptr<Node>& op)
{
    const auto args = op->get_arguments();
    if (args.empty())
    {
        return false;
    }
    const std::shared_ptr<Node>& first = args.front();
    return std::all_of(args.begin(), args.end(), [&first](const std::shared_ptr<Node>& arg) {
        return arg == first;
    });
}
// Returns true when the first input of `concat_op` has extent 1 along the
// concatenation axis. The node is cast to op::Concat without a check, so the
// caller must pass a Concat.
bool check_concat_axis_dim_value(const std::shared_ptr<Node>& concat_op)
{
    const auto concat = std::static_pointer_cast<op::Concat>(concat_op);
    const size_t axis = concat->get_concatenation_axis();
    const auto first_input_shape = concat_op->get_input_shape(0);
    return first_input_shape[axis] == 1;
}
bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op)
{
auto users = op->get_users(true);
std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
size_t num_unique_users = user_set.size();
if (num_unique_users == 1)
{
return true;
}
else
{
NGRAPH_DEBUG << "self_concat_fusion: " << op->get_name() << " has fan out\n";
return false;
}
}
// A concat is a valid fusion candidate when all of its arguments are the same
// node and the input extent along the concatenation axis is 1. The axis check
// is only evaluated once the self-concat check has passed.
bool valid_self_concat(const std::shared_ptr<Node>& Op)
{
    if (check_self_concat_op(Op))
    {
        if (check_concat_axis_dim_value(Op))
        {
            return true;
        }
        NGRAPH_DEBUG << "self_concat_fusion: Input shape value along concat axis of "
                     << Op->get_name() << " is not equal to 1\n";
        return false;
    }
    NGRAPH_DEBUG << "self_concat_fusion: Matcher matched " << Op->get_name()
                 << " but it is not a self concat\n";
    return false;
}
std::vector<size_t> get_concatenation_axis_vector(const NodeVector& bounded_concat_ops)
{
std::vector<size_t> concat_axis_vec;
for (auto iter : bounded_concat_ops)
{
auto concat_op = std::static_pointer_cast<op::Concat>(iter);
concat_axis_vec.push_back(concat_op->get_concatenation_axis());
}
return concat_axis_vec;
}
}
// Registers a matcher for a Concat with a single labelled input. When the
// matched root is a Concat whose first input shape equals its output shape,
// the Concat is replaced by the node bound to the input label.
void pass::ConcatElimination::construct_concat_elimination()
{
    auto op_label = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3});
    auto concat = std::make_shared<op::Concat>(NodeVector{op_label}, 0);
    auto concat_label = std::make_shared<pattern::op::Label>(concat, nullptr, NodeVector{concat});

    auto callback = [op_label](pattern::Matcher& m) {
        NGRAPH_DEBUG
            << "concat_elimination: In callback for construct_concat_elimination against node = "
            << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        auto op = pattern_map[op_label];
        auto root = std::dynamic_pointer_cast<op::Concat>(m.get_match_root());

        // Only a Concat that leaves the shape untouched can be dropped.
        const bool is_noop_concat =
            root && (root->get_input_shape(0) == root->get_output_shape(0));
        if (!is_noop_concat)
        {
            NGRAPH_DEBUG << " Incorrect match in callback\n";
            return false;
        }

        NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n";
        replace_node(m.get_match_root(), op);
        return true;
    };

    this->add_matcher(std::make_shared<pattern::Matcher>(concat_label, callback));
}
bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
{
bool modify_graph = false;
auto has_multiple_inputs = [](std::shared_ptr<Node> n) {
auto input_size = n->get_input_size();
auto root = std::dynamic_pointer_cast<op::Concat>(n);
return (root && input_size > 1);
};
auto print_state_of_bounded_vectors = [this]() -> std::string {
std::stringstream ss;
ss << "-----------------------------------------------------------" << std::endl;
ss << "State of bounded pattern node vectors: " << std::endl;
ss << "-----------------------------------------------------------" << std::endl;
ss << "Number of pattern node vectors: " << this->m_concat_pattern_vectors.size()
<< std::endl;
size_t c = 0;
for (auto iter : this->m_concat_pattern_vectors)
{
ss << "For vector " << c << std::endl;
auto iter_node_vec = iter;
ss << "concat_op_vector: ";
for (auto it : iter_node_vec)
{
ss << it->get_name() << " ";
}
ss << std::endl;
c++;
}
ss << "-----------------------------" << std::endl;
return ss.str();
};
auto concat_op_label =
std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3}, has_multiple_inputs);
auto matcher = std::make_shared<pattern::Matcher>(concat_op_label);
for (auto n : function->get_ordered_ops())
{
construct_concat_patterns(matcher, concat_op_label, n);
}
NGRAPH_DEBUG << print_state_of_bounded_vectors();
remove_single_concat_op_pattern();
for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors)
{
modify_graph = replace_patterns(concat_op_pattern_node_vector);
}
return modify_graph;
}
// Tries to match `n` against `matcher`. When the node bound to
// `concat_op_label` is a valid self concat it is recorded in
// m_concat_pattern_vectors, either as the first chain or through
// update_concat_pattern_vectors(). Anything else is ignored.
void ngraph::pass::SelfConcatFusion::construct_concat_patterns(
    const std::shared_ptr<pattern::Matcher>& matcher,
    const std::shared_ptr<pattern::op::Label>& concat_op_label,
    const std::shared_ptr<Node>& n)
{
    if (!matcher->match(n))
    {
        return;
    }

    auto concat_op = matcher->get_pattern_map()[concat_op_label];
    const bool is_concat = std::dynamic_pointer_cast<op::Concat>(concat_op) != nullptr;
    if (!is_concat)
    {
        NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched "
                     << concat_op->get_name() << " instead of a self concat";
        return;
    }

    if (!valid_self_concat(concat_op))
    {
        NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
                     << " is not a valid self concat\n";
        return;
    }
    NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
                 << " is a VALID self concat\n";

    if (!this->m_concat_pattern_vectors.empty())
    {
        update_concat_pattern_vectors(concat_op);
    }
    else
    {
        this->m_concat_pattern_vectors.push_back(NodeVector{concat_op});
    }
}
void ngraph::pass::SelfConcatFusion::update_concat_pattern_vectors(
const std::shared_ptr<Node>& concat_op)
{
bool concat_source_found = false;
for (auto& concat_pattern_vec : this->m_concat_pattern_vectors)
{
auto last_op_in_pattern_vec = concat_pattern_vec.back();
if ((concat_op->get_argument(0) == last_op_in_pattern_vec) &&
(check_concat_has_no_fan_out(last_op_in_pattern_vec)))
{
concat_pattern_vec.push_back(concat_op);
concat_source_found = true;
break;
}
}
if (!concat_source_found)
{
this->m_concat_pattern_vectors.push_back(NodeVector{concat_op});
}
}
void ngraph::pass::SelfConcatFusion::remove_single_concat_op_pattern()
{
auto iter = m_concat_pattern_vectors.begin();
while (iter != m_concat_pattern_vectors.end())
{
if (iter->size() == 1)
{
iter = m_concat_pattern_vectors.erase(iter);
}
else
{
iter++;
}
}
}
bool ngraph::pass::SelfConcatFusion::replace_patterns(const NodeVector& bounded_concat_ops)
{
auto scalarize_dim = [](std::vector<size_t> concat_axis_vector,
const Shape& input_shape) -> Shape {
Shape scalarized_shape;
for (size_t i = 0; i < input_shape.size(); i++)
{
auto it = std::find(concat_axis_vector.begin(), concat_axis_vector.end(), i);
if (it == concat_axis_vector.end())
{
scalarized_shape.push_back(input_shape[i]);
}
}
return scalarized_shape;
};
auto concat_axis_vector = get_concatenation_axis_vector(bounded_concat_ops);
auto& first_bounded_concat = (*bounded_concat_ops.begin());
auto driver_op = first_bounded_concat->get_argument(0);
const Shape& input_shape = first_bounded_concat->get_input_shape(0);
auto scalarized_shape = scalarize_dim(concat_axis_vector, input_shape);
AxisVector axis_order = get_default_order(input_shape);
auto reshape = std::make_shared<op::Reshape>(driver_op, axis_order, scalarized_shape);
auto last_bounded_concat_op = bounded_concat_ops.back();
auto broadcast_out_shape = last_bounded_concat_op->get_shape();
auto broadcast =
std::make_shared<op::Broadcast>(reshape, broadcast_out_shape, concat_axis_vector);
replace_node(last_bounded_concat_op, broadcast);
return true;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
namespace ngraph
{
namespace pass
{
// Forward declarations of the two passes declared in this header.
class ConcatElimination;
class SelfConcatFusion;
}
}
// GraphRewrite pass for concat elimination. All setup happens in the
// constructor, which calls construct_concat_elimination().
class ngraph::pass::ConcatElimination : public ngraph::pass::GraphRewrite
{
public:
ConcatElimination()
: GraphRewrite()
{
construct_concat_elimination();
}
private:
// Builds the pattern and callback used by this pass.
void construct_concat_elimination();
};
// Function pass that collects chains of self concats (concats whose inputs
// are all the same node) and replaces each chain of two or more.
class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
{
public:
    // Entry point called for each function; returns whether the graph changed.
    bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

private:
    // Adds `concat_op` to an existing chain or starts a new one.
    void update_concat_pattern_vectors(const std::shared_ptr<Node>& concat_op);

    // Removes chains that consist of a single concat.
    void remove_single_concat_op_pattern();

    // Matches `n` and records it when it is a valid self concat.
    void construct_concat_patterns(const std::shared_ptr<pattern::Matcher>& matcher,
                                   const std::shared_ptr<pattern::op::Label>& concat_op_label,
                                   const std::shared_ptr<Node>& n);

    // Rewrites one chain of concats in the graph.
    bool replace_patterns(const NodeVector& bounded_concat_ops);

    // One NodeVector per chain of consecutive self concats.
    std::vector<NodeVector> m_concat_pattern_vectors;
};
...@@ -33,6 +33,7 @@ set(SRC ...@@ -33,6 +33,7 @@ set(SRC
build_graph.cpp build_graph.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
constant_folding.cpp constant_folding.cpp
concat_fusion.cpp
control_dependencies.cpp control_dependencies.cpp
coordinate.cpp coordinate.cpp
copy.cpp copy.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/concat_fusion.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
// A single chain: two single-input concats followed by two self concats.
// The optimized graph must produce the same output as the baseline and
// contain exactly one Reshape and one Broadcast.
TEST(concat_fusion, single_branch)
{
    const Shape input_shape{128, 2048, 1, 1};
    auto make_graph = [input_shape]() {
        auto param = make_shared<op::Parameter>(element::f32, input_shape);
        auto concat_1 = make_shared<op::Concat>(NodeVector{param}, 2);
        auto concat_2 = make_shared<op::Concat>(NodeVector{concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(
            NodeVector{concat_2, concat_2, concat_2, concat_2, concat_2, concat_2, concat_2}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);
        return make_shared<Function>(NodeVector{concat_4}, ParameterVector{param});
    };

    auto reference_f = make_graph();
    auto fused_f = make_graph();
    const auto param_shape = reference_f->get_parameters().at(0)->get_shape();

    // Only the second copy of the graph is run through the passes.
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("before_single_branch.pdf");
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.register_pass<pass::VisualizeTree>("after_single_branch.pdf");
    pass_manager.run_passes(fused_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<float> input_values(shape_size(param_shape));
    rng.initialize(input_values);
    vector<vector<float>> args;
    args.push_back(input_values);

    auto reference_out = execute(reference_f, args, "INTERPRETER");
    auto fused_out = execute(fused_f, args, "INTERPRETER");
    EXPECT_TRUE(test::all_close(reference_out.at(0), fused_out.at(0)));

    size_t reshape_count = count_ops_of_type<op::Reshape>(fused_f);
    size_t broadcast_count = count_ops_of_type<op::Broadcast>(fused_f);
    ASSERT_EQ(reshape_count, 1);
    ASSERT_EQ(broadcast_count, 1);
}
// Two independent self-concat chains that start from the same parameter.
// Every function output is compared between the baseline and the optimized
// graph, and the optimized graph must contain two Reshape/Broadcast pairs.
TEST(concat_fusion, multiple_branches_1)
{
    Shape shape_a{128, 2048, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);

        // First branch.
        auto concat_1 = make_shared<op::Concat>(NodeVector{A}, 2);
        auto concat_2 = make_shared<op::Concat>(NodeVector{concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(
            NodeVector{concat_2, concat_2, concat_2, concat_2, concat_2, concat_2, concat_2}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);

        // Second branch.
        auto concat_5 = make_shared<op::Concat>(NodeVector{A, A}, 2);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 3);

        auto f_concat_1 = make_shared<Function>(NodeVector{concat_4, concat_6}, ParameterVector{A});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("before_multiple_branches_1.pdf");
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.register_pass<pass::VisualizeTree>("after_multiple_branches_1.pdf");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    // The function has two results; check all of them, not just the first.
    ASSERT_EQ(baseline_results.size(), optimized_results.size());
    for (size_t i = 0; i < baseline_results.size(); ++i)
    {
        EXPECT_TRUE(test::all_close(baseline_results.at(i), optimized_results.at(i)));
    }

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 2);
    ASSERT_EQ(num_broadcast_optimized, 2);
}
// One chain of two self concats plus a stand-alone self concat of the same
// parameter. Every function output is compared between the baseline and the
// optimized graph, and exactly one Reshape/Broadcast pair is expected.
TEST(concat_fusion, multiple_branches_2)
{
    Shape shape_a{128, 2048, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto concat_3 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);
        auto concat_6 = make_shared<op::Concat>(NodeVector{A, A, A}, 3);
        auto f_concat_1 = make_shared<Function>(NodeVector{concat_4, concat_6}, ParameterVector{A});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("before_multiple_branches_2.pdf");
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.register_pass<pass::VisualizeTree>("after_multiple_branches_2.pdf");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    // The function has two results; check all of them, not just the first.
    ASSERT_EQ(baseline_results.size(), optimized_results.size());
    for (size_t i = 0; i < baseline_results.size(); ++i)
    {
        EXPECT_TRUE(test::all_close(baseline_results.at(i), optimized_results.at(i)));
    }

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 1);
    ASSERT_EQ(num_broadcast_optimized, 1);
}
// Two parameters feed separate concat chains whose results are added; the
// graph also contains an explicit Broadcast. The optimized graph must match
// the baseline output and contain two Reshapes and three Broadcasts.
TEST(concat_fusion, non_fusable_self_concat)
{
    const Shape first_shape{128, 1, 1, 1};
    const Shape second_shape{128, 1, 1};
    auto make_graph = [first_shape, second_shape]() {
        auto first = make_shared<op::Parameter>(element::f32, first_shape);
        auto second = make_shared<op::Parameter>(element::f32, second_shape);

        // Chain driven by the first parameter.
        auto concat_1 = make_shared<op::Concat>(NodeVector{first, first, first, first}, 1);
        auto concat_2 = make_shared<op::Concat>(
            NodeVector{concat_1, concat_1, concat_1, concat_1, concat_1, concat_1, concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(NodeVector{concat_2, concat_2}, 1);
        auto concat_4 = make_shared<op::Concat>(NodeVector{concat_3, concat_3, concat_3}, 3);

        // Chain driven by the second parameter, broadcast to the shape of concat_4.
        auto concat_5 = make_shared<op::Concat>(
            NodeVector{second, second, second, second, second, second, second}, 1);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 2);
        auto broadcast = make_shared<op::Broadcast>(concat_6, Shape{128, 8, 7, 3}, AxisSet{1});

        auto add = make_shared<op::Add>(concat_4, broadcast);
        return make_shared<Function>(NodeVector{add}, ParameterVector{first, second});
    };

    auto reference_f = make_graph();
    auto fused_f = make_graph();
    const auto first_param_shape = reference_f->get_parameters().at(0)->get_shape();
    const auto second_param_shape = reference_f->get_parameters().at(1)->get_shape();

    // Only the second copy of the graph is run through the passes.
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("before_non_fusable_self_concat.pdf");
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.register_pass<pass::VisualizeTree>("after_non_fusable_self_concat.pdf");
    pass_manager.run_passes(fused_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<float> first_values(shape_size(first_param_shape));
    vector<float> second_values(shape_size(second_param_shape));
    rng.initialize(first_values);
    rng.initialize(second_values);
    vector<vector<float>> args;
    args.push_back(first_values);
    args.push_back(second_values);

    auto reference_out = execute(reference_f, args, "INTERPRETER");
    auto fused_out = execute(fused_f, args, "INTERPRETER");
    EXPECT_TRUE(test::all_close(reference_out.at(0), fused_out.at(0)));

    size_t reshape_count = count_ops_of_type<op::Reshape>(fused_f);
    size_t broadcast_count = count_ops_of_type<op::Broadcast>(fused_f);
    ASSERT_EQ(reshape_count, 2);
    ASSERT_EQ(broadcast_count, 3);
}
// concat_2 feeds both concat_3 and concat_6, i.e. it has fan out. Every
// function output is compared between the baseline and the optimized graph,
// and exactly one Reshape/Broadcast pair is expected.
TEST(concat_fusion, self_concat_with_fan_out)
{
    Shape shape_a{8, 1, 1, 1};
    Shape shape_b{8, 4, 1, 1};
    auto generate_func = [shape_a, shape_b]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto B = make_shared<op::Parameter>(element::f32, shape_b);
        auto concat_1 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2);
        auto concat_2 =
            make_shared<op::Concat>(NodeVector{concat_1, concat_1, concat_1, concat_1}, 1);
        auto concat_3 =
            make_shared<op::Concat>(NodeVector{concat_2, concat_2, concat_2, concat_2}, 3);
        auto concat_4 = make_shared<op::Concat>(NodeVector{B, B, B, B, B, B, B}, 2);
        // NOTE: concat_5 is built but is not an input of either function result.
        auto concat_5 = make_shared<op::Concat>(NodeVector{concat_4, concat_4, concat_4}, 3);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_2, concat_4}, 3);
        auto f_concat_1 =
            make_shared<Function>(NodeVector{concat_3, concat_6}, ParameterVector{A, B});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape_1 = baseline_f->get_parameters().at(0)->get_shape();
    auto baseline_input_shape_2 = baseline_f->get_parameters().at(1)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("before_self_concat_with_fan_out.pdf");
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.register_pass<pass::VisualizeTree>("after_self_concat_with_fan_out.pdf");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val_1(shape_size(baseline_input_shape_1));
    vector<float> tensor_val_2(shape_size(baseline_input_shape_2));
    rng.initialize(tensor_val_1);
    rng.initialize(tensor_val_2);
    args.push_back(tensor_val_1);
    args.push_back(tensor_val_2);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    // The function has two results; check all of them, not just the first.
    ASSERT_EQ(baseline_results.size(), optimized_results.size());
    for (size_t i = 0; i < baseline_results.size(); ++i)
    {
        EXPECT_TRUE(test::all_close(baseline_results.at(i), optimized_results.at(i)));
    }

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 1);
    ASSERT_EQ(num_broadcast_optimized, 1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment