Commit 67ff40f5 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Use node IDs for comparison when sorting in pattern match. (#3105)

* Use node IDs for comparison when sorting in pattern match.

* Make args const.
parent c2bfb94d
......@@ -298,9 +298,14 @@ namespace ngraph
if (graph_node->is_commutative())
{
std::sort(
begin(pattern_args),
end(pattern_args)); // TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
// TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
std::sort(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() < n2->get_instance_id();
});
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
......@@ -311,7 +316,13 @@ namespace ngraph
pattern_map.insert(begin(copy), end(copy));
return true;
}
} while (std::next_permutation(begin(pattern_args), end(pattern_args)));
} while (std::next_permutation(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() <
n2->get_instance_id();
}));
}
else
{
......
......@@ -514,6 +514,33 @@ TEST(pattern, previous_matches)
}
}
TEST(pattern, test_sort)
{
using ngraph::pattern::Matcher;
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto abs1 = make_shared<op::Abs>(a);
auto abs2 = make_shared<op::Abs>(b);
auto add = abs1 + abs2;
auto pa = make_shared<op::Parameter>(element::i32, shape);
auto pb = make_shared<op::Parameter>(element::i32, shape);
auto pabs1 = make_shared<op::Abs>(pa);
auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
auto pabs2 = make_shared<op::Abs>(b);
auto padd = pabs1_label + pabs2;
{
Matcher n1(padd);
ASSERT_TRUE(n1.match(add));
auto r1 = n1.get_pattern_map()[pabs1_label];
ASSERT_TRUE(n1.match(add));
ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
}
}
TEST(pattern, recurrent_pattern)
{
using ngraph::pattern::RecurrentMatcher;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment