Commit 22819e78 authored by nikolay.korovaiko

conv+bias fusion

parent 5c0a29ee
/*******************************************************************************
 * Copyright 2017-2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
                           std::shared_ptr<ngraph::Node> arg,
                           bool& transpose_w,
                           ngraph::Shape& shape_w)
{
    auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);

    if (!r_w)
    {
        return true; // nothing to do; the argument isn't a reshape
    }

    if (r_w->get_shape().size() != 2)
    {
        NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
                     << ngraph::vector_to_string(r_w->get_shape());
        return false;
    }

    auto io = r_w->get_input_order();
    if (r_w->get_shape().size() != arg->get_shape().size()) // rank-changing reshape
    {
        ngraph::AxisVector dio(io.size());
        std::iota(begin(dio), end(dio), 0);

        if (io != dio) // we can't reshape and transpose at the same time
        {
            NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
                         << ngraph::vector_to_string(io);
            NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
            NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
            return false;
        }

        shape_w = r_w->get_shape();
    }
    else
    {
        if (io == ngraph::AxisVector{1, 0})
        {
            transpose_w = true;
        }
        // otherwise no-op reshape
    }

    return true;
}
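For reference, a minimal sketch of the two cases init_cblas_arg accepts, using only ops already included above (variable names and shapes are illustrative, not part of the commit):

    // Case 1: same-rank reshape with input order {1, 0} is a pure transpose;
    // init_cblas_arg sets transpose_w = true and leaves shape_w untouched.
    auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
    auto t = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{4, 2});

    // Case 2: a rank-changing reshape in default order {0, 1, 2} that flattens
    // into a matrix; init_cblas_arg overwrites shape_w with the reshaped 2D shape.
    auto W3 = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 2});
    auto r = std::make_shared<op::Reshape>(W3, AxisVector{0, 1, 2}, Shape{2, 4});

Anything else (a non-matrix target shape, or a reshape that permutes and changes rank at once) makes the function return false and blocks the fusion.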
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
    if (input.size() != order.size())
    {
        throw "input and order sizes don't match!";
    }

    std::vector<T> output(input.size());

    for (size_t i = 0; i < order.size(); i++)
    {
        output[i] = input.at(order.at(i));
    }

    return output;
}
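A quick standalone illustration of apply_permutation's indexing convention, output[i] = input[order[i]] (so order element i says which input slot lands at position i):

    #include <cassert>
    #include <vector>

    int main()
    {
        std::vector<int> input{10, 20, 30};
        std::vector<size_t> order{2, 0, 1};
        std::vector<int> output(input.size());
        for (size_t i = 0; i < order.size(); i++)
        {
            output[i] = input.at(order.at(i)); // same loop as apply_permutation
        }
        assert((output == std::vector<int>{30, 10, 20}));
        return 0;
    }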
void ngraph::runtime::cpu::pass::CPUFusion::construct_gemm_pattern()
{
    Shape shape_w{2, 4};
    Shape shape_x{4, 1};
    Shape shape_b{1};
    Shape shape_dot{2, 1};

    auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
    auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);

    auto reshape_pred = [](std::shared_ptr<Node> n) {
        return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
    };

    auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
    auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);

    auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
    auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
    auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
    auto padd = pdot + pbroadcast;

    ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
                     << m.match_root()->get_name();
        auto pattern_map = m.get_pattern_map();
        std::shared_ptr<Node> nn = nullptr;

        auto mpattern = m.match_root();
        if (mpattern->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
            return nn;
        }

        auto dot = mpattern->get_input_op(0);
        if (dot->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "dot = " << dot->get_name() << " rank is not equal to 2!";
            return nn;
        }

        bool transpose_w = false;
        Shape shape_arg0{pattern_map[W]->get_shape()};
        if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
        {
            return nn;
        }

        bool transpose_x = false;
        Shape shape_arg1{pattern_map[x]->get_shape()};
        if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
        {
            return nn;
        }

        auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
                                                           pattern_map[x],
                                                           mpattern->get_input_op(1),
                                                           shape_arg0,
                                                           shape_arg1,
                                                           transpose_w,
                                                           transpose_x));
        return cg;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
    this->add_matcher(m);
}
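A minimal sketch of the subgraph this matcher targets, built from the same ops the pattern uses (names and shapes are illustrative, not part of the commit):

    auto W = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
    auto x = std::make_shared<op::Parameter>(element::f32, Shape{4, 1});
    auto b = std::make_shared<op::Parameter>(element::f32, Shape{1});
    auto dot = std::make_shared<op::Dot>(W, x);                              // shape {2, 1}
    auto bcast = std::make_shared<op::Broadcast>(b, Shape{2, 1}, AxisSet{0});
    auto sum = dot + bcast; // rewritten by this pass into a single op::MatmulBias

The pattern::op::Any wrappers with reshape_pred additionally let an optional op::Reshape sit between each label and the Dot; init_cblas_arg then folds such a reshape into the GEMM transpose flags instead of leaving it in the graph.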
void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
{
    // construct variance
    auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
    auto input_sq = std::make_shared<op::Multiply>(input, input);
    auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
    auto square_summed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
    auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
    auto avg_input_sum_sq = std::make_shared<op::Divide>(square_summed_input, N);
    auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
    auto variance = std::make_shared<op::Divide>(xmu, N);
    auto variance_label = std::make_shared<pattern::op::Label>(variance, nullptr, Nodes{variance});
    auto variance_with_broadcast =
        std::make_shared<op::Broadcast>(variance_label, Shape{2, 3}, AxisSet{0});

    // construct mean
    auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
    auto mean = std::make_shared<op::Divide>(sum_input1, N);
    auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, Nodes{mean});
    auto mean_with_broadcast = std::make_shared<op::Broadcast>(mean_label, Shape{2, 3}, AxisSet{0});
    auto input_diff_mean = std::make_shared<op::Subtract>(input, mean_with_broadcast);

    // Eps
    auto eps_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto eps_with_broadcast = std::make_shared<op::Broadcast>(eps_label, Shape{2, 3}, AxisSet{0});

    auto add1 = std::make_shared<op::Add>(eps_with_broadcast, variance_with_broadcast);
    auto sqrt_variance_eps = std::make_shared<op::Sqrt>(add1);
    auto divide_mean_variance = std::make_shared<op::Divide>(input_diff_mean, sqrt_variance_eps);

    // Gamma
    auto gamma_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto gamma_with_broadcast =
        std::make_shared<op::Broadcast>(gamma_label, Shape{2, 3}, AxisSet{0});
    auto multiply_gamma =
        std::make_shared<op::Multiply>(gamma_with_broadcast, divide_mean_variance);

    // Beta
    auto beta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
    auto beta_with_broadcast = std::make_shared<op::Broadcast>(beta_label, Shape{2, 3}, AxisSet{0});

    auto add_beta = std::make_shared<op::Add>(beta_with_broadcast, multiply_gamma);
    // This completes the fprop bn pattern

    // Define a callback that gets called once the DFG matches the pattern
    ngraph::pattern::gr_callback_fn callback =
        [variance_label, mean_label, input, eps_label, gamma_label, beta_label](
            pattern::Matcher& m) {
            NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
                         << m.match_root()->get_name();

            std::shared_ptr<Node> nn = nullptr;
            // TODO - add asserts based on the matched node
            auto pattern_map = m.get_pattern_map();
            NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " "
                         << pattern_map[input]->get_shape().size();
            NGRAPH_DEBUG << "Variance: " << pattern_map[variance_label]->get_name() << " "
                         << pattern_map[variance_label]->get_shape().size();
            NGRAPH_DEBUG << "Mean: " << pattern_map[mean_label]->get_name() << " "
                         << pattern_map[mean_label]->get_shape().size();
            NGRAPH_DEBUG << "eps: " << pattern_map[eps_label]->get_name() << " "
                         << pattern_map[eps_label]->get_shape().size();
            NGRAPH_DEBUG << "gamma: " << pattern_map[gamma_label]->get_name() << " "
                         << pattern_map[gamma_label]->get_shape().size();
            NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " "
                         << pattern_map[beta_label]->get_shape().size();

            // don't fuse if the input doesn't have 4 dims
            if (pattern_map[input]->get_shape().size() != 4)
            {
                NGRAPH_DEBUG << "Input to bn doesn't have rank 4, so not fusing";
                return nn;
            }
            Shape bn_output_shape{m.match_root()->get_shape()};
            Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
            Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};

            // get epsilon value
            auto eps_ptr = std::dynamic_pointer_cast<op::Constant>(pattern_map[eps_label]);
            double epsilon = *(reinterpret_cast<const double*>(eps_ptr->get_data_ptr()));
            auto bn_node = std::shared_ptr<Node>(new op::BatchNorm(epsilon,
                                                                   pattern_map[gamma_label],
                                                                   pattern_map[beta_label],
                                                                   pattern_map[input],
                                                                   pattern_map[mean_label],
                                                                   pattern_map[variance_label]));

            return bn_node;
        };

    auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
    this->add_matcher(m);
}
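The variance subgraph above encodes the usual sum-of-squares shortcut. Writing $N$ for the per-channel element count, the pattern computes

$$\mu = \frac{1}{N}\sum_i x_i, \qquad
\sigma^2 = \frac{\sum_i x_i^2 - \bigl(\sum_i x_i\bigr)^2 / N}{N}, \qquad
y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta$$

which is exactly the forward batch-norm expression that the replacement op::BatchNorm node evaluates in a single fused kernel.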
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
{
    Shape shape{2, 2, 1, 1};
    auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
    auto pbias = std::make_shared<pattern::op::Label>(element::f32, Shape{});

    auto pbroadcast = std::make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});

    auto pconv1 = std::make_shared<op::Convolution>(data_batch,
                                                    filters,
                                                    Strides{1, 1},
                                                    Strides{1, 1},
                                                    CoordinateDiff{0, 0},
                                                    CoordinateDiff{0, 0},
                                                    Strides{1, 1});
    auto p_conv_bias = pbroadcast + pconv1;

    ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
                     << m.match_root()->get_name();
        auto pattern_map = m.get_pattern_map();

        auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
        auto bias = m.match_root()->get_input_op(1);

        auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
        return conv_bias;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
    this->add_matcher(m);
}
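For orientation, a minimal sketch of the graph shape this new matcher rewrites, built from the same ops the pattern uses (names and shapes are illustrative, not part of the commit):

    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 1, 1});
    auto filters = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 1, 1});
    auto bias = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto conv = std::make_shared<op::Convolution>(data,
                                                  filters,
                                                  Strides{1, 1},
                                                  Strides{1, 1},
                                                  CoordinateDiff{0, 0},
                                                  CoordinateDiff{0, 0},
                                                  Strides{1, 1});
    // scalar bias broadcast to the convolution's output shape
    auto bcast = std::make_shared<op::Broadcast>(bias, Shape{2, 2, 1, 1}, AxisSet{0, 1, 2, 3});
    auto with_bias = bcast + conv; // replaced by a single op::ConvolutionBias

Note that the callback dynamic-casts the match root's first input to op::Convolution and takes the second input as the bias, so it relies on the matcher presenting the convolution at input 0 of the matched Add.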
@@ -40,9 +40,11 @@ public:
    {
        construct_gemm_pattern();
        construct_fprop_bn();
        construct_conv_bias();
    }

private:
    void construct_gemm_pattern();
    void construct_fprop_bn();
    void construct_conv_bias();
};