Commit 10ef07e6 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Repackaging match_recurring_pattern into RecurrentMatcher (#832)

* repacking recurrent matching as a standalone class

* RecurrentMatcher

* add a getter for root node

* address Scott's feedback
parent a2ab7b50
......@@ -58,7 +58,7 @@ void pass::CoreFusion::construct_relu_pattern()
auto skip_broadcast = std::make_shared<pattern::op::Any>(zero, broadcast_pred);
auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) {
pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
<< m.match_root()->get_name();
......
......@@ -161,7 +161,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name();
......
......@@ -38,57 +38,6 @@ namespace ngraph
begin(arguments), end(arguments)); //vector is needed for generating permutations
}
bool Matcher::match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
{
bool matched = false;
Matcher m(pattern);
PatternMap previous_matches;
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
//move to the next cell
graph = m.m_pattern_map[rpattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
//copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.m_pattern_map)
{
patterns[cur_match.first].push_back(cur_match.second);
}
//pre-populate the pattern map for the next cell with the bound nodes
//from the current match. Only bound nodes whose labels are in
//correlated_patterns are pre-populated. Any other labels are
//unbounded by default
for (auto cor_pat : correlated_patterns)
{
if (m.m_pattern_map.count(cor_pat) != 0)
{
//assert that bound nodes from the previous and current matches are the same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.m_pattern_map[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.m_pattern_map[cor_pat];
}
}
}
return matched;
}
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
......@@ -253,9 +202,9 @@ namespace ngraph
return false;
}
bool Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
bool Matcher::process_match(::ngraph::pattern::graph_rewrite_callback callback)
{
gr_callback_fn cb = m_callback;
graph_rewrite_callback cb = m_callback;
if (callback)
{
cb = callback;
......@@ -320,5 +269,55 @@ namespace ngraph
}
return is_match;
}
bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
{
bool matched = false;
Matcher m(m_pattern);
Matcher::PatternMap previous_matches;
m_matches.clear();
m_match_root.reset();
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
//move to the next cell
graph = m.get_pattern_map()[m_recurrent_pattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
//copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.get_pattern_map())
{
m_matches[cur_match.first].push_back(cur_match.second);
}
//pre-populate the pattern map for the next cell with the bound nodes
//from the current match. Only bound nodes whose labels are in
//correlated_patterns are pre-populated. Any other labels are
//unbounded by default
for (auto cor_pat : m_correlated_patterns)
{
if (m.get_pattern_map().count(cor_pat) != 0)
{
//assert that bound nodes from the previous and current matches are the same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.get_pattern_map()[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.get_pattern_map()[cor_pat];
}
}
}
return matched;
}
bool RecurrentMatcher::process_match() { return m_callback(*this); }
}
}
......@@ -32,7 +32,8 @@ namespace ngraph
namespace pattern
{
using gr_callback_fn = std::function<bool(class Matcher& m)>;
using graph_rewrite_callback = std::function<bool(class Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(class RecurrentMatcher& m)>;
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
namespace op
......@@ -52,7 +53,7 @@ namespace ngraph
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr)
graph_rewrite_callback callback = nullptr)
: m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
......@@ -91,7 +92,7 @@ namespace ngraph
return matched;
}
bool process_match(gr_callback_fn callback = nullptr);
bool process_match(graph_rewrite_callback callback = nullptr);
void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
......@@ -103,14 +104,6 @@ namespace ngraph
/// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches
/// \param correlated_patterns specify labels whose bound nodes should be
/// the same across all cells
static bool match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
friend op::Label; //TODO: refine to match_class
protected:
......@@ -138,8 +131,68 @@ namespace ngraph
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
gr_callback_fn m_callback;
graph_rewrite_callback m_callback;
size_t m_depth;
};
class RecurrentMatcher
{
public:
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same across all cells
// \param is a callback function that will be called on a successful match
RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
recurrent_graph_rewrite_callback callback)
: m_pattern(pattern)
, m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns)
, m_callback(callback)
{
}
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
/// describing an individual cell
NodeVector get_bound_nodes_for_pattern(std::shared_ptr<op::Label> pattern) const
{
if (m_matches.count(pattern) == 0)
{
throw ngraph_error("No bound nodes for a given label");
}
return NodeVector{m_matches.at(pattern)};
}
size_t get_number_of_recurrent_matches() const
{
if (m_matches.size() == 0)
{
return 0;
}
return (*m_matches.begin()).second.size();
}
size_t get_number_of_bound_labels() const { return m_matches.size(); }
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(std::shared_ptr<Node> graph);
/// \brief Invoked by a pass to process a successful match
bool process_match();
std::shared_ptr<Node> match_root() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
std::shared_ptr<op::Label> m_recurrent_pattern;
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches;
recurrent_graph_rewrite_callback m_callback;
std::shared_ptr<Node> m_match_root;
};
}
}
......@@ -136,7 +136,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0});
auto padd = pmmb + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
<< m.match_root()->get_name();
......@@ -182,7 +182,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -285,7 +285,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
// This completes fprop bn pattern
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback =
ngraph::pattern::graph_rewrite_callback callback =
[variance_label, mean_label, input, eps_label, gamma_label, beta_label](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
......@@ -411,7 +411,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback =
ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label](
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
......@@ -489,7 +489,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback =
ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, conv_filter, conv_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
......@@ -554,7 +554,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_
Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback =
ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, output_delta, conv_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
......@@ -616,7 +616,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = [input](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -668,7 +668,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
auto negtive_2 = std::make_shared<op::Negative>(multiply_2);
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = [input, delta](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -712,7 +712,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
Strides{1, 1});
auto p_conv_bias = pbroadcast + pconv1;
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -767,7 +767,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto goe = std::make_shared<op::GetOutputElement>(bn, 0);
auto prelu = std::make_shared<op::Relu>(goe);
ngraph::pattern::gr_callback_fn callback = [input, gamma, beta](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [input, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name();
......@@ -841,7 +841,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
auto prelu = std::make_shared<op::Relu>(pconv);
pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_relu against "
<< m.match_root()->get_name();
......
......@@ -155,7 +155,7 @@ public:
auto iconst1 = construct_constant_node(1);
auto pattern = std::make_shared<pattern::op::Label>(iconst1);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2);
......@@ -243,7 +243,7 @@ public:
{
auto sum_pattern = construct_sum_pattern();
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
......@@ -556,7 +556,7 @@ TEST(pattern, previous_matches)
TEST(pattern, recurrent_pattern)
{
using ngraph::pattern::Matcher;
using ngraph::pattern::RecurrentMatcher;
Shape shape{};
ngraph::pattern::Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::i32, shape);
......@@ -568,12 +568,11 @@ TEST(pattern, recurrent_pattern)
auto add2 = iconst0 + add1;
auto add3 = iconst0 + add2;
auto padd = iconst0 + rpattern;
ngraph::pattern::RPatternMap matches;
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
Matcher::match_recurring_pattern(add3, padd, rpattern, matches, empty_correlated_matches);
ASSERT_EQ(matches.size(), 1);
ASSERT_EQ(matches.count(rpattern), 1);
auto recurrent_matches = matches[rpattern];
RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
ASSERT_TRUE(rm.match(add3));
ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
......@@ -584,37 +583,37 @@ TEST(pattern, recurrent_pattern)
auto add2_2 = iconst1 + add1;
auto add3_2 = iconst0 + add2_2;
auto padd2 = iconst_label + rpattern;
matches.clear();
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, empty_correlated_matches);
ASSERT_EQ(matches.size(), 2);
recurrent_matches = matches[rpattern];
RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
ASSERT_TRUE(rm2.match(add3_2));
ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2_2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
auto iconst_matches = matches[iconst_label];
auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst1);
ASSERT_EQ(iconst_matches.at(2), iconst0);
//Non-matching correlated labels
matches.clear();
std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
correlated_matches.insert(iconst_label);
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, correlated_matches);
ASSERT_EQ(matches.size(), 2);
iconst_matches = matches[iconst_label];
RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
ASSERT_TRUE(rm3.match(add3_2));
ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.size(), 1);
ASSERT_EQ(iconst_matches.at(0), iconst0);
//Matching correlated labels
matches.clear();
Matcher::match_recurring_pattern(add3, padd2, rpattern, matches, correlated_matches);
ASSERT_EQ(matches.size(), 2);
recurrent_matches = matches[rpattern];
//Matching correlated labels and
//testing if RecurrentMatcher can be reused for different nodes
ASSERT_TRUE(rm3.match(add3));
ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
iconst_matches = matches[iconst_label];
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst0);
ASSERT_EQ(iconst_matches.at(2), iconst0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment