Commit 0ac2a8b6 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

swim a special case of broadcast (#2034)

parent 1daac094
......@@ -84,6 +84,25 @@ static void delete_reshape(std::shared_ptr<Node> reshape)
}
}
static bool unique_dims(const Shape& shape)
{
if (shape.size() == 0)
{
return true;
}
size_t n = shape.at(0);
for (size_t i = 1; i < shape.size(); i++)
{
if (n == shape.at(i))
{
return false;
}
}
return true;
}
static void mark_reshape_for_deletion(std::shared_ptr<Node> reshape,
std::set<std::shared_ptr<Node>>& reshapes_to_delete)
{
......@@ -144,6 +163,23 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
<< n->get_name() << " to " << unary->get_argument(0);
}
else if (std::dynamic_pointer_cast<op::Broadcast>(n) &&
n->get_argument(0)->get_shape().size() == 1 && unique_dims(n->get_shape()))
{
auto old_broadcast = std::static_pointer_cast<op::Broadcast>(n);
ngraph::AxisSet as;
size_t channel = n->get_argument(0)->get_shape().at(0);
for (size_t i = 0; i < n->get_shape().size(); i++)
{
if (csw.reshape->get_shape().at(i) != channel)
{
as.insert(i);
}
}
auto new_broadcast =
std::make_shared<op::Broadcast>(n->get_argument(0), csw.reshape->get_shape(), as);
csw.input->replace_output(new_broadcast->get_outputs().at(0));
}
//TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
else
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment