Commit 587b96e5 authored by Jayaram Bobba, committed by Scott Cyphers

Adding leaky relu (#2096)

* Adding leaky relu

* Silence compiler warning around fp compares

* Fix copy-paste error and enable in-place for relu mkldnn kernels
parent 266843fa
......@@ -46,6 +46,7 @@ set(SRC
builder/convolution.cpp
builder/dot.cpp
builder/function_call.cpp
builder/leaky_relu.cpp
builder/lstm.cpp
builder/lrn.cpp
builder/matmul_bias.cpp
......@@ -90,6 +91,7 @@ set(SRC
op/conv_bias.cpp
op/conv_relu.cpp
op/convert_layout.cpp
op/leaky_relu.cpp
op/loop_kernel.cpp
op/lstm.cpp
op/matmul_bias.cpp
......
......@@ -38,10 +38,15 @@ namespace ngraph
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
size_t count = out[0].get_size();
auto alpha = static_cast<const op::BoundedRelu*>(node)->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto bounded_relu_index = mkldnn_emitter->build_bounded_relu(node, args, out);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto bounded_relu_index =
mkldnn_emitter->build_bounded_relu(input_desc, result_desc, alpha);
auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);
auto functor = [&, bounded_relu_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
......@@ -58,7 +63,6 @@ namespace ngraph
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::bounded_relu);
auto alpha = static_cast<const op::BoundedRelu*>(node)->get_alpha();
auto functor = [&, kernel, alpha, count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(input_tensor, out_tensor, alpha, count, ectx->arena);
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// Appends the functor that evaluates a LeakyRelu node at run time.
// When use_mkldnn_kernel(node) is true the functor drives the MKLDNN primitive
// built by build_leaky_relu(); otherwise it calls the kernel::leaky_relu
// instantiation chosen by SELECT_KERNEL for the output element type.
template <>
void Builder::BUILDER_DECL(ngraph::op::LeakyRelu)
{
    auto& functors = external_function->get_functors();

    auto& src_tensor = external_function->get_tensor_data(args[0].get_name());
    auto& dst_tensor = external_function->get_tensor_data(out[0].get_name());
    size_t element_count = out[0].get_size();
    auto slope = static_cast<const op::LeakyRelu*>(node)->get_alpha();

    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
    {
        // Non-MKLDNN path: element-wise kernel over the flat buffers.
        std::function<decltype(runtime::cpu::kernel::leaky_relu<float>)> kernel;
        SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::leaky_relu);

        functors.emplace_back([&, kernel, slope, element_count](CPURuntimeContext* ctx,
                                                                CPUExecutionContext* ectx) {
            kernel(src_tensor, dst_tensor, slope, element_count, ectx->arena);
        });
        return;
    }

    // MKLDNN path: build the primitive once, then bind the tensor pointers to
    // its two memory dependencies (input, output) on every invocation.
    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
    auto src_md = mkldnn_utils::get_input_mkldnn_md(node, 0);
    auto dst_md = mkldnn_utils::get_output_mkldnn_md(node, 0);
    auto primitive_id = mkldnn_emitter->build_leaky_relu(src_md, dst_md, slope);
    auto& deps = mkldnn_emitter->get_primitive_deps(primitive_id);

    functors.emplace_back([&, primitive_id](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], src_tensor);
        cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], dst_tensor);
        cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, primitive_id);
    });
}
REGISTER_OP_BUILDER(LeakyRelu);
}
}
}
......@@ -112,6 +112,7 @@
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -4119,6 +4120,39 @@ namespace ngraph
}
}
// Emits the generated-code fragment for a LeakyRelu node.
// When use_mkldnn_kernel(node) is true the fragment binds the input/output
// buffers to the MKLDNN primitive built here and invokes it; otherwise it is an
// OpenMP-parallel loop computing x > 0 ? x : (alpha * x) for each element.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::LeakyRelu)
{
    const float alpha = static_cast<const ngraph::op::LeakyRelu*>(node)->get_alpha();
    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();

    const auto& src_name = args[0].get_name();
    const auto& dst_name = out[0].get_name();

    if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
    {
        writer << "#pragma omp parallel for\n";
        writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
        writer.block_begin();
        writer << dst_name << "[i] = " << src_name << "[i] > 0 ? " << src_name << "[i] : ("
               << alpha << " * " << src_name << "[i]);\n";
        writer.block_end();
        return;
    }

    auto src_md = mkldnn_utils::get_input_mkldnn_md(node, 0);
    auto dst_md = mkldnn_utils::get_output_mkldnn_md(node, 0);
    auto primitive_id = mkldnn_emitter->build_leaky_relu(src_md, dst_md, alpha);
    auto& deps = mkldnn_emitter->get_primitive_deps(primitive_id);

    // deps[0] is the input memory primitive, deps[1] the output one.
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", " << src_name
           << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", " << dst_name
           << ");\n";
    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, " << to_string(primitive_id)
           << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BoundedRelu)
{
......@@ -4127,7 +4161,10 @@ namespace ngraph
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto bounded_relu_index = mkldnn_emitter->build_bounded_relu(node, args, out);
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
auto bounded_relu_index =
mkldnn_emitter->build_bounded_relu(input_desc, result_desc, alpha);
auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
......
......@@ -153,6 +153,7 @@
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -395,6 +396,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
{TI(ngraph::op::And), &runtime::cpu::CPU_Emitter::emit<op::And>},
{TI(ngraph::op::Or), &runtime::cpu::CPU_Emitter::emit<op::Or>},
{TI(ngraph::op::LeakyRelu), &runtime::cpu::CPU_Emitter::emit<op::LeakyRelu>},
{TI(ngraph::runtime::cpu::op::LoopKernel),
&runtime::cpu::CPU_Emitter::emit<runtime::cpu::op::LoopKernel>},
{TI(ngraph::op::LRN), &runtime::cpu::CPU_Emitter::emit<ngraph::op::LRN>},
......
......@@ -63,6 +63,23 @@ namespace ngraph
in0.cwiseMax(ElementType(0)).cwiseMin(alpha);
}
template <typename ElementType>
void leaky_relu(
void* input0, void* output, ElementType alpha, size_t count, int arena)
{
Eigen::array<Eigen::Index, 1> out_dims, in_dims;
out_dims[0] = in_dims[0] = count;
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> out(
static_cast<ElementType*>(output), out_dims);
Eigen::TensorMap<Eigen::Tensor<ElementType, 1, Eigen::RowMajor>> in0(
static_cast<ElementType*>(input0), in_dims);
out.device(ngraph::runtime::cpu::executor::GetCPUExecutor().get_device(arena)) =
in0.cwiseMax(in0 * alpha);
}
template <typename ElementType>
void relu_backprop(void* arg, void* delta_arg, void* out, size_t count, int arena)
{
......
......@@ -1162,6 +1162,27 @@ size_t MKLDNNEmitter::build_softmax_forward(const mkldnn::memory::desc& input_de
return primitive_index;
}
// Creates an MKLDNN eltwise_relu forward primitive whose alpha parameter is
// `alpha` (beta is fixed at 0), together with the memory primitives for its
// input and result. Records {input, result} as the primitive's dependencies
// and returns the new primitive's index.
size_t MKLDNNEmitter::build_leaky_relu(const mkldnn::memory::desc& input_desc,
                                       const mkldnn::memory::desc& result_desc,
                                       float alpha)
{
    const size_t src_index = build_memory_primitive(input_desc);
    const size_t dst_index = build_memory_primitive(result_desc);

    const mkldnn::eltwise_forward::desc relu_desc(mkldnn::prop_kind::forward_training,
                                                  mkldnn::algorithm::eltwise_relu,
                                                  input_desc,
                                                  alpha,
                                                  0.0f);
    const mkldnn::eltwise_forward::primitive_desc relu_pd(relu_desc,
                                                          executor::global_cpu_engine);

    // insert_primitive() receives the raw pointer, as in the other build_* helpers.
    const size_t primitive_index = insert_primitive(new mkldnn::eltwise_forward(
        relu_pd, *m_mkldnn_primitives[src_index], *m_mkldnn_primitives[dst_index]));

    m_primitive_deps[primitive_index] = {src_index, dst_index};
    return primitive_index;
}
size_t MKLDNNEmitter::build_bounded_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha)
......
......@@ -598,17 +598,9 @@ namespace ngraph
const mkldnn::memory::desc& result_desc,
int softmax_axis);
size_t build_bounded_relu(const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out)
{
auto bounded_relu_node = static_cast<const ngraph::op::BoundedRelu*>(node);
float alpha = bounded_relu_node->get_alpha();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
return build_bounded_relu(input_desc, result_desc, alpha);
}
size_t build_leaky_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha);
size_t build_bounded_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
// Constructs a LeakyRelu node over `arg` with scale factor `alpha`.
// Output 0 is given arg's element type and shape.
// Throws ngraph_error when alpha is negative or NaN.
op::LeakyRelu::LeakyRelu(shared_ptr<Node> arg, float alpha)
    : Op("LeakyRelu", check_single_output_args({arg}))
    , m_alpha(alpha)
{
    constructor_validate_and_infer_types();
    // Written as !(alpha >= 0) rather than (alpha < 0): NaN fails every ordered
    // comparison, so the negated form rejects it along with negative values.
    if (!(alpha >= 0))
    {
        throw ngraph_error("Leaky Relu expects non-negative alpha");
    }
    set_output_type(0, arg->get_element_type(), arg->get_shape());
}
// Returns a new LeakyRelu with this node's alpha applied to the single node in
// `new_args`. Throws ngraph_error unless exactly one argument is supplied.
shared_ptr<Node> op::LeakyRelu::copy_with_new_args(const NodeVector& new_args) const
{
    const bool has_single_arg = (new_args.size() == 1);
    if (!has_single_arg)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }

    const auto& replacement_arg = new_args.front();
    return make_shared<LeakyRelu>(replacement_arg, m_alpha);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise Maximum(arg, arg * alpha) operation.
///
/// alpha must be non-negative (zero is accepted); the constructor throws
/// ngraph_error for a negative alpha.
///
class LeakyRelu : public Op
{
public:
/// \brief Constructs a LeakyRelu operation.
///
/// \param arg Node input to the LeakyRelu.
/// \param alpha Factor applied to arg in Maximum(arg, arg * alpha); must be
///              non-negative.
LeakyRelu(std::shared_ptr<ngraph::Node> arg, float alpha);
/// \return The alpha value this node was constructed with.
float get_alpha() const { return m_alpha; }
/// \brief Creates a LeakyRelu with the same alpha over a new argument.
///
/// \param new_args Replacement arguments; exactly one node is expected,
///                 otherwise ngraph_error is thrown.
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
// Factor supplied at construction; returned by get_alpha().
float m_alpha;
};
}
}
......@@ -53,6 +53,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
......@@ -719,10 +720,39 @@ namespace ngraph
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
if (get_user_count(node->get_argument(0).get()) == 1)
{
// Safe to overwrite input
op_annotations->add_in_place_oi_pair({0, 0, true});
}
bounded_relu->set_op_annotations(op_annotations);
}
}
// Marks a LeakyRelu node as an MKLDNN op when its input is f32 with rank 2 or 4.
// If the input node has exactly one user, output 0 is additionally annotated as
// an in-place pair with input 0 so the result may overwrite the input buffer.
// Nodes that do not qualify are left without annotations.
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LeakyRelu)
{
    auto leaky_relu = static_cast<op::LeakyRelu*>(node);

    // Only the rank is needed, so the shape is not copied into a local.
    const auto arg0_rank = node->get_input_shape(0).size();

    if ((arg0_rank == 4 || arg0_rank == 2) && node->get_input_element_type(0) == element::f32)
    {
        auto op_annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
        op_annotations->set_mkldnn_op(true);
        if (get_user_count(node->get_argument(0).get()) == 1)
        {
            // Safe to overwrite input
            op_annotations->add_in_place_oi_pair({0, 0, true});
        }
        leaky_relu->set_op_annotations(op_annotations);
    }
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolution)
{
......@@ -891,6 +921,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
{TI(ngraph::op::LeakyRelu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::LeakyRelu>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
......
......@@ -60,6 +60,7 @@
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid_mul.hpp"
#include "ngraph/util.hpp"
......@@ -1291,6 +1292,63 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_multiply()
this->add_matcher(m);
}
// Registers a matcher that rewrites Maximum(x, Multiply(x, alpha)) into a single
// op::LeakyRelu node. alpha may sit behind Broadcast nodes, which the pattern
// skips.
//
// The rewrite is applied only when alpha is an f32 Constant whose elements are
// all equal and lie in [0, 1]. The range matters: Maximum(x, alpha * x) equals
// the "x > 0 ? x : alpha * x" form used by the MKLDNN eltwise_relu primitive and
// by the generated-code path only for 0 <= alpha <= 1, so fusing any other
// value would change the result of the graph.
void ngraph::runtime::cpu::pass::CPUFusion::construct_leaky_relu()
{
    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto iconst1 = op::Constant::create(element::f32, Shape{}, {1});
    auto alpha = std::make_shared<pattern::op::Label>(iconst1);
    auto broadcast_pred = [](std::shared_ptr<Node> n) {
        return (std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr);
    };
    auto skip_broadcast = std::make_shared<pattern::op::Skip>(alpha, broadcast_pred);
    auto leaky_relu =
        std::make_shared<op::Maximum>(input, std::make_shared<op::Multiply>(input, skip_broadcast));

    pattern::graph_rewrite_callback callback = [input, alpha](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In a callback for construct_leaky_relu against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        if (!std::dynamic_pointer_cast<op::Constant>(pattern_map[alpha]))
        {
            NGRAPH_DEBUG << "alpha must be constant for leaky relu";
            return false;
        }

        if (pattern_map[alpha]->get_element_type() != element::f32)
        {
            NGRAPH_DEBUG << "Only float negative slope supported for leaky relu";
            return false;
        }

        auto alpha_const_op = std::static_pointer_cast<op::Constant>(pattern_map[alpha]);
        auto alpha_vec = alpha_const_op->get_vector<float>();
        if (alpha_vec.empty())
        {
            // Nothing to read alpha from; indexing element 0 would be out of bounds.
            NGRAPH_DEBUG << "alpha constant has no elements";
            return false;
        }

        const float alpha_val = alpha_vec[0];
        for (auto val : alpha_vec)
        {
// Exact comparison is intended: every element must be the very same value.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
            if (val != alpha_val)
            {
                NGRAPH_DEBUG << "alpha is not a singular constant";
                return false;
            }
#pragma clang diagnostic pop
        }

        // Outside [0, 1] Maximum(x, alpha * x) is not a leaky relu (see header comment).
        if (alpha_val < 0 || alpha_val > 1)
        {
            NGRAPH_DEBUG << "alpha must be in [0, 1] to fuse into leaky relu";
            return false;
        }

        std::shared_ptr<Node> cg = std::make_shared<op::LeakyRelu>(pattern_map[input], alpha_val);
        ngraph::replace_node(m.get_match_root(), cg);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(leaky_relu, callback);
    this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_bounded_relu()
{
auto relu_input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
......
......@@ -74,6 +74,7 @@ public:
construct_conv_bias_relu();
construct_conv_bias_add();
construct_conv_bias_add_relu();
construct_leaky_relu();
construct_bounded_relu();
// construct_conv_add() should always be after construct_conv_bias()
construct_conv_add();
......@@ -99,6 +100,7 @@ private:
void construct_conv_bias_add_relu();
void construct_conv_add();
void construct_conv_add_relu();
void construct_leaky_relu();
void construct_bounded_relu();
void construct_conv_bias_folded_batch_norm();
void construct_conv_bias_affine_folding();
......
......@@ -59,6 +59,7 @@
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
......@@ -1903,6 +1904,22 @@ namespace ngraph
set_native_layouts(external_function, node);
}
}
// Chooses the output layout for a LeakyRelu node: when use_mkldnn_kernel()
// is true the output takes the input's MKLDNN memory descriptor; otherwise
// native layouts are set.
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::LeakyRelu)
{
    if (!mkldnn_utils::use_mkldnn_kernel(node.get()))
    {
        set_native_layouts(external_function, node);
        return;
    }

    // Single output, laid out exactly like input 0.
    vector<memory::desc> output_mds;
    output_mds.push_back(mkldnn_utils::get_input_mkldnn_md(node.get(), 0));
    set_output_layouts(node, output_mds);
}
}
}
}
......@@ -1969,6 +1986,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
{TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>},
{TI(ngraph::op::LeakyRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LeakyRelu>},
{TI(ngraph::op::ConvolutionAdd),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ConvolutionAdd>},
{TI(ngraph::op::Slice), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Slice>},
......
......@@ -56,6 +56,7 @@
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
#include "ngraph/runtime/cpu/op/group_conv_bias.hpp"
#include "ngraph/runtime/cpu/op/leaky_relu.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -3023,6 +3024,45 @@ TEST(cpu_fusion, fuse_bounded_relu_inter_vs_cpu)
check_bounded_relu(Shape{4, 3, 2}, 2.0f);
}
// Checks the LeakyRelu fusion on Maximum(x, x * alpha) graphs:
//  - a negative alpha and a non-uniform alpha must not be fused;
//  - a uniform alpha of 0.1 must be fused into exactly one LeakyRelu and give
//    the expected values for both a rank-3 and a rank-2 input.
TEST(cpu_fusion, fuse_leaky_relu)
{
    // Builds Maximum(x, Multiply(x, alpha)) with alpha a constant of x's shape.
    auto build_graph = [](Shape input_shape, vector<float> alpha_val) {
        auto x = std::make_shared<op::Parameter>(element::f32, input_shape);
        auto slope = op::Constant::create<float>(element::f32, input_shape, alpha_val);
        auto scaled = std::make_shared<op::Multiply>(x, slope);
        auto maxed = std::make_shared<op::Maximum>(x, scaled);
        return make_shared<Function>(NodeVector{maxed}, ParameterVector{x});
    };

    // Graphs that must be left untouched by the fusion pass.
    auto negative_alpha = build_graph(Shape{1, 2, 3}, std::vector<float>(6, -1.0f));
    auto mixed_alpha = build_graph(Shape{1, 3}, std::vector<float>{1.4f, 1.2f, 1.4f});

    pass::Manager pass_manager;
    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
    pass_manager.run_passes(negative_alpha);
    pass_manager.run_passes(mixed_alpha);
    EXPECT_EQ(0, count_ops_of_type<op::LeakyRelu>(negative_alpha));
    EXPECT_EQ(0, count_ops_of_type<op::LeakyRelu>(mixed_alpha));

    // non-mkldnn kernel
    auto rank3_f = build_graph(Shape{1, 2, 3}, std::vector<float>(6, 0.1f));
    // mkldnn kernel
    auto rank2_f = build_graph(Shape{2, 3}, std::vector<float>(6, 0.1f));

    vector<vector<float>> args;
    args.push_back(std::vector<float>{-1, -2, 0, 1, 2, 3});
    std::vector<float> expected_result{-0.1f, -0.2f, 0.0f, 1.0f, 2.0f, 3.0f};

    // Runs the function on the CPU backend, then verifies fusion and output.
    auto expect_fused = [&](const std::shared_ptr<Function>& f) {
        auto results = execute(f, args, "CPU");
        EXPECT_EQ(1, count_ops_of_type<op::LeakyRelu>(f));
        EXPECT_TRUE(test::all_close(results.at(0), expected_result));
    };

    expect_fused(rank3_f);
    expect_fused(rank2_f);
}
TEST(cpu_fusion, dot_batch_forward)
{
const Shape shape_a{2, 3, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment