Commit 953c65f8 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

CSE constant (#1271)

parent f0283c6f
......@@ -64,6 +64,23 @@ using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_constant for " << a->get_name() << " and " << b->get_name();
if (a->get_shape() != b->get_shape() || a->get_element_type() != b->get_element_type())
{
return false;
}
auto ca = std::dynamic_pointer_cast<op::Constant>(a);
auto cb = std::dynamic_pointer_cast<op::Constant>(b);
size_t size = shape_size(a->get_shape()) * a->get_element_type().size();
return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
}
static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();
......@@ -123,6 +140,7 @@ static std::unordered_map<std::type_index,
{TI(op::Asin), cse_unarywise},
{TI(op::Atan), cse_unarywise},
{TI(op::Ceiling), cse_unarywise},
{TI(op::Constant), cse_constant},
{TI(op::Cos), cse_unarywise},
{TI(op::Cosh), cse_unarywise},
{TI(op::Exp), cse_unarywise},
......@@ -233,8 +251,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter() ||
n->is_constant() /*we could CSE constants as well*/)
if (n->is_output() || n->is_parameter())
{
continue;
}
......
......@@ -271,3 +271,39 @@ TEST(CSE, reduction_ops)
execute_cse_reduction_test<op::Sum>();
execute_cse_reduction_test<op::Product>();
}
TEST(CSE, constant)
{
Shape zero_shape{0};
auto iconst0 = op::Constant::create(element::i32, Shape{}, {0});
auto iconst0_1 = op::Constant::create(element::i32, Shape{}, {0});
auto iconst1 = op::Constant::create(element::i32, Shape{}, {1});
auto iconst1_1 = op::Constant::create(element::i32, Shape{}, {1});
auto fconst0 = op::Constant::create(element::f32, Shape{}, {0});
auto iconst111 = op::Constant::create(element::i32, Shape{3}, {1, 1, 1});
auto iconst112 = op::Constant::create(element::i32, Shape{3}, {1, 1, 2});
auto abs0 = std::make_shared<op::Abs>(iconst0);
auto abs0_1 = std::make_shared<op::Abs>(iconst0_1);
auto abs1 = std::make_shared<op::Abs>(iconst1);
auto abs1_1 = std::make_shared<op::Abs>(iconst1_1);
auto absf = std::make_shared<op::Abs>(fconst0);
auto abs111 = std::make_shared<op::Abs>(iconst111);
auto abs112 = std::make_shared<op::Abs>(iconst112);
auto f = std::make_shared<Function>(
NodeVector{abs0, abs0_1, abs1, abs1_1, absf, abs111, abs112}, op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(abs0->get_argument(0), abs0_1->get_argument(0));
ASSERT_EQ(abs1->get_argument(0), abs1_1->get_argument(0));
ASSERT_NE(abs0->get_argument(0), abs1->get_argument(0));
ASSERT_NE(abs0->get_argument(0), absf->get_argument(0));
ASSERT_NE(abs111->get_argument(0), abs112->get_argument(0));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment