Commit 656dfa55 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Adam Procter

enable cse for reduction ops (#1030)

* enable cse for reduction ops

* reduction tests
parent 7d6a0d1c
......@@ -77,6 +77,17 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1));
}
static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();
auto ar_a = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}
static std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
initialize_ops_to_cse_handlers()
......@@ -110,6 +121,8 @@ static std::unordered_map<std::type_index,
{TI(op::Power), cse_binarywise},
//{TI(op::Remainder), cse_binarywise},
{TI(op::Subtract), cse_binarywise},
{TI(op::Sum), cse_reduction},
{TI(op::Product), cse_reduction},
});
}
......
......@@ -188,3 +188,35 @@ TEST(CSE, abs_add_abs_add_negative)
ASSERT_EQ(oadd4->get_argument(1), D);
ASSERT_EQ(oadd3->get_argument(0), oadd4->get_argument(0));
}
template <typename T>
static void execute_cse_reduction_test()
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
auto a_reduction_op = std::make_shared<T>(A, AxisSet{0, 1});
auto a_reduction_op2 = std::make_shared<T>(A, AxisSet{0, 1});
auto a_reduction_op3 = std::make_shared<T>(A, AxisSet{0});
auto sub_aa = a_reduction_op - a_reduction_op2;
auto B = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
auto b_reduction_op = std::make_shared<T>(B, AxisSet{0, 1});
auto sub_ab = a_reduction_op - b_reduction_op;
auto f = std::make_shared<Function>(NodeVector{sub_aa, sub_ab, a_reduction_op3},
op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(sub_aa->get_argument(0), sub_aa->get_argument(1));
ASSERT_NE(sub_ab->get_argument(0), sub_ab->get_argument(1));
ASSERT_NE(f->get_results().at(2)->get_argument(0), sub_aa->get_argument(0));
}
TEST(CSE, reduction_ops)
{
execute_cse_reduction_test<op::Sum>();
execute_cse_reduction_test<op::Product>();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment