Commit 4015565b authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Reshape sinking: fix issue with handling rank changing reshape. (#3313)

parent ad20f217
...@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node) ...@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node)
return ss.str(); return ss.str();
} }
static shared_ptr<op::Reshape>
make_reshape(shared_ptr<Node> arg, const AxisVector& input_order, const Shape& output_shape)
{
auto reshape = make_shared<op::Reshape>(arg, input_order, output_shape);
NGRAPH_DEBUG << "Make Reshape " << describe_reshape(reshape);
return reshape;
}
static void
write_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target, shared_ptr<op::Reshape> reshape)
{
NGRAPH_DEBUG << "Write ReshapeMap[" << target->get_name()
<< "] = " << describe_reshape(reshape);
reorders[target] = reshape;
}
static shared_ptr<op::Reshape> read_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target)
{
auto reorder = reorders.at(target);
NGRAPH_DEBUG << "Read ReshapeMap[" << target->get_name() << "] -> "
<< describe_reshape(reorder);
return reorder;
}
static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1, static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
shared_ptr<op::Reshape> r2) shared_ptr<op::Reshape> r2)
{ {
auto default_order = ngraph::get_default_order(r1->get_shape()); auto default_order = ngraph::get_default_order(r1->get_shape());
auto perm_r1 = apply_permutation(default_order, r1->get_input_order()); auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order()); auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
auto rreshape = make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape()); auto rreshape = make_reshape(r2->get_argument(0), perm_r2, r2->get_shape());
NGRAPH_DEBUG << "Combining " << describe_reshape(r1) << " and " << describe_reshape(r2)
<< " into " << describe_reshape(rreshape);
return rreshape; return rreshape;
} }
static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index) static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
{ {
NGRAPH_DEBUG << "Inserting reshape at input " << target->get_name() << " input index "
<< input_index;
auto arg = target->input(input_index).get_source_output(); auto arg = target->input(input_index).get_source_output();
NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
auto new_reshape = reshape->copy_with_new_inputs({arg}); auto new_reshape = reshape->copy_with_new_inputs({arg});
NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
<< target->get_name() << " input index " << input_index;
target->input(input_index).replace_source_output(new_reshape->output(0)); target->input(input_index).replace_source_output(new_reshape->output(0));
} }
...@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape, ...@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n) static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{ {
auto default_order = ngraph::get_default_order(n->get_shape()); auto default_order = ngraph::get_default_order(n->get_shape());
auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape()); auto default_reshape = make_reshape(n, default_order, n->get_shape());
NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
return default_reshape; return default_reshape;
} }
...@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto new_arg_shape = auto new_arg_shape =
ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order); ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
broadcast_input = broadcast_input =
make_shared<op::Reshape>(broadcast_input, new_source_axis_order, new_arg_shape); make_reshape(broadcast_input, new_source_axis_order, new_arg_shape);
} }
auto new_broadcast = make_shared<op::Broadcast>( auto new_broadcast = make_shared<op::Broadcast>(
...@@ -209,12 +241,11 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -209,12 +241,11 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
//of a binary op isn't in the default format (i.e. nhwc instead of nchw) //of a binary op isn't in the default format (i.e. nhwc instead of nchw)
//We have to normalize this other argument to nchw by swimming nchw towards parameters //We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can //as far as we can
static void convert_binary_to_default_order( static void convert_binary_to_default_order(shared_ptr<Node> binary,
shared_ptr<Node> binary, const Input<Node>& input,
const Input<Node>& input, shared_ptr<Node> right,
shared_ptr<Node> right, ReshapeMap& reorders,
unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders, set<shared_ptr<Node>>& reshapes_to_delete)
set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto left = input.get_source_output().get_node_shared_ptr(); auto left = input.get_source_output().get_node_shared_ptr();
auto perm_to_def = auto perm_to_def =
...@@ -222,13 +253,13 @@ static void convert_binary_to_default_order( ...@@ -222,13 +253,13 @@ static void convert_binary_to_default_order(
auto new_shape = apply_permutation(left->get_shape(), perm_to_def); auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", " NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
<< right->get_name(); << right->get_name();
auto new_reshape = make_shared<op::Reshape>(left, perm_to_def, new_shape); auto new_reshape = make_reshape(left, perm_to_def, new_shape);
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to " NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
<< left->get_name(); << left->get_name();
//this should now insert and swim reshape on right //this should now insert and swim reshape on right
swim(input, new_reshape); swim(input, new_reshape);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
reorders[binary] = reorders.at(right); write_reshapemap(reorders, binary, read_reshapemap(reorders, right));
} }
static void materialize_shapes(shared_ptr<Node> n, static void materialize_shapes(shared_ptr<Node> n,
...@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n, ...@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n,
auto arg = n->get_argument(i); auto arg = n->get_argument(i);
if (reorders.count(arg) != 0) if (reorders.count(arg) != 0)
{ {
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for " auto arg_reshape = reorders.at(arg);
NGRAPH_DEBUG << "Materializing " << describe_reshape(arg_reshape) << " for "
<< arg->get_name(); << arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete); mark_reshape_for_deletion(arg_reshape, reshapes_to_delete);
if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape())) auto arg_shape = arg->get_shape();
if (arg_reshape->get_input_order() != get_default_order(arg->get_shape()))
{ {
// Insert if arg needs to be transposed. // Insert if arg needs to be transposed.
insert_reshape(n, reorders.at(arg), i); insert_reshape(n, arg_reshape, i);
} }
//no swimming up //no swimming up
} }
} }
reorders[n] = create_default_reshape(n); write_reshapemap(reorders, n, create_default_reshape(n));
} }
static void sink_reshape(shared_ptr<op::Reshape> reshape, static void sink_reshape(shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders, ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
NGRAPH_DEBUG << "Sinking Reshape :" << describe_reshape(reshape);
auto orig_reshape = reorders.at(reshape->get_argument(0)); auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose()) // 1) Not a Transpose or 2) Rank changing operation.
if ((reshape->get_output_shape().size() != reshape->get_input_order().size()) ||
(!reshape->get_is_transpose()))
{ {
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape " NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
<< reshape->get_name(); << describe_reshape(reshape);
insert_reshape(reshape, orig_reshape, 0); insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete); mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape); write_reshapemap(reorders, reshape, create_default_reshape(reshape));
} }
else else
{ {
...@@ -284,9 +320,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape, ...@@ -284,9 +320,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
//replace reshape with combined one //replace reshape with combined one
ngraph::replace_node(reshape, new_reshape); ngraph::replace_node(reshape, new_reshape);
mark_reshape_for_deletion(new_reshape, reshapes_to_delete); mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
reorders[new_reshape] = new_reshape; write_reshapemap(reorders, new_reshape, new_reshape);
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
} }
} }
...@@ -294,9 +328,9 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n, ...@@ -294,9 +328,9 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders, ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete) set<shared_ptr<Node>>& reshapes_to_delete)
{ {
auto arg_reshape = reorders.at(n->get_argument(0)); auto arg_reshape = read_reshapemap(reorders, n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
reorders[n] = reorders[n->get_argument(0)]; write_reshapemap(reorders, n, arg_reshape);
} }
static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary, static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
...@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary ...@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
{ {
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for " NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< binary->get_name(); << binary->get_name();
reorders[binary] = reorders.at(left); write_reshapemap(reorders, binary, read_reshapemap(reorders, left));
//at this point, both reshapes will be eventually removed //at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
...@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n, ...@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
ngraph::replace_node(n, new_slice); ngraph::replace_node(n, new_slice);
auto new_reshape = make_shared<op::Reshape>(new_slice, order, n->get_shape()); auto new_reshape = make_reshape(new_slice, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_slice] = new_reshape; write_reshapemap(reorders, new_slice, new_reshape);
} }
static void static void
...@@ -385,9 +419,9 @@ static void ...@@ -385,9 +419,9 @@ static void
ngraph::replace_node(dummy_correct_shape, n->get_argument(0)); ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
ngraph::replace_node(n, new_pad); ngraph::replace_node(n, new_pad);
auto new_reshape = make_shared<op::Reshape>(new_pad, order, n->get_shape()); auto new_reshape = make_reshape(new_pad, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_pad] = new_reshape; write_reshapemap(reorders, new_pad, new_reshape);
} }
static void sink_quantize(shared_ptr<op::Quantize> quantize, static void sink_quantize(shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders, ReshapeMap& reorders,
...@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize, ...@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize,
quantize->get_round_mode()); quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize); ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape; write_reshapemap(reorders, new_quantize, arg_reshape);
} }
static void sink_concat(shared_ptr<op::Concat> n, static void sink_concat(shared_ptr<op::Concat> n,
...@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n, ...@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name(); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
ngraph::replace_node(n, new_concat); ngraph::replace_node(n, new_concat);
auto new_reshape = make_shared<op::Reshape>(new_concat, order, n->get_shape()); auto new_reshape = make_reshape(new_concat, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_concat] = new_reshape; write_reshapemap(reorders, new_concat, new_reshape);
} }
static void sink_dequantize(shared_ptr<op::Dequantize> dequantize, static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
...@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize, ...@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
axes_in_def_order); axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize); ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape; write_reshapemap(reorders, new_dequantize, arg_reshape);
} }
//The goal of ReshapeSinking is to remove //The goal of ReshapeSinking is to remove
...@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> ...@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
//STEP 1 : Sink or Swim reshapes away for op clusters //STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
{ {
NGRAPH_DEBUG << "Processing node " << n->get_name(); NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
//collect all Result nodes for a sanity check //collect all Result nodes for a sanity check
if (n->is_output()) if (n->is_output())
{ {
...@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> ...@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
} }
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n)) else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
{ {
reorders[goe] = create_default_reshape(goe); write_reshapemap(reorders, goe, create_default_reshape(goe));
} }
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n)) else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
{ {
...@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> ...@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
{ {
materialize_shapes(n, reorders, reshapes_to_delete); materialize_shapes(n, reorders, reshapes_to_delete);
} }
NGRAPH_DEBUG << "End: Processing node " << n->get_name();
} }
//STEP 2: purge all the reshapes we either sunk or swam. //STEP 2: purge all the reshapes we either sunk or swam.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment