Commit 1234eb97 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

matcher's strict mode (#2235)

* matcher's strict mode

* fix clang warnings
parent 28622bde
...@@ -190,6 +190,23 @@ namespace ngraph ...@@ -190,6 +190,23 @@ namespace ngraph
add_node(graph_node); add_node(graph_node);
size_t watermark = m_matched_list.size() - 1; size_t watermark = m_matched_list.size() - 1;
// we can skip multi-output nodes since their shapes will be compared
// when their individual GOE are matched
// this also gives a bit more flexibility since we don't have to worry
// about *all* outputs of a pattern node but only the ones we want to match.
if (m_strict_mode && graph_node->get_outputs().size() == 1)
{
bool shape_match = pattern_node->get_output_partial_shape(0).compatible(
graph_node->get_output_partial_shape(0));
bool et_match =
pattern_node->get_element_type().compatible(graph_node->get_element_type());
if (!shape_match || !et_match)
{
return abort_match(watermark, false);
}
}
// This env var allows one to specify node name patterns to abort pattern matching // This env var allows one to specify node name patterns to abort pattern matching
// at particular nodes. The upshot is that one can quickly zero in on an offending fusion by // at particular nodes. The upshot is that one can quickly zero in on an offending fusion by
// disabling individual fusions or optimizations that use Matcher. // disabling individual fusions or optimizations that use Matcher.
......
...@@ -68,11 +68,13 @@ namespace ngraph ...@@ -68,11 +68,13 @@ namespace ngraph
/// \param callback is a callback function that will be called on a successful match /// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr, Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
graph_rewrite_callback callback = nullptr, graph_rewrite_callback callback = nullptr,
const std::string& name = "Unnamed") const std::string& name = "Unnamed",
bool strict_mode = false)
: m_pattern_node(pattern_node) : m_pattern_node(pattern_node)
, m_callback(callback) , m_callback(callback)
, m_depth(0) , m_depth(0)
, m_name(name) , m_name(name)
, m_strict_mode(strict_mode)
{ {
} }
...@@ -171,6 +173,7 @@ namespace ngraph ...@@ -171,6 +173,7 @@ namespace ngraph
graph_rewrite_callback m_callback; graph_rewrite_callback m_callback;
size_t m_depth; size_t m_depth;
std::string m_name; std::string m_name;
bool m_strict_mode;
}; };
class RecurrentMatcher class RecurrentMatcher
......
...@@ -541,6 +541,29 @@ TEST(pattern, matcher) ...@@ -541,6 +541,29 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(add_label1, add + a)); ASSERT_TRUE(n.match(add_label1, add + a));
ASSERT_EQ(n.get_pattern_map()[label1], a); ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add); ASSERT_EQ(n.get_pattern_map()[label2], add);
// strict mode
{
TestMatcher sm(nullptr, nullptr, "TestMatcher", true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
auto label_dynamic_shape =
make_shared<pattern::op::Label>(element::i32, PartialShape::dynamic());
auto param = make_shared<op::Parameter>(element::f32, Shape{});
ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
// wrong type
auto scalar_param_wrong_type = make_shared<op::Parameter>(element::f32, Shape{});
ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
// dynamic dimension
auto label_dynamic_dimension =
make_shared<pattern::op::Label>(element::i32, PartialShape{Dimension::dynamic()});
auto vector_param = make_shared<op::Parameter>(element::i32, Shape{10});
ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
// dynamic type
auto label_dynamic_type =
make_shared<pattern::op::Label>(element::dynamic, PartialShape{Dimension::dynamic()});
ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
}
} }
TEST(pattern, sum) TEST(pattern, sum)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment