Commit 81c48453 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

is_contained (#1257)

parent 6457ed2e
......@@ -19,6 +19,7 @@
#include <typeindex>
#include <typeinfo>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/parameter.hpp"
......@@ -71,6 +72,27 @@ namespace ngraph
return is_match;
}
bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unused)
{
if (exclusions.empty())
{
NodeVector label_exclusions;
for (auto entry : m_pattern_map)
{
//leaf label
if (entry.first->get_inputs().empty())
{
label_exclusions.push_back(entry.second);
}
}
return ngraph::get_subgraph_outputs(
get_matched_nodes(), label_exclusions, ignore_unused)
.size() < 2;
}
return ngraph::get_subgraph_outputs(get_matched_nodes(), exclusions).size() < 2;
}
bool Matcher::match_skip(const std::shared_ptr<op::Skip>& skip,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
......
......@@ -108,6 +108,8 @@ namespace ngraph
return matched;
}
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
bool process_match(graph_rewrite_callback callback = nullptr);
NodeVector get_matched_nodes() { return m_matched_list; }
void reset() {}
......
......@@ -794,3 +794,23 @@ TEST(pattern, label_on_skip)
ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[label], b);
}
TEST(pattern, is_contained_match)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto absn = make_shared<op::Abs>(a);
TestMatcher n(nullptr);
auto label_a = std::make_shared<pattern::op::Label>(a);
auto label_abs = make_shared<op::Abs>(a);
ASSERT_TRUE(n.match(label_abs, absn));
auto result_absn = make_shared<op::Result>(absn);
ASSERT_TRUE(n.is_contained_match());
auto absn2 = make_shared<op::Abs>(absn);
auto result_absn2 = make_shared<op::Result>(absn2);
auto label_abs2 = make_shared<op::Abs>(label_abs);
ASSERT_TRUE(n.match(label_abs2, absn2));
ASSERT_FALSE(n.is_contained_match());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment