Commit 10ef07e6 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Repackaging match_recurring_pattern into RecurrentMatcher (#832)

* repacking recurrent matching as a standalone class

* RecurrentMatcher

* add a getter for root node

* address Scott's feedback
parent a2ab7b50
...@@ -58,7 +58,7 @@ void pass::CoreFusion::construct_relu_pattern() ...@@ -58,7 +58,7 @@ void pass::CoreFusion::construct_relu_pattern()
auto skip_broadcast = std::make_shared<pattern::op::Any>(zero, broadcast_pred); auto skip_broadcast = std::make_shared<pattern::op::Any>(zero, broadcast_pred);
auto max = make_shared<op::Maximum>(skip_broadcast, val); auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::gr_callback_fn callback = [val, zero](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu_pattern against " NGRAPH_DEBUG << "In a callback for construct_relu_pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
......
...@@ -161,7 +161,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -161,7 +161,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred); auto pdot = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2}); auto preshape = std::make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = " NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
......
...@@ -38,57 +38,6 @@ namespace ngraph ...@@ -38,57 +38,6 @@ namespace ngraph
begin(arguments), end(arguments)); //vector is needed for generating permutations begin(arguments), end(arguments)); //vector is needed for generating permutations
} }
bool Matcher::match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
{
bool matched = false;
Matcher m(pattern);
PatternMap previous_matches;
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
//move to the next cell
graph = m.m_pattern_map[rpattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
//copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.m_pattern_map)
{
patterns[cur_match.first].push_back(cur_match.second);
}
//pre-populate the pattern map for the next cell with the bound nodes
//from the current match. Only bound nodes whose labels are in
//correlated_patterns are pre-populated. Any other labels are
//unbounded by default
for (auto cor_pat : correlated_patterns)
{
if (m.m_pattern_map.count(cor_pat) != 0)
{
//assert that bound nodes from the previous and current matches are the same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.m_pattern_map[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.m_pattern_map[cor_pat];
}
}
}
return matched;
}
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; } std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label, bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
...@@ -253,9 +202,9 @@ namespace ngraph ...@@ -253,9 +202,9 @@ namespace ngraph
return false; return false;
} }
bool Matcher::process_match(::ngraph::pattern::gr_callback_fn callback) bool Matcher::process_match(::ngraph::pattern::graph_rewrite_callback callback)
{ {
gr_callback_fn cb = m_callback; graph_rewrite_callback cb = m_callback;
if (callback) if (callback)
{ {
cb = callback; cb = callback;
...@@ -320,5 +269,55 @@ namespace ngraph ...@@ -320,5 +269,55 @@ namespace ngraph
} }
return is_match; return is_match;
} }
bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
{
bool matched = false;
Matcher m(m_pattern);
Matcher::PatternMap previous_matches;
m_matches.clear();
m_match_root.reset();
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
//move to the next cell
graph = m.get_pattern_map()[m_recurrent_pattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
//copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.get_pattern_map())
{
m_matches[cur_match.first].push_back(cur_match.second);
}
//pre-populate the pattern map for the next cell with the bound nodes
//from the current match. Only bound nodes whose labels are in
//correlated_patterns are pre-populated. Any other labels are
//unbounded by default
for (auto cor_pat : m_correlated_patterns)
{
if (m.get_pattern_map().count(cor_pat) != 0)
{
//assert that bound nodes from the previous and current matches are the same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.get_pattern_map()[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.get_pattern_map()[cor_pat];
}
}
}
return matched;
}
bool RecurrentMatcher::process_match() { return m_callback(*this); }
} }
} }
...@@ -32,7 +32,8 @@ namespace ngraph ...@@ -32,7 +32,8 @@ namespace ngraph
namespace pattern namespace pattern
{ {
using gr_callback_fn = std::function<bool(class Matcher& m)>; using graph_rewrite_callback = std::function<bool(class Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(class RecurrentMatcher& m)>;
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>; using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
namespace op namespace op
...@@ -52,7 +53,7 @@ namespace ngraph ...@@ -52,7 +53,7 @@ namespace ngraph
/// \param pattern_node is a pattern sub graph that will be matched against input graphs /// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match /// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr, Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr) graph_rewrite_callback callback = nullptr)
: m_pattern_node(pattern_node) : m_pattern_node(pattern_node)
, m_callback(callback) , m_callback(callback)
, m_depth(0) , m_depth(0)
...@@ -91,7 +92,7 @@ namespace ngraph ...@@ -91,7 +92,7 @@ namespace ngraph
return matched; return matched;
} }
bool process_match(gr_callback_fn callback = nullptr); bool process_match(graph_rewrite_callback callback = nullptr);
void reset() {} void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; } std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
...@@ -103,14 +104,6 @@ namespace ngraph ...@@ -103,14 +104,6 @@ namespace ngraph
/// \param pattern is a recurring pattern /// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next /// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches /// \param patterns a map from labels to matches
/// \param correlated_patterns specify labels whose bound nodes should be
/// the same across all cells
static bool match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
friend op::Label; //TODO: refine to match_class friend op::Label; //TODO: refine to match_class
protected: protected:
...@@ -138,8 +131,68 @@ namespace ngraph ...@@ -138,8 +131,68 @@ namespace ngraph
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map); PatternMap& pattern_map);
gr_callback_fn m_callback; graph_rewrite_callback m_callback;
size_t m_depth; size_t m_depth;
}; };
class RecurrentMatcher
{
public:
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same across all cells
// \param is a callback function that will be called on a successful match
RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
recurrent_graph_rewrite_callback callback)
: m_pattern(pattern)
, m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns)
, m_callback(callback)
{
}
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
/// describing an individual cell
NodeVector get_bound_nodes_for_pattern(std::shared_ptr<op::Label> pattern) const
{
if (m_matches.count(pattern) == 0)
{
throw ngraph_error("No bound nodes for a given label");
}
return NodeVector{m_matches.at(pattern)};
}
size_t get_number_of_recurrent_matches() const
{
if (m_matches.size() == 0)
{
return 0;
}
return (*m_matches.begin()).second.size();
}
size_t get_number_of_bound_labels() const { return m_matches.size(); }
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(std::shared_ptr<Node> graph);
/// \brief Invoked by a pass to process a successful match
bool process_match();
std::shared_ptr<Node> match_root() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
std::shared_ptr<op::Label> m_recurrent_pattern;
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches;
recurrent_graph_rewrite_callback m_callback;
std::shared_ptr<Node> m_match_root;
};
} }
} }
...@@ -136,7 +136,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias() ...@@ -136,7 +136,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0}); auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0});
auto padd = pmmb + pbroadcast; auto padd = pmmb + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
...@@ -182,7 +182,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul() ...@@ -182,7 +182,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x); auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -285,7 +285,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() ...@@ -285,7 +285,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
// This completes fprop bn pattern // This completes fprop bn pattern
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::graph_rewrite_callback callback =
[variance_label, mean_label, input, eps_label, gamma_label, beta_label]( [variance_label, mean_label, input, eps_label, gamma_label, beta_label](
pattern::Matcher& m) { pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
...@@ -411,7 +411,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv( ...@@ -411,7 +411,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv(
Strides{1, 1}); Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label]( [pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label](
pattern::Matcher& m) { pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -489,7 +489,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() ...@@ -489,7 +489,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
Strides{1, 1}); Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, conv_filter, conv_label](pattern::Matcher& m) { [pad_input, pad_value, pad_label, conv_filter, conv_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -554,7 +554,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_ ...@@ -554,7 +554,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv_backprop_
Strides{1, 1}); Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::graph_rewrite_callback callback =
[pad_input, pad_value, pad_label, output_delta, conv_label](pattern::Matcher& m) { [pad_input, pad_value, pad_label, output_delta, conv_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -616,7 +616,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid() ...@@ -616,7 +616,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp); auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = [input](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -668,7 +668,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop() ...@@ -668,7 +668,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
auto negtive_2 = std::make_shared<op::Negative>(multiply_2); auto negtive_2 = std::make_shared<op::Negative>(multiply_2);
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = [input, delta](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -712,7 +712,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() ...@@ -712,7 +712,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
Strides{1, 1}); Strides{1, 1});
auto p_conv_bias = pbroadcast + pconv1; auto p_conv_bias = pbroadcast + pconv1;
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = " NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -767,7 +767,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -767,7 +767,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto goe = std::make_shared<op::GetOutputElement>(bn, 0); auto goe = std::make_shared<op::GetOutputElement>(bn, 0);
auto prelu = std::make_shared<op::Relu>(goe); auto prelu = std::make_shared<op::Relu>(goe);
ngraph::pattern::gr_callback_fn callback = [input, gamma, beta](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [input, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = " NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
...@@ -841,7 +841,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu() ...@@ -841,7 +841,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
auto prelu = std::make_shared<op::Relu>(pconv); auto prelu = std::make_shared<op::Relu>(pconv);
pattern::gr_callback_fn callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_relu against " NGRAPH_DEBUG << "In a callback for construct_conv_relu against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
......
...@@ -155,7 +155,7 @@ public: ...@@ -155,7 +155,7 @@ public:
auto iconst1 = construct_constant_node(1); auto iconst1 = construct_constant_node(1);
auto pattern = std::make_shared<pattern::op::Label>(iconst1); auto pattern = std::make_shared<pattern::op::Label>(iconst1);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against " NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2); assert(m.match_root()->get_input_ops().size() == 2);
...@@ -243,7 +243,7 @@ public: ...@@ -243,7 +243,7 @@ public:
{ {
auto sum_pattern = construct_sum_pattern(); auto sum_pattern = construct_sum_pattern();
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) { ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against " NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root()); auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
...@@ -556,7 +556,7 @@ TEST(pattern, previous_matches) ...@@ -556,7 +556,7 @@ TEST(pattern, previous_matches)
TEST(pattern, recurrent_pattern) TEST(pattern, recurrent_pattern)
{ {
using ngraph::pattern::Matcher; using ngraph::pattern::RecurrentMatcher;
Shape shape{}; Shape shape{};
ngraph::pattern::Matcher::PatternMap previous_matches; ngraph::pattern::Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::i32, shape); auto a = make_shared<op::Parameter>(element::i32, shape);
...@@ -568,12 +568,11 @@ TEST(pattern, recurrent_pattern) ...@@ -568,12 +568,11 @@ TEST(pattern, recurrent_pattern)
auto add2 = iconst0 + add1; auto add2 = iconst0 + add1;
auto add3 = iconst0 + add2; auto add3 = iconst0 + add2;
auto padd = iconst0 + rpattern; auto padd = iconst0 + rpattern;
ngraph::pattern::RPatternMap matches;
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches; std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
Matcher::match_recurring_pattern(add3, padd, rpattern, matches, empty_correlated_matches); RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
ASSERT_EQ(matches.size(), 1); ASSERT_TRUE(rm.match(add3));
ASSERT_EQ(matches.count(rpattern), 1); ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
auto recurrent_matches = matches[rpattern]; auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2); ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1); ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b); ASSERT_EQ(recurrent_matches.at(2), b);
...@@ -584,37 +583,37 @@ TEST(pattern, recurrent_pattern) ...@@ -584,37 +583,37 @@ TEST(pattern, recurrent_pattern)
auto add2_2 = iconst1 + add1; auto add2_2 = iconst1 + add1;
auto add3_2 = iconst0 + add2_2; auto add3_2 = iconst0 + add2_2;
auto padd2 = iconst_label + rpattern; auto padd2 = iconst_label + rpattern;
matches.clear(); RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, empty_correlated_matches); ASSERT_TRUE(rm2.match(add3_2));
ASSERT_EQ(matches.size(), 2); ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
recurrent_matches = matches[rpattern]; recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2_2); ASSERT_EQ(recurrent_matches.at(0), add2_2);
ASSERT_EQ(recurrent_matches.at(1), add1); ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b); ASSERT_EQ(recurrent_matches.at(2), b);
auto iconst_matches = matches[iconst_label]; auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0); ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst1); ASSERT_EQ(iconst_matches.at(1), iconst1);
ASSERT_EQ(iconst_matches.at(2), iconst0); ASSERT_EQ(iconst_matches.at(2), iconst0);
//Non-matching correlated labels //Non-matching correlated labels
matches.clear();
std::set<std::shared_ptr<pattern::op::Label>> correlated_matches; std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
correlated_matches.insert(iconst_label); correlated_matches.insert(iconst_label);
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, correlated_matches); RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
ASSERT_EQ(matches.size(), 2); ASSERT_TRUE(rm3.match(add3_2));
iconst_matches = matches[iconst_label]; ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.size(), 1); ASSERT_EQ(iconst_matches.size(), 1);
ASSERT_EQ(iconst_matches.at(0), iconst0); ASSERT_EQ(iconst_matches.at(0), iconst0);
//Matching correlated labels //Matching correlated labels and
matches.clear(); //testing if RecurrentMatcher can be reused for different nodes
Matcher::match_recurring_pattern(add3, padd2, rpattern, matches, correlated_matches); ASSERT_TRUE(rm3.match(add3));
ASSERT_EQ(matches.size(), 2); ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
recurrent_matches = matches[rpattern]; recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2); ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1); ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b); ASSERT_EQ(recurrent_matches.at(2), b);
iconst_matches = matches[iconst_label]; iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0); ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst0); ASSERT_EQ(iconst_matches.at(1), iconst0);
ASSERT_EQ(iconst_matches.at(2), iconst0); ASSERT_EQ(iconst_matches.at(2), iconst0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment