Commit 922aaaf8 authored by Nick Korovaiko, committed by Scott Cyphers

Reshape Broadcast (#2198)

* reshape broadcast

* fix warnings
parent 556179a2
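This pass looks for a Reshape that only inserts size-1 dimensions into a 1-D input and is immediately followed by a Broadcast, and collapses the pair into a single Broadcast taken directly from the input: the positions of the inserted size-1 dimensions simply become additional broadcast axes. A minimal before/after sketch, built from the shapes used in the tests below (the variable names are illustrative, not code from this commit):

// Before: reshape {10} -> {1, 10, 1}, then broadcast to {1, 5, 10, 8, 1, 20}
auto input     = make_shared<op::Parameter>(element::f32, Shape{10});
auto reshape1  = make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, 10, 1});
auto broadcast = make_shared<op::Broadcast>(reshape1, Shape{1, 5, 10, 8, 1, 20}, AxisSet{1, 3, 5});

// After: the reshape is dropped and its size-1 positions (0 and 4 in the output shape)
// are folded into the broadcast axes, so the {10} input is broadcast directly.
auto fused = make_shared<op::Broadcast>(input, Shape{1, 5, 10, 8, 1, 20}, AxisSet{0, 1, 3, 4, 5});

The pass bails out and leaves the graph untouched when the reshape does anything other than insert size-1 dimensions, or when the broadcast itself introduces size-1 dimensions that the reshape did not provide.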
......@@ -440,6 +440,76 @@ static size_t shape_to_index(Shape shape)
}
}
void ngraph::pass::CoreFusion::construct_reshape_broadcast()
{
    Shape input_shape{10};
    auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
    auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{10, 1});
    auto broadcast = make_shared<op::Broadcast>(reshape1, Shape{10, 1, 20}, AxisSet{2});

    pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_reshape_broadcast against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        auto broadcast_m = std::static_pointer_cast<op::Broadcast>(m.get_match_root());
        auto reshape1_m = std::static_pointer_cast<op::Reshape>(broadcast_m->get_argument(0));
        auto input_m = m.get_pattern_map()[input];

        // it doesn't seem to make sense to support shapes [0] or [1]
        if (input_m->get_shape().size() != 1 || input_m->get_shape().at(0) < 2)
        {
            NGRAPH_DEBUG << "input_m isn't a 1-D vector with at least two elements";
            return false;
        }

        size_t dim = input_m->get_shape().at(0);

        // We are going to support the most common case, where the broadcast doesn't add
        // 1-dimensions of its own, since it's also very simple to implement
        size_t dim_one_count = 0;
        for (auto d : reshape1_m->get_shape())
        {
            if (d != 1 && d != dim)
            {
                NGRAPH_DEBUG << "Input is reshaped in a way we can't directly broadcast (shape = "
                             << ngraph::vector_to_string(reshape1_m->get_shape()) << ")";
                return false;
            }

            if (d == 1)
            {
                dim_one_count++;
            }
        }

        // Fold the size-1 dimensions introduced by the reshape into the broadcast axes
        AxisSet new_axes = broadcast_m->get_broadcast_axes();
        auto broadcast_shape = broadcast_m->get_shape();
        for (size_t i = 0; i < broadcast_shape.size(); i++)
        {
            if (broadcast_shape[i] == 1)
            {
                dim_one_count--;
                new_axes.insert(i);
            }
        }

        if (dim_one_count != 0)
        {
            NGRAPH_DEBUG << "Broadcast adds 1-dimensions";
            return false;
        }

        auto new_broadcast =
            make_shared<op::Broadcast>(input_m, broadcast_m->get_shape(), new_axes);
        ngraph::replace_node(m.get_match_root(), new_broadcast);
        return true;
    };

    auto m = make_shared<pattern::Matcher>(broadcast, callback, "CoreFusion.ReshapeBroadcast");
    this->add_matcher(m);
}
//   conv(56w3s1)                 conv(28w3s2)
//        |                            |
//   conv(56w1s1)       ==>       conv(28w1s1)
......
......@@ -38,6 +38,7 @@ public:
        construct_sigmoid();
        construct_sigmoid_bprop();
        construct_optimized_strided_conv();
        construct_reshape_broadcast();
        construct_reshape_softmax_reshape();
    }
    void construct_relu();
......@@ -46,5 +47,6 @@ public:
    void construct_sigmoid();
    void construct_sigmoid_bprop();
    void construct_optimized_strided_conv();
    void construct_reshape_broadcast();
    void construct_reshape_softmax_reshape();
};
......@@ -134,6 +134,91 @@ TEST(core_fusion, sigmoid_bprop_fusion)
    ASSERT_EQ(ccg, 1);
}

TEST(core_fusion, reshape_broadcast)
{
    auto generate_func = []() {
        auto input = make_shared<op::Parameter>(element::f32, Shape{10});
        auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, 10, 1});
        auto broadcast =
            make_shared<op::Broadcast>(reshape1, Shape{1, 5, 10, 8, 1, 20}, AxisSet{1, 3, 5});
        auto f = make_shared<Function>(broadcast, ParameterVector{input});
        return f;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
}
TEST(core_fusion, reshape_broadcast_graph_optimized)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{10});
    auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, 10, 1});
    auto broadcast =
        make_shared<op::Broadcast>(reshape1, Shape{1, 5, 10, 8, 1, 20}, AxisSet{1, 3, 5});
    auto optimized_f = make_shared<Function>(broadcast, ParameterVector{input});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    pass_manager.run_passes(optimized_f);

    auto new_broadcast =
        std::dynamic_pointer_cast<op::Broadcast>(optimized_f->get_results().at(0)->get_argument(0));
    EXPECT_EQ(new_broadcast->get_argument(0), input);
    EXPECT_EQ(new_broadcast->get_broadcast_axes(), (AxisSet{0, 1, 3, 4, 5}));
}
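The expected axes in this test are just the original broadcast axes plus the positions at which the output shape {1, 5, 10, 8, 1, 20} holds the size-1 dimensions inserted by the reshape: {1, 3, 5} ∪ {0, 4} = {0, 1, 3, 4, 5}, applied to the {10} parameter directly.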
TEST(core_fusion, reshape_broadcast_adds_one)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{10});
    auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, 10, 1});
    auto broadcast =
        make_shared<op::Broadcast>(reshape1, Shape{1, 5, 10, 8, 1, 20, 1}, AxisSet{1, 3, 5, 6});
    auto optimized_f = make_shared<Function>(broadcast, ParameterVector{input});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    pass_manager.run_passes(optimized_f);

    auto new_broadcast =
        std::dynamic_pointer_cast<op::Broadcast>(optimized_f->get_results().at(0)->get_argument(0));
    EXPECT_EQ(new_broadcast, broadcast);
    EXPECT_EQ(new_broadcast->get_argument(0), reshape1);
}

TEST(core_fusion, reshape_broadcast_wrong_reshape)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{10});
    auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{1, 5, 2});
    auto broadcast =
        make_shared<op::Broadcast>(reshape1, Shape{1, 5, 5, 8, 2, 20}, AxisSet{1, 3, 5});
    auto optimized_f = make_shared<Function>(broadcast, ParameterVector{input});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    pass_manager.run_passes(optimized_f);

    auto new_broadcast =
        std::dynamic_pointer_cast<op::Broadcast>(optimized_f->get_results().at(0)->get_argument(0));
    EXPECT_EQ(new_broadcast, broadcast);
    EXPECT_EQ(new_broadcast->get_argument(0), reshape1);
}
TEST(core_fusion, sparsity_opt_56x56)
{
    Shape win_size_3{1, 1, 3, 3};
......