Commit 42a3a0a4 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Implements proveance tag propagation for reshape sinking pass (#3742)

* Implements proveance tag propagation for reshape sinking pass

* Addresses code review feedback.

* Applies style fixes.
parent 5a43c0f7
...@@ -105,6 +105,7 @@ static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, si ...@@ -105,6 +105,7 @@ static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, si
auto arg = target->input(input_index).get_source_output(); auto arg = target->input(input_index).get_source_output();
NGRAPH_DEBUG << "Arg shape: " << arg.get_shape(); NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
auto new_reshape = reshape->copy_with_new_inputs({arg}); auto new_reshape = reshape->copy_with_new_inputs({arg});
new_reshape->merge_provenance_tags_from(reshape);
NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input " NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
<< target->get_name() << " input index " << input_index; << target->get_name() << " input index " << input_index;
target->input(input_index).replace_source_output(new_reshape->output(0)); target->input(input_index).replace_source_output(new_reshape->output(0));
...@@ -115,7 +116,8 @@ static void delete_reshape(shared_ptr<Node> reshape) ...@@ -115,7 +116,8 @@ static void delete_reshape(shared_ptr<Node> reshape)
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name(); NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
if (!reshape->get_users().empty()) if (!reshape->get_users().empty())
{ {
ngraph::replace_node(reshape, reshape->get_argument(0)); ngraph::replace_node(
reshape, reshape->input(0).get_source_output().get_node_shared_ptr(), true);
} }
} }
...@@ -130,6 +132,7 @@ static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n) ...@@ -130,6 +132,7 @@ static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{ {
auto default_order = ngraph::get_default_order(n->get_shape()); auto default_order = ngraph::get_default_order(n->get_shape());
auto default_reshape = make_reshape(n, default_order, n->get_shape()); auto default_reshape = make_reshape(n, default_order, n->get_shape());
default_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape); NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
return default_reshape; return default_reshape;
} }
...@@ -230,6 +233,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -230,6 +233,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto new_broadcast = make_shared<op::Broadcast>( auto new_broadcast = make_shared<op::Broadcast>(
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes); broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
new_broadcast->merge_provenance_tags_from(old_broadcast);
csw.input.replace_source_output(new_broadcast->output(0)); csw.input.replace_source_output(new_broadcast->output(0));
} }
//TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic //TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
...@@ -237,6 +241,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -237,6 +241,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
{ {
//materialize //materialize
auto new_reshape = csw.reshape->copy_with_new_args({n}); auto new_reshape = csw.reshape->copy_with_new_args({n});
new_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape); NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0)); csw.input.replace_source_output(new_reshape->output(0));
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment