Commit 27fb77b6 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Fix allocation size in batchnorm kernel (#1386)

* Fix allocation size in batchnorm kernel

* added missing brackets
parent 134b0ae2
......@@ -50,8 +50,6 @@ namespace ngraph
const OP* batchnorm = static_cast<const OP*>(node);
shared_ptr<uint8_t> stacked_weights(new uint8_t[2 * args[0].get_size()]);
// Kill clang diagnostics bug
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"
......@@ -62,6 +60,8 @@ namespace ngraph
#pragma clang diagnostic pop
shared_ptr<uint8_t> stacked_weights(new uint8_t[weight_sizes[0] + weight_sizes[1]]);
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
......@@ -265,9 +265,6 @@ namespace ngraph
auto& out1_tensor = tensor_data[out[1].get_name()];
auto& out2_tensor = tensor_data[out[2].get_name()];
shared_ptr<uint8_t> stacked_weights(new uint8_t[2 * args[0].get_size()]);
shared_ptr<uint8_t> stacked_dweights(new uint8_t[2 * args[0].get_size()]);
// Kill clang diagnostics bug
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"
......@@ -277,6 +274,9 @@ namespace ngraph
args[1].get_size() * args[1].get_element_type().size()};
#pragma clang diagnostic pop
shared_ptr<uint8_t> stacked_weights(new uint8_t[weight_sizes[0] + weight_sizes[1]]);
shared_ptr<uint8_t> stacked_dweights(
new uint8_t[weight_sizes[0] + weight_sizes[1]]);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment