Commit eaa6091c authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

MKLDNN BoundedRelu implementation for Relu6 (#1179)

* 1. Added MKLDNNN BoundedRelu op support for Relu6
2. CpuLayout && CPU assignment pass for BoundedRelu Op
3. Unit test inter v/s CPU for BoundedReluOp
4. MKLDNN and default emitter code for BoundedReluOp

* Removed Debug prints

* 1. Added support for boundedrelu to work on any constant literal
2. unit test case for rank2, rank3, rank4 for bounded relu without serialized graph

* Removed is_six() method
parent e42e5815
......@@ -37,6 +37,7 @@ set(SRC
mkldnn_utils.cpp
op/batch_dot.cpp
op/batch_norm_relu.cpp
op/bounded_relu.cpp
op/group_conv.cpp
op/conv_bias.cpp
op/conv_relu.cpp
......
......@@ -94,6 +94,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_dot.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -4195,6 +4196,44 @@ namespace ngraph
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BoundedRelu)
{
auto bounded_relu_node = static_cast<const ngraph::op::BoundedRelu*>(node);
float alpha = bounded_relu_node->get_alpha();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t bounded_relu_index =
mkldnn_emitter->build_bounded_relu(input_desc, result_desc, alpha);
auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(bounded_relu_index) << ");\n";
}
else
{
writer << "#pragma omp parallel for\n";
writer << "for (size_t i = 0; i < " << out[0].get_size() << "; i++)\n";
writer.block_begin();
writer << args[0].get_name() << "[i] = " << args[0].get_name() << "[i] > 0 ? "
<< args[0].get_name() << "[i] : 0;\n";
writer << out[0].get_name() << "[i] = " << args[0].get_name() << "[i] < "
<< alpha << " ? " << args[0].get_name() << "[i] : " << alpha << ";\n";
writer.block_end();
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Sigmoid)
{
......
......@@ -120,6 +120,7 @@
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_dot.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -288,6 +289,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::BatchNorm), &runtime::cpu::CPU_Emitter::emit<op::BatchNorm>},
{TI(ngraph::op::BatchNormRelu), &runtime::cpu::CPU_Emitter::emit<op::BatchNormRelu>},
{TI(ngraph::op::BatchNormBackprop), &runtime::cpu::CPU_Emitter::emit<op::BatchNormBackprop>},
{TI(ngraph::op::BoundedRelu), &runtime::cpu::CPU_Emitter::emit<op::BoundedRelu>},
{TI(ngraph::op::Lstm), &runtime::cpu::CPU_Emitter::emit<op::Lstm>},
{TI(ngraph::op::MaxPoolBackprop), &runtime::cpu::CPU_Emitter::emit<op::MaxPoolBackprop>},
{TI(ngraph::op::MaxPoolWithIndicesBackprop),
......
......@@ -957,3 +957,24 @@ size_t MKLDNNEmitter::build_softmax_forward(const mkldnn::memory::desc& input_de
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_bounded_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
size_t primitive_index =
insert_primitive(new mkldnn::eltwise_forward({{mkldnn::prop_kind::forward_training,
mkldnn::algorithm::eltwise_bounded_relu,
input_desc,
alpha,
0.0f},
mkldnn_utils::global_cpu_engine},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
......@@ -243,6 +243,10 @@ namespace ngraph
const mkldnn::memory::desc& result_desc,
int softmax_axis);
size_t build_bounded_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha);
private:
std::vector<mkldnn::primitive*> m_mkldnn_primitives;
std::vector<mkldnn::stream> m_mkldnn_streams;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
op::BoundedRelu::BoundedRelu(shared_ptr<Node> arg, float alpha)
: RequiresTensorViewArgs("BoundedRelu", {arg})
, m_alpha(alpha)
{
set_value_type_checked(arg->get_element_type(), arg->get_shape());
}
shared_ptr<Node> op::BoundedRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<BoundedRelu>(new_args.at(0), m_alpha);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise Minimum(Relu(arg, 0), alpha) operation.
///
class BoundedRelu : public util::RequiresTensorViewArgs
{
public:
/// \brief Constructs a BoundedRelu operation.
///
/// \param arg Node input to the Relu.
BoundedRelu(std::shared_ptr<ngraph::Node> arg, float alpha);
float get_alpha() const { return m_alpha; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
float m_alpha;
};
}
}
......@@ -37,6 +37,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/group_conv.hpp"
......@@ -662,6 +663,25 @@ namespace ngraph
softmax->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BoundedRelu)
{
auto bounded_relu = static_cast<op::BoundedRelu*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if ((arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
bounded_relu->set_op_annotations(op_annotations);
}
}
}
}
}
......@@ -676,6 +696,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::AvgPoolBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNorm>},
{TI(ngraph::op::BoundedRelu),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BoundedRelu>},
{TI(ngraph::op::BatchNormBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::Convolution),
......
......@@ -31,6 +31,8 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
......@@ -45,6 +47,7 @@
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
......@@ -1322,3 +1325,51 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_multiply()
auto m = std::make_shared<ngraph::pattern::Matcher>(elem_mul, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_bounded_relu()
{
auto relu_input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto relu = std::make_shared<op::Relu>(relu_input);
auto iconst1 = op::Constant::create(element::f32, Shape{}, {1});
auto alpha = std::make_shared<pattern::op::Label>(iconst1);
auto min = std::make_shared<op::Minimum>(relu, alpha);
pattern::graph_rewrite_callback callback = [relu_input, alpha](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_bounded_relu against "
<< m.get_match_root()->get_name();
if (m.get_match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
<< " type is not float!";
return false;
}
auto pattern_map = m.get_pattern_map();
if (!std::dynamic_pointer_cast<op::Constant>(pattern_map[alpha]))
{
throw ngraph_error("alpha must be constant for bounded relu");
}
//we wont fuse if the alpha and the Relu output element type are not same
if (pattern_map[alpha]->get_element_type() != pattern_map[relu_input]->get_element_type())
{
return false;
}
if (pattern_map[alpha]->get_shape() != pattern_map[relu_input]->get_shape())
{
return false;
}
auto alpha_const_op = std::dynamic_pointer_cast<op::Constant>(pattern_map[alpha]);
float alpha_val = *(static_cast<float const*>(alpha_const_op->get_data_ptr()));
NGRAPH_DEBUG << "relu_input: " << pattern_map[relu_input] << " min_val: "
<< *(static_cast<float const*>(alpha_const_op->get_data_ptr()));
auto cg = std::shared_ptr<Node>(new op::BoundedRelu(pattern_map[relu_input], alpha_val));
ngraph::replace_node(m.get_match_root(), cg);
return true;
};
auto m = std::make_shared<pattern::Matcher>(min, callback);
this->add_matcher(m);
}
......@@ -66,6 +66,7 @@ public:
construct_conv_bias_relu();
construct_conv_bias_add();
construct_conv_bias_add_relu();
construct_bounded_relu();
}
if (fusions & DIFFERENTIABLE_FUSIONS)
......@@ -93,4 +94,5 @@ private:
void construct_conv_bias_relu();
void construct_conv_bias_add();
void construct_conv_bias_add_relu();
void construct_bounded_relu();
};
......@@ -41,6 +41,7 @@
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -1484,6 +1485,23 @@ namespace ngraph
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BoundedRelu)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
}
}
}
......@@ -1538,6 +1556,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Rnn>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Softmax>},
{TI(ngraph::op::BoundedRelu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::BoundedRelu>},
};
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
......@@ -46,6 +46,7 @@
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/op/batch_dot.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -2685,3 +2686,40 @@ TEST(cpu_fusion, fuse_rnn_across_2layer_1timestep)
EXPECT_TRUE(test::all_close(cpu_results.at(1), int_results.at(1), 1.0e-4f, 1.0e-4f));
}
}
static void check_bounded_relu(Shape param_shape, float constant_val)
{
auto make_function = [](Shape input_shape, float alpha_val) {
auto relu_input = std::make_shared<op::Parameter>(element::f32, input_shape);
auto relu = std::make_shared<op::Relu>(relu_input);
auto alpha = op::Constant::create<float>(
element::f32, input_shape, std::vector<float>(1.0f, alpha_val));
auto min = std::make_shared<op::Minimum>(relu, alpha);
auto f = make_shared<Function>(NodeVector{min}, op::ParameterVector{relu_input});
return f;
};
auto cpu_f = make_function(param_shape, constant_val);
auto int_f = make_function(param_shape, constant_val);
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_EQ(1, count_ops_of_type<op::BoundedRelu>(cpu_f));
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0), 1.0e-4f, 1.0e-4f));
}
TEST(cpu_fusion, fuse_bounded_relu_inter_vs_cpu)
{
check_bounded_relu(Shape{4, 3, 2, 2}, 6.0f);
check_bounded_relu(Shape{4, 3}, 4.0f);
check_bounded_relu(Shape{4, 3, 2}, 2.0f);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment