Commit fbc38cf4 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Added DEX support for BoundedRelu (#1355)

* - Added DEX support for BoundedRelu
- Refactored bounded_relu in cpu_emitter to use mkldnn_emitter helper methods

* Removed the unneeded templatization from the bounded_relu mkldnn_emitter helper
parent 87bcec21
......@@ -30,6 +30,7 @@ set(SRC
builder/avg_pool.cpp
builder/batch_norm.cpp
builder/broadcast.cpp
builder/bounded_relu.cpp
builder/concat.cpp
builder/convert.cpp
builder/convert_layout.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            /// DEX builder for BoundedRelu (min(max(x, 0), alpha)).
            ///
            /// BoundedRelu is only implemented via MKLDNN — there is no reference
            /// INTERPRETER kernel — so a missing MKLDNN assignment is a hard error.
            /// The MKLDNN primitive is constructed once at build time; the emitted
            /// functor merely rebinds the input/output memory pointers and invokes
            /// the cached primitive on each call.
            template <>
            void Builder::BUILDER_DECL(ngraph::op::BoundedRelu)
            {
                // Fail fast: without an MKLDNN kernel assignment there is no
                // fallback implementation for this op.
                if (!runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
                {
                    throw ngraph_error(
                        "BoundedRelu is supported only through MKLDNN and doesnt have reference "
                        "INTERPRETER implementation");
                }

                auto& functors = external_function->get_functors();
                auto& tensor_data = external_function->get_tensor_data();
                auto& input_tensor = tensor_data[args[0].get_name()];
                auto& out_tensor = tensor_data[out[0].get_name()];

                // NOTE: the original code re-tested use_mkldnn_kernel(node) here;
                // that check is always true after the guard above, so the
                // redundant conditional (and its extra nesting) is removed.
                auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
                auto bounded_relu_index = mkldnn_emitter->build_bounded_relu(node, args, out);
                auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);

                // deps[0] = input memory, deps[1] = output memory for this primitive.
                // The index is captured by value so the functor stays valid after
                // this builder returns; tensor slots are captured by reference
                // because tensor_data entries outlive the functor list.
                auto functor = [&, bounded_relu_index](CPURuntimeContext* ctx) {
                    cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], input_tensor);
                    cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
                    cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, bounded_relu_index);
                };
                functors.emplace_back(functor);
            }
            REGISTER_OP_BUILDER(BoundedRelu);
        }
    }
}
......@@ -3727,15 +3727,10 @@ namespace ngraph
{
auto bounded_relu_node = static_cast<const ngraph::op::BoundedRelu*>(node);
float alpha = bounded_relu_node->get_alpha();
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t bounded_relu_index =
mkldnn_emitter->build_bounded_relu(input_desc, result_desc, alpha);
auto bounded_relu_index = mkldnn_emitter->build_bounded_relu(node, args, out);
auto& deps = mkldnn_emitter->get_primitive_deps(bounded_relu_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
......
......@@ -28,6 +28,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/bounded_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/shape.hpp"
......@@ -483,6 +484,18 @@ namespace ngraph
const mkldnn::memory::desc& result_desc,
int softmax_axis);
size_t build_bounded_relu(const ngraph::Node* node,
                          const std::vector<TensorViewWrapper>& args,
                          const std::vector<TensorViewWrapper>& out)
{
    // Convenience overload: pull alpha and the MKLDNN memory descriptors
    // off the node, then delegate to the descriptor-based overload below.
    const auto* bounded_relu = static_cast<const ngraph::op::BoundedRelu*>(node);
    const auto src_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
    const auto dst_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
    return build_bounded_relu(src_desc, dst_desc, bounded_relu->get_alpha());
}
size_t build_bounded_relu(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment