Commit 4015565b authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Reshape sinking: fix issue with handling rank changing reshape. (#3313)

parent ad20f217
......@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node)
return ss.str();
}
// Factory for op::Reshape nodes used throughout this pass.
// Creates a Reshape over `arg` with axis permutation `input_order` and the
// given `output_shape`, and logs it via NGRAPH_DEBUG so every reshape this
// pass fabricates can be traced in debug output.
static shared_ptr<op::Reshape>
    make_reshape(shared_ptr<Node> arg, const AxisVector& input_order, const Shape& output_shape)
{
    auto reshape = make_shared<op::Reshape>(arg, input_order, output_shape);
    NGRAPH_DEBUG << "Make Reshape " << describe_reshape(reshape);
    return reshape;
}
// Record `reshape` as the pending reorder associated with `target` in the
// pass-wide ReshapeMap, logging the write so map updates can be traced.
// Companion to read_reshapemap below; together they funnel all map access
// through NGRAPH_DEBUG logging.
static void
    write_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target, shared_ptr<op::Reshape> reshape)
{
    NGRAPH_DEBUG << "Write ReshapeMap[" << target->get_name()
                 << "] = " << describe_reshape(reshape);
    reorders[target] = reshape;
}
// Look up the reorder recorded for `target`, logging the read.
// Uses ReshapeMap::at, so it throws std::out_of_range if `target` has no
// entry — callers are expected to have populated the map first.
static shared_ptr<op::Reshape> read_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target)
{
    auto reorder = reorders.at(target);
    NGRAPH_DEBUG << "Read ReshapeMap[" << target->get_name() << "] -> "
                 << describe_reshape(reorder);
    return reorder;
}
// Fuse two back-to-back transposing reshapes r1 and r2 into a single
// equivalent Reshape over r2's argument: the combined permutation is
// r1's order applied to the default order, then r2's order applied on top.
// The combined node takes r2's output shape.
// NOTE(review): the two `rreshape` lines below are the removed (make_shared)
// and added (make_reshape) versions of the same statement from the commit
// diff this text was scraped from — only the make_reshape line exists in
// the actual post-commit source; as shown here the block would not compile.
static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
                                                shared_ptr<op::Reshape> r2)
{
    auto default_order = ngraph::get_default_order(r1->get_shape());
    auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
    auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
    auto rreshape = make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape());
    auto rreshape = make_reshape(r2->get_argument(0), perm_r2, r2->get_shape());
    NGRAPH_DEBUG << "Combining " << describe_reshape(r1) << " and " << describe_reshape(r2)
                 << " into " << describe_reshape(rreshape);
    return rreshape;
}
// Materialize `reshape` in front of input `input_index` of `target`:
// clone the reshape with the input's current source output as its argument
// (copy_with_new_inputs), then rewire the input to consume the clone's
// output. The original `reshape` node itself is left untouched.
static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
{
    NGRAPH_DEBUG << "Inserting reshape at input " << target->get_name() << " input index "
                 << input_index;
    auto arg = target->input(input_index).get_source_output();
    NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
    auto new_reshape = reshape->copy_with_new_inputs({arg});
    NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
                 << target->get_name() << " input index " << input_index;
    target->input(input_index).replace_source_output(new_reshape->output(0));
}
......@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
// Build an identity reshape for node `n`: default (unpermuted) axis order
// and n's own shape. Used to mark nodes whose output needs no reordering.
// NOTE(review): the two `default_reshape` lines below are the removed
// (make_shared) and added (make_reshape) versions of the same statement
// from the scraped commit diff — only the make_reshape line exists in the
// actual post-commit source; as shown here the block would not compile.
static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{
    auto default_order = ngraph::get_default_order(n->get_shape());
    auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape());
    auto default_reshape = make_reshape(n, default_order, n->get_shape());
    NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
    return default_reshape;
}
......@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto new_arg_shape =
ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
broadcast_input =
make_shared<op::Reshape>(broadcast_input, new_source_axis_order, new_arg_shape);
make_reshape(broadcast_input, new_source_axis_order, new_arg_shape);
}
auto new_broadcast = make_shared<op::Broadcast>(
......@@ -209,12 +241,11 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
//of a binary op isn't in the default format (i.e. nhwc instead of nchw)
//We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can
static void convert_binary_to_default_order(
shared_ptr<Node> binary,
const Input<Node>& input,
shared_ptr<Node> right,
unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
static void convert_binary_to_default_order(shared_ptr<Node> binary,
const Input<Node>& input,
shared_ptr<Node> right,
ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto left = input.get_source_output().get_node_shared_ptr();
auto perm_to_def =
......@@ -222,13 +253,13 @@ static void convert_binary_to_default_order(
auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
<< right->get_name();
auto new_reshape = make_shared<op::Reshape>(left, perm_to_def, new_shape);
auto new_reshape = make_reshape(left, perm_to_def, new_shape);
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
<< left->get_name();
//this should now insert and swim reshape on right
swim(input, new_reshape);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
reorders[binary] = reorders.at(right);
write_reshapemap(reorders, binary, read_reshapemap(reorders, right));
}
static void materialize_shapes(shared_ptr<Node> n,
......@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n,
auto arg = n->get_argument(i);
if (reorders.count(arg) != 0)
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
auto arg_reshape = reorders.at(arg);
NGRAPH_DEBUG << "Materializing " << describe_reshape(arg_reshape) << " for "
<< arg->get_name();
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape()))
mark_reshape_for_deletion(arg_reshape, reshapes_to_delete);
auto arg_shape = arg->get_shape();
if (arg_reshape->get_input_order() != get_default_order(arg->get_shape()))
{
// Insert if arg needs to be transposed.
insert_reshape(n, reorders.at(arg), i);
insert_reshape(n, arg_reshape, i);
}
//no swimming up
}
}
reorders[n] = create_default_reshape(n);
write_reshapemap(reorders, n, create_default_reshape(n));
}
static void sink_reshape(shared_ptr<op::Reshape> reshape,
ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
NGRAPH_DEBUG << "Sinking Reshape :" << describe_reshape(reshape);
auto orig_reshape = reorders.at(reshape->get_argument(0));
if (!reshape->get_is_transpose())
// 1) Not a Transpose or 2) Rank changing operation.
if ((reshape->get_output_shape().size() != reshape->get_input_order().size()) ||
(!reshape->get_is_transpose()))
{
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
<< reshape->get_name();
<< describe_reshape(reshape);
insert_reshape(reshape, orig_reshape, 0);
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
reorders[reshape] = create_default_reshape(reshape);
write_reshapemap(reorders, reshape, create_default_reshape(reshape));
}
else
{
......@@ -284,9 +320,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
//replace reshape with combined one
ngraph::replace_node(reshape, new_reshape);
mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
reorders[new_reshape] = new_reshape;
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
write_reshapemap(reorders, new_reshape, new_reshape);
}
}
......@@ -294,9 +328,9 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
auto arg_reshape = reorders.at(n->get_argument(0));
auto arg_reshape = read_reshapemap(reorders, n->get_argument(0));
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
reorders[n] = reorders[n->get_argument(0)];
write_reshapemap(reorders, n, arg_reshape);
}
static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
......@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
{
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
<< binary->get_name();
reorders[binary] = reorders.at(left);
write_reshapemap(reorders, binary, read_reshapemap(reorders, left));
//at this point, both reshapes will be eventually removed
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
......@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
ngraph::replace_node(n, new_slice);
auto new_reshape = make_shared<op::Reshape>(new_slice, order, n->get_shape());
auto new_reshape = make_reshape(new_slice, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_slice] = new_reshape;
write_reshapemap(reorders, new_slice, new_reshape);
}
static void
......@@ -385,9 +419,9 @@ static void
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
ngraph::replace_node(n, new_pad);
auto new_reshape = make_shared<op::Reshape>(new_pad, order, n->get_shape());
auto new_reshape = make_reshape(new_pad, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_pad] = new_reshape;
write_reshapemap(reorders, new_pad, new_reshape);
}
static void sink_quantize(shared_ptr<op::Quantize> quantize,
ReshapeMap& reorders,
......@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize,
quantize->get_round_mode());
ngraph::replace_node(quantize, new_quantize);
reorders[new_quantize] = arg_reshape;
write_reshapemap(reorders, new_quantize, arg_reshape);
}
static void sink_concat(shared_ptr<op::Concat> n,
......@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n,
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
ngraph::replace_node(n, new_concat);
auto new_reshape = make_shared<op::Reshape>(new_concat, order, n->get_shape());
auto new_reshape = make_reshape(new_concat, order, n->get_shape());
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
reorders[new_concat] = new_reshape;
write_reshapemap(reorders, new_concat, new_reshape);
}
static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
......@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
axes_in_def_order);
ngraph::replace_node(dequantize, new_dequantize);
reorders[new_dequantize] = arg_reshape;
write_reshapemap(reorders, new_dequantize, arg_reshape);
}
//The goal of ReshapeSinking is to remove
......@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
//STEP 1 : Sink or Swim reshapes away for op clusters
for (auto n : f->get_ordered_ops())
{
NGRAPH_DEBUG << "Processing node " << n->get_name();
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
//collect all Result nodes for a sanity check
if (n->is_output())
{
......@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
}
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
{
reorders[goe] = create_default_reshape(goe);
write_reshapemap(reorders, goe, create_default_reshape(goe));
}
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
{
......@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
{
materialize_shapes(n, reorders, reshapes_to_delete);
}
NGRAPH_DEBUG << "End: Processing node " << n->get_name();
}
//STEP 2: purge all the reshapes we either sunk or swam.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment