Commit 78c57f10 authored by Louis Feng's avatar Louis Feng

format.

parent 2fa8a678
/*******************************************************************************
 * Copyright 2017-2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
#include "cpu_fusion.hpp" #include "cpu_fusion.hpp"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <unordered_set> #include <unordered_set>
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/batch_norm.hpp" #include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp" #include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp" #include "ngraph/ops/dot.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/pad.hpp" #include "ngraph/ops/pad.hpp"
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/sqrt.hpp" #include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp" #include "ngraph/ops/sum.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/conv_bias.hpp" #include "ngraph/runtime/cpu/ops/conv_bias.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp" #include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape, static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> arg,
bool& transpose_w, bool& transpose_w,
ngraph::Shape& shape_w) ngraph::Shape& shape_w)
{ {
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape); auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);
if (!r_w) if (!r_w)
{ {
if (arg->get_shape().size() != 2) if (arg->get_shape().size() != 2)
{ {
NGRAPH_DEBUG << arg->get_name() << " 's rank != 2 " NGRAPH_DEBUG << arg->get_name() << " 's rank != 2 "
<< ngraph::vector_to_string(arg->get_shape()); << ngraph::vector_to_string(arg->get_shape());
return false; return false;
} }
return true; //nth to do; reshape isn't a reshape return true; //nth to do; reshape isn't a reshape
} }
if (r_w->get_shape().size() != 2) if (r_w->get_shape().size() != 2)
{ {
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix" NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape()); << ngraph::vector_to_string(r_w->get_shape());
return false; return false;
} }
auto io = r_w->get_input_order(); auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{ {
ngraph::AxisVector dio(io.size()); ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0); std::iota(begin(dio), end(dio), 0);
if (io != dio) //we can't reshape and transpose at the same time if (io != dio) //we can't reshape and transpose at the same time
{ {
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order " NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io); << ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape()); NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape()); NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false; return false;
} }
shape_w = r_w->get_shape(); shape_w = r_w->get_shape();
} }
else else
{ {
if (io == ngraph::AxisVector{1, 0}) if (io == ngraph::AxisVector{1, 0})
{ {
transpose_w = true; transpose_w = true;
} }
//otherwise no-op reshape //otherwise no-op reshape
} }
return true; return true;
} }
template <typename T> template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order) static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{ {
if (input.size() != order.size()) if (input.size() != order.size())
{ {
throw "input and order sizes don't match!"; throw "input and order sizes don't match!";
} }
std::vector<T> output(input.size()); std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++) for (size_t i = 0; i < order.size(); i++)
{ {
output[i] = input.at(order.at(i)); output[i] = input.at(order.at(i));
} }
return output; return output;
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern() void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias_pattern()
{ {
Shape shape_w{2, 4}; Shape shape_w{2, 4};
Shape shape_x{4, 1}; Shape shape_x{4, 1};
Shape shape_b{1}; Shape shape_b{1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w); auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x); auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b); auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pmmb = std::make_shared<op::MatmulBias>( auto pmmb = std::make_shared<op::MatmulBias>(
W, x, nullptr, W->get_shape(), x->get_shape(), false, false); W, x, nullptr, W->get_shape(), x->get_shape(), false, false);
auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0}); auto pbroadcast = std::make_shared<op::Broadcast>(b, pmmb->get_shape(), AxisSet{0});
auto padd = pmmb + pbroadcast; auto padd = pmmb + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) { ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmulbias_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto mpattern = m.match_root(); //add auto mpattern = m.match_root(); //add
auto m_matmul = mpattern->get_input_op(0); auto m_matmul = mpattern->get_input_op(0);
auto m_broadcast = mpattern->get_input_op(1); auto m_broadcast = mpattern->get_input_op(1);
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
return m_matmul->copy_with_new_args( return m_matmul->copy_with_new_args(
NodeVector{pattern_map[W], pattern_map[x], m_broadcast}); NodeVector{pattern_map[W], pattern_map[x], m_broadcast});
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern() void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul_pattern()
{ {
Shape shape_w{2, 4}; Shape shape_w{2, 4};
Shape shape_x{4, 1}; Shape shape_x{4, 1};
Shape shape_b{1}; Shape shape_b{1};
Shape shape_dot{2, 1}; Shape shape_dot{2, 1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w); auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x); auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) { auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n)); return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
}; };
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred); auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred); auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x); auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) { ngraph::pattern::gr_callback_fn callback = [W, x](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = " NGRAPH_DEBUG << "In callback for construct_matmul_pattern against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn; std::shared_ptr<Node> nn;
auto mpattern = m.match_root(); auto mpattern = m.match_root();
auto dot = m.match_root(); auto dot = m.match_root();
if (mpattern->get_element_type() != element::f32) if (mpattern->get_element_type() != element::f32)
{ {
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!"; NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn; return nn;
} }
if (dot->get_shape().size() != 2) if (dot->get_shape().size() != 2)
{ {
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!"; NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn; return nn;
} }
if (shape_size(dot->get_shape()) == 0) if (shape_size(dot->get_shape()) == 0)
{ {
NGRAPH_DEBUG << "dot has a zero dimension"; NGRAPH_DEBUG << "dot has a zero dimension";
return nn; return nn;
} }
bool transpose_w = false; bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()}; Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0)) if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{ {
return nn; return nn;
} }
bool transpose_x = false; bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()}; Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1)) if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{ {
return nn; return nn;
} }
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W], auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x], pattern_map[x],
nullptr, nullptr,
shape_arg0, shape_arg0,
shape_arg1, shape_arg1,
transpose_w, transpose_w,
transpose_x)); transpose_x));
return cg; return cg;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(pdot, callback);
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn() void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
{ {
// construct varaiance // construct varaiance
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2}); auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3}); auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input); auto input_sq = std::make_shared<op::Multiply>(input, input);
auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0}); auto sum_input = std::make_shared<op::Sum>(input, AxisSet{0});
auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input); auto square_sumed_input = std::make_shared<op::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0}); auto sum_squared_input = std::make_shared<op::Sum>(input_sq, AxisSet{0});
auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N); auto avg_input_sum_sq = std::make_shared<op::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq); auto xmu = std::make_shared<op::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::Divide>(xmu, N); auto variance = std::make_shared<op::Divide>(xmu, N);
auto variance_label = auto variance_label =
std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance}); std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
auto variance_with_broadcast = auto variance_with_broadcast =
std::make_shared<op::Broadcast>(variance_label, Shape{2, 3}, AxisSet{0}); std::make_shared<op::Broadcast>(variance_label, Shape{2, 3}, AxisSet{0});
// construct mean // construct mean
auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0}); auto sum_input1 = std::make_shared<op::Sum>(input, AxisSet{0});
auto mean = std::make_shared<op::Divide>(sum_input1, N); auto mean = std::make_shared<op::Divide>(sum_input1, N);
auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean}); auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
auto mean_with_broadcast = std::make_shared<op::Broadcast>(mean_label, Shape{2, 3}, AxisSet{0}); auto mean_with_broadcast = std::make_shared<op::Broadcast>(mean_label, Shape{2, 3}, AxisSet{0});
auto input_diff_mean = std::make_shared<op::Subtract>(input, mean_with_broadcast); auto input_diff_mean = std::make_shared<op::Subtract>(input, mean_with_broadcast);
// Eps // Eps
auto eps_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3}); auto eps_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto eps_with_broadcast = std::make_shared<op::Broadcast>(eps_label, Shape{2, 3}, AxisSet{0}); auto eps_with_broadcast = std::make_shared<op::Broadcast>(eps_label, Shape{2, 3}, AxisSet{0});
auto add1 = std::make_shared<op::Add>(eps_with_broadcast, variance_with_broadcast); auto add1 = std::make_shared<op::Add>(eps_with_broadcast, variance_with_broadcast);
auto sqrt_variance_eps = std::make_shared<op::Sqrt>(add1); auto sqrt_variance_eps = std::make_shared<op::Sqrt>(add1);
auto divide_mean_variance = std::make_shared<op::Divide>(input_diff_mean, sqrt_variance_eps); auto divide_mean_variance = std::make_shared<op::Divide>(input_diff_mean, sqrt_variance_eps);
//Gamma //Gamma
auto gamma_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3}); auto gamma_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto gamma_with_broadcast = auto gamma_with_broadcast =
std::make_shared<op::Broadcast>(gamma_label, Shape{2, 3}, AxisSet{0}); std::make_shared<op::Broadcast>(gamma_label, Shape{2, 3}, AxisSet{0});
auto multiply_gamma = auto multiply_gamma =
std::make_shared<op::Multiply>(gamma_with_broadcast, divide_mean_variance); std::make_shared<op::Multiply>(gamma_with_broadcast, divide_mean_variance);
//Beta //Beta
auto beta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3}); auto beta_label = std::make_shared<pattern::op::Label>(element::f32, Shape{3});
auto beta_with_broadcast = std::make_shared<op::Broadcast>(beta_label, Shape{2, 3}, AxisSet{0}); auto beta_with_broadcast = std::make_shared<op::Broadcast>(beta_label, Shape{2, 3}, AxisSet{0});
auto add_beta = std::make_shared<op::Add>(beta_with_broadcast, multiply_gamma); auto add_beta = std::make_shared<op::Add>(beta_with_broadcast, multiply_gamma);
// This completes fprop bn pattern // This completes fprop bn pattern
//Define a call back that needs to called once the DFG matches the pattern //Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback =
[variance_label, mean_label, input, eps_label, gamma_label, beta_label]( [variance_label, mean_label, input, eps_label, gamma_label, beta_label](
pattern::Matcher& m) { pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against " NGRAPH_DEBUG << "In a callback for construct_fprop_bn pattern against "
<< m.match_root()->get_name(); << m.match_root()->get_name();
std::shared_ptr<Node> nn = nullptr; std::shared_ptr<Node> nn = nullptr;
//TODO - add assert's based on the matched node //TODO - add assert's based on the matched node
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " " NGRAPH_DEBUG << "Input: " << pattern_map[input]->get_name() << " "
<< pattern_map[input]->get_shape().size(); << pattern_map[input]->get_shape().size();
NGRAPH_DEBUG << "Variance: " << pattern_map[variance_label]->get_name() << " " NGRAPH_DEBUG << "Variance: " << pattern_map[variance_label]->get_name() << " "
<< pattern_map[variance_label]->get_shape().size(); << pattern_map[variance_label]->get_shape().size();
NGRAPH_DEBUG << "Mean: " << pattern_map[mean_label]->get_name() << " " NGRAPH_DEBUG << "Mean: " << pattern_map[mean_label]->get_name() << " "
<< pattern_map[mean_label]->get_shape().size(); << pattern_map[mean_label]->get_shape().size();
NGRAPH_DEBUG << "eps: " << pattern_map[eps_label]->get_name() << " " NGRAPH_DEBUG << "eps: " << pattern_map[eps_label]->get_name() << " "
<< pattern_map[eps_label]->get_shape().size(); << pattern_map[eps_label]->get_shape().size();
NGRAPH_DEBUG << "gamma: " << pattern_map[gamma_label]->get_name() << " " NGRAPH_DEBUG << "gamma: " << pattern_map[gamma_label]->get_name() << " "
<< pattern_map[gamma_label]->get_shape().size(); << pattern_map[gamma_label]->get_shape().size();
NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " " NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " "
<< pattern_map[beta_label]->get_shape().size(); << pattern_map[beta_label]->get_shape().size();
// dont fuse if the inout doesnt have 4dims // dont fuse if the inout doesnt have 4dims
if (pattern_map[input]->get_shape().size() != 4) if (pattern_map[input]->get_shape().size() != 4)
{ {
NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing"; NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing";
return nn; return nn;
} }
Shape bn_output_shape{m.match_root()->get_shape()}; Shape bn_output_shape{m.match_root()->get_shape()};
Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()}; Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()}; Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};
// get epsilon value // get epsilon value
auto eps_ptr = std::dynamic_pointer_cast<op::Constant>(pattern_map[eps_label]); auto eps_ptr = std::dynamic_pointer_cast<op::Constant>(pattern_map[eps_label]);
double epsilon = *(reinterpret_cast<const double*>(eps_ptr->get_data_ptr())); double epsilon = *(reinterpret_cast<const double*>(eps_ptr->get_data_ptr()));
auto bn_node = std::shared_ptr<Node>(new op::BatchNorm(epsilon, auto bn_node = std::shared_ptr<Node>(new op::BatchNorm(epsilon,
pattern_map[gamma_label], pattern_map[gamma_label],
pattern_map[beta_label], pattern_map[beta_label],
pattern_map[input], pattern_map[input],
pattern_map[mean_label], pattern_map[mean_label],
pattern_map[variance_label])); pattern_map[variance_label]));
return bn_node; return bn_node;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(add_beta, callback);
this->add_matcher(m); this->add_matcher(m);
} }
static bool static bool
zero_padded_conv_consistency_check(const std::shared_ptr<ngraph::Node>& match_root, zero_padded_conv_consistency_check(const std::shared_ptr<ngraph::Node>& match_root,
const std::shared_ptr<ngraph::op::Constant>& pad_value_op, const std::shared_ptr<ngraph::op::Constant>& pad_value_op,
const std::shared_ptr<ngraph::Node>& pad_input, const std::shared_ptr<ngraph::Node>& pad_input,
const std::shared_ptr<ngraph::op::Pad>& matched_pad, const std::shared_ptr<ngraph::op::Pad>& matched_pad,
const std::shared_ptr<ngraph::op::Convolution>& matched_conv, const std::shared_ptr<ngraph::op::Convolution>& matched_conv,
size_t batch_index, size_t batch_index,
size_t channel_index) size_t channel_index)
{ {
// Only match float32 convolutions // Only match float32 convolutions
if (match_root->get_element_type() != ngraph::element::f32) if (match_root->get_element_type() != ngraph::element::f32)
{ {
return false; return false;
} }
// Only match zero padding // Only match zero padding
if (pad_value_op->get_vector<float>().at(0) != 0.0f) if (pad_value_op->get_vector<float>().at(0) != 0.0f)
{ {
return false; return false;
} }
// Only match 4D tensors // Only match 4D tensors
if (pad_input->get_shape().size() != 4) if (pad_input->get_shape().size() != 4)
{ {
return false; return false;
} }
// Only match no interior padding // Only match no interior padding
if (matched_pad->get_padding_interior() != ngraph::Shape(pad_input->get_shape().size())) if (matched_pad->get_padding_interior() != ngraph::Shape(pad_input->get_shape().size()))
{ {
return false; return false;
} }
// Only match convolutions with no padding specification // Only match convolutions with no padding specification
if (matched_conv->get_padding_below() != ngraph::CoordinateDiff(2) || if (matched_conv->get_padding_below() != ngraph::CoordinateDiff(2) ||
matched_conv->get_padding_above() != ngraph::CoordinateDiff(2)) matched_conv->get_padding_above() != ngraph::CoordinateDiff(2))
{ {
return false; return false;
} }
// Only match no padding in the batch dimension // Only match no padding in the batch dimension
if (matched_pad->get_padding_above().at(batch_index) != 0 || if (matched_pad->get_padding_above().at(batch_index) != 0 ||
matched_pad->get_padding_below().at(batch_index) != 0) matched_pad->get_padding_below().at(batch_index) != 0)
{ {
return false; return false;
} }
// Only match no padding in the channel dimension // Only match no padding in the channel dimension
if (matched_pad->get_padding_above().at(channel_index) != 0 || if (matched_pad->get_padding_above().at(channel_index) != 0 ||
matched_pad->get_padding_below().at(channel_index) != 0) matched_pad->get_padding_below().at(channel_index) != 0)
{ {
return false; return false;
} }
return true; return true;
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv() void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_reshaped_conv()
{ {
auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto pad = std::make_shared<op::Pad>(pad_input, pad_value, Shape{}, Shape{}, Shape{}); auto pad = std::make_shared<op::Pad>(pad_input, pad_value, Shape{}, Shape{}, Shape{});
auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad}); auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad});
auto reshape = std::make_shared<op::Reshape>(pad_label, AxisVector{}, Shape{1, 1, 1, 1}); auto reshape = std::make_shared<op::Reshape>(pad_label, AxisVector{}, Shape{1, 1, 1, 1});
auto reshape_label = auto reshape_label =
std::make_shared<pattern::op::Label>(reshape, nullptr, NodeVector{reshape}); std::make_shared<pattern::op::Label>(reshape, nullptr, NodeVector{reshape});
auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto conv = std::make_shared<op::Convolution>(reshape_label, auto conv = std::make_shared<op::Convolution>(reshape_label,
conv_filter, conv_filter,
Strides{1, 1}, Strides{1, 1},
Strides{1, 1}, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1},
Strides{1, 1}); Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback =
[pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label]( [pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label](
pattern::Matcher& m) -> std::shared_ptr<Node> { pattern::Matcher& m) -> std::shared_ptr<Node> {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]); auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]);
const auto& matched_conv = const auto& matched_conv =
std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]); std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]);
const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]); const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]);
const auto& matched_reshape = const auto& matched_reshape =
std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]); std::dynamic_pointer_cast<op::Reshape>(pattern_map[reshape_label]);
const auto& input_order = matched_reshape->get_input_order(); const auto& input_order = matched_reshape->get_input_order();
auto hoisted_reshape_output_shape = auto hoisted_reshape_output_shape =
apply_permutation<Shape::value_type>(pattern_map[pad_input]->get_shape(), input_order); apply_permutation<Shape::value_type>(pattern_map[pad_input]->get_shape(), input_order);
auto hoisted_reshape = std::make_shared<op::Reshape>( auto hoisted_reshape = std::make_shared<op::Reshape>(
pattern_map[pad_input], pattern_map[pad_input],
input_order, input_order,
Shape(hoisted_reshape_output_shape.begin(), hoisted_reshape_output_shape.end())); Shape(hoisted_reshape_output_shape.begin(), hoisted_reshape_output_shape.end()));
if (!zero_padded_conv_consistency_check(m.match_root(), if (!zero_padded_conv_consistency_check(m.match_root(),
pad_value_op, pad_value_op,
pattern_map[pad_input], pattern_map[pad_input],
matched_pad, matched_pad,
matched_conv, matched_conv,
input_order[0], input_order[0],
input_order[1])) input_order[1]))
{ {
return nullptr; return nullptr;
} }
CoordinateDiff padding_below{static_cast<CoordinateDiff::value_type>( CoordinateDiff padding_below{static_cast<CoordinateDiff::value_type>(
matched_pad->get_padding_below().at(input_order[2])), matched_pad->get_padding_below().at(input_order[2])),
static_cast<CoordinateDiff::value_type>( static_cast<CoordinateDiff::value_type>(
matched_pad->get_padding_below().at(input_order[3]))}; matched_pad->get_padding_below().at(input_order[3]))};
CoordinateDiff padding_above{static_cast<CoordinateDiff::value_type>( CoordinateDiff padding_above{static_cast<CoordinateDiff::value_type>(
matched_pad->get_padding_above().at(input_order[2])), matched_pad->get_padding_above().at(input_order[2])),
static_cast<CoordinateDiff::value_type>( static_cast<CoordinateDiff::value_type>(
matched_pad->get_padding_above().at(input_order[3]))}; matched_pad->get_padding_above().at(input_order[3]))};
auto zero_padded_conv = auto zero_padded_conv =
std::make_shared<op::Convolution>(hoisted_reshape, std::make_shared<op::Convolution>(hoisted_reshape,
pattern_map[conv_filter], pattern_map[conv_filter],
matched_conv->get_window_movement_strides(), matched_conv->get_window_movement_strides(),
matched_conv->get_window_dilation_strides(), matched_conv->get_window_dilation_strides(),
padding_below, padding_below,
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
return zero_padded_conv; return zero_padded_conv;
}; };
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback)); this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv() void ngraph::runtime::cpu::pass::CPUFusion::construct_zero_padded_conv()
{ {
auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto pad = std::make_shared<op::Pad>( auto pad = std::make_shared<op::Pad>(
pad_input, pad_value, Shape{0, 0, 0, 0}, Shape{0, 0, 0, 0}, Shape{0, 0, 0, 0}); pad_input, pad_value, Shape{0, 0, 0, 0}, Shape{0, 0, 0, 0}, Shape{0, 0, 0, 0});
auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad}); auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad});
auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto conv = std::make_shared<op::Convolution>(pad_label, auto conv = std::make_shared<op::Convolution>(pad_label,
conv_filter, conv_filter,
Strides{1, 1}, Strides{1, 1},
Strides{1, 1}, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1},
Strides{1, 1}); Strides{1, 1});
auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv}); auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
ngraph::pattern::gr_callback_fn callback = ngraph::pattern::gr_callback_fn callback =
[pad_input, pad_value, pad_label, conv_filter, conv_label]( [pad_input, pad_value, pad_label, conv_filter, conv_label](
pattern::Matcher& m) -> std::shared_ptr<Node> { pattern::Matcher& m) -> std::shared_ptr<Node> {
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]); auto pad_value_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[pad_value]);
const auto& matched_conv = const auto& matched_conv =
std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]); std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_label]);
const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]); const auto& matched_pad = std::dynamic_pointer_cast<op::Pad>(pattern_map[pad_label]);
if (!zero_padded_conv_consistency_check(m.match_root(), if (!zero_padded_conv_consistency_check(m.match_root(),
pad_value_op, pad_value_op,
pattern_map[pad_input], pattern_map[pad_input],
matched_pad, matched_pad,
matched_conv, matched_conv,
0, 0,
1)) 1))
{ {
return nullptr; return nullptr;
} }
CoordinateDiff padding_below{ CoordinateDiff padding_below{
static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(2)), static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(2)),
static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(3))}; static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(3))};
CoordinateDiff padding_above{ CoordinateDiff padding_above{
static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(2)), static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(2)),
static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(3))}; static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(3))};
auto zero_padded_conv = auto zero_padded_conv =
std::make_shared<op::Convolution>(pattern_map[pad_input], std::make_shared<op::Convolution>(pattern_map[pad_input],
pattern_map[conv_filter], pattern_map[conv_filter],
matched_conv->get_window_movement_strides(), matched_conv->get_window_movement_strides(),
matched_conv->get_window_dilation_strides(), matched_conv->get_window_dilation_strides(),
padding_below, padding_below,
padding_above, padding_above,
matched_conv->get_data_dilation_strides()); matched_conv->get_data_dilation_strides());
return zero_padded_conv; return zero_padded_conv;
}; };
this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback)); this->add_matcher(std::make_shared<ngraph::pattern::Matcher>(conv_label, callback));
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
{ {
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape); auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape); auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto pbias = std::make_shared<pattern::op::Label>(element::f32, Shape{}); auto pbias = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto pbroadcast = std::make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3}); auto pbroadcast = std::make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
auto pconv1 = std::make_shared<op::Convolution>(data_batch, auto pconv1 = std::make_shared<op::Convolution>(data_batch,
filters, filters,
Strides{1, 1}, Strides{1, 1},
Strides{1, 1}, Strides{1, 1},
CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
Strides{1, 1}); Strides{1, 1});
auto p_conv_bias = pbroadcast + pconv1; auto p_conv_bias = pbroadcast + pconv1;
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) { ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = " NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.match_root()->get_name(); << m.match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn; std::shared_ptr<Node> nn;
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0)); auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
auto bias = m.match_root()->get_input_op(1)->get_input_op(0); auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias)); auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
return conv_bias; return conv_bias;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback); auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
this->add_matcher(m); this->add_matcher(m);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment