Commit 7060d794, authored by Amy Zhuang

Add Batch Norm Inference Relu fusion.

parent 877fb219
......@@ -650,6 +650,99 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
this->add_matcher(m, callback);
}
// Registers a matcher that fuses the inference-time sub-graph
//
//     Relu(Add(Multiply(BatchNormInference(gamma, beta, input, mean, var),
//                       Broadcast(scale)),
//              Broadcast(shift)))
//
// into a single BatchNormInferenceRelu node. The extra scale/shift are folded
// into the batch-norm parameters:
//
//     new_gamma = gamma * scale
//     new_beta  = beta * scale + shift
//
// The fold is only valid when `scale` and `shift` are per-channel vectors that
// are broadcast along every axis except the channel axis (axis 1), so the
// callback verifies the broadcast axes before rewriting the graph.
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_infer_relu_with_multi_add()
{
    // Shapes below are placeholders used only to build a well-formed pattern graph.
    auto input_shape = Shape{1, 3, 2, 2};
    auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
    auto mean_shape = Shape{3};
    auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
    auto var_shape = Shape{3};
    auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
    auto gamma_shape = Shape{3};
    auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto beta_shape = Shape{3};
    auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
    double eps = 0.001;

    auto bn = std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
    auto bn_label = std::make_shared<pattern::op::Label>(bn, nullptr, NodeVector{bn});

    // Multiply(bn, Broadcast(scale))
    auto broadcast1_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto broadcast1 =
        std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet{0, 2, 3});
    auto broadcast1_label =
        std::make_shared<pattern::op::Label>(broadcast1, nullptr, NodeVector{broadcast1});
    auto multiply = std::make_shared<ngraph::op::Multiply>(bn_label, broadcast1_label);
    auto multi_label =
        std::make_shared<pattern::op::Label>(multiply, nullptr, NodeVector{multiply});

    // Add(multiply, Broadcast(shift))
    auto broadcast2_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto broadcast2 =
        std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet{0, 2, 3});
    auto broadcast2_label =
        std::make_shared<pattern::op::Label>(broadcast2, nullptr, NodeVector{broadcast2});
    auto add = std::make_shared<ngraph::op::Add>(multi_label, broadcast2_label);

    auto relu = std::make_shared<ngraph::op::Relu>(add);

    auto callback = [input,
                     mean,
                     var,
                     gamma,
                     beta,
                     bn_label,
                     multi_label,
                     broadcast1_label,
                     broadcast2_label,
                     broadcast1_input,
                     broadcast2_input](pattern::Matcher& m) {
        NGRAPH_DEBUG
            << "In callback for construct_batch_norm_infer_relu_with_multi_add against node = "
            << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();

        // The intermediate results disappear after fusion, so nobody else may use them.
        auto bn_match = pattern_map[bn_label];
        if (bn_match->get_users().size() > 1)
        {
            NGRAPH_DEBUG << "Multiply isn't the only user of BatchNorm's output";
            return false;
        }
        auto multi_match = pattern_map[multi_label];
        if (multi_match->get_users().size() > 1)
        {
            NGRAPH_DEBUG << "Add isn't the only user of Multiply's output";
            return false;
        }

        // Reject early, before any replacement nodes are created.
        auto bn_inference = std::dynamic_pointer_cast<ngraph::op::BatchNormInference>(bn_match);
        if (!bn_inference)
        {
            return false;
        }
        if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
        {
            return false;
        }

        // The matcher compares node types only, not Broadcast attributes. Folding the
        // broadcast inputs into gamma/beta is valid only if each one is a per-channel
        // vector, i.e. it is broadcast along every axis except the channel axis (1).
        AxisSet non_channel_axes{0};
        const size_t rank = bn_match->get_shape().size();
        for (size_t axis = 2; axis < rank; ++axis)
        {
            non_channel_axes.insert(axis);
        }
        auto is_per_channel_broadcast = [&non_channel_axes](const std::shared_ptr<Node>& node) {
            auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node);
            return broadcast && broadcast->get_broadcast_axes() == non_channel_axes;
        };
        if (!is_per_channel_broadcast(pattern_map[broadcast1_label]) ||
            !is_per_channel_broadcast(pattern_map[broadcast2_label]))
        {
            NGRAPH_DEBUG << "Broadcast axes are not all axes except the channel axis";
            return false;
        }

        // (gamma * x_hat + beta) * scale + shift
        //     == (gamma * scale) * x_hat + (beta * scale + shift)
        auto new_gamma = std::make_shared<ngraph::op::Multiply>(pattern_map[gamma],
                                                                pattern_map[broadcast1_input]);
        auto new_multi = std::make_shared<ngraph::op::Multiply>(pattern_map[beta],
                                                                pattern_map[broadcast1_input]);
        auto new_beta =
            std::make_shared<ngraph::op::Add>(new_multi, pattern_map[broadcast2_input]);

        auto bn_relu =
            std::make_shared<ngraph::op::BatchNormInferenceRelu>(bn_inference->get_eps_value(),
                                                                 new_gamma,
                                                                 new_beta,
                                                                 pattern_map[input],
                                                                 pattern_map[mean],
                                                                 pattern_map[var]);
        ngraph::replace_node(m.get_match_root(), bn_relu);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(relu,
                                                        "CPUFusion.BatchNormInferReluWithMultiAdd");
    this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{
Shape shape{2, 2, 1, 1};
......
......@@ -78,6 +78,7 @@ public:
construct_deconvolution_affine_folding_relu();
}
construct_dropout();
construct_batch_norm_infer_relu_with_multi_add();
}
}
......@@ -90,6 +91,7 @@ private:
void construct_sigmoid_multiply();
void construct_batch_norm_relu();
void construct_batch_norm_relu_global_stats();
// Registers the matcher that rewrites
// Relu(Add(Multiply(BatchNormInference, Broadcast), Broadcast))
// into a single BatchNormInferenceRelu node.
void construct_batch_norm_infer_relu_with_multi_add();
void construct_conv_relu();
void construct_conv_bias_relu();
void construct_conv_bias_add();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment