Commit 652840ab authored by Jaikrishnan Menon's avatar Jaikrishnan Menon Committed by Scott Cyphers

DEX: Add MKLDNN path for Relu (#1409)

parent 3ff9e490
......@@ -29,6 +29,38 @@ namespace ngraph
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Relu)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg_tensor = tensor_data[args[0].get_name()];
auto& out_tensor = tensor_data[out[0].get_name()];
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
size_t relu_index = mkldnn_emitter->build_relu_forward(input_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(relu_index);
auto functor = [&, relu_index](CPURuntimeContext* ctx) {
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[0], arg_tensor);
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[1], out_tensor);
cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, relu_index);
};
functors.emplace_back(functor);
}
else
{
BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::relu);
}
}
template <>
void Builder::BUILDER_DECL(ngraph::op::ReluBackprop)
{
......@@ -71,6 +103,8 @@ namespace ngraph
functors.emplace_back(functor);
}
}
REGISTER_OP_BUILDER(Relu);
REGISTER_OP_BUILDER(ReluBackprop);
}
}
......
......@@ -53,7 +53,6 @@
#include "ngraph/op/or.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
......@@ -90,7 +89,6 @@
#include "ngraph/runtime/cpu/kernel/not.hpp"
#include "ngraph/runtime/cpu/kernel/not_equal.hpp"
#include "ngraph/runtime/cpu/kernel/or.hpp"
#include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/runtime/cpu/kernel/result.hpp"
#include "ngraph/runtime/cpu/kernel/sign.hpp"
#include "ngraph/runtime/cpu/kernel/sin.hpp"
......@@ -276,12 +274,6 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::negative);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Relu)
{
BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::relu);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Sqrt)
{
......@@ -390,7 +382,6 @@ namespace ngraph
REGISTER_OP_BUILDER(Cosh)
REGISTER_OP_BUILDER(Floor);
REGISTER_OP_BUILDER(Negative);
REGISTER_OP_BUILDER(Relu);
REGISTER_OP_BUILDER(Exp);
REGISTER_OP_BUILDER(Log);
REGISTER_OP_BUILDER(Sqrt);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment