Commit c349056e authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

simplifier for log (#962)

parent 1f37c26d
......@@ -24,8 +24,12 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
......@@ -130,6 +134,24 @@ static bool simplify_add(std::shared_ptr<Node> n)
return false;
}
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static bool simplify_log(std::shared_ptr<Node> n)
{
if (auto div = std::dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
{
if (auto exp = std::dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
{
auto denom = div->get_argument(1);
auto diff = std::make_shared<op::Subtract>(exp->get_argument(0),
std::make_shared<op::Log>(denom));
ngraph::replace_node(n, diff);
return true;
}
}
return false;
}
static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
{
size_t prod = 1;
......@@ -224,6 +246,7 @@ static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<No
{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Sum), simplify_sum},
{TI(op::Log), simplify_log},
});
}
......
......@@ -29,9 +29,14 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
......@@ -315,3 +320,71 @@ TEST(algebraic_simplification, multiply_sum_negative)
auto f_sum = f->get_results().at(0)->get_argument(0);
ASSERT_EQ(f_sum, sum_fconst1);
}
TEST(algebraic_simplification, log_neg_neg)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto exp_a = make_shared<op::Exp>(a);
auto div = exp_a / b;
auto log_div = make_shared<op::Log>(div);
auto neg_inner = make_shared<op::Negative>(log_div);
auto neg2 = make_shared<op::Negative>(neg_inner);
auto neg3 = make_shared<op::Negative>(neg2);
auto neg4 = make_shared<op::Negative>(neg3);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, op::ParameterVector{a, b});
pass_manager.run_passes(f);
auto sub = std::dynamic_pointer_cast<op::Subtract>(neg_inner->get_argument(0));
ASSERT_TRUE(sub != nullptr);
ASSERT_EQ(sub->get_argument(0), a);
auto new_log = std::dynamic_pointer_cast<op::Log>(sub->get_argument(1));
ASSERT_TRUE(new_log != nullptr);
ASSERT_EQ(new_log->get_argument(0), b);
}
TEST(algebraic_simplification, log_no_exp)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto abs_a = make_shared<op::Abs>(a);
auto div = abs_a / b;
auto log_div = make_shared<op::Log>(div);
auto neg_inner = make_shared<op::Negative>(log_div);
auto neg2 = make_shared<op::Negative>(neg_inner);
auto neg3 = make_shared<op::Negative>(neg2);
auto neg4 = make_shared<op::Negative>(neg3);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, op::ParameterVector{a, b});
pass_manager.run_passes(f);
ASSERT_EQ(neg_inner->get_argument(0), log_div);
}
TEST(algebraic_simplification, log_no_divide)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto exp_a = make_shared<op::Exp>(a);
auto mul = exp_a * b;
auto log_mul = make_shared<op::Log>(mul);
auto neg_inner = make_shared<op::Negative>(log_mul);
auto neg2 = make_shared<op::Negative>(neg_inner);
auto neg3 = make_shared<op::Negative>(neg2);
auto neg4 = make_shared<op::Negative>(neg3);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, op::ParameterVector{a, b});
pass_manager.run_passes(f);
ASSERT_EQ(neg_inner->get_argument(0), log_mul);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment