Unverified Commit b32b5c23 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

make sure matcher respects argument order for non-commutative ops (#847)

parent 638f36ee
...@@ -26,18 +26,6 @@ namespace ngraph ...@@ -26,18 +26,6 @@ namespace ngraph
{ {
namespace pattern namespace pattern
{ {
static std::vector<std::shared_ptr<Node>> get_arguments(std::shared_ptr<Node> n)
{
std::unordered_set<std::shared_ptr<Node>> arguments;
for (const auto& input : n->get_inputs())
{
arguments.insert(input.get_output().get_node());
}
return std::vector<std::shared_ptr<Node>>(
begin(arguments), end(arguments)); //vector is needed for generating permutations
}
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; } std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label, bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
...@@ -62,7 +50,7 @@ namespace ngraph ...@@ -62,7 +50,7 @@ namespace ngraph
if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner) if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner)
{ {
auto args = get_arguments(label); auto args = label->get_input_ops();
if (args.size() > 0) if (args.size() > 0)
{ {
if (args.size() != 1) if (args.size() != 1)
...@@ -95,7 +83,7 @@ namespace ngraph ...@@ -95,7 +83,7 @@ namespace ngraph
} }
else else
{ {
auto args = get_arguments(any); auto args = any->get_input_ops();
if (args.size() != 1) if (args.size() != 1)
{ {
throw ngraph_error("Any can only take one argument"); throw ngraph_error("Any can only take one argument");
...@@ -165,8 +153,8 @@ namespace ngraph ...@@ -165,8 +153,8 @@ namespace ngraph
<< "pattern = " << pattern_node->get_name() << " " << "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node->get_name(); << "matched " << graph_node->get_name();
auto args = get_arguments(graph_node); auto args = graph_node->get_input_ops();
auto pattern_args = get_arguments(pattern_node); auto pattern_args = pattern_node->get_input_ops();
if (args.size() != pattern_args.size()) if (args.size() != pattern_args.size())
{ {
......
...@@ -458,6 +458,16 @@ TEST(pattern, matcher) ...@@ -458,6 +458,16 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add))); ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
ASSERT_EQ(n.get_pattern_map()[label], add); ASSERT_EQ(n.get_pattern_map()[label], add);
//Correct argument order
ASSERT_FALSE(n.match(b - a, a - b));
auto aab = a * (a - b);
auto paab = pattern * (pattern - b);
ASSERT_TRUE(n.match(paab, aab));
auto aba = a * (b - a);
ASSERT_FALSE(n.match(paab, aba));
auto paba = pattern * (b - pattern);
ASSERT_FALSE(n.match(paba, aab));
//Correlations //Correlations
auto label1 = std::make_shared<pattern::op::Label>(a); auto label1 = std::make_shared<pattern::op::Label>(a);
auto tmp = label1 + b; auto tmp = label1 + b;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment