Unverified Commit 554d9e53 authored by Ashok Emani's avatar Ashok Emani Committed by GitHub

fix reshape_sinking for broadcast & single user swim only (#4283)

Co-authored-by: 's avatarasemx <998264+asemx@users.noreply.github.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 4381c416
......@@ -166,6 +166,18 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto csw = work_queue.front();
work_queue.pop_front();
auto n = csw.input.get_source_output().get_node_shared_ptr();
auto materialize = [csw, n]() {
auto new_reshape = csw.reshape->copy_with_new_args({n});
new_reshape->merge_provenance_tags_from(n);
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
};
// Only swim past nodes which have a single user
if (n->get_users().size() > 1)
{
materialize();
continue;
}
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (n->is_unary_elementwise_arithmetic())
{
......@@ -179,6 +191,12 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
auto broadcast_axes = old_broadcast->get_broadcast_axes();
auto broadcast_reshape = csw.reshape;
// swimming can only handle 1 dim change
if (broadcast_reshape->get_shape().size() - old_broadcast->get_shape().size() > 1)
{
materialize();
continue;
}
bool in_order = true;
AxisSet new_broadcast_axes;
vector<size_t> new_source_axes;
......@@ -230,9 +248,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
else
{
// materialize
auto new_reshape = csw.reshape->copy_with_new_args({n});
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
csw.input.replace_source_output(new_reshape->output(0));
materialize();
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment