Commit a28f9a67 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Algebraic Simplifier for Sum (#907)

* simplifier for sum

* add comment, remove visualization passes
parent 7f582a99
......@@ -18,6 +18,7 @@
#include <set>
#include "algebraic_simplification.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
......@@ -106,11 +107,65 @@ static bool simplify_add(std::shared_ptr<Node> n)
return false;
}
static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
{
size_t prod = 1;
for (auto axis : axes)
{
prod *= shape.at(axis);
}
return prod;
}
//`simplify_sum` optimizes the following case:
//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
static bool simplify_sum(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_sum for " << n->get_name();
auto sum = std::dynamic_pointer_cast<op::Sum>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast)
{
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
return false;
}
auto cnst = std::dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
{
NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant";
return false;
}
auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape());
double sum_const_value = cnst->get_vector<double>().at(0) * multiplier;
std::shared_ptr<Node> sum_cnst =
op::Constant::create(cnst->get_element_type(), Shape{}, {sum_const_value});
auto new_node = sum_cnst;
if (sum->get_shape().size() > 0)
{
ngraph::AxisSet axes{};
for (size_t i = 0; i < sum->get_shape().size(); i++)
{
axes.insert(i);
}
new_node = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes);
}
ngraph::replace_node(n, new_node);
return true;
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
initialize_const_values_to_ops()
{
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({
{TI(op::Add), simplify_add}, {TI(op::Multiply), simplify_multiply},
{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Sum), simplify_sum},
});
}
......
......@@ -228,3 +228,59 @@ TEST(algebraic_simplification, multiply_negative_tests)
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
TEST(algebraic_simplification, multiply_sum_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 15);
}
TEST(algebraic_simplification, multiply_sum_vector_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_broadcast =
std::dynamic_pointer_cast<op::Broadcast>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_broadcast);
auto new_const = std::dynamic_pointer_cast<op::Constant>(new_broadcast->get_argument(0));
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 5);
}
TEST(algebraic_simplification, multiply_sum_negative)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{2, 5}, AxisSet{1});
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto f_sum = f->get_results().at(0)->get_argument(0);
ASSERT_EQ(f_sum, sum_fconst1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment