Unverified Commit b32b5c23 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

make sure matcher respects argument order for non-commutative ops (#847)

parent 638f36ee
......@@ -26,18 +26,6 @@ namespace ngraph
{
namespace pattern
{
static std::vector<std::shared_ptr<Node>> get_arguments(std::shared_ptr<Node> n)
{
std::unordered_set<std::shared_ptr<Node>> arguments;
for (const auto& input : n->get_inputs())
{
arguments.insert(input.get_output().get_node());
}
return std::vector<std::shared_ptr<Node>>(
begin(arguments), end(arguments)); //vector is needed for generating permutations
}
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
......@@ -62,7 +50,7 @@ namespace ngraph
if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner)
{
auto args = get_arguments(label);
auto args = label->get_input_ops();
if (args.size() > 0)
{
if (args.size() != 1)
......@@ -95,7 +83,7 @@ namespace ngraph
}
else
{
auto args = get_arguments(any);
auto args = any->get_input_ops();
if (args.size() != 1)
{
throw ngraph_error("Any can only take one argument");
......@@ -165,8 +153,8 @@ namespace ngraph
<< "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node->get_name();
auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node);
auto args = graph_node->get_input_ops();
auto pattern_args = pattern_node->get_input_ops();
if (args.size() != pattern_args.size())
{
......
......@@ -458,6 +458,16 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
ASSERT_EQ(n.get_pattern_map()[label], add);
//Correct argument order
ASSERT_FALSE(n.match(b - a, a - b));
auto aab = a * (a - b);
auto paab = pattern * (pattern - b);
ASSERT_TRUE(n.match(paab, aab));
auto aba = a * (b - a);
ASSERT_FALSE(n.match(paab, aba));
auto paba = pattern * (b - pattern);
ASSERT_FALSE(n.match(paba, aab));
//Correlations
auto label1 = std::make_shared<pattern::op::Label>(a);
auto tmp = label1 + b;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment