Commit 659d2565 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Algebraic Simplification for Product (#949)

* product simplifier

* char -> signed char
parent 41c50b44
......@@ -171,6 +171,19 @@ static std::shared_ptr<Node>
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
}
template <typename T>
static std::shared_ptr<Node>
pow_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
{
T prod = static_cast<T>(1);
T val = cnst->get_vector<T>().at(0);
for (size_t i = 0; i < multiplier; i++)
{
prod *= val;
}
return op::Constant::create<T>(type, Shape{}, {prod});
}
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
......@@ -179,7 +192,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
}
else if (cnst->get_element_type() == element::i8)
{
return multiply_by<char>(cnst->get_element_type(), multiplier, cnst);
return multiply_by<signed char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
......@@ -193,13 +206,40 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
return nullptr;
}
//`simplify_sum` optimizes the following case:
static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cnst,
size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return pow_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return pow_by<signed char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return pow_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return pow_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
//`simplify_reduction` optimizes the following case:
//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
static bool simplify_sum(std::shared_ptr<Node> n)
//product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T,
std::shared_ptr<Node> (*F)(std::shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_sum for " << n->get_name();
auto sum = std::dynamic_pointer_cast<op::Sum>(n);
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto reduction = std::dynamic_pointer_cast<T>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast)
......@@ -215,39 +255,44 @@ static bool simplify_sum(std::shared_ptr<Node> n)
return false;
}
auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape());
auto sum_cnst = get_sum_constant(cnst, multiplier);
auto multiplier = reduction_shape_size(reduction->get_reduction_axes(), broadcast->get_shape());
auto reduction_cnst = F(cnst, multiplier);
//Unsupported type
if (!sum_cnst)
if (!reduction_cnst)
{
NGRAPH_DEBUG << "unsupported type";
return false;
}
if (sum->get_shape().size() > 0)
if (reduction->get_shape().size() > 0)
{
ngraph::AxisSet axes{};
for (size_t i = 0; i < sum->get_shape().size(); i++)
for (size_t i = 0; i < reduction->get_shape().size(); i++)
{
axes.insert(i);
}
sum_cnst = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes);
reduction_cnst =
std::make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
}
ngraph::replace_node(n, sum_cnst);
ngraph::replace_node(n, reduction_cnst);
return true;
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
initialize_ops_to_simplifiers()
{
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({
{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Sum), simplify_sum},
{TI(op::Log), simplify_log},
});
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>(
{{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Sum),
std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Sum, get_sum_constant>}},
{TI(op::Product),
std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Product, get_prod_constant>}},
{TI(op::Log), simplify_log}});
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
......
......@@ -34,6 +34,7 @@
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/subtract.hpp"
......@@ -265,6 +266,62 @@ TEST(algebraic_simplification, multiply_negative_tests)
}
}
TEST(algebraic_simplification, multiply_prod_vector_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_broadcast =
std::dynamic_pointer_cast<op::Broadcast>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_broadcast);
auto new_const = std::dynamic_pointer_cast<op::Constant>(new_broadcast->get_argument(0));
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32);
}
TEST(algebraic_simplification, multiply_prod_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32768);
}
TEST(algebraic_simplification, multiply_prod_negative)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{2, 5}, AxisSet{1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto f_prod = f->get_results().at(0)->get_argument(0);
ASSERT_EQ(f_prod, prod_fconst1);
}
TEST(algebraic_simplification, multiply_sum_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment