Commit f15877e2 authored by Nick Korovaiko, committed by Scott Cyphers

replace maxpool + broadcast with broadcast of appropriate shapes (#1142)

parent 3b49dd1a
...@@ -186,6 +186,21 @@ static bool are_img_dims_equal(Shape conv_shape, Shape image_shape) ...@@ -186,6 +186,21 @@ static bool are_img_dims_equal(Shape conv_shape, Shape image_shape)
return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1]; return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1];
} }
static std::shared_ptr<Node> reduce_broadcast(std::shared_ptr<Node> broadcast)
{
const size_t H = 2;
const size_t W = 3;
auto matched_broadcast_w1 = std::dynamic_pointer_cast<op::Broadcast>(broadcast);
Shape shape_w1{matched_broadcast_w1->get_shape()};
shape_w1[H] /= 2;
shape_w1[W] /= 2;
auto new_broadcast_w1 =
std::make_shared<op::Broadcast>(matched_broadcast_w1->get_argument(0),
shape_w1,
matched_broadcast_w1->get_broadcast_axes());
return new_broadcast_w1;
}
static size_t shape_to_index(Shape shape) static size_t shape_to_index(Shape shape)
{ {
if (shape.size() != 4) if (shape.size() != 4)
...@@ -353,17 +368,15 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -353,17 +368,15 @@ void pass::CoreFusion::construct_optimized_strided_conv()
pad_1, pad_1,
pad_1); pad_1);
auto maxpool_w3 = auto new_add_conv_28w3s2 = std::make_shared<op::Add>(
std::make_shared<op::MaxPool>(pattern_map[broadcast_w3_label], Shape{1, 1}, stride_2); conv_28w3s2, reduce_broadcast(pattern_map[broadcast_w3_label]));
auto new_add_conv_28w3s2 = std::make_shared<op::Add>(conv_28w3s2, maxpool_w3);
auto new_relu_28w3s2 = std::make_shared<op::Relu>(new_add_conv_28w3s2); auto new_relu_28w3s2 = std::make_shared<op::Relu>(new_add_conv_28w3s2);
auto conv_28w1s1 = std::make_shared<op::Convolution>( auto conv_28w1s1 = std::make_shared<op::Convolution>(
new_relu_28w3s2, m_conv_stride1->get_argument(1), stride_1, stride_1); new_relu_28w3s2, m_conv_stride1->get_argument(1), stride_1, stride_1);
auto maxpool_w1 = auto new_add_conv28s1 = std::make_shared<op::Add>(
std::make_shared<op::MaxPool>(pattern_map[broadcast_w1_label], Shape{1, 1}, stride_2); conv_28w1s1, reduce_broadcast(pattern_map[broadcast_w1_label]));
auto new_add_conv28s1 = std::make_shared<op::Add>(conv_28w1s1, maxpool_w1);
auto maxpool = auto maxpool =
std::make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2); std::make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment