Commit 1a8b1f97 authored by Jayaram Bobba, committed by Scott Cyphers

Fold affine transformations on 4d convolution (#1347)

* Fold affine transformations on 4d convolution

* Handle more cases for affine parameters

* Style fix
parent 6f61679c
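
What the new pass exploits, in isolation: convolution is linear in its filters, so a per-output-channel scale A_c applied after the convolution can be folded into the weights, while the per-channel shift B_c remains an additive bias. The following is a minimal standalone sketch of that identity, not part of the commit; the naive 1x1 NCHW convolution helper and all values are illustrative only.

    // Verify numerically: A * conv(x, W) + B  ==  conv(x, W * A_c) + B
    #include <cassert>
    #include <cmath>
    #include <vector>

    // Naive NCHW convolution with 1x1 filters, stride 1, no padding.
    // out[n][co][h][w] = sum_ci in[n][ci][h][w] * wt[co][ci]
    static std::vector<float> naive_conv_1x1(const std::vector<float>& in,
                                             const std::vector<float>& wt,
                                             int N, int Ci, int Co, int H, int W)
    {
        std::vector<float> out(N * Co * H * W, 0.0f);
        for (int n = 0; n < N; ++n)
            for (int co = 0; co < Co; ++co)
                for (int ci = 0; ci < Ci; ++ci)
                    for (int hw = 0; hw < H * W; ++hw)
                        out[(n * Co + co) * H * W + hw] +=
                            in[(n * Ci + ci) * H * W + hw] * wt[co * Ci + ci];
        return out;
    }

    int main()
    {
        const int N = 1, Ci = 2, Co = 2, H = 2, W = 2; // illustrative shapes
        std::vector<float> x = {1, 2, 3, 4, -1, 0, 2, -2}; // input (N, Ci, H, W)
        std::vector<float> wt = {0.5f, -1.0f, 2.0f, 0.25f}; // weights (Co, Ci, 1, 1)
        std::vector<float> A = {-0.9f, 0.02f};               // per-channel scale
        std::vector<float> B = {11.0f, 1.3f};                // per-channel shift

        // Unfused: y = A * conv(x, wt) + B   (N == 1, so batch index is dropped)
        auto y = naive_conv_1x1(x, wt, N, Ci, Co, H, W);
        for (int co = 0; co < Co; ++co)
            for (int hw = 0; hw < H * W; ++hw)
                y[co * H * W + hw] = A[co] * y[co * H * W + hw] + B[co];

        // Fused: fold the scale into the filters, keep the shift as a bias.
        std::vector<float> wt_folded = wt;
        for (int co = 0; co < Co; ++co)
            for (int ci = 0; ci < Ci; ++ci)
                wt_folded[co * Ci + ci] *= A[co];
        auto z = naive_conv_1x1(x, wt_folded, N, Ci, Co, H, W);
        for (int co = 0; co < Co; ++co)
            for (int hw = 0; hw < H * W; ++hw)
                z[co * H * W + hw] += B[co];

        // Both paths produce the same result.
        for (size_t i = 0; i < y.size(); ++i)
            assert(std::fabs(y[i] - z[i]) < 1e-5f);
        return 0;
    }

The pass added in this commit performs the same rewrite on the nGraph graph: it multiplies the filter input by the broadcast scale and re-emits the convolution followed by the bias add.
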
@@ -581,7 +581,7 @@ bool ngraph::possibly_overwritten(Node* node)
        {
            for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
            {
-               if (input->get_index() == oi_pair.input)
+               if (input->get_index() == oi_pair.input && oi_pair.destructive)
                {
                    return true;
                }
...
@@ -34,6 +34,7 @@
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
@@ -259,6 +260,117 @@ void pass::CoreFusion::construct_folded_batch_norm()
    this->add_matcher(m);
}
void pass::CoreFusion::construct_conv_affine_folding()
{
    // A * Conv (input, filters) + B -> ConvBias (input, filters * A_c, B_c)
    Shape shape{2, 2, 1, 1};
    auto input = std::make_shared<pattern::op::Label>(element::f32, shape);
    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);

    auto conv = std::make_shared<op::Convolution>(input,
                                                  filters,
                                                  Strides{1, 1},
                                                  Strides{1, 1},
                                                  CoordinateDiff{0, 0},
                                                  CoordinateDiff{0, 0},
                                                  Strides{1, 1});
    auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});

    auto Ac = std::make_shared<pattern::op::Label>(element::f32, Shape{2});
    auto A = std::make_shared<op::Broadcast>(Ac, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
    auto A_label = std::make_shared<pattern::op::Label>(A, nullptr, NodeVector{A});
    auto Bc = std::make_shared<pattern::op::Label>(element::f32, Shape{2});
    auto B = std::make_shared<op::Broadcast>(Bc, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
    auto B_label = std::make_shared<pattern::op::Label>(B, nullptr, NodeVector{B});
    auto multiply = std::make_shared<op::Multiply>(conv_label, A_label);
    auto add = std::make_shared<op::Add>(multiply, B_label);

    ngraph::pattern::graph_rewrite_callback callback =
        [input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
            NGRAPH_DEBUG << "In callback for conv affine folding against node = "
                         << m.get_match_root()->get_name();
            auto pattern_map = m.get_pattern_map();

            auto conv_m = std::static_pointer_cast<op::Convolution>(pattern_map[conv_label]);

            if (conv_m->get_users().size() > 1)
            {
                return false;
            }

            if (conv_m->get_shape().size() != 4)
            {
                return false;
            }

            auto A_m = std::static_pointer_cast<op::Broadcast>(pattern_map[A_label]);
            auto B_m = std::static_pointer_cast<op::Broadcast>(pattern_map[B_label]);

            // Check if values are being broadcast along channel (2nd) dimension
            auto is_channel_bcast = [](const shared_ptr<op::Broadcast>& bcast) {
                if (bcast->get_argument(0)->get_shape().size() == 1 &&
                    bcast->get_broadcast_axes() == AxisSet{0, 2, 3})
                {
                    return true;
                }

                if (bcast->get_argument(0)->get_shape().size() == 2)
                {
                    auto input_shape = bcast->get_argument(0)->get_shape();
                    if (input_shape[0] == 1 && bcast->get_broadcast_axes() == AxisSet{2, 3})
                        return true;
                }
                return false;
            };

            if (!is_channel_bcast(A_m) || !is_channel_bcast(B_m))
            {
                return false;
            }

            auto get_bcast_input = [](const shared_ptr<op::Broadcast>& bcast) {
                if (bcast->get_argument(0)->get_shape().size() == 1)
                {
                    return bcast->get_argument(0);
                }
                if (bcast->get_argument(0)->get_shape().size() == 2)
                {
                    Shape bshape{bcast->get_argument(0)->get_shape()[1]};
                    return static_pointer_cast<ngraph::Node>(std::make_shared<op::Reshape>(
                        bcast->get_argument(0), AxisVector{0, 1}, bshape));
                }
                throw ngraph_error("Unexpected shape for bcast input");
            };

            auto Ac_m = get_bcast_input(A_m);

            // new weights = old weights * Ac_m
            // new biases = Bc_m
            auto filters_n = std::make_shared<op::Multiply>(
                pattern_map[filters],
                std::make_shared<op::Broadcast>(
                    Ac_m, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));

            auto conv_n = std::make_shared<op::Convolution>(pattern_map[input],
                                                            filters_n,
                                                            conv_m->get_window_movement_strides(),
                                                            conv_m->get_window_dilation_strides(),
                                                            conv_m->get_padding_below(),
                                                            conv_m->get_padding_above(),
                                                            conv_m->get_data_dilation_strides());
            auto convbias_n = conv_n + B_m;
            ngraph::replace_node(m.get_match_root(), convbias_n);

            return true;
        };

    auto m = std::make_shared<ngraph::pattern::Matcher>(add, callback);
    this->add_matcher(m);
}
static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
                                   bool skip_pad_checks = false)
{
...
@@ -34,12 +34,14 @@ public:
    {
        construct_relu();
        construct_folded_batch_norm();
        construct_conv_affine_folding();
        construct_sigmoid();
        construct_sigmoid_bprop();
        construct_optimized_strided_conv();
    }
    void construct_relu();
    void construct_folded_batch_norm();
    void construct_conv_affine_folding();
    void construct_sigmoid();
    void construct_sigmoid_bprop();
    void construct_optimized_strided_conv();
...
@@ -1561,6 +1561,63 @@ TEST(cpu_fusion, batch_norm_folding)
    EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
TEST(cpu_fusion, affine_folding)
{
    Shape shape_input{1, 8, 3, 3};
    Shape shape_weights{2, 8, 1, 1};
    Shape shape_norm{2};

    auto make_function = [shape_input, shape_weights, shape_norm]() {
        auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
        auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
        auto a = std::make_shared<op::Parameter>(element::f32, shape_norm);
        auto b = std::make_shared<op::Parameter>(element::f32, shape_norm);

        auto conv = std::make_shared<op::Convolution>(input, weights, Strides{1, 1}, Strides{1, 1});

        auto out = std::make_shared<op::Add>(
            std::make_shared<op::Multiply>(
                conv, std::make_shared<op::Broadcast>(a, conv->get_shape(), AxisSet{0, 2, 3})),
            std::make_shared<op::Broadcast>(b, conv->get_shape(), AxisSet{0, 2, 3}));

        auto f = make_shared<Function>(NodeVector{out}, op::ParameterVector{input, weights, a, b});
        return f;
    };

    auto int_f = make_function();
    auto cpu_f = make_function();

    vector<vector<float>> args{
        {1.25f, 2.25f, 5.25f, 6.25f, -1.25f, -1.25f, 3.25f, -4.25f, 7.25f, 8.25f, -1.25f,
         -1.25f, 1.25f, 2.25f, -3.25f, 2.25f, 4.25f, 4.25f, 1.25f, 2.25f, -4.25f, 2.25f,
         4.25f, 4.25f, 0.f, 0.f, -1.f, 0.f, 2.f, 2.f, 0.f, 0.f, 0.f,
         0.f, 2.f, 2.f, 1.25f, 2.25f, 5.25f, 6.25f, 1.25f, 1.25f, 3.25f, 4.25f,
         -7.25f, 8.25f, 1.25f, -1.25f, -1.25f, 2.25f, 3.25f, 2.25f, -4.25f, -4.25f, -1.25f,
         -2.25f, 4.25f, 2.25f, 4.25f, 4.25f, 0.f, 0.f, 1.f, 0.f, -2.f, 2.f,
         0.f, 0.f, 0.f, 0.f, -2.f, -2.f},
        {1.25f, 2.25f, 5.25f, 6.25f, -1.25f, -1.25f, 3.25f, -4.25f,
         7.25f, 8.25f, -1.25f, 0.f, 0.f, 0.f, 0.f, -2.f},
        {-0.9384f, 0.01875f},
        {11.0f, 1.3f},
    };

    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
TEST(cpu_fusion, group_convolution_fusion)
{
    Shape shape_a{1, 32, 2, 2};
...