Commit 0d125c51 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Batch norm folding (#992)

* Batch norm folding

* Addressed PR feedback

* Style fixes

* Style fix
parent 39d0453f
......@@ -21,11 +21,19 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
......@@ -73,3 +81,84 @@ void pass::CoreFusion::construct_relu()
auto m = make_shared<pattern::Matcher>(max, callback);
this->add_matcher(m);
}
void pass::CoreFusion::construct_folded_batch_norm()
{
Shape shape{2, 2, 1, 1};
auto input = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto pconv = std::make_shared<op::Convolution>(input,
filters,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto mean_shape = Shape{2};
auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, pconv, mean, var);
ngraph::pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for folded batch norm against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(m.get_match_root());
auto m_conv = std::dynamic_pointer_cast<op::Convolution>(m_bn->get_argument(2));
if (m_conv->get_users().size() > 1)
{
return false;
}
if (m_conv->get_shape().size() != 4)
{
return false;
}
// new weights = old weights * gamma / sqrt(variance + epsilon)
// new biases = -mean * gamma / sqrt(variance + epsilon) + beta
auto bn_eps = op::Constant::create(element::f32, Shape{}, {m_bn->get_eps_value()});
auto var_eps = std::make_shared<op::Add>(
pattern_map[var],
std::make_shared<op::Broadcast>(bn_eps, pattern_map[var]->get_shape(), AxisSet{0}));
auto sqrt_var_eps = std::make_shared<op::Sqrt>(var_eps);
auto mean_gamma = std::make_shared<op::Multiply>(pattern_map[mean], pattern_map[gamma]);
auto new_biases = std::make_shared<op::Subtract>(
pattern_map[beta], std::make_shared<op::Divide>(mean_gamma, sqrt_var_eps));
auto weight_scaling = std::make_shared<op::Divide>(pattern_map[gamma], sqrt_var_eps);
auto new_weights = std::make_shared<op::Multiply>(
pattern_map[filters],
std::make_shared<op::Broadcast>(
weight_scaling, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
auto conv = std::make_shared<op::Convolution>(pattern_map[input],
new_weights,
m_conv->get_window_movement_strides(),
m_conv->get_window_dilation_strides(),
m_conv->get_padding_below(),
m_conv->get_padding_above(),
m_conv->get_data_dilation_strides());
auto conv_bias =
conv + std::make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
ngraph::replace_node(m.get_match_root(), conv_bias);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback);
this->add_matcher(m);
}
......@@ -33,6 +33,8 @@ public:
: GraphRewrite()
{
construct_relu();
construct_folded_batch_norm();
}
void construct_relu();
void construct_folded_batch_norm();
};
......@@ -943,3 +943,71 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
auto m = std::make_shared<pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_relu()
{
Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto conv_bias = std::make_shared<op::ConvolutionBias>(data_batch,
filters,
bias,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto prelu = std::make_shared<op::Relu>(conv_bias);
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_relu against "
<< m.get_match_root()->get_name();
auto conv =
std::dynamic_pointer_cast<op::ConvolutionBias>(m.get_match_root()->get_argument(0));
//These checks are to make sure a MKLDNN Convolution kernel can be used.
bool data_dilated = false;
for (size_t s : conv->get_data_dilation_strides())
{
data_dilated = data_dilated || (s != 1);
}
if (data_dilated)
{
NGRAPH_DEBUG << "Convolution has dilations greater than 1";
return false;
}
if (conv->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "Convolution isn't of type float";
return false;
}
auto arg0_rank = conv->get_input_shape(0).size();
auto arg1_rank = conv->get_input_shape(1).size();
if (arg0_rank != 4 || arg1_rank != 4)
{
NGRAPH_DEBUG << "Convolution's arguments ranks aren't equal to 4";
return false;
}
if (conv->get_users().size() > 1)
{
NGRAPH_DEBUG << "Convolution has more than one user";
return false;
}
auto conv_relu = std::shared_ptr<Node>(new op::ConvolutionBiasRelu(conv));
ngraph::replace_node(m.get_match_root(), conv_relu);
return true;
};
auto m = std::make_shared<pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
......@@ -63,6 +63,7 @@ public:
construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu();
construct_conv_bias_relu();
}
if (fusions & DIFFERENTIABLE_FUSIONS)
......@@ -84,4 +85,5 @@ private:
void construct_batch_norm_relu();
void construct_batch_norm_relu_global_stats();
void construct_conv_relu();
void construct_conv_bias_relu();
};
......@@ -1234,3 +1234,62 @@ TEST(cpu_fusion, backwards_maxpool_with_indices_n4_c1_hw4_2x2_max)
backend->call(df, {output}, {input, ep});
ASSERT_TRUE(read_vector<float>(output) == expected);
}
TEST(cpu_fusion, batch_norm_folding)
{
Shape shape_input{1, 8, 3, 3};
Shape shape_weights{2, 8, 1, 1};
Shape shape_norm{2};
auto make_function = [shape_input, shape_weights, shape_norm]() {
auto input = std::make_shared<op::Parameter>(element::f32, shape_input);
auto weights = std::make_shared<op::Parameter>(element::f32, shape_weights);
double eps = 0.001;
auto gamma = std::make_shared<op::Parameter>(element::f32, shape_norm);
auto beta = std::make_shared<op::Parameter>(element::f32, shape_norm);
auto mean = std::make_shared<op::Parameter>(element::f32, shape_norm);
auto var = std::make_shared<op::Parameter>(element::f32, shape_norm);
auto conv = std::make_shared<op::Convolution>(input, weights, Strides{1, 1}, Strides{1, 1});
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, conv, mean, var);
auto f = make_shared<Function>(NodeVector{bn},
op::ParameterVector{input, weights, gamma, beta, mean, var});
return f;
};
auto int_f = make_function();
auto cpu_f = make_function();
vector<vector<float>> args{
{1.25f, 2.25f, 5.25f, 6.25f, -1.25f, -1.25f, 3.25f, -4.25f, 7.25f, 8.25f, -1.25f,
-1.25f, 1.25f, 2.25f, -3.25f, 2.25f, 4.25f, 4.25f, 1.25f, 2.25f, -4.25f, 2.25f,
4.25f, 4.25f, 0.f, 0.f, -1.f, 0.f, 2.f, 2.f, 0.f, 0.f, 0.f,
0.f, 2.f, 2.f, 1.25f, 2.25f, 5.25f, 6.25f, 1.25f, 1.25f, 3.25f, 4.25f,
-7.25f, 8.25f, 1.25f, -1.25f, -1.25f, 2.25f, 3.25f, 2.25f, -4.25f, -4.25f, -1.25f,
-2.25f, 4.25f, 2.25f, 4.25f, 4.25f, 0.f, 0.f, 1.f, 0.f, -2.f, 2.f,
0.f, 0.f, 0.f, 0.f, -2.f, -2.f},
{1.25f,
2.25f,
5.25f,
6.25f,
-1.25f,
-1.25f,
3.25f,
-4.25f,
7.25f,
8.25f,
-1.25f,
0.f,
0.f,
0.f,
0.f,
-2.f},
{-0.9384f, 0.01875f},
{11.0f, 1.3f},
{0.12f, 0.31f},
{0.01f, 0.11f},
};
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment