Commit e07637c0 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Collect matched nodes (#1166)

* collect matched nodes

* clear m_matched_list

* tests

* address feedback
parent 137f002b
......@@ -122,24 +122,27 @@ namespace ngraph
throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!");
}
add_node(graph_node);
size_t watermark = m_matched_list.size() - 1;
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
return match_pattern(label_node, graph_node, pattern_map);
return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map));
}
if (auto skip_node = std::dynamic_pointer_cast<op::Skip>(
pattern_node)) //matches PatternSkipOp semantics
{
return match_skip(skip_node, graph_node, pattern_map);
return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map));
}
if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node))
{
return match_any(any_node, graph_node, pattern_map);
return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
}
auto p_pattern_node = pattern_node.get();
......@@ -147,10 +150,11 @@ namespace ngraph
if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node)))
{
return match_arguments(pattern_node, graph_node, pattern_map);
return abort_match(watermark,
match_arguments(pattern_node, graph_node, pattern_map));
}
return false;
return abort_match(watermark, false);
}
bool Matcher::match_permutation(const NodeVector& pattern_args,
......@@ -240,6 +244,7 @@ namespace ngraph
//clear our state
m_match_root.reset();
m_pattern_map.clear();
m_matched_list.clear();
if (!m_pattern_node || !graph_node)
{
......
......@@ -109,7 +109,7 @@ namespace ngraph
}
bool process_match(graph_rewrite_callback callback = nullptr);
NodeVector get_matched_nodes() { return m_matched_list; }
void reset() {}
std::string get_name() { return m_name; }
std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
......@@ -124,6 +124,16 @@ namespace ngraph
friend op::Label; //TODO: refine to match_class
protected:
void add_node(std::shared_ptr<Node> node) { m_matched_list.push_back(node); }
bool abort_match(size_t watermark, bool matched)
{
if (!matched)
{
m_matched_list.erase(m_matched_list.begin() + watermark, m_matched_list.end());
}
return matched;
}
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
......@@ -135,6 +145,7 @@ namespace ngraph
std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map;
NodeVector m_matched_list;
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
......
......@@ -391,33 +391,51 @@ TEST(pattern, graph_rewrite)
}
}
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
{
std::vector<std::string> names;
for (auto n : nv)
{
names.push_back(n->get_name());
}
os << vector_to_string(names);
return os;
}
TEST(pattern, matcher)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(a, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto abs = make_shared<op::Abs>(a);
auto any = std::make_shared<pattern::op::Skip>(a);
ASSERT_TRUE(n.match(any, abs));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
auto false_pred = [](std::shared_ptr<Node> no) { return false; };
auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
ASSERT_TRUE(n.match(any_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
ASSERT_FALSE(n.match(pattern_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
ASSERT_TRUE(n.match(bea, a + b));
auto add_ab = a + b;
ASSERT_TRUE(n.match(bea, add_ab));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
ASSERT_TRUE(n.match(bea, b + a));
auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
......@@ -432,22 +450,33 @@ TEST(pattern, matcher)
ASSERT_FALSE(n.match(d, b));
ASSERT_FALSE(n.match(abs + b, b + b));
ASSERT_TRUE(n.match(any + b, abs + b));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto add_absb = abs + b;
ASSERT_TRUE(n.match(any + b, add_absb));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
ASSERT_TRUE(n.match(pattern + b, abs + b));
ASSERT_TRUE(n.match(pattern + b, add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
ASSERT_TRUE(n.match(b + pattern, abs + b));
ASSERT_TRUE(n.match(b + pattern, add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
auto c = make_shared<op::Parameter>(element::i32, shape);
ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b)));
auto mul_add_absb = c * (add_absb);
ASSERT_TRUE(n.match(c * (b + pattern), mul_add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); //nested any
ASSERT_TRUE(n.match(c * (any_false + b), (b + a) * c)); //permutations w/ any_false
ASSERT_TRUE(n.match(c * (any + b), mul_add_absb)); //nested any
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
auto mul_c_add_ab = c * add_ab;
ASSERT_TRUE(n.match(c * (any_false + b), c * (a + b))); //nested any
ASSERT_TRUE(n.match(c * (any_false + b), mul_c_add_ab)); //permutations w/ any_false
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1);
......@@ -462,6 +491,7 @@ TEST(pattern, matcher)
auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
ASSERT_FALSE(n.match(label, a - b));
......@@ -483,9 +513,11 @@ TEST(pattern, matcher)
auto tmp = label1 + b;
auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
auto sub_label1 = label1 - label2;
ASSERT_TRUE(n.match(sub_label1, a - add));
auto sub_add = a - add;
ASSERT_TRUE(n.match(sub_label1, sub_add));
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
ASSERT_FALSE(n.match(sub_label1, add - a));
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
//this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
//this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
bool result =
pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
if (result)
{
m_matched_list.push_back(graph_node);
}
return result;
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
m_matched_list.clear();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment