Commit a8cd0e94 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Support for Recurring Patterns (#782)

* initial support for recurring matching

* fix a bug where patterns weren't populated w/ matched nodes; add recurrent tests

* add a missing newline

* address feedback

* fix function comment
parent fa6c2a60
...@@ -38,6 +38,57 @@ namespace ngraph ...@@ -38,6 +38,57 @@ namespace ngraph
begin(arguments), end(arguments)); //vector is needed for generating permutations begin(arguments), end(arguments)); //vector is needed for generating permutations
} }
bool Matcher::match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
{
bool matched = false;
Matcher m(pattern);
PatternMap previous_matches;
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
//try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
//move to the next cell
graph = m.m_pattern_map[rpattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
//copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.m_pattern_map)
{
patterns[cur_match.first].push_back(cur_match.second);
}
//pre-populate the pattern map for the next cell with the bound nodes
//from the current match. Only bound nodes whose labels are in
//correlated_patterns are pre-populated. Any other labels are
//unbounded by default
for (auto cor_pat : correlated_patterns)
{
if (m.m_pattern_map.count(cor_pat) != 0)
{
//assert that bound nodes from the previous and current matches are the same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.m_pattern_map[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.m_pattern_map[cor_pat];
}
}
}
return matched;
}
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; } std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label, bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
...@@ -230,7 +281,33 @@ namespace ngraph ...@@ -230,7 +281,33 @@ namespace ngraph
if (!m_pattern_node || !graph_node) if (!m_pattern_node || !graph_node)
{ {
throw "m_pattern_node or graph_node are not set!"; throw ngraph_error("m_pattern_node or graph_node are not set");
}
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
bool Matcher::match(const std::shared_ptr<Node>& graph_node,
const PatternMap& previous_matches)
{
//clear our state
m_match_root.reset();
m_pattern_map.clear();
//insert previous matches
m_pattern_map.insert(previous_matches.cbegin(), previous_matches.cend());
if (!m_pattern_node || !graph_node)
{
throw ngraph_error("m_pattern_node or graph_node are not set");
} }
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name() NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
......
...@@ -33,6 +33,7 @@ namespace ngraph ...@@ -33,6 +33,7 @@ namespace ngraph
namespace pattern namespace pattern
{ {
using gr_callback_fn = std::function<bool(class Matcher& m)>; using gr_callback_fn = std::function<bool(class Matcher& m)>;
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
namespace op namespace op
{ {
...@@ -63,6 +64,12 @@ namespace ngraph ...@@ -63,6 +64,12 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against /// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node); bool match(const std::shared_ptr<Node>& graph_node);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
/// \param previous_matches contains previous mappings from labels to nodes to use
bool match(const std::shared_ptr<Node>& graph_node, const PatternMap& previous_matches);
template <typename T> template <typename T>
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node) static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
{ {
...@@ -90,6 +97,20 @@ namespace ngraph ...@@ -90,6 +97,20 @@ namespace ngraph
std::shared_ptr<Node> pattern_node() { return m_pattern_node; } std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root(); std::shared_ptr<Node> match_root();
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; } PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
/// \brief Low-level helper to match recurring patterns
///
/// \param graph is a graph to be matched against
/// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches
/// \param correlated_patterns specify labels whose bound nodes should be
/// the same across all cells
static bool match_recurring_pattern(
std::shared_ptr<Node> graph,
std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
RPatternMap& patterns,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
friend op::Label; //TODO: refine to match_class friend op::Label; //TODO: refine to match_class
protected: protected:
......
...@@ -530,3 +530,92 @@ TEST(pattern, variance) ...@@ -530,3 +530,92 @@ TEST(pattern, variance)
ASSERT_TRUE(n.match(var_graph, variance)); ASSERT_TRUE(n.match(var_graph, variance));
ASSERT_EQ(n.get_pattern_map()[var_graph], variance); ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
} }
TEST(pattern, previous_matches)
{
using ngraph::pattern::Matcher;
Shape shape{};
Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto pattern = std::make_shared<pattern::op::Label>(b);
auto abs = make_shared<op::Abs>(a);
auto add = abs + b;
{
Matcher n(pattern + b);
ASSERT_TRUE(n.match(add, previous_matches));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
}
{
Matcher n(pattern + b);
previous_matches.insert(std::make_pair(pattern, a));
ASSERT_FALSE(n.match(add, previous_matches));
}
}
TEST(pattern, recurrent_pattern)
{
using ngraph::pattern::Matcher;
Shape shape{};
ngraph::pattern::Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto rpattern = std::make_shared<pattern::op::Label>(b);
auto iconst0 = construct_constant_node(0);
auto abs = make_shared<op::Abs>(a);
auto add1 = iconst0 + b;
auto add2 = iconst0 + add1;
auto add3 = iconst0 + add2;
auto padd = iconst0 + rpattern;
ngraph::pattern::RPatternMap matches;
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
Matcher::match_recurring_pattern(add3, padd, rpattern, matches, empty_correlated_matches);
ASSERT_EQ(matches.size(), 1);
ASSERT_EQ(matches.count(rpattern), 1);
auto recurrent_matches = matches[rpattern];
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
//Multiple labels in a reccuring pattern
auto iconst1 = construct_constant_node(1);
auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
auto add2_2 = iconst1 + add1;
auto add3_2 = iconst0 + add2_2;
auto padd2 = iconst_label + rpattern;
matches.clear();
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, empty_correlated_matches);
ASSERT_EQ(matches.size(), 2);
recurrent_matches = matches[rpattern];
ASSERT_EQ(recurrent_matches.at(0), add2_2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
auto iconst_matches = matches[iconst_label];
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst1);
ASSERT_EQ(iconst_matches.at(2), iconst0);
//Non-matching correlated labels
matches.clear();
std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
correlated_matches.insert(iconst_label);
Matcher::match_recurring_pattern(add3_2, padd2, rpattern, matches, correlated_matches);
ASSERT_EQ(matches.size(), 2);
iconst_matches = matches[iconst_label];
ASSERT_EQ(iconst_matches.size(), 1);
ASSERT_EQ(iconst_matches.at(0), iconst0);
//Matching correlated labels
matches.clear();
Matcher::match_recurring_pattern(add3, padd2, rpattern, matches, correlated_matches);
ASSERT_EQ(matches.size(), 2);
recurrent_matches = matches[rpattern];
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
iconst_matches = matches[iconst_label];
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst0);
ASSERT_EQ(iconst_matches.at(2), iconst0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment