Commit 76c73c91 authored by Ayan Moitra's avatar Ayan Moitra Committed by Scott Cyphers

General-purpose recurrent reshape elimination pass (#2665)

* [WIP] First commit

* Incremental code changes

* Incremental code changes

* Further mods

* Improve src + add more tests

* Another test added

* clang

* Added NGRAPH_DEBUG statements

* Incorporate Xiaoyu's and Scott's comments

* Incorporate Adam's comments in tests

* Incorporate Adam's comments

* Add Jayaram's comments
parent f9d0bd57
......@@ -338,6 +338,8 @@ set (SRC
pass/zero_dim_tensor_elimination.hpp
pass/concat_fusion.hpp
pass/concat_fusion.cpp
pass/pass_util.hpp
pass/pass_util.cpp
pattern/matcher.cpp
pattern/matcher.hpp
pattern/op/any.hpp
......
......@@ -53,18 +53,8 @@ namespace
// Returns true iff `op` has exactly one distinct user, i.e. the concat's
// output does not fan out to multiple consumers.
bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op)
{
// NOTE(review): `true` presumably widens the user set returned by
// get_users — confirm its exact meaning against the Node API.
auto users = op->get_users(true);
// Deduplicate: a consumer reading `op` through several inputs counts once.
std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
size_t num_unique_users = user_set.size();
if (num_unique_users == 1)
{
return true;
}
else
{
NGRAPH_DEBUG << "self_concat_fusion: " << op->get_name() << " has fan out\n";
return false;
}
// NOTE(review): the two statements below are unreachable as written — both
// branches above already return. This looks like a diff overlay where the
// inline check was replaced by delegation to the shared pass_util helper;
// only one of the two implementations should survive.
auto no_fan_out = ngraph::pass::get_no_fan_out_function();
return no_fan_out(op);
}
bool valid_self_concat(const std::shared_ptr<Node>& Op)
......
......@@ -18,6 +18,7 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_util.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/pass_util.hpp"
using namespace std;
using namespace ngraph;
std::function<bool(std::shared_ptr<Node>)> ngraph::pass::get_no_fan_out_function()
{
auto ret_fun = [](std::shared_ptr<Node> n) {
auto users = n->get_users(true);
std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
size_t num_unique_users = user_set.size();
if (num_unique_users == 1)
{
return true;
}
else
{
NGRAPH_DEBUG << n->get_name() << " has fan out\n";
return false;
}
};
return ret_fun;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <cstdlib> // llvm 8.1 gets confused about `malloc` otherwise
#include <functional>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include "ngraph/node.hpp"
namespace ngraph
{
    namespace pass
    {
        // Returns a predicate that is true iff a node has exactly one distinct
        // user (no fan-out). Intended as a filter for pattern::op::Label so that
        // recurrent matchers only bind nodes whose output is consumed once.
        std::function<bool(std::shared_ptr<Node>)> get_no_fan_out_function();
    }
}
......@@ -24,7 +24,6 @@
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -192,3 +191,105 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
auto m = make_shared<pattern::Matcher>(preshape, callback);
this->add_matcher(m);
}
// Registers a recurrent matcher that finds chains of consecutive Reshape ops
// and collapses each fusable run into a single Reshape. A run is fusable when
// it starts with a reshape of any axis order followed only by reshapes with
// the default axis order.
//
// Fixes in this revision:
//  - guard get_users(true)[0] against an empty user list (was UB),
//  - add the missing leading space in the "does not have default axis order"
//    debug message (it printed glued to the node name),
//  - drop the unused local `driver_op`.
void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
{
    // The concrete shapes below only seed the pattern graph; the labels match
    // reshapes of any shape at runtime.
    Shape shape_op{3};
    Shape shape_r{1, 3};

    auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
    auto reshape = make_shared<op::Reshape>(op, AxisVector{0}, shape_r);
    // Only reshapes without fan-out may appear inside the recurrent chain.
    auto reshape_label =
        make_shared<pattern::op::Label>(reshape, get_no_fan_out_function(), NodeVector{reshape});

    auto callback = [op, reshape_label](pattern::RecurrentMatcher& m) {
        NGRAPH_DEBUG << "In callback for construct_recurrent_reshape against node = "
                     << reshape_label->get_argument(0)->get_name();
        auto reshape_node_vector = m.get_bound_nodes_for_pattern(reshape_label);

        // The bound node vector is in reverse order. It is convenient to have the
        // bound node vector in the correct order.
        std::reverse(std::begin(reshape_node_vector), std::end(reshape_node_vector));

        auto first_bound_reshape_op = reshape_node_vector.front();
        auto last_bound_reshape_op = reshape_node_vector.back();

        // Need to check if the user of the last bound op is a reshape since the last
        // reshape is allowed to have fan-out but the matcher will discard any reshape
        // if it has fan-out. Guard against a last reshape with no users before
        // indexing element [0].
        auto users_of_last_reshape = last_bound_reshape_op->get_users(true);
        if (!users_of_last_reshape.empty())
        {
            auto user_of_last_bound_reshape_op = users_of_last_reshape[0];
            if (std::dynamic_pointer_cast<op::Reshape>(user_of_last_bound_reshape_op))
            {
                reshape_node_vector.push_back(user_of_last_bound_reshape_op);
                last_bound_reshape_op = reshape_node_vector.back();
            }
        }

        // Return if the recurrent matcher matched only one reshape.
        if (reshape_node_vector.size() == 1)
        {
            return false;
        }

        // The complete reshape node vector may not contain contiguous reshapes that
        // can be fused. Only the subset of reshapes with a reshape (any axis order)
        // followed by reshapes with default axis order can be fused. Creating such
        // subpatterns here:
        std::vector<NodeVector> sub_patterns{NodeVector{first_bound_reshape_op}};
        for (auto it = std::next(reshape_node_vector.begin()); it != reshape_node_vector.end();
             ++it)
        {
            auto r = std::dynamic_pointer_cast<op::Reshape>(*it);
            if (!r)
            {
                NGRAPH_DEBUG
                    << "Incorrect match. Something went wrong. Non-reshape op has been matched";
                return false;
            }

            auto default_order_r = get_default_order(r->get_input_shape(0));
            if (r->get_input_order() == default_order_r)
            {
                // Default axis order: extends the current fusable run.
                sub_patterns.back().push_back(r);
            }
            else
            {
                NGRAPH_DEBUG << r->get_name() << " does not have default axis order. "
                             << "It might be part of a different subpattern";
                sub_patterns.push_back(NodeVector{r});
            }
        }

        bool modify_graph = false;

        // Replace each fusable run by a single reshape that maps the run's input
        // directly to the run's final output shape.
        for (auto sub_pattern : sub_patterns)
        {
            // Do not consider subpatterns with just one reshape in them.
            if (sub_pattern.size() == 1)
            {
                continue;
            }

            auto first_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.front());
            auto input_to_first_reshape = first_reshape->get_argument(0);
            auto last_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.back());

            auto new_input_order = first_reshape->get_input_order();
            auto new_out_shape = last_reshape->get_shape();
            auto new_reshape = std::make_shared<op::Reshape>(
                input_to_first_reshape, new_input_order, new_out_shape);
            replace_node(last_reshape, new_reshape);
            modify_graph = true;
        }
        return modify_graph;
    };

    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
    auto m = std::make_shared<pattern::RecurrentMatcher>(
        reshape_label, op, empty_correlated_matches, callback);
    this->add_matcher(m);
}
......@@ -17,12 +17,14 @@
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass_util.hpp"
namespace ngraph
{
namespace pass
{
class ReshapeElimination;
class RecurrentReshapeElimination;
}
}
......@@ -42,3 +44,16 @@ private:
void construct_identity_reshape_pattern();
void construct_reshapex2_pattern();
};
// Recurrent graph-rewrite pass that repeatedly collapses chains of
// consecutive Reshape ops into a single Reshape until no more matches fire.
class ngraph::pass::RecurrentReshapeElimination : public ngraph::pass::RecurrentGraphRewrite
{
public:
    // Registers the recurrent reshape-chain matcher with the base rewriter.
    RecurrentReshapeElimination()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_reshape();
    }
private:
    // Builds the pattern and callback that fuse runs of reshapes.
    void construct_recurrent_reshape();
};
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment