Commit cf568ef9 authored by Louis Feng's avatar Louis Feng Committed by Robert Kimball

Added reshape and broadcast to CSE (#1221)

* reshape inplace without copy data if possible.

* added reshape and broadcast to CSE.

* Fixed debug messages.
parent 41942f8b
......@@ -47,6 +47,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/remainder.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -63,6 +64,28 @@ using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();
auto reshape_a = std::dynamic_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = std::dynamic_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) &&
(reshape_a->get_output_shape() == reshape_b->get_output_shape());
}
static bool cse_broadcast(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();
auto broadcast_a = std::dynamic_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = std::dynamic_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
(broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
}
static bool cse_unarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
......@@ -94,38 +117,39 @@ static std::unordered_map<std::type_index,
initialize_ops_to_cse_handlers()
{
return std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>({
{TI(op::Abs), cse_unarywise},
{TI(op::Acos), cse_unarywise},
{TI(op::Asin), cse_unarywise},
{TI(op::Atan), cse_unarywise},
{TI(op::Ceiling), cse_unarywise},
{TI(op::Cos), cse_unarywise},
{TI(op::Cosh), cse_unarywise},
{TI(op::Exp), cse_unarywise},
{TI(op::Floor), cse_unarywise},
{TI(op::Log), cse_unarywise},
{TI(op::Negative), cse_unarywise},
{TI(op::Relu), cse_unarywise},
{TI(op::Sigmoid), cse_unarywise},
{TI(op::Sign), cse_unarywise},
{TI(op::Sin), cse_unarywise},
{TI(op::Sinh), cse_unarywise},
//{TI(op::Softmax), cse_unarywise},
{TI(op::Sqrt), cse_unarywise},
{TI(op::Tan), cse_unarywise},
{TI(op::Tanh), cse_unarywise},
{TI(op::Add), cse_binarywise},
{TI(op::Divide), cse_binarywise},
{TI(op::Maximum), cse_binarywise},
{TI(op::Minimum), cse_binarywise},
{TI(op::Multiply), cse_binarywise},
{TI(op::Power), cse_binarywise},
//{TI(op::Remainder), cse_binarywise},
{TI(op::Subtract), cse_binarywise},
{TI(op::Sum), cse_reduction},
{TI(op::Product), cse_reduction},
});
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>(
{{TI(op::Abs), cse_unarywise},
{TI(op::Acos), cse_unarywise},
{TI(op::Asin), cse_unarywise},
{TI(op::Atan), cse_unarywise},
{TI(op::Ceiling), cse_unarywise},
{TI(op::Cos), cse_unarywise},
{TI(op::Cosh), cse_unarywise},
{TI(op::Exp), cse_unarywise},
{TI(op::Floor), cse_unarywise},
{TI(op::Log), cse_unarywise},
{TI(op::Negative), cse_unarywise},
{TI(op::Relu), cse_unarywise},
{TI(op::Sigmoid), cse_unarywise},
{TI(op::Sign), cse_unarywise},
{TI(op::Sin), cse_unarywise},
{TI(op::Sinh), cse_unarywise},
//{TI(op::Softmax), cse_unarywise},
{TI(op::Sqrt), cse_unarywise},
{TI(op::Tan), cse_unarywise},
{TI(op::Tanh), cse_unarywise},
{TI(op::Add), cse_binarywise},
{TI(op::Divide), cse_binarywise},
{TI(op::Maximum), cse_binarywise},
{TI(op::Minimum), cse_binarywise},
{TI(op::Multiply), cse_binarywise},
{TI(op::Power), cse_binarywise},
//{TI(op::Remainder), cse_binarywise},
{TI(op::Subtract), cse_binarywise},
{TI(op::Sum), cse_reduction},
{TI(op::Product), cse_reduction},
{TI(op::Reshape), cse_reshape},
{TI(op::Broadcast), cse_broadcast}});
}
static std::unordered_map<std::type_index,
......
......@@ -134,6 +134,57 @@ TEST(CSE, abs_add)
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, abs_add_reshape_broadcast)
{
Shape zero_shape{1};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_a1 = std::make_shared<op::Abs>(A);
auto abs_b1 = std::make_shared<op::Abs>(B);
auto abs_a2 = std::make_shared<op::Abs>(A);
auto abs_b2 = std::make_shared<op::Abs>(B);
auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
{
// success case
auto reshape1 = std::make_shared<op::Reshape>(add1, AxisVector{0}, Shape{1, 1});
auto reshape2 = std::make_shared<op::Reshape>(add2, AxisVector{0}, Shape{1, 1});
auto broadcast1 = std::make_shared<op::Broadcast>(reshape1, Shape{1, 1, 3}, AxisSet{2});
auto broadcast2 = std::make_shared<op::Broadcast>(reshape2, Shape{1, 1, 3}, AxisSet{2});
auto f = std::make_shared<Function>(NodeVector{broadcast1, broadcast2},
op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
{
// fail case
auto reshape1 = std::make_shared<op::Reshape>(add1, AxisVector{0}, Shape{1});
auto reshape2 = std::make_shared<op::Reshape>(add2, AxisVector{0}, Shape{1, 1});
auto f =
std::make_shared<Function>(NodeVector{reshape1, reshape2}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_NE(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
{
// fail case
auto broadcast1 = std::make_shared<op::Broadcast>(add1, Shape{1, 2}, AxisSet{1});
auto broadcast2 = std::make_shared<op::Broadcast>(add2, Shape{1, 1, 2}, AxisSet{1, 2});
auto f = std::make_shared<Function>(NodeVector{broadcast1, broadcast2},
op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_NE(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
}
TEST(CSE, abs_add_abs_add)
{
Shape zero_shape{0};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment