Commit 87b5758d authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

Added DEX execution support for ReluBprop (#1305)

parent 1fdf2d98
...@@ -40,6 +40,7 @@ set(SRC ...@@ -40,6 +40,7 @@ set(SRC
builder/max.cpp builder/max.cpp
builder/max_pool.cpp builder/max_pool.cpp
builder/min.cpp builder/min.cpp
builder/relu.cpp
builder/product.cpp builder/product.cpp
builder/reshape.cpp builder/reshape.cpp
builder/reverse.cpp builder/reverse.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::ReluBackprop)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_fwd_tensor = tensor_data[args[0].get_name()];
auto& delta_tensor = tensor_data[args[1].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
size_t count = out[0].get_size();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_memory_descriptor(
args[0], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0));
auto delta_desc = mkldnn_emitter->build_memory_descriptor(
args[1], runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 1));
auto result_desc = mkldnn_emitter->build_memory_descriptor(
out[0], runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0));
size_t relu_index =
mkldnn_emitter->build_relu_backward(input_desc, delta_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, relu_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_fwd_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], delta_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[2], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index);
};
functors.emplace_back(functor);
}
else
{
std::function<decltype(runtime::cpu::kernel::relu_backprop<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::relu_backprop);
auto functor = [&, kernel, count](CPURuntimeContext* ctx) {
kernel(arg_fwd_tensor, delta_tensor, out_tensor, count);
};
functors.emplace_back(functor);
}
}
REGISTER_OP_BUILDER(ReluBackprop);
}
}
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <unsupported/Eigen/CXX11/Tensor> #include <unsupported/Eigen/CXX11/Tensor>
#include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp" #include "ngraph/runtime/cpu/kernel/eigen_thread_pool.hpp"
#include "ngraph/runtime/reference/relu.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -43,6 +44,15 @@ namespace ngraph ...@@ -43,6 +44,15 @@ namespace ngraph
out.device(eigen::global_thread_pool_device) = in0.cwiseMax(ElementType(0)); out.device(eigen::global_thread_pool_device) = in0.cwiseMax(ElementType(0));
} }
template <typename ElementType>
void relu_backprop(void* arg, void* delta_arg, void* out, size_t count)
{
reference::relu_backprop<ElementType>(static_cast<ElementType*>(arg),
static_cast<ElementType*>(delta_arg),
static_cast<ElementType*>(out),
count);
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment