Commit e07637c0 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Collect matched nodes (#1166)

* collect matched nodes

* clear m_matched_list

* tests

* address feedback
parent 137f002b
...@@ -122,24 +122,27 @@ namespace ngraph ...@@ -122,24 +122,27 @@ namespace ngraph
throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!"); throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!");
} }
add_node(graph_node);
size_t watermark = m_matched_list.size() - 1;
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node->get_name() << " matched " << "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name(); << graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node)) if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{ {
return match_pattern(label_node, graph_node, pattern_map); return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map));
} }
if (auto skip_node = std::dynamic_pointer_cast<op::Skip>( if (auto skip_node = std::dynamic_pointer_cast<op::Skip>(
pattern_node)) //matches PatternSkipOp semantics pattern_node)) //matches PatternSkipOp semantics
{ {
return match_skip(skip_node, graph_node, pattern_map); return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map));
} }
if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node)) if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node))
{ {
return match_any(any_node, graph_node, pattern_map); return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
} }
auto p_pattern_node = pattern_node.get(); auto p_pattern_node = pattern_node.get();
...@@ -147,10 +150,11 @@ namespace ngraph ...@@ -147,10 +150,11 @@ namespace ngraph
if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node))) if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node)))
{ {
return match_arguments(pattern_node, graph_node, pattern_map); return abort_match(watermark,
match_arguments(pattern_node, graph_node, pattern_map));
} }
return false; return abort_match(watermark, false);
} }
bool Matcher::match_permutation(const NodeVector& pattern_args, bool Matcher::match_permutation(const NodeVector& pattern_args,
...@@ -240,6 +244,7 @@ namespace ngraph ...@@ -240,6 +244,7 @@ namespace ngraph
//clear our state //clear our state
m_match_root.reset(); m_match_root.reset();
m_pattern_map.clear(); m_pattern_map.clear();
m_matched_list.clear();
if (!m_pattern_node || !graph_node) if (!m_pattern_node || !graph_node)
{ {
......
...@@ -109,7 +109,7 @@ namespace ngraph ...@@ -109,7 +109,7 @@ namespace ngraph
} }
bool process_match(graph_rewrite_callback callback = nullptr); bool process_match(graph_rewrite_callback callback = nullptr);
NodeVector get_matched_nodes() { return m_matched_list; }
void reset() {} void reset() {}
std::string get_name() { return m_name; } std::string get_name() { return m_name; }
std::shared_ptr<Node> get_pattern() { return m_pattern_node; } std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
...@@ -124,6 +124,16 @@ namespace ngraph ...@@ -124,6 +124,16 @@ namespace ngraph
friend op::Label; //TODO: refine to match_class friend op::Label; //TODO: refine to match_class
protected: protected:
void add_node(std::shared_ptr<Node> node) { m_matched_list.push_back(node); }
bool abort_match(size_t watermark, bool matched)
{
if (!matched)
{
m_matched_list.erase(m_matched_list.begin() + watermark, m_matched_list.end());
}
return matched;
}
bool virtual match_node(const std::shared_ptr<Node>& pattern_node, bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map); PatternMap& pattern_map);
...@@ -135,6 +145,7 @@ namespace ngraph ...@@ -135,6 +145,7 @@ namespace ngraph
std::shared_ptr<Node> m_match_root; std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node; std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map; PatternMap m_pattern_map;
NodeVector m_matched_list;
private: private:
static std::string pad(size_t num) { return std::string(num, ' '); } static std::string pad(size_t num) { return std::string(num, ' '); }
......
...@@ -391,33 +391,51 @@ TEST(pattern, graph_rewrite) ...@@ -391,33 +391,51 @@ TEST(pattern, graph_rewrite)
} }
} }
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
{
std::vector<std::string> names;
for (auto n : nv)
{
names.push_back(n->get_name());
}
os << vector_to_string(names);
return os;
}
TEST(pattern, matcher) TEST(pattern, matcher)
{ {
Shape shape{}; Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape); auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr); TestMatcher n(nullptr);
ASSERT_TRUE(n.match(a, a)); ASSERT_TRUE(n.match(a, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto abs = make_shared<op::Abs>(a); auto abs = make_shared<op::Abs>(a);
auto any = std::make_shared<pattern::op::Skip>(a); auto any = std::make_shared<pattern::op::Skip>(a);
ASSERT_TRUE(n.match(any, abs)); ASSERT_TRUE(n.match(any, abs));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
auto false_pred = [](std::shared_ptr<Node> no) { return false; }; auto false_pred = [](std::shared_ptr<Node> no) { return false; };
auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred); auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
ASSERT_TRUE(n.match(any_false, a)); ASSERT_TRUE(n.match(any_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
auto pattern = std::make_shared<pattern::op::Label>(a); auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a)); ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred); auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
ASSERT_FALSE(n.match(pattern_false, a)); ASSERT_FALSE(n.match(pattern_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto b = make_shared<op::Parameter>(element::i32, shape); auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>(); auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b}); auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
ASSERT_TRUE(n.match(bea, a + b)); auto add_ab = a + b;
ASSERT_TRUE(n.match(bea, add_ab));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
ASSERT_TRUE(n.match(bea, b + a)); ASSERT_TRUE(n.match(bea, b + a));
auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b}); auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
...@@ -432,22 +450,33 @@ TEST(pattern, matcher) ...@@ -432,22 +450,33 @@ TEST(pattern, matcher)
ASSERT_FALSE(n.match(d, b)); ASSERT_FALSE(n.match(d, b));
ASSERT_FALSE(n.match(abs + b, b + b)); ASSERT_FALSE(n.match(abs + b, b + b));
ASSERT_TRUE(n.match(any + b, abs + b)); ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto add_absb = abs + b;
ASSERT_TRUE(n.match(any + b, add_absb));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
ASSERT_TRUE(n.match(pattern + b, abs + b)); ASSERT_TRUE(n.match(pattern + b, add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
ASSERT_TRUE(n.match(b + pattern, abs + b)); ASSERT_TRUE(n.match(b + pattern, add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
auto c = make_shared<op::Parameter>(element::i32, shape); auto c = make_shared<op::Parameter>(element::i32, shape);
ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b))); auto mul_add_absb = c * (add_absb);
ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); //nested any
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); //nested any ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
ASSERT_TRUE(n.match(c * (any_false + b), (b + a) * c)); //permutations w/ any_false auto mul_c_add_ab = c * add_ab;
ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); //nested any
ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); //permutations w/ any_false
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
auto iconst1_0 = construct_constant_node(1); auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1); auto iconst1_1 = construct_constant_node(1);
...@@ -462,6 +491,7 @@ TEST(pattern, matcher) ...@@ -462,6 +491,7 @@ TEST(pattern, matcher)
auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add}); auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
ASSERT_TRUE(n.match(label, add)); ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add); ASSERT_EQ(n.get_pattern_map()[label], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
ASSERT_FALSE(n.match(label, a - b)); ASSERT_FALSE(n.match(label, a - b));
...@@ -483,9 +513,11 @@ TEST(pattern, matcher) ...@@ -483,9 +513,11 @@ TEST(pattern, matcher)
auto tmp = label1 + b; auto tmp = label1 + b;
auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp}); auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
auto sub_label1 = label1 - label2; auto sub_label1 = label1 - label2;
ASSERT_TRUE(n.match(sub_label1, a - add)); auto sub_add = a - add;
ASSERT_TRUE(n.match(sub_label1, sub_add));
ASSERT_EQ(n.get_pattern_map()[label1], a); ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add); ASSERT_EQ(n.get_pattern_map()[label2], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
ASSERT_FALSE(n.match(sub_label1, add - a)); ASSERT_FALSE(n.match(sub_label1, add - a));
......
/******************************************************************************* /*******************************************************************************
* Copyright 2017-2018 Intel Corporation * Copyright 2017-2018 Intel Corporation
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
//this is for more nuanced testing //this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher class TestMatcher : public ngraph::pattern::Matcher
{ {
using ngraph::pattern::Matcher::Matcher; using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node, bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node, const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override PatternMap& pattern_map) override
{ {
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node)) if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{ {
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get()); bool result =
} pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
if (result)
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map); {
} m_matched_list.push_back(graph_node);
}
public: return result;
bool match(const std::shared_ptr<ngraph::Node>& pattern_node, }
const std::shared_ptr<ngraph::Node>& graph_node)
{ return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
assert( }
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match` public:
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name() bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
<< " , graph_node = " << graph_node->get_name(); const std::shared_ptr<ngraph::Node>& graph_node)
{
m_pattern_map.clear(); assert(
m_match_root.reset(); pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
bool is_match = match_node(pattern_node, graph_node, m_pattern_map); NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
if (is_match) << " , graph_node = " << graph_node->get_name();
{
m_match_root = graph_node; m_pattern_map.clear();
} m_match_root.reset();
return is_match; m_matched_list.clear();
}
}; bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment