Commit 49bd01fc authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

added DEX support for BatchNormRelu (#1375)

* added DEX support for BatchNormRelu

* - templatized build_batchnorm_emitter
parent e5b50db7
......@@ -22,6 +22,7 @@
#include "ngraph/runtime/cpu/kernel/batchnorm.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
using namespace std;
using namespace ngraph;
......@@ -32,6 +33,7 @@ namespace ngraph
{
namespace cpu
{
template <typename OP>
static void build_batch_norm(CPU_ExternalFunction* external_function,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
......@@ -46,8 +48,7 @@ namespace ngraph
auto& arg2_tensor = tensor_data[args[2].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
const OP* batchnorm = static_cast<const OP*>(node);
shared_ptr<uint8_t> stacked_weights(new uint8_t[2 * args[0].get_size()]);
......@@ -239,7 +240,8 @@ namespace ngraph
}
else
{
build_batch_norm(external_function, node, args, out, false);
build_batch_norm<ngraph::op::BatchNorm>(
external_function, node, args, out, false);
}
}
......@@ -324,7 +326,18 @@ namespace ngraph
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormRelu)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNormRelu is only supported with 4-D MKLDNN kernel.");
}
build_batch_norm<ngraph::op::BatchNormRelu>(
external_function, node, args, out, true);
}
REGISTER_OP_BUILDER(BatchNorm);
REGISTER_OP_BUILDER(BatchNormRelu);
REGISTER_OP_BUILDER(BatchNormBackprop);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment