Commit 94844d13 authored by Nick Korovaiko, committed by Adam Procter

Strided Convolution (#1058)

* optimized strided convolutions

* clean up debug messages

* format fixes

* more tests

* even more tests

* adapt to resnet-50.v1

* fix format errors; remove changes belonging to different PRs
parent 656dfa55
......@@ -27,6 +27,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
......@@ -162,3 +163,217 @@ void pass::CoreFusion::construct_folded_batch_norm()
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback);
this->add_matcher(m);
}
static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
bool skip_pad_checks = false)
{
Strides stride_1{1, 1};
CoordinateDiff pad_0{0, 0};
return conv->get_window_dilation_strides() == stride_1 &&
conv->get_data_dilation_strides() == stride_1 &&
(skip_pad_checks ||
(conv->get_padding_above() == pad_0 && conv->get_padding_below() == pad_0));
}
static bool are_img_dims_equal(Shape conv_shape, Shape image_shape)
{
if (conv_shape.size() != 4)
{
return false;
}
return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1];
}
static size_t shape_to_index(Shape shape)
{
if (shape.size() != 4)
{
return 0;
}
const size_t HEIGHT_DIM = 2;
const size_t WIDTH_DIM = 3;
if (shape.at(HEIGHT_DIM) != shape.at(WIDTH_DIM))
{
return 0;
}
switch (shape.at(HEIGHT_DIM))
{
case 28: return 1;
case 14: return 2;
case 7: return 3;
default: return 0;
}
}
// conv(56w3s1) conv(28w3s2)
// | |
// conv(56w1s1) ==> conv(28w1s1)
// | |
//elt------------56 elt------------pool(28s2)
// | | | |
//conv(28w1s2) conv(28w1s2) conv(28w1s1) conv(28w1s1)
// Rewrites a residual block whose output feeds two stride-2 1x1 convolutions
// (see the diagram above) so that the stride-2 downsampling is pushed up into
// the first 3x3 convolution.  The downstream 1x1 convolutions then run at
// stride 1 on the already-smaller feature maps, and every side input (the two
// broadcast biases and the residual eltwise argument) is downsampled with a
// 1x1/stride-2 max-pool so all shapes stay consistent.
// Matches the resnet-50.v1 block layout (56/28/14/7 square feature maps).
void pass::CoreFusion::construct_optimized_strided_conv()
{
    // --- pattern: conv3x3 -> +bias -> relu -> conv1x1 -> +bias -> add(residual) -> relu ---
    Shape win_size_1{1, 1, 1, 1};
    auto is_bc = ngraph::pattern::has_class<op::Broadcast>();
    auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
    auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
    auto conv_stride3 = std::make_shared<op::Convolution>(data_stride3, weights_stride3);
    auto conv_stride3_label =
        std::make_shared<pattern::op::Label>(conv_stride3, nullptr, NodeVector{conv_stride3});
    auto broadcast_w3_label = std::make_shared<pattern::op::Label>(conv_stride3_label, is_bc);
    auto add_w3 = std::make_shared<op::Add>(conv_stride3_label, broadcast_w3_label);
    auto relu_w3 = std::make_shared<op::Relu>(add_w3);
    auto weights_stride1 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
    auto conv_stride1 = std::make_shared<op::Convolution>(relu_w3, weights_stride1);
    auto conv_stride1_label =
        std::make_shared<pattern::op::Label>(conv_stride1, nullptr, NodeVector{conv_stride1});
    auto broadcast_w1_label = std::make_shared<pattern::op::Label>(conv_stride1_label, is_bc);
    auto add_w1 = std::make_shared<op::Add>(conv_stride1_label, broadcast_w1_label);
    auto eltwise_arg_label =
        std::make_shared<pattern::op::Label>(element::f32, conv_stride1->get_shape());
    auto add_two_convs = std::make_shared<op::Add>(add_w1, eltwise_arg_label);
    auto relu_two_convs = std::make_shared<op::Relu>(add_two_convs);
    auto eltwise_label =
        std::make_shared<pattern::op::Label>(relu_two_convs, nullptr, NodeVector{relu_two_convs});
    auto weights_eltwise = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
    auto eltwise_conv = std::make_shared<op::Convolution>(eltwise_label, weights_eltwise);
    pattern::graph_rewrite_callback callback = [win_size_1,
                                                eltwise_label,
                                                conv_stride1_label,
                                                conv_stride3_label,
                                                eltwise_arg_label,
                                                broadcast_w3_label,
                                                broadcast_w1_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_optimized_strided_conv against "
                     << m.get_match_root()->get_name();
        auto pattern_map = m.get_pattern_map();
        auto m_eltwise = pattern_map[eltwise_label];
        auto strided_convs = m_eltwise->get_users();
        // The optimization only applies when the residual relu feeds exactly
        // the two stride-2 convolutions of the next block.
        if (strided_convs.size() != 2)
        {
            NGRAPH_DEBUG << "Number of users of element wise operation isn't equal to 2";
            return false;
        }
        // supported_shapes[i] is the output size; supported_shapes[i - 1] is
        // the corresponding (2x larger) input size of the strided conv.
        Shape supported_shapes[] = {Shape{56, 56}, Shape{28, 28}, Shape{14, 14}, Shape{7, 7}};
        Shape shape_1{1, 1};
        Shape shape_3{3, 3};
        Strides stride_2{2, 2};
        Strides stride_1{1, 1};
        CoordinateDiff pad_0{0, 0};
        CoordinateDiff pad_1{1, 1};
        Shape win_size_3{1, 1, 3, 3};
        size_t sparse_shape_index = 0;
        NodeVector sconvs;
        // Validate both users: each must be a 1x1, stride-2, trivial
        // convolution whose data input is the matched eltwise relu.
        for (auto sc : strided_convs)
        {
            if (sc->get_argument(0) != m_eltwise)
            {
                NGRAPH_DEBUG << "element-wise isn't data";
                return false;
            }
            auto sconv = std::dynamic_pointer_cast<op::Convolution>(sc);
            if (!sconv)
            {
                // A user can be any node type; bail out instead of
                // dereferencing a null cast result.
                NGRAPH_DEBUG << sc->get_name() << " isn't a convolution";
                return false;
            }
            sparse_shape_index = shape_to_index(sconv->get_shape());
            if (sparse_shape_index == 0)
            {
                NGRAPH_DEBUG << "Unsupported shape of " << sconv->get_name();
                return false;
            }
            if (!are_img_dims_equal(sconv->get_shape(), supported_shapes[sparse_shape_index]) ||
                !are_img_dims_equal(sconv->get_argument(1)->get_shape(), shape_1) ||
                sconv->get_window_movement_strides() != stride_2 || !is_trivial_convolution(sconv))
            {
                NGRAPH_DEBUG << sconv->get_name() << " and its weights are of the wrong shape (not "
                             << vector_to_string(supported_shapes[sparse_shape_index])
                             << " and 1x1) and strides (2x2)";
                return false;
            }
            sconvs.push_back(sconv);
        }
        const size_t full_shape_index = sparse_shape_index - 1;
        // The 1x1 convolution inside the block must be stride-1, trivial,
        // and operate at the full (pre-downsampling) spatial size.
        auto m_conv_stride1 =
            std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]);
        if (!m_conv_stride1)
        {
            NGRAPH_DEBUG << "conv_stride1 label didn't match a convolution";
            return false;
        }
        if (!are_img_dims_equal(m_conv_stride1->get_shape(), supported_shapes[full_shape_index]) ||
            !are_img_dims_equal(m_conv_stride1->get_argument(1)->get_shape(), win_size_1) ||
            m_conv_stride1->get_window_movement_strides() != stride_1 ||
            !is_trivial_convolution(m_conv_stride1))
        {
            NGRAPH_DEBUG << m_conv_stride1->get_name()
                         << " and its weights are of the wrong shape (not "
                         << vector_to_string(supported_shapes[full_shape_index])
                         << " and 1x1) and strides (1x1)";
            return false;
        }
        // The 3x3 convolution may carry non-zero padding (pad checks skipped)
        // since the rewrite re-creates it with explicit pad_1.
        auto m_conv_stride3 =
            std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]);
        if (!m_conv_stride3)
        {
            NGRAPH_DEBUG << "conv_stride3 label didn't match a convolution";
            return false;
        }
        if (!are_img_dims_equal(m_conv_stride3->get_shape(), supported_shapes[full_shape_index]) ||
            !are_img_dims_equal(m_conv_stride3->get_argument(1)->get_shape(), shape_3) ||
            m_conv_stride3->get_window_movement_strides() != stride_1 ||
            !is_trivial_convolution(m_conv_stride3, true))
        {
            NGRAPH_DEBUG << m_conv_stride3->get_name()
                         << " and its weights are of the wrong shape (not "
                         << vector_to_string(supported_shapes[full_shape_index])
                         << " and 3x3) and strides (1x1)";
            return false;
        }
        // --- rewrite: move the stride-2 into the 3x3 conv, max-pool every
        // side input down to the new spatial size, and rebuild the chain. ---
        auto conv_28w3s2 = std::make_shared<op::Convolution>(m_conv_stride3->get_argument(0),
                                                             m_conv_stride3->get_argument(1),
                                                             stride_2,
                                                             stride_1,
                                                             pad_1,
                                                             pad_1);
        auto maxpool_w3 =
            std::make_shared<op::MaxPool>(pattern_map[broadcast_w3_label], Shape{1, 1}, stride_2);
        auto new_add_conv_28w3s2 = std::make_shared<op::Add>(conv_28w3s2, maxpool_w3);
        auto new_relu_28w3s2 = std::make_shared<op::Relu>(new_add_conv_28w3s2);
        auto conv_28w1s1 = std::make_shared<op::Convolution>(
            new_relu_28w3s2, m_conv_stride1->get_argument(1), stride_1, stride_1);
        auto maxpool_w1 =
            std::make_shared<op::MaxPool>(pattern_map[broadcast_w1_label], Shape{1, 1}, stride_2);
        auto new_add_conv28s1 = std::make_shared<op::Add>(conv_28w1s1, maxpool_w1);
        auto maxpool =
            std::make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2);
        auto new_add_two_convs = std::make_shared<op::Add>(new_add_conv28s1, maxpool);
        auto new_relu_two_convs = std::make_shared<op::Relu>(new_add_two_convs);
        // Finally, replace each stride-2 consumer with a stride-1 equivalent
        // fed by the rebuilt (already downsampled) residual output.
        for (auto sconv : sconvs)
        {
            auto sconv_28w1s1 = std::make_shared<op::Convolution>(
                new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1);
            NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with "
                         << sconv_28w1s1->get_name();
            ngraph::replace_node(sconv, sconv_28w1s1);
        }
        return true;
    };
    auto m = make_shared<pattern::Matcher>(eltwise_conv, callback);
    this->add_matcher(m);
}
......@@ -34,7 +34,9 @@ public:
{
construct_relu();
construct_folded_batch_norm();
construct_optimized_strided_conv();
}
void construct_relu();
void construct_folded_batch_norm();
void construct_optimized_strided_conv();
};
......@@ -55,3 +55,59 @@ TEST(core_fusion, core_fusion_pass_basic)
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::Relu>(graph->get_argument(0)), nullptr);
}
TEST(core_fusion, sparsity_opt_56x56)
{
    // Builds the resnet-50.v1 style residual block from the diagram in
    // core_fusion.cpp — conv3x3(s1)+bias -> relu -> conv1x1(s1)+bias ->
    // add(residual) -> relu — feeding TWO stride-2 1x1 convolutions, then
    // runs CoreFusion and checks both consumers were rewritten to stride 1.
    Strides stride_2{2, 2};
    Strides stride_1{1, 1};
    CoordinateDiff pad_1{1, 1};
    // 3x3 stride-1 convolution with padding 1 (preserves 56x56), plus bias.
    auto data_stride3 = std::make_shared<op::Parameter>(element::f32, Shape{1, 64, 56, 56});
    auto weights_stride3 = std::make_shared<op::Parameter>(element::f32, Shape{64, 64, 3, 3});
    auto conv_stride3 = std::make_shared<op::Convolution>(
        data_stride3, weights_stride3, stride_1, stride_1, pad_1, pad_1);
    auto param_broadcast_w3 = std::make_shared<op::Parameter>(element::f32, Shape{64});
    auto broadcast_w3 =
        std::make_shared<op::Broadcast>(param_broadcast_w3, Shape{1, 64, 56, 56}, AxisSet{0, 2, 3});
    auto add_w3 = std::make_shared<op::Add>(conv_stride3, broadcast_w3);
    auto relu_w3 = std::make_shared<op::Relu>(add_w3);
    // 1x1 stride-1 convolution plus bias.
    auto weights_stride1 = std::make_shared<op::Parameter>(element::f32, Shape{256, 64, 1, 1});
    auto conv_stride1 = std::make_shared<op::Convolution>(relu_w3, weights_stride1);
    auto param_broadcast_w1 = std::make_shared<op::Parameter>(element::f32, Shape{256});
    auto broadcast_w1 = std::make_shared<op::Broadcast>(
        param_broadcast_w1, Shape{1, 256, 56, 56}, AxisSet{0, 2, 3});
    auto add_w1 = std::make_shared<op::Add>(conv_stride1, broadcast_w1);
    // Residual add + relu.
    auto other_arg = std::make_shared<op::Parameter>(element::f32, Shape{1, 256, 56, 56});
    auto add_two_convs = std::make_shared<op::Add>(add_w1, other_arg);
    auto relu_two_convs = std::make_shared<op::Relu>(add_two_convs);
    // Two stride-2 1x1 consumers — the pattern the pass should optimize.
    auto weights_conv_s2 = std::make_shared<op::Parameter>(element::f32, Shape{512, 256, 1, 1});
    auto conv_s2_1 = std::make_shared<op::Convolution>(relu_two_convs, weights_conv_s2, stride_2);
    auto conv_s2_2 = std::make_shared<op::Convolution>(relu_two_convs, weights_conv_s2, stride_2);
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    auto params = op::ParameterVector{data_stride3,
                                      weights_stride3,
                                      param_broadcast_w3,
                                      weights_stride1,
                                      param_broadcast_w1,
                                      other_arg,
                                      weights_conv_s2};
    auto func = make_shared<Function>(NodeVector{conv_s2_1, conv_s2_2}, params);
    pass_manager.run_passes(func);
    // After the rewrite, each result should be produced by a stride-1
    // convolution (the stride-2 work having moved up into the block).
    auto results = func->get_results();
    auto t_eltwise_conv1 =
        std::dynamic_pointer_cast<op::Convolution>(results.at(0)->get_argument(0));
    auto t_eltwise_conv2 =
        std::dynamic_pointer_cast<op::Convolution>(results.at(1)->get_argument(0));
    ASSERT_TRUE(t_eltwise_conv1);
    ASSERT_TRUE(t_eltwise_conv2);
    ASSERT_EQ(t_eltwise_conv1->get_window_movement_strides(), stride_1);
    ASSERT_EQ(t_eltwise_conv2->get_window_movement_strides(), stride_1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment