Unverified Commit 554d9e53 authored by Ashok Emani's avatar Ashok Emani Committed by GitHub

fix reshape_sinking for broadcast & single user swim only (#4283)

Co-authored-by: 's avatarasemx <998264+asemx@users.noreply.github.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 4381c416
...@@ -166,6 +166,18 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -166,6 +166,18 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto csw = work_queue.front(); auto csw = work_queue.front();
work_queue.pop_front(); work_queue.pop_front();
auto n = csw.input.get_source_output().get_node_shared_ptr(); auto n = csw.input.get_source_output().get_node_shared_ptr();
auto materialize = [csw, n]() {
auto new_reshape = csw.reshape->copy_with_new_args({n});
new_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
};
// Only swim past nodes which have a single user
if (n->get_users().size() > 1)
{
materialize();
continue;
}
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name(); NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (n->is_unary_elementwise_arithmetic()) if (n->is_unary_elementwise_arithmetic())
{ {
...@@ -179,6 +191,12 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -179,6 +191,12 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto old_broadcast = static_pointer_cast<op::Broadcast>(n); auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
auto broadcast_axes = old_broadcast->get_broadcast_axes(); auto broadcast_axes = old_broadcast->get_broadcast_axes();
auto broadcast_reshape = csw.reshape; auto broadcast_reshape = csw.reshape;
// swimming can only handle 1 dim change
if (broadcast_reshape->get_shape().size() - old_broadcast->get_shape().size() > 1)
{
materialize();
continue;
}
bool in_order = true; bool in_order = true;
AxisSet new_broadcast_axes; AxisSet new_broadcast_axes;
vector<size_t> new_source_axes; vector<size_t> new_source_axes;
...@@ -230,9 +248,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape) ...@@ -230,9 +248,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
else else
{ {
// materialize // materialize
auto new_reshape = csw.reshape->copy_with_new_args({n}); materialize();
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment