Commit 5b760fff authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Adam Procter

Relu(BatchNorm) Fusion (#757)

parent 334ae2ad
......@@ -211,6 +211,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/op/convert_layout.cpp
runtime/cpu/op/sigmoid.cpp
runtime/cpu/op/matmul_bias.cpp
runtime/cpu/op/batch_norm_relu.cpp
runtime/cpu/pass/cpu_assignment.cpp
runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_layout.cpp
......
......@@ -90,6 +90,7 @@
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -474,6 +475,85 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormRelu)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNormRelu is only supported with MKLDNN kernel.");
}
const ngraph::op::BatchNormRelu* batchnorm =
static_cast<const ngraph::op::BatchNormRelu*>(node);
if (!batchnorm->get_training_flag() || batchnorm->get_inputs().size() != 3)
{
throw ngraph_error("Only training batchnorm should have been fused");
}
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
ops.append_eltwise(ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
writer.block_begin();
writer << "{\n";
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value(),
batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< out[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
writer.block_end();
writer << "}\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormBackprop)
{
......
......@@ -110,6 +110,7 @@
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_tracing.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -261,6 +262,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::AvgPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::AvgPoolBackprop>},
{TI(ngraph::op::Pad), &runtime::cpu::CPU_Emitter::emit<op::Pad>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::emit<op::BatchNorm>},
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::Product), &runtime::cpu::CPU_Emitter::emit<op::Product>},
......
......@@ -578,7 +578,8 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps,
bool bn_training_flag)
bool bn_training_flag,
const mkldnn::post_ops& pops)
{
size_t input_index = build_memory_primitive(input_desc);
size_t weights_index = build_memory_primitive(weights_desc);
......@@ -586,6 +587,9 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
mkldnn::primitive_attr bn_attr;
bn_attr.set_post_ops(pops);
if (bn_training_flag)
{
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
......@@ -593,6 +597,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
bn_attr,
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
......@@ -612,6 +617,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
eps,
mkldnn::batch_normalization_flag::use_scale_shift |
mkldnn::batch_normalization_flag::use_global_stats},
bn_attr,
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[mean_index]),
......
......@@ -171,7 +171,8 @@ namespace ngraph
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps,
bool bn_training_flag);
bool bn_training_flag,
const mkldnn::post_ops& pops = mkldnn::post_ops());
size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input)
: RequiresTensorViewArgs("BatchNormRelu", {gamma, beta, input})
, m_bn_input_shape(input->get_shape())
, m_epsilon(eps)
, m_training(true)
{
if (m_bn_input_shape.size() != 4)
{
throw ngraph_error("input tensor to batchnorm must have rank 4");
}
else
{
this->m_bn_variance_shape.push_back(input->get_shape()[1]);
this->m_bn_mean_shape.push_back(input->get_shape()[1]);
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta"};
for (size_t i = 0; i < 2; i++)
{
if (get_input_op(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma and beta shoud have rank 1");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
add_output(input->get_element_type(), m_bn_input_shape);
add_output(input->get_element_type(), m_bn_mean_shape);
add_output(input->get_element_type(), m_bn_variance_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<BatchNormRelu>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
class BatchNormRelu : public util::RequiresTensorViewArgs
{
public:
BatchNormRelu(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; }
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_training_flag() const { return m_training; }
private:
Shape m_bn_input_shape;
Shape m_bn_variance_shape;
Shape m_bn_mean_shape;
double m_epsilon;
bool m_training;
};
}
}
......@@ -33,6 +33,7 @@
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
......@@ -112,6 +113,19 @@ namespace ngraph
convolution->set_op_annotations(op_annotations);
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormRelu)
{
if (node->get_input_op(2 /*input data*/)->get_shape().size() == 4)
{
auto bn_relu = static_cast<op::BatchNormRelu*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
bn_relu->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBackpropData)
{
......@@ -411,6 +425,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionRelu>},
{TI(ngraph::op::BatchNormRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormRelu>},
{TI(ngraph::op::ConvolutionBackpropData),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::ConvolutionBackpropFilters),
......
......@@ -44,6 +44,7 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -681,6 +682,80 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto goe = std::make_shared<op::GetOutputElement>(bn, 0);
auto prelu = std::make_shared<op::Relu>(goe);
ngraph::pattern::gr_callback_fn callback = [input, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_input_op(0)->get_inputs().at(0).get_output().get_node());
if (!m_bn->get_training_flag())
{
NGRAPH_DEBUG << " This is an inference batchnorm, so skipping fusion";
return false;
}
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << " Input data's rank isn't equal to 4. Shape = "
<< pattern_map[input]->get_shape().size();
return false;
}
std::vector<std::shared_ptr<Node>> mgoes(m_bn->get_outputs().size());
for (auto bn_in : m_bn->get_output_inputs(0))
{
auto mgoe = std::dynamic_pointer_cast<op::GetOutputElement>(bn_in->get_node());
mgoes[mgoe->get_n()] = mgoe;
}
if (mgoes[0]->get_users().size() > 1)
{
NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output";
return false;
}
mgoes[0] = m.match_root(); //replace relu instead of its GetOutputElement
auto bn_relu = std::make_shared<op::BatchNormRelu>(
m_bn->get_eps_value(), pattern_map[gamma], pattern_map[beta], pattern_map[input]);
auto bn_relu_output = std::make_shared<op::GetOutputElement>(bn_relu, 0);
auto bn_relu_mean = std::make_shared<op::GetOutputElement>(bn_relu, 1);
auto bn_relu_var = std::make_shared<op::GetOutputElement>(bn_relu, 2);
std::shared_ptr<Node> new_nodes[] = {bn_relu_output, bn_relu_mean, bn_relu_var};
for (size_t i = 0; i < mgoes.size(); i++)
{
ngraph::replace_node(mgoes.at(i), new_nodes[i]);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{
Shape shape{2, 2, 1, 1};
......
......@@ -46,6 +46,7 @@ public:
construct_sigmoid();
construct_sigmoid_bprop();
construct_conv_bias();
construct_batch_norm_relu();
construct_conv_relu();
}
......@@ -58,5 +59,6 @@ private:
void construct_sigmoid_bprop();
void construct_zero_padded_reshaped_conv();
void construct_zero_padded_conv();
void construct_batch_norm_relu();
void construct_conv_relu();
};
......@@ -38,6 +38,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
......@@ -1053,6 +1054,40 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormRelu)
{
auto bn = static_cast<const ngraph::op::BatchNormRelu*>(node.get());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 2);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
if (!bn->get_training_flag() || bn->get_inputs().size() != 3)
{
throw ngraph_error("Only training batchnorm should have been fused");
}
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
throw ngraph_error("BatchnormRelu only supported in MKLDNN for now");
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNormBackprop)
{
......@@ -1138,6 +1173,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormRelu),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::GetOutputElement),
......
......@@ -37,6 +37,7 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -53,6 +54,8 @@
#include "util/random.hpp"
#include "util/test_tools.hpp"
#include "util/random.hpp"
using namespace ngraph;
using namespace std;
......@@ -744,6 +747,84 @@ TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
// Note, op::Splice is used to break Relu(BatchNorm) fusion
// otherwise we will be comparing two BatchNormRelus
// Unfortunately, we can't use INTERPRETER for
// verifying the results as it doesn't implement
// BatchNorm op.
auto slice =
std::make_shared<op::Slice>(output_rt, Coordinate{0, 0, 0, 0}, Coordinate{1, 2, 2, 2});
auto output_relu = std::make_shared<op::Relu>(slice);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto bn_relu = make_shared<op::BatchNormRelu>(eps, gamma, beta, input);
auto output_rt_bnr = std::make_shared<op::GetOutputElement>(bn_relu, 0);
auto mean_rt_bnr = std::make_shared<op::GetOutputElement>(bn_relu, 1);
auto variance_rt_bnr = std::make_shared<op::GetOutputElement>(bn_relu, 2);
auto f = make_shared<Function>(
NodeVector{output_relu, mean_rt, variance_rt, output_rt_bnr, mean_rt_bnr, variance_rt_bnr},
op::ParameterVector{input, gamma, beta});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto input_t = backend->make_primary_tensor_view(element::f32, Shape{1, 2, 2, 2});
copy_data(input_t,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto gamma_t = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(gamma_t, vector<float>{1.0f, 1.0f});
auto beta_t = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(beta_t, vector<float>{0.0f, 0.0f});
auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
auto bn_output_bnr = backend->make_primary_tensor_view(element::f32, shape_r);
auto result_mean_bnr = backend->make_primary_tensor_view(element::f32, mean_shape);
auto result_variance_bnr = backend->make_primary_tensor_view(element::f32, var_shape);
cf->call({bn_output,
result_mean,
result_variance,
bn_output_bnr,
result_mean_bnr,
result_variance_bnr},
{input_t, gamma_t, beta_t});
EXPECT_TRUE(test::all_close(read_vector<float>(bn_output), read_vector<float>(bn_output_bnr)));
EXPECT_TRUE(
test::all_close(read_vector<float>(result_mean), read_vector<float>(result_mean_bnr)));
EXPECT_TRUE(test::all_close(read_vector<float>(result_variance),
read_vector<float>(result_variance_bnr)));
}
TEST(cpu_fusion, fuse_conv_relu)
{
auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 1, 2, 2});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment