Commit 50283370 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Refactor ReshapeSinking into individual sinkers (#2222)

* Move CPU ReshapeSinking to Core pass

* Modify clang compile error

* Fix for style-apply check

* refactor into invidivual sinkerso

* adding individual sinkers

* add a header back

* fix build break
parent 02d4aa59
...@@ -44,6 +44,8 @@ extern template ngraph::AxisVector ...@@ -44,6 +44,8 @@ extern template ngraph::AxisVector
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input, extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order); ngraph::AxisVector order);
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>;
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1, static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2) std::shared_ptr<op::Reshape> r2)
{ {
...@@ -232,36 +234,15 @@ static void convert_binary_to_default_order( ...@@ -232,36 +234,15 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right); reorders[binary] = reorders.at(right);
} }
//The goal of ReshapeSinking is to remove static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc) ReshapeMap& reorders,
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool) std::set<std::shared_ptr<Node>>& reshapes_to_delete)
//This is achieved by both **sinking**, propagating reshapes
//through ops towards op::Results,
//or **swimming** Reshapes up towards op::Parameter
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
{ {
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>> reorders; auto orig_reshape = reorders.at(reshape->get_argument(0));
NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete;
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Processing node " << n->get_name();
if (n->is_output())
{
results.push_back(n);
}
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
{
auto orig_reshape = reorders.at(n->get_argument(0));
if (!reshape->get_is_transpose()) if (!reshape->get_is_transpose())
{ {
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
<< " for reshape " << reshape->get_name(); << reshape->get_name();
insert_reshape(reshape, orig_reshape, 0); insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete); mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape); reorders[reshape] = create_default_reshape(reshape);
...@@ -277,39 +258,41 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -277,39 +258,41 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(reshape, new_reshape); ngraph::replace_node(reshape, new_reshape);
reorders[new_reshape] = new_reshape; reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and" NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into " << describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
<< describe_reshape(new_reshape);
} }
} }
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{ static void sink_unary(std::shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = reorders.at(n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
<< n->get_name();
reorders[n] = reorders[n->get_argument(0)]; reorders[n] = reorders[n->get_argument(0)];
} }
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{ static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
auto left = n->get_argument(0); ReshapeMap& reorders,
auto right = n->get_argument(1); std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto left = binary->get_argument(0);
auto right = binary->get_argument(1);
if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order()) if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order())
{ {
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for " NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< n->get_name(); << binary->get_name();
reorders[n] = reorders.at(left); reorders[binary] = reorders.at(left);
//at this point, both reshapes will be eventually removed //at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
} }
else if (reorders.at(left)->get_input_order() == else if (reorders.at(left)->get_input_order() == ngraph::get_default_order(left->get_shape()))
ngraph::get_default_order(left->get_shape()))
{ {
convert_binary_to_default_order( convert_binary_to_default_order(
binary, binary->get_inputs().at(0), right, reorders, reshapes_to_delete); binary, binary->get_inputs().at(0), right, reorders, reshapes_to_delete);
} }
else if (reorders.at(right)->get_input_order() == else if (reorders.at(right)->get_input_order() == ngraph::get_default_order(right->get_shape()))
ngraph::get_default_order(right->get_shape()))
{ {
convert_binary_to_default_order( convert_binary_to_default_order(
binary, binary->get_inputs().at(1), left, reorders, reshapes_to_delete); binary, binary->get_inputs().at(1), left, reorders, reshapes_to_delete);
...@@ -324,14 +307,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -324,14 +307,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
insert_reshape(binary, reorders.at(left), 0); insert_reshape(binary, reorders.at(left), 0);
insert_reshape(binary, reorders.at(right), 1); insert_reshape(binary, reorders.at(right), 1);
} }
} }
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
{ static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
reorders[goe] = create_default_reshape(goe); ReshapeMap& reorders,
} std::set<std::shared_ptr<Node>>& reshapes_to_delete)
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n)) {
{ auto arg_reshape = reorders.at(quantize->get_argument(0));
auto arg_reshape = reorders.at(n->get_argument(0));
AxisSet axes_in_def_order = AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes()); get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0), auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0),
...@@ -343,10 +325,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -343,10 +325,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(quantize, new_quantize); ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape; reorders[new_quantize] = arg_reshape;
} }
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
{ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
auto arg_reshape = reorders.at(n->get_argument(0)); ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(dequantize->get_argument(0));
AxisSet axes_in_def_order = AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes()); get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0), auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0),
...@@ -357,33 +342,90 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -357,33 +342,90 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(dequantize, new_dequantize); ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape; reorders[new_dequantize] = arg_reshape;
} }
else
{ static void materialize_shapes(std::shared_ptr<Node> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
//skip multiple output nodes and deal with GOEs exclusively //skip multiple output nodes and deal with GOEs exclusively
if (n->get_outputs().size() > 1) if (n->get_outputs().size() > 1)
{ {
continue; return;
} }
//TODO: multiple outputs
for (size_t i = 0; i < n->get_arguments().size(); i++) for (size_t i = 0; i < n->get_arguments().size(); i++)
{ {
//materialize all pending reshapes, flush pending reshapes //materialize all pending reshapes, flush pending reshapes
auto arg = n->get_argument(i); auto arg = n->get_argument(i);
if (reorders.count(arg) != 0) if (reorders.count(arg) != 0)
{ {
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
<< " for " << arg->get_name(); << arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
insert_reshape(n, reorders.at(arg), i); insert_reshape(n, reorders.at(arg), i);
//no swimming up //no swimming up
} }
} }
reorders[n] = create_default_reshape(n); reorders[n] = create_default_reshape(n);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
//This is achieved by both **sinking**, propagating reshapes
//through ops towards op::Results,
//or **swimming** Reshapes up towards op::Parameter
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
{
ReshapeMap reorders;
NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete;
//STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Processing node " << n->get_name();
//collect all Result nodes for a sanity check
if (n->is_output())
{
results.push_back(n);
}
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{
sink_unary(unary, reorders, reshapes_to_delete);
}
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{
sink_binary(binary, reorders, reshapes_to_delete);
}
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
{
reorders[goe] = create_default_reshape(goe);
}
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n))
{
sink_quantize(quantize, reorders, reshapes_to_delete);
}
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
{
sink_dequantize(dequantize, reorders, reshapes_to_delete);
}
else
{
materialize_shapes(n, reorders, reshapes_to_delete);
} }
} }
//purge all the reshapes we either sunk or swam. //STEP 2: purge all the reshapes we either sunk or swam.
for (auto r : reshapes_to_delete) for (auto r : reshapes_to_delete)
{ {
delete_reshape(r); delete_reshape(r);
...@@ -397,7 +439,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -397,7 +439,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
<< " op::Result = " << *r << ", Arg = " << *r->get_argument(0); << " op::Result = " << *r << ", Arg = " << *r->get_argument(0);
} }
//fix wrong shape info wholesale //STEP 3: fix wrong shape info wholesale
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
n->revalidate_and_infer_types(); n->revalidate_and_infer_types();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment