Unverified Commit bf365b12 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge pull request #2313 from NervanaSystems/krovatkin/rs_concat

Sink Concat
parents ad30e973 80923525
......@@ -26,6 +26,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -230,6 +231,32 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right);
}
// Flush every pending (swimming) reshape on the inputs of `n` back into the
// graph as an explicit Reshape node, then record an identity (default)
// reshape for `n` itself so downstream sinking starts from a clean slate.
// Multi-output nodes are skipped; GOEs are handled elsewhere.
static void materialize_shapes(std::shared_ptr<Node> n,
                               ReshapeMap& reorders,
                               std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
    //skip multiple output nodes and deal with GOEs exclusively
    if (n->get_outputs().size() > 1)
    {
        return;
    }

    const size_t arg_count = n->get_arguments().size();
    for (size_t arg_index = 0; arg_index < arg_count; arg_index++)
    {
        auto input = n->get_argument(arg_index);
        if (reorders.count(input) == 0)
        {
            continue; // nothing pending on this input
        }
        //materialize all pending reshapes, flush pending reshapes
        NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(input)) << " for "
                     << input->get_name();
        mark_reshape_for_deletion(reorders.at(input), reshapes_to_delete);
        insert_reshape(n, reorders.at(input), arg_index);
        //no swimming up
    }
    reorders[n] = create_default_reshape(n);
}
static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
......@@ -379,6 +406,55 @@ static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
reorders[new_quantize] = arg_reshape;
}
// Push pending reshapes through a Concat: rebuild the concat in the
// pre-reshape (default) layout — using shape-only Label placeholders so the
// new node gets the right output shape — then record the equivalent reshape
// on the new concat's output so it can keep sinking. If the inputs carry
// differently-ordered reshapes, give up and materialize them instead.
static void sink_concat(std::shared_ptr<op::Concat> n,
                        ReshapeMap& reorders,
                        std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
    auto first_reshape = reorders.at(n->get_argument(0));
    auto first_order = first_reshape->get_input_order();
    // we need the correct input shape to produce the right output shape
    // we are going to create a label of the right input shape,
    // so a new slice will have the right shape
    auto to_default = ngraph::get_permutation_to_default_order(first_order);
    auto default_shape = ngraph::apply_permutation(first_reshape->get_shape(), to_default);
    auto placeholder =
        std::make_shared<pattern::op::Label>(first_reshape->get_element_type(), default_shape);

    NodeVector replacement_args;
    replacement_args.push_back(placeholder);
    for (size_t arg_index = 1; arg_index < n->get_input_size(); arg_index++)
    {
        auto pending = reorders.at(n->get_argument(arg_index));
        auto pending_order = pending->get_input_order();
        if (pending_order != first_order)
        {
            // mixed layouts: sinking through this concat is not possible
            NGRAPH_DEBUG << " input order at " << arg_index
                         << "-th arg is different from first arg";
            materialize_shapes(n, reorders, reshapes_to_delete);
            return;
        }
        auto arg_default_shape = ngraph::apply_permutation(pending->get_shape(), to_default);
        auto arg_placeholder =
            std::make_shared<pattern::op::Label>(pending->get_element_type(), arg_default_shape);
        replacement_args.push_back(arg_placeholder);
    }

    // the concatenation axis moves along with the permutation
    auto remapped_axis = first_order.at(n->get_concatenation_axis());
    auto new_concat = std::make_shared<op::Concat>(replacement_args, remapped_axis);
    //put back the original arguments
    for (size_t arg_index = 0; arg_index < new_concat->get_input_size(); arg_index++)
    {
        ngraph::replace_node(replacement_args.at(arg_index), n->get_argument(arg_index));
    }
    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
    ngraph::replace_node(n, new_concat);

    // the reshape that used to sit on the inputs now sits on the output
    auto new_reshape = std::make_shared<op::Reshape>(new_concat, first_order, n->get_shape());
    NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
    reorders[new_concat] = new_reshape;
}
static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
......@@ -396,32 +472,6 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
reorders[new_dequantize] = arg_reshape;
}
// Flush any reshapes still pending (swimming) on the inputs of `n`:
// each pending reorder is inserted back into the graph as an explicit
// Reshape on the corresponding input, marked for later deletion from the
// pending set, and `n` itself is then assigned an identity (default)
// reshape so sinking continues from a clean state.
// NOTE(review): this is the pre-move copy of the function shown earlier in
// the diff; the two bodies are textually identical.
static void materialize_shapes(std::shared_ptr<Node> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
//skip multiple output nodes and deal with GOEs exclusively
if (n->get_outputs().size() > 1)
{
return;
}
for (size_t i = 0; i < n->get_arguments().size(); i++)
{
//materialize all pending reshapes, flush pending reshapes
auto arg = n->get_argument(i);
if (reorders.count(arg) != 0)
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
<< arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
insert_reshape(n, reorders.at(arg), i);
//no swimming up
}
}
// record an identity reshape for n now that its inputs are materialized
reorders[n] = create_default_reshape(n);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes (i.e. nhwc -> nchw -> (nchw-only op) -> nchw -> nhwc)
//around nchw-only ops (e.g. Convolution, BatchNorm, Avg/MaxPool)
......@@ -493,6 +543,10 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
{
sink_pad(pad, reorders, reshapes_to_delete);
}
else if (auto concat = std::dynamic_pointer_cast<op::Concat>(n))
{
sink_concat(concat, reorders, reshapes_to_delete);
}
else
{
materialize_shapes(n, reorders, reshapes_to_delete);
......
......@@ -203,3 +203,76 @@ TEST(reshape_sinking, slice_pad)
size_t before_after = count_ops_of_type<op::Reshape>(f);
ASSERT_LE(before_after, before_count);
}
// Verify that ReshapeSinking can push nhwc<->nchw reshapes through a Concat:
// two structurally identical convolution branches (each computed in nchw and
// reshaped back to nhwc) are concatenated on the trailing (channel) axis and
// fed into a third convolution. After sinking + ReshapeElimination + CSE the
// number of Reshape ops must not have grown.
TEST(reshape_sinking, concat)
{
    Shape shape_w{1, 1, 1, 1};
    Shape shape_x{1, 3, 3, 1};

    // first branch: 1x1 convolution in nchw, result reshaped back to nhwc
    auto B_ = op::Constant::create(element::f32, shape_w, {3});
    auto B = make_shared<op::Reshape>(B_, AxisVector{3, 2, 0, 1}, Shape{1, 1, 1, 1}); /* nchw */
    auto A_ = make_shared<op::Parameter>(element::f32, shape_x);
    auto A = make_shared<op::Reshape>(A_, AxisVector{0, 3, 1, 2}, Shape{1, 1, 3, 3}); /* nchw */
    auto C = op::Constant::create(element::f32, Shape{1}, {2});
    auto conv = make_shared<op::Convolution>(A,
                                             B,
                                             Strides{1, 1},
                                             Strides{1, 1},
                                             CoordinateDiff{0, 0},
                                             CoordinateDiff{0, 0},
                                             Strides{1, 1});
    auto reshape_conv =
        make_shared<op::Reshape>(conv, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1}); /* nhwc */
    auto broadcast = make_shared<op::Broadcast>(C, reshape_conv->get_shape(), AxisSet{0, 1, 2});
    auto add = broadcast + reshape_conv;

    // second branch: identical structure, independent parameters
    auto B1_ = op::Constant::create(element::f32, shape_w, {3});
    auto B1 = make_shared<op::Reshape>(B1_, AxisVector{3, 2, 0, 1}, Shape{1, 1, 1, 1});
    auto A1_ = make_shared<op::Parameter>(element::f32, shape_x);
    auto A1 = make_shared<op::Reshape>(A1_, AxisVector{0, 3, 1, 2}, Shape{1, 1, 3, 3});
    auto C1 = op::Constant::create(element::f32, Shape{1}, {2});
    auto conv1 = make_shared<op::Convolution>(A1,
                                              B1,
                                              Strides{1, 1},
                                              Strides{1, 1},
                                              CoordinateDiff{0, 0},
                                              CoordinateDiff{0, 0},
                                              Strides{1, 1});
    auto reshape_conv1 =
        make_shared<op::Reshape>(conv1, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1});
    // BUGFIX: this broadcast previously took its shape from reshape_conv (the
    // FIRST branch) — a copy/paste slip; the shapes happen to match, but the
    // second branch should depend only on its own nodes.
    auto broadcast1 = make_shared<op::Broadcast>(C1, reshape_conv1->get_shape(), AxisSet{0, 1, 2});
    auto add1 = broadcast1 + reshape_conv1;

    // join the two nhwc branches on the channel (last) axis, then a third
    // convolution that forces a reshape back to nchw around it
    auto concat = make_shared<op::Concat>(NodeVector{add, add1}, 3);
    auto relu = make_shared<op::Relu>(concat);
    auto reshape_relu =
        make_shared<op::Reshape>(relu, AxisVector{0, 3, 1, 2}, Shape{1, 2, 3, 3}); /* nchw */
    auto B2_ = op::Constant::create(element::f32, Shape{1, 1, 2, 1}, {2});
    auto B2 = make_shared<op::Reshape>(B2_, AxisVector{3, 2, 0, 1}, Shape{1, 2, 1, 1});
    auto conv2 = make_shared<op::Convolution>(reshape_relu,
                                              B2,
                                              Strides{1, 1},
                                              Strides{1, 1},
                                              CoordinateDiff{0, 0},
                                              CoordinateDiff{0, 0},
                                              Strides{1, 1});
    auto reshape_conv2 =
        make_shared<op::Reshape>(conv2, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1}); /* nhwc */
    auto f = make_shared<Function>(reshape_conv2, ParameterVector{A_, A1_});

    pass::Manager pass_manager;
    size_t before_count = count_ops_of_type<op::Reshape>(f);
    // NOTE(review): removed leftover pass::VisualizeTree("before.pdf") /
    // ("after.pdf") debug passes that wrote files on every test run; also
    // removed unused locals (R, R1, shape, shape_b, r_shape).
    pass_manager.register_pass<pass::ReshapeSinking>();
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.register_pass<pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    size_t before_after = count_ops_of_type<op::Reshape>(f);
    ASSERT_LE(before_after, before_count);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment