Commit 82c19d24 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Replace with Broadcast if exists in Algebraic Simplifier (#908)

* pick broadcast if exists

* remove logic for sum

* get at broadcast using the label-on-skip approach

* tests for broadcast fix

*  add comments
parent c9d65479
......@@ -41,11 +41,27 @@ static std::shared_ptr<pattern::Matcher>
auto bcst_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
auto matcher = std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst), nullptr);
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher =
std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst_label), nullptr);
return matcher;
}
static std::shared_ptr<pattern::op::Label>
get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher)
{
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->pattern_node()->get_argument(1));
}
//`simplify_multiply` optimizes the following 4 *base* cases
//(8 cases in total including variants due to commutativity)
//
//a * 0 -> 0
//a * broadcast(0) -> broadcast(0)
//a * 1 -> a
//a * broadcast(1) -> a
static bool simplify_multiply(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
......@@ -55,14 +71,16 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
auto const_label_one =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_one, NodeVector{iconst});
auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero);
auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one);
if (matcher_const_zero->match(n))
{
auto cnst = matcher_const_zero->get_pattern_map()[const_label_zero];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << cnst->get_name();
ngraph::replace_node(n, cnst);
auto bcst_label = get_broadcast_label(matcher_const_zero);
auto bcst_or_cnst = matcher_const_zero->get_pattern_map()[bcst_label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << bcst_or_cnst->get_name();
ngraph::replace_node(n, bcst_or_cnst);
return true;
}
......@@ -77,6 +95,11 @@ static bool simplify_multiply(std::shared_ptr<Node> n)
return false;
}
//`simplify_multiply` optimizes the following 2 *base* cases
//(4 cases in total including variants due to commutativity)
//
//a + 0 -> a
//a + broadcast(0) -> a
static bool simplify_add(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
......
......@@ -116,6 +116,37 @@ TEST(algebraic_simplification, add_broadcast)
}
}
TEST(algebraic_simplification, multiply_broadcast)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
auto mul_a_0 = a * const_broadcast;
auto mul_a_0_0 = mul_a_0 * const_broadcast;
auto mul_b_0 = b * const_broadcast;
auto mul_b_0_0 = mul_b_0 * const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
TEST(algebraic_simplification, zero_plus_zero_commutativity)
{
Shape shape{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment