Commit 659d2565 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Algebraic Simplification for Product (#949)

* product simplifier

* char -> signed char
parent 41c50b44
...@@ -171,6 +171,19 @@ static std::shared_ptr<Node> ...@@ -171,6 +171,19 @@ static std::shared_ptr<Node>
return op::Constant::create<T>(type, Shape{}, {sum_cnst}); return op::Constant::create<T>(type, Shape{}, {sum_cnst});
} }
template <typename T>
static std::shared_ptr<Node>
pow_by(element::Type type, size_t multiplier, std::shared_ptr<op::Constant> cnst)
{
T prod = static_cast<T>(1);
T val = cnst->get_vector<T>().at(0);
for (size_t i = 0; i < multiplier; i++)
{
prod *= val;
}
return op::Constant::create<T>(type, Shape{}, {prod});
}
static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier) static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst, size_t multiplier)
{ {
if (cnst->get_element_type() == element::i32) if (cnst->get_element_type() == element::i32)
...@@ -179,7 +192,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst ...@@ -179,7 +192,7 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
} }
else if (cnst->get_element_type() == element::i8) else if (cnst->get_element_type() == element::i8)
{ {
return multiply_by<char>(cnst->get_element_type(), multiplier, cnst); return multiply_by<signed char>(cnst->get_element_type(), multiplier, cnst);
} }
else if (cnst->get_element_type() == element::f32) else if (cnst->get_element_type() == element::f32)
{ {
...@@ -193,13 +206,40 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst ...@@ -193,13 +206,40 @@ static std::shared_ptr<Node> get_sum_constant(std::shared_ptr<op::Constant> cnst
return nullptr; return nullptr;
} }
//`simplify_sum` optimizes the following case: static std::shared_ptr<Node> get_prod_constant(std::shared_ptr<op::Constant> cnst,
size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return pow_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return pow_by<signed char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return pow_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return pow_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
//`simplify_reduction` optimizes the following case:
//sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant) //sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant * shape_size(reduction_axes) //where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
static bool simplify_sum(std::shared_ptr<Node> n) //product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
//where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T,
std::shared_ptr<Node> (*F)(std::shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(std::shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_sum for " << n->get_name(); NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto sum = std::dynamic_pointer_cast<op::Sum>(n); auto reduction = std::dynamic_pointer_cast<T>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0)); auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast) if (!broadcast)
...@@ -215,39 +255,44 @@ static bool simplify_sum(std::shared_ptr<Node> n) ...@@ -215,39 +255,44 @@ static bool simplify_sum(std::shared_ptr<Node> n)
return false; return false;
} }
auto multiplier = reduction_shape_size(sum->get_reduction_axes(), broadcast->get_shape()); auto multiplier = reduction_shape_size(reduction->get_reduction_axes(), broadcast->get_shape());
auto sum_cnst = get_sum_constant(cnst, multiplier); auto reduction_cnst = F(cnst, multiplier);
//Unsupported type //Unsupported type
if (!sum_cnst) if (!reduction_cnst)
{ {
NGRAPH_DEBUG << "unsupported type"; NGRAPH_DEBUG << "unsupported type";
return false; return false;
} }
if (sum->get_shape().size() > 0) if (reduction->get_shape().size() > 0)
{ {
ngraph::AxisSet axes{}; ngraph::AxisSet axes{};
for (size_t i = 0; i < sum->get_shape().size(); i++) for (size_t i = 0; i < reduction->get_shape().size(); i++)
{ {
axes.insert(i); axes.insert(i);
} }
sum_cnst = std::make_shared<op::Broadcast>(sum_cnst, sum->get_shape(), axes); reduction_cnst =
std::make_shared<op::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
} }
ngraph::replace_node(n, sum_cnst); ngraph::replace_node(n, reduction_cnst);
return true; return true;
} }
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>> static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
initialize_ops_to_simplifiers() initialize_ops_to_simplifiers()
{ {
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({ return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>(
{TI(op::Add), simplify_add}, {{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply}, {TI(op::Multiply), simplify_multiply},
{TI(op::Sum), simplify_sum}, {TI(op::Sum),
{TI(op::Log), simplify_log}, std::function<bool(std::shared_ptr<Node>)>{
}); simplify_reduction<op::Sum, get_sum_constant>}},
{TI(op::Product),
std::function<bool(std::shared_ptr<Node>)>{
simplify_reduction<op::Product, get_prod_constant>}},
{TI(op::Log), simplify_log}});
} }
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>> static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
...@@ -265,6 +266,62 @@ TEST(algebraic_simplification, multiply_negative_tests) ...@@ -265,6 +266,62 @@ TEST(algebraic_simplification, multiply_negative_tests)
} }
} }
TEST(algebraic_simplification, multiply_prod_vector_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_broadcast =
std::dynamic_pointer_cast<op::Broadcast>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_broadcast);
auto new_const = std::dynamic_pointer_cast<op::Constant>(new_broadcast->get_argument(0));
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32);
}
TEST(algebraic_simplification, multiply_prod_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32768);
}
TEST(algebraic_simplification, multiply_prod_negative)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{2, 5}, AxisSet{1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, op::ParameterVector{});
pass_manager.run_passes(f);
auto f_prod = f->get_results().at(0)->get_argument(0);
ASSERT_EQ(f_prod, prod_fconst1);
}
TEST(algebraic_simplification, multiply_sum_scalar_one) TEST(algebraic_simplification, multiply_sum_scalar_one)
{ {
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0}); auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment