Commit 50283370 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Refactor ReshapeSinking into individual sinkers (#2222)

* Move CPU ReshapeSinking to Core pass

* Modify clang compile error

* Fix for style-apply check

* refactor into invidivual sinkerso

* adding individual sinkers

* add a header back

* fix build break
parent 02d4aa59
...@@ -44,6 +44,8 @@ extern template ngraph::AxisVector ...@@ -44,6 +44,8 @@ extern template ngraph::AxisVector
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input, extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order); ngraph::AxisVector order);
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>;
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1, static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2) std::shared_ptr<op::Reshape> r2)
{ {
...@@ -232,6 +234,142 @@ static void convert_binary_to_default_order( ...@@ -232,6 +234,142 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right); reorders[binary] = reorders.at(right);
} }
static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose())
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
<< reshape->get_name();
insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape);
}
else
{
//combine both reshapes
auto new_reshape = combine_reshapes(orig_reshape, reshape);
//remove original reshape now it's combined with a new one
//should be safe to remove an already detached node
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
//replace reshape with combined one
ngraph::replace_node(reshape, new_reshape);
reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
}
}
static void sink_unary(std::shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
reorders[n] = reorders[n->get_argument(0)];
}
static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto left = binary->get_argument(0);
auto right = binary->get_argument(1);
if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order())
{
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< binary->get_name();
reorders[binary] = reorders.at(left);
//at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
}
else if (reorders.at(left)->get_input_order() == ngraph::get_default_order(left->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(0), right, reorders, reshapes_to_delete);
}
else if (reorders.at(right)->get_input_order() == ngraph::get_default_order(right->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(1), left, reorders, reshapes_to_delete);
}
else
{
NGRAPH_DEBUG << "Materializing both reshapes for " << binary->get_name();
NGRAPH_DEBUG << "Left = " << describe_reshape(reorders.at(left));
NGRAPH_DEBUG << "Right = " << describe_reshape(reorders.at(right));
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
insert_reshape(binary, reorders.at(left), 0);
insert_reshape(binary, reorders.at(right), 1);
}
}
static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(quantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0),
quantize->get_argument(1),
quantize->get_argument(2),
quantize->get_element_type(),
axes_in_def_order,
quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape;
}
static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(dequantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0),
dequantize->get_argument(1),
dequantize->get_argument(2),
dequantize->get_element_type(),
axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape;
}
static void materialize_shapes(std::shared_ptr<Node> n,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
//skip multiple output nodes and deal with GOEs exclusively
if (n->get_outputs().size() > 1)
{
return;
}
for (size_t i = 0; i < n->get_arguments().size(); i++)
{
//materialize all pending reshapes, flush pending reshapes
auto arg = n->get_argument(i);
if (reorders.count(arg) != 0)
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
<< arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
insert_reshape(n, reorders.at(arg), i);
//no swimming up
}
}
reorders[n] = create_default_reshape(n);
}
//The goal of ReshapeSinking is to remove //The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc) //round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool) //around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
...@@ -243,13 +381,15 @@ static void convert_binary_to_default_order( ...@@ -243,13 +381,15 @@ static void convert_binary_to_default_order(
//materialize pending reshapes if they can't be propagated through op //materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f) bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
{ {
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>> reorders; ReshapeMap reorders;
NodeVector results; NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete; std::set<std::shared_ptr<Node>> reshapes_to_delete;
//STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
NGRAPH_DEBUG << "Processing node " << n->get_name(); NGRAPH_DEBUG << "Processing node " << n->get_name();
//collect all Result nodes for a sanity check
if (n->is_output()) if (n->is_output())
{ {
results.push_back(n); results.push_back(n);
...@@ -257,73 +397,15 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -257,73 +397,15 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n)) if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
{ {
auto orig_reshape = reorders.at(n->get_argument(0)); sink_reshape(reshape, reorders, reshapes_to_delete);
if (!reshape->get_is_transpose())
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape)
<< " for reshape " << reshape->get_name();
insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape);
}
else
{
//combine both reshapes
auto new_reshape = combine_reshapes(orig_reshape, reshape);
//remove original reshape now it's combined with a new one
//should be safe to remove an already detached node
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
//replace reshape with combined one
ngraph::replace_node(reshape, new_reshape);
reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into "
<< describe_reshape(new_reshape);
}
} }
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n)) else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); sink_unary(unary, reorders, reshapes_to_delete);
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for "
<< n->get_name();
reorders[n] = reorders[n->get_argument(0)];
} }
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n)) else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{ {
auto left = n->get_argument(0); sink_binary(binary, reorders, reshapes_to_delete);
auto right = n->get_argument(1);
if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order())
{
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< n->get_name();
reorders[n] = reorders.at(left);
//at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
}
else if (reorders.at(left)->get_input_order() ==
ngraph::get_default_order(left->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(0), right, reorders, reshapes_to_delete);
}
else if (reorders.at(right)->get_input_order() ==
ngraph::get_default_order(right->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(1), left, reorders, reshapes_to_delete);
}
else
{
NGRAPH_DEBUG << "Materializing both reshapes for " << binary->get_name();
NGRAPH_DEBUG << "Left = " << describe_reshape(reorders.at(left));
NGRAPH_DEBUG << "Right = " << describe_reshape(reorders.at(right));
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
insert_reshape(binary, reorders.at(left), 0);
insert_reshape(binary, reorders.at(right), 1);
}
} }
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n)) else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
{ {
...@@ -331,59 +413,19 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -331,59 +413,19 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
} }
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n)) else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n))
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); sink_quantize(quantize, reorders, reshapes_to_delete);
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0),
quantize->get_argument(1),
quantize->get_argument(2),
quantize->get_element_type(),
axes_in_def_order,
quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape;
} }
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n)) else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); sink_dequantize(dequantize, reorders, reshapes_to_delete);
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0),
dequantize->get_argument(1),
dequantize->get_argument(2),
dequantize->get_element_type(),
axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape;
} }
else else
{ {
//skip multiple output nodes and deal with GOEs exclusively materialize_shapes(n, reorders, reshapes_to_delete);
if (n->get_outputs().size() > 1)
{
continue;
}
//TODO: multiple outputs
for (size_t i = 0; i < n->get_arguments().size(); i++)
{
//materialize all pending reshapes, flush pending reshapes
auto arg = n->get_argument(i);
if (reorders.count(arg) != 0)
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg))
<< " for " << arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
insert_reshape(n, reorders.at(arg), i);
//no swimming up
}
}
reorders[n] = create_default_reshape(n);
} }
} }
//purge all the reshapes we either sunk or swam. //STEP 2: purge all the reshapes we either sunk or swam.
for (auto r : reshapes_to_delete) for (auto r : reshapes_to_delete)
{ {
delete_reshape(r); delete_reshape(r);
...@@ -397,7 +439,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct ...@@ -397,7 +439,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
<< " op::Result = " << *r << ", Arg = " << *r->get_argument(0); << " op::Result = " << *r << ", Arg = " << *r->get_argument(0);
} }
//fix wrong shape info wholesale //STEP 3: fix wrong shape info wholesale
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
n->revalidate_and_infer_types(); n->revalidate_and_infer_types();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment