Commit a708df68 authored by Nick Korovaiko, committed by Scott Cyphers

support a general case of Broadcast swimming for ReshapeSinking (#2068)

* swim a special case of broadcast

* general case broadcast swimming for reshape sinking

* fix in_order=false case

* fix reshape redef warning

* add broadcast swimming test

* cleanup test case

* fix warnings

* fix test case
parent b1a22df1
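
The general case works by re-indexing the Broadcast's axes through the Reshape's input order: output position i of a pure-permutation Reshape pulls old axis input_order[i], so a formerly broadcast axis stays broadcast at its new position, while the remaining positions become surviving source axes. The standalone program below is a minimal sketch of that bookkeeping (plain C++, not part of the commit), using the shapes from the new test: broadcast axes {1, 2} in NHWC under the NHWC-to-NCHW order {0, 3, 1, 2}.

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

int main()
{
    // Broadcast axes in the old (NHWC) output and the Reshape's input order.
    std::set<std::size_t> broadcast_axes{1, 2};
    std::vector<std::size_t> input_order{0, 3, 1, 2}; // NHWC -> NCHW

    std::set<std::size_t> new_broadcast_axes;
    std::vector<std::size_t> new_source_axes; // surviving source axes, old numbering
    bool in_order = true;

    for (std::size_t i = 0; i < input_order.size(); i++)
    {
        if (broadcast_axes.count(input_order.at(i)) != 0)
        {
            // Output position i still comes from broadcasting.
            new_broadcast_axes.insert(i);
        }
        else
        {
            // Output position i carries a real source axis; track whether the
            // surviving axes keep their original relative order.
            if (!new_source_axes.empty() && new_source_axes.back() > input_order.at(i))
            {
                in_order = false;
            }
            new_source_axes.push_back(input_order.at(i));
        }
    }

    std::cout << "new broadcast axes:";
    for (auto a : new_broadcast_axes)
    {
        std::cout << ' ' << a; // prints: 2 3  (axes {1,2} in NHWC become {2,3} in NCHW)
    }
    std::cout << "\nin_order = " << std::boolalpha << in_order << '\n'; // true
}

The surviving source axes {0, 3} keep their relative order here, so in_order stays true and the broadcast input can feed the rebuilt Broadcast directly; the out-of-order branch is sketched after the second hunk below.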
@@ -84,25 +84,6 @@ static void delete_reshape(std::shared_ptr<Node> reshape)
     }
 }
-static bool unique_dims(const Shape& shape)
-{
-    if (shape.size() == 0)
-    {
-        return true;
-    }
-    size_t n = shape.at(0);
-    for (size_t i = 1; i < shape.size(); i++)
-    {
-        if (n == shape.at(i))
-        {
-            return false;
-        }
-    }
-    return true;
-}
 static void mark_reshape_for_deletion(std::shared_ptr<Node> reshape,
                                       std::set<std::shared_ptr<Node>>& reshapes_to_delete)
 {
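
One reason the unique_dims guard could go along with the old special case: it compares every later dimension only against the first one, so duplicates that do not involve dimension 0 slip through. A minimal sketch (plain C++; Shape is aliased to std::vector<std::size_t> here as a stand-in for ngraph::Shape, and the deleted function is reproduced with std:: qualifiers):

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>; // stand-in for ngraph::Shape

static bool unique_dims(const Shape& shape)
{
    if (shape.size() == 0)
    {
        return true;
    }
    std::size_t n = shape.at(0);
    for (std::size_t i = 1; i < shape.size(); i++)
    {
        if (n == shape.at(i))
        {
            return false;
        }
    }
    return true;
}

int main()
{
    assert(unique_dims(Shape{3, 3, 2}) == false); // duplicate of dim 0: caught
    assert(unique_dims(Shape{2, 3, 3}) == true);  // duplicate later on: missed
}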
@@ -163,21 +144,56 @@ void swim(descriptor::Input* input, std::shared_ptr<op::Reshape> reshape)
         NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
                      << n->get_name() << " to " << unary->get_argument(0);
     }
-    else if (std::dynamic_pointer_cast<op::Broadcast>(n) &&
-             n->get_argument(0)->get_shape().size() == 1 && unique_dims(n->get_shape()))
+    else if (std::dynamic_pointer_cast<op::Broadcast>(n))
     {
         auto old_broadcast = std::static_pointer_cast<op::Broadcast>(n);
-        ngraph::AxisSet as;
-        size_t channel = n->get_argument(0)->get_shape().at(0);
-        for (size_t i = 0; i < n->get_shape().size(); i++)
+        auto broadcast_axes = old_broadcast->get_broadcast_axes();
+        auto broadcast_reshape = csw.reshape;
+        bool in_order = true;
+        AxisSet new_broadcast_axes;
+        std::vector<size_t> new_source_axes;
+        auto input_order = broadcast_reshape->get_input_order();
+        for (size_t i = 0; i < input_order.size(); i++)
         {
-            if (csw.reshape->get_shape().at(i) != channel)
+            if (broadcast_axes.count(input_order.at(i)) != 0)
             {
-                as.insert(i);
+                new_broadcast_axes.insert(i);
             }
+            else
+            {
+                if (new_source_axes.size() != 0 && new_source_axes.back() > input_order.at(i))
+                {
+                    in_order = false;
+                }
+                new_source_axes.push_back(input_order.at(i));
+            }
         }
+        auto broadcast_input = old_broadcast->get_argument(0);
+        if (!in_order)
+        {
+            AxisVector new_source_axes_sorted{new_source_axes};
+            std::sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());
+            std::map<size_t, size_t> old_new_source_axes;
+            for (size_t i = 0; i < new_source_axes_sorted.size(); i++)
+            {
+                old_new_source_axes.insert({new_source_axes.at(i), i});
+            }
+            AxisVector new_source_axis_order;
+            for (auto axis : new_source_axes_sorted)
+            {
+                new_source_axis_order.push_back(old_new_source_axes.at(axis));
+            }
+            auto new_arg_shape =
+                ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
+            broadcast_input = std::make_shared<op::Reshape>(
+                broadcast_input, new_source_axis_order, new_arg_shape);
+        }
-        auto new_broadcast =
-            std::make_shared<op::Broadcast>(n->get_argument(0), csw.reshape->get_shape(), as);
+        auto new_broadcast = std::make_shared<op::Broadcast>(
+            broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
         csw.input->replace_output(new_broadcast->get_outputs().at(0));
     }
     //TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
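
When in_order ends up false, the surviving source axes would reach the rebuilt Broadcast in the wrong relative order, so the hunk above first permutes the broadcast input with an extra Reshape. Below is a minimal sketch of how new_source_axis_order is derived (plain C++, with illustrative values of my own: broadcast axis {0} and reshape order {0, 2, 1}, which swaps the two surviving source axes; the new test below exercises only the in-order path).

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main()
{
    // Surviving source axes (old numbering) in the order the reshaped
    // output now wants them: old axis 2 first, then old axis 1.
    std::vector<std::size_t> new_source_axes{2, 1}; // in_order == false

    // Map each old axis to its appearance index...
    std::vector<std::size_t> new_source_axes_sorted{new_source_axes};
    std::sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());

    std::map<std::size_t, std::size_t> old_new_source_axes;
    for (std::size_t i = 0; i < new_source_axes_sorted.size(); i++)
    {
        old_new_source_axes.insert({new_source_axes.at(i), i});
    }

    // ...then walk the axes in ascending order to read off the permutation
    // that the extra input Reshape applies.
    std::vector<std::size_t> new_source_axis_order;
    for (auto axis : new_source_axes_sorted)
    {
        new_source_axis_order.push_back(old_new_source_axes.at(axis));
    }

    // An input of shape {A, B} is thus reshaped with order {1, 0} to
    // {B, A} before feeding the rebuilt Broadcast.
    for (auto o : new_source_axis_order)
    {
        std::cout << o << ' '; // prints: 1 0
    }
    std::cout << '\n';
}

Sorting the old axes and reading back each one's appearance rank yields the permutation that restores ascending order, so the rebuilt Broadcast can again fill its non-broadcast positions from the input axes in order.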
@@ -75,6 +75,43 @@ TEST(cpu_reshape_sinking, edge_splitting)
     ASSERT_EQ(new_reshape->get_shape(), shape_nchw);
 }
 
+TEST(cpu_reshape_sinking, broadcast_swimming)
+{
+    Shape shape_nchw{1, 32, 536, 536};
+    Shape shape_nhwc{1, 536, 536, 32};
+    Shape shape_weights{16, 32, 3, 3};
+    Shape conv_nhwc{1, 534, 534, 16};
+    Shape conv_nchw{1, 16, 534, 534};
+    AxisVector to_nhwc{0, 2, 3, 1};
+    AxisVector to_nchw{0, 3, 1, 2};
+    size_t channel = 16;
+
+    auto bias = make_shared<op::Parameter>(element::i32, Shape{channel});
+    auto bias_reshape = make_shared<op::Reshape>(bias, AxisVector{0}, Shape{1, channel});
+    auto bias_broadcast = make_shared<op::Broadcast>(bias_reshape, conv_nhwc, AxisSet{1, 2});
+
+    auto input = make_shared<op::Parameter>(element::i32, shape_nhwc);
+    auto reshape_input = make_shared<op::Reshape>(input, to_nchw, shape_nchw);
+    auto weights = make_shared<op::Parameter>(element::i32, shape_weights);
+    auto conv = make_shared<op::Convolution>(reshape_input, weights);
+    auto conv_reshape = make_shared<op::Reshape>(conv, to_nhwc, conv_nhwc);
+    auto add = bias_broadcast + conv_reshape;
+    auto relu = make_shared<op::Relu>(add);
+    auto func = make_shared<Function>(NodeVector{relu}, ParameterVector{bias, input, weights});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
+    pass_manager.register_pass<pass::ReshapeElimination>();
+    pass_manager.register_pass<pass::CommonSubexpressionElimination>();
+    pass_manager.run_passes(func);
+
+    ASSERT_EQ(add->get_shape(), conv_nchw);
+    ASSERT_EQ(add->get_argument(0)->get_shape(), conv_nchw);
+    ASSERT_EQ(add->get_argument(1), conv);
+}
 TEST(cpu_reshape_sinking, mnist_conv)
 {
     const string json_path = file_util::path_join(SERIALIZED_ZOO, "tf_conv_mnist_nhwc.json");