Commit ba96aa0f authored by Ayan Moitra's avatar Ayan Moitra Committed by Scott Cyphers

Concat Elimination and Self Concat Fusion pass (#2634)

* [WIP] First commit

* Remove some commented code

* Further changes

* Further changes

* Add method to remove patterns with just one concat

* Add tests

* Add more tests

* Fix fan out case

* refactor code

* refactor code

* Added NGRAPH_DEBUG statements

* Use INTERPRETER as backend instead of CPU...travis build failure

* clang

* minor edit

* add more checks in the tests

* Incorporate Bob's comment

* Removed some NGRAPH_DEBUG statements and incorporated Pruthvi's comment

* Incorporate Xiaoyu's comments

* some refactoring
parent abd1c70d
......@@ -320,6 +320,8 @@ set (SRC
pass/zero_dim_tensor_elimination.cpp
pass/zero_dim_tensor_elimination.hpp
pass/zero_dim_tensor_elimination.hpp
pass/concat_fusion.hpp
pass/concat_fusion.cpp
pattern/matcher.cpp
pattern/matcher.hpp
pattern/op/any.hpp
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
namespace ngraph
{
namespace pass
{
class ConcatElimination;
class SelfConcatFusion;
}
}
class ngraph::pass::ConcatElimination : public ngraph::pass::GraphRewrite
{
public:
ConcatElimination()
: GraphRewrite()
{
construct_concat_elimination();
}
private:
void construct_concat_elimination();
};
class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
{
public:
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
private:
void update_concat_pattern_vectors(const std::shared_ptr<Node>&);
void remove_single_concat_op_pattern();
void construct_concat_patterns(const std::shared_ptr<pattern::Matcher>&,
const std::shared_ptr<pattern::op::Label>&,
const std::shared_ptr<Node>&);
bool replace_patterns(const NodeVector&);
std::vector<NodeVector> m_concat_pattern_vectors;
};
......@@ -33,6 +33,7 @@ set(SRC
build_graph.cpp
builder_autobroadcast.cpp
constant_folding.cpp
concat_fusion.cpp
control_dependencies.cpp
coordinate.cpp
copy.cpp
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment