Commit b14d5665 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

RecurrentGraphRewrite + tests (#833)

* add a getter for root node

* recurrent graph rewrite

* fix perms, rename match_root -> get_match_root

* fix comp errors

* make match_root return the topmost match; fix tests
parent b9b7845c
......@@ -53,3 +53,33 @@ bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Functio
{
return run_matchers_on_nodes_list(f->get_ordered_ops(), m_matchers, f);
}
bool ngraph::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool changed = false;
size_t i = 0;
do
{
for (auto node : f->get_ops())
{
for (auto matcher : m_matchers)
{
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name() << " , is_output = " << node->is_output();
if (matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
<< node->get_name();
if (matcher->process_match())
{
changed = true;
goto next_fusion;
}
}
}
}
next_fusion:
i++;
} while (changed && i < m_num_iters);
return changed;
}
......@@ -25,10 +25,12 @@ namespace ngraph
namespace pass
{
class GraphRewrite;
class RecurrentGraphRewrite;
}
namespace pattern
{
class Matcher;
class RecurrentMatcher;
}
}
......@@ -62,3 +64,20 @@ private:
//enable cascading rewrites
std::vector<std::shared_ptr<pattern::Matcher>> m_matchers;
};
class ngraph::pass::RecurrentGraphRewrite : public FunctionPass
{
public:
RecurrentGraphRewrite(size_t num_iters = 10)
: FunctionPass()
, m_num_iters(num_iters)
{
}
void add_matcher(std::shared_ptr<pattern::RecurrentMatcher> m) { m_matchers.push_back(m); }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
size_t m_num_iters;
std::vector<std::shared_ptr<pattern::RecurrentMatcher>> m_matchers;
};
......@@ -276,7 +276,7 @@ namespace ngraph
Matcher m(m_pattern);
Matcher::PatternMap previous_matches;
m_matches.clear();
m_match_root.reset();
m_match_root = graph;
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
......@@ -315,6 +315,12 @@ namespace ngraph
}
}
}
if (!matched)
{
m_match_root.reset();
}
return matched;
}
......
......@@ -185,7 +185,7 @@ namespace ngraph
/// \brief Invoked by a pass to process a successful match
bool process_match();
std::shared_ptr<Node> match_root() { return m_match_root; }
std::shared_ptr<Node> get_match_root() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
std::shared_ptr<op::Label> m_recurrent_pattern;
......
......@@ -618,3 +618,95 @@ TEST(pattern, recurrent_pattern)
ASSERT_EQ(iconst_matches.at(1), iconst0);
ASSERT_EQ(iconst_matches.at(2), iconst0);
}
class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
void construct_recurrent_add()
{
Shape shape{};
auto iconst0 = construct_constant_node(0);
auto iconst_label =
std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
auto padd = iconst_label + rpattern;
auto sum_pattern = construct_sum_pattern();
ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
pattern::RecurrentMatcher& rm) {
NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
<< rm.get_match_root()->get_name();
auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
auto is_iconst_zero = [](std::shared_ptr<Node> n) {
bool result = is_zero(n);
NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
return is_zero(n);
};
bool are_all_iconst_zeros =
std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
if (!are_all_iconst_zeros)
{
return false;
}
auto number_of_adds = rm.get_number_of_recurrent_matches();
//replace the topmost add with the seed (i.e. the first parameter to add)
//matches are added in reverse order (i.e. the first match is the topmost node)
auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
<< arg->get_name();
ngraph::replace_node(rm.get_match_root(), arg);
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto rm = make_shared<pattern::RecurrentMatcher>(
padd, rpattern, empty_correlated_matches, callback);
this->add_matcher(rm);
}
TestRecurrentGraphRewrite()
: RecurrentGraphRewrite()
{
construct_recurrent_add();
}
};
TEST(pattern, recurrent_graph_rewrite)
{
Shape shape{};
pass::Manager pass_manager;
pass_manager.register_pass<TestRecurrentGraphRewrite>();
{
auto a = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = construct_constant_node(0);
auto add_a1 = a + iconst0;
auto add_a2 = add_a1 + iconst0;
auto add_a3 = add_a2 + iconst0;
auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto add_b1 = b + iconst0;
auto add_b2 = add_b1 + iconst0;
auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
auto graph = abs_add_a3 * abs_add_b2;
auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, op::ParameterVector{a, b});
pass_manager.run_passes(f);
auto left_abs = graph->get_input_op(0);
auto add_a = left_abs->get_input_op(0);
ASSERT_EQ(add_a, a);
auto right_abs = graph->get_input_op(1);
auto add_b = right_abs->get_input_op(0);
ASSERT_EQ(add_b, b);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment