Commit 04985b27 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

label on skip (#909)

parent a28f9a67
......@@ -715,3 +715,35 @@ TEST(pattern, recurrent_graph_rewrite)
ASSERT_EQ(add_b, b);
}
}
TEST(pattern, label_on_skip)
{
Shape shape{2, 2};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, Shape{});
auto iconst = ngraph::make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
auto bcst_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = std::make_shared<pattern::Matcher>(
std::make_shared<op::Multiply>(label, bcst_label), nullptr);
auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
auto mul = a * const_broadcast;
auto mul_scalar = b * iconst;
ASSERT_TRUE(matcher->match(mul));
ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[label], a);
ASSERT_TRUE(matcher->match(mul_scalar));
ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[label], b);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment