Commit 49bd01fc authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

added DEX support for BatchNormRelu (#1375)

* added DEX support for BatchNormRelu

* - templatized build_batchnorm_emitter
parent e5b50db7
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/runtime/cpu/kernel/batchnorm.hpp" #include "ngraph/runtime/cpu/kernel/batchnorm.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp" #include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -32,6 +33,7 @@ namespace ngraph ...@@ -32,6 +33,7 @@ namespace ngraph
{ {
namespace cpu namespace cpu
{ {
template <typename OP>
static void build_batch_norm(CPU_ExternalFunction* external_function, static void build_batch_norm(CPU_ExternalFunction* external_function,
const ngraph::Node* node, const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args, const std::vector<TensorViewWrapper>& args,
...@@ -46,8 +48,7 @@ namespace ngraph ...@@ -46,8 +48,7 @@ namespace ngraph
auto& arg2_tensor = tensor_data[args[2].get_name()]; auto& arg2_tensor = tensor_data[args[2].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()]; auto& out0_tensor = tensor_data[out[0].get_name()];
const ngraph::op::BatchNorm* batchnorm = const OP* batchnorm = static_cast<const OP*>(node);
static_cast<const ngraph::op::BatchNorm*>(node);
shared_ptr<uint8_t> stacked_weights(new uint8_t[2 * args[0].get_size()]); shared_ptr<uint8_t> stacked_weights(new uint8_t[2 * args[0].get_size()]);
...@@ -239,7 +240,8 @@ namespace ngraph ...@@ -239,7 +240,8 @@ namespace ngraph
} }
else else
{ {
build_batch_norm(external_function, node, args, out, false); build_batch_norm<ngraph::op::BatchNorm>(
external_function, node, args, out, false);
} }
} }
...@@ -324,7 +326,18 @@ namespace ngraph ...@@ -324,7 +326,18 @@ namespace ngraph
functors.emplace_back(functor); functors.emplace_back(functor);
} }
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchNormRelu)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNormRelu is only supported with 4-D MKLDNN kernel.");
}
build_batch_norm<ngraph::op::BatchNormRelu>(
external_function, node, args, out, true);
}
REGISTER_OP_BUILDER(BatchNorm); REGISTER_OP_BUILDER(BatchNorm);
REGISTER_OP_BUILDER(BatchNormRelu);
REGISTER_OP_BUILDER(BatchNormBackprop); REGISTER_OP_BUILDER(BatchNormBackprop);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment