Commit 50283370 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Refactor ReshapeSinking into individual sinkers (#2222)

* Move CPU ReshapeSinking to Core pass

* Modify clang compile error

* Fix for style-apply check

* refactor into individual sinkers

* adding individual sinkers

* add a header back

* fix build break
parent 02d4aa59
......@@ -44,6 +44,8 @@ extern template ngraph::AxisVector
extern template ngraph::Shape ngraph::apply_permutation<ngraph::Shape>(ngraph::Shape input,
ngraph::AxisVector order);
using ReshapeMap = std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>>;
static std::shared_ptr<op::Reshape> combine_reshapes(std::shared_ptr<op::Reshape> r1,
std::shared_ptr<op::Reshape> r2)
{
......@@ -232,36 +234,15 @@ static void convert_binary_to_default_order(
reorders[binary] = reorders.at(right);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
//This is achieved by both **sinking**, propagating reshapes
//through ops towards op::Results,
//or **swimming** Reshapes up towards op::Parameter
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
static void sink_reshape(std::shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>> reorders;
NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete;
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Processing node " << n->get_name();
if (n->is_output())
{
results.push_back(n);
}
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
{
auto orig_reshape = reorders.at(n->get_argument(0));
auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose())
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape)
<< " for reshape " << reshape->get_name();
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
<< reshape->get_name();
insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape);
......@@ -277,39 +258,41 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(reshape, new_reshape);
reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into "
<< describe_reshape(new_reshape);
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
}
}
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{
}
// Sinks a pending reshape through a unary elementwise op.
// Elementwise ops are layout-agnostic, so the argument's pending reorder is
// simply propagated to the op's own entry in `reorders` — no reshape node is
// materialized here.
// Note: `reshapes_to_delete` is unused but kept so all sink_* helpers share
// the same signature.
static void sink_unary(std::shared_ptr<op::util::UnaryElementwiseArithmetic> n,
                       ReshapeMap& reorders,
                       std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
    auto arg_reshape = reorders.at(n->get_argument(0));
    NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
    reorders[n] = reorders[n->get_argument(0)];
}
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{
auto left = n->get_argument(0);
auto right = n->get_argument(1);
}
static void sink_binary(std::shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto left = binary->get_argument(0);
auto right = binary->get_argument(1);
if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order())
{
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< n->get_name();
reorders[n] = reorders.at(left);
<< binary->get_name();
reorders[binary] = reorders.at(left);
//at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
}
else if (reorders.at(left)->get_input_order() ==
ngraph::get_default_order(left->get_shape()))
else if (reorders.at(left)->get_input_order() == ngraph::get_default_order(left->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(0), right, reorders, reshapes_to_delete);
}
else if (reorders.at(right)->get_input_order() ==
ngraph::get_default_order(right->get_shape()))
else if (reorders.at(right)->get_input_order() == ngraph::get_default_order(right->get_shape()))
{
convert_binary_to_default_order(
binary, binary->get_inputs().at(1), left, reorders, reshapes_to_delete);
......@@ -324,14 +307,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
insert_reshape(binary, reorders.at(left), 0);
insert_reshape(binary, reorders.at(right), 1);
}
}
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
{
reorders[goe] = create_default_reshape(goe);
}
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n))
{
auto arg_reshape = reorders.at(n->get_argument(0));
}
static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(quantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
auto new_quantize = std::make_shared<op::Quantize>(quantize->get_argument(0),
......@@ -343,10 +325,13 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape;
}
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
{
auto arg_reshape = reorders.at(n->get_argument(0));
}
static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
ReshapeMap& reorders,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(dequantize->get_argument(0));
AxisSet axes_in_def_order =
get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
auto new_dequantize = std::make_shared<op::Dequantize>(dequantize->get_argument(0),
......@@ -357,33 +342,90 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape;
}
else
{
}
// Fallback for ops the pass cannot sink reshapes through: insert (materialize)
// every pending reshape directly in front of the op's arguments, mark those
// reshapes for later deletion, and record a default (identity) reshape for the
// op's own output so downstream processing sees a clean slate.
static void materialize_shapes(std::shared_ptr<Node> n,
                               ReshapeMap& reorders,
                               std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
    //skip multiple output nodes and deal with GOEs exclusively
    if (n->get_outputs().size() > 1)
    {
        return;
    }
    //TODO: multiple outputs
    for (size_t i = 0; i < n->get_arguments().size(); i++)
    {
        //materialize all pending reshapes, flush pending reshapes
        auto arg = n->get_argument(i);
        if (reorders.count(arg) != 0)
        {
            NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
                         << arg->get_name();
            mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
            insert_reshape(n, reorders.at(arg), i);
            //no swimming up
        }
    }
    reorders[n] = create_default_reshape(n);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
//This is achieved by both **sinking**, propagating reshapes
//through ops towards op::Results,
//or **swimming** Reshapes up towards op::Parameter
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
{
ReshapeMap reorders;
NodeVector results;
std::set<std::shared_ptr<Node>> reshapes_to_delete;
//STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Processing node " << n->get_name();
//collect all Result nodes for a sanity check
if (n->is_output())
{
results.push_back(n);
}
if (auto reshape = std::dynamic_pointer_cast<op::Reshape>(n))
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
else if (auto unary = std::dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
{
sink_unary(unary, reorders, reshapes_to_delete);
}
else if (auto binary = std::dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
{
sink_binary(binary, reorders, reshapes_to_delete);
}
else if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(n))
{
reorders[goe] = create_default_reshape(goe);
}
else if (auto quantize = std::dynamic_pointer_cast<op::Quantize>(n))
{
sink_quantize(quantize, reorders, reshapes_to_delete);
}
else if (auto dequantize = std::dynamic_pointer_cast<op::Dequantize>(n))
{
sink_dequantize(dequantize, reorders, reshapes_to_delete);
}
else
{
materialize_shapes(n, reorders, reshapes_to_delete);
}
}
//purge all the reshapes we either sunk or swam.
//STEP 2: purge all the reshapes we either sunk or swam.
for (auto r : reshapes_to_delete)
{
delete_reshape(r);
......@@ -397,7 +439,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
<< " op::Result = " << *r << ", Arg = " << *r->get_argument(0);
}
//fix wrong shape info wholesale
//STEP 3: fix wrong shape info wholesale
for (auto n : f->get_ordered_ops())
{
n->revalidate_and_infer_types();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment