Commit 42a3a0a4 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Implements proveance tag propagation for reshape sinking pass (#3742)

* Implements proveance tag propagation for reshape sinking pass

* Addresses code review feedback.

* Applies style fixes.
parent 5a43c0f7
......@@ -105,6 +105,7 @@ static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, si
auto arg = target->input(input_index).get_source_output();
NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
auto new_reshape = reshape->copy_with_new_inputs({arg});
new_reshape->merge_provenance_tags_from(reshape);
NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
<< target->get_name() << " input index " << input_index;
target->input(input_index).replace_source_output(new_reshape->output(0));
......@@ -115,7 +116,8 @@ static void delete_reshape(shared_ptr<Node> reshape)
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
if (!reshape->get_users().empty())
{
ngraph::replace_node(reshape, reshape->get_argument(0));
ngraph::replace_node(
reshape, reshape->input(0).get_source_output().get_node_shared_ptr(), true);
}
}
......@@ -130,6 +132,7 @@ static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{
auto default_order = ngraph::get_default_order(n->get_shape());
auto default_reshape = make_reshape(n, default_order, n->get_shape());
default_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
return default_reshape;
}
......@@ -230,6 +233,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto new_broadcast = make_shared<op::Broadcast>(
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
new_broadcast->merge_provenance_tags_from(old_broadcast);
csw.input.replace_source_output(new_broadcast->output(0));
}
//TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
......@@ -237,6 +241,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
{
//materialize
auto new_reshape = csw.reshape->copy_with_new_args({n});
new_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment