Commit adb38ab4 authored by Chris Sullivan, committed by Robert Kimball

Properly support global stats in BN (#1753)

* global stats fix

* Formatting.
parent 3d21f6ed
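
The change threads a global_stats flag from the GPU emitter into CUDNNEmitter::build_batchnorm: when the BatchNorm op is constructed with five arguments, the running mean and variance arrive as extra inputs (inputs[3] and inputs[4]), and the cuDNN forward-training call should read and update those buffers instead of the op's own outputs (outputs[1] and outputs[2]). Below is a minimal stand-alone sketch of that buffer selection only, using a hypothetical select_bn_stats helper rather than the actual CUDNNEmitter primitive.

// Hedged sketch, not the nGraph implementation: illustrates the
// inputs[3]/inputs[4] vs outputs[1]/outputs[2] selection added in this commit.
#include <cstdio>
#include <utility>

static std::pair<void*, void*>
    select_bn_stats(void** inputs, void** outputs, bool global_stats)
{
    // With global stats, the running mean/variance arrive as extra inputs;
    // otherwise cuDNN writes freshly computed batch statistics to the outputs.
    void* mean = global_stats ? inputs[3] : outputs[1];
    void* variance = global_stats ? inputs[4] : outputs[2];
    return {mean, variance};
}

int main()
{
    float buf[8] = {};
    void* inputs[5] = {&buf[0], &buf[1], &buf[2], &buf[3], &buf[4]};
    void* outputs[3] = {&buf[5], &buf[6], &buf[7]};

    auto stats = select_bn_stats(inputs, outputs, /*global_stats=*/true);
    std::printf("uses inputs[3]/inputs[4]: %d\n",
                stats.first == inputs[3] && stats.second == inputs[4]);
    return 0;
}
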
@@ -1253,14 +1253,16 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                                const Prop& direction,
                                                const Shape& tensor_shape,
                                                const Shape& param_shape,
-                                               double epsilon)
+                                               double epsilon,
+                                               bool global_stats)
 {
     // Assumes NC{d1...dN} format
     std::stringstream ss;
     ss.precision(std::numeric_limits<double>::digits10 + 2);
     ss << "bn_op" << bn_op << "_dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_ts"
-       << join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon;
+       << join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon << "_g"
+       << global_stats;
     std::string hash = ss.str();
     std::replace(hash.begin(), hash.end(), '.', '_');
@@ -1324,6 +1326,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
             void* bias_factor = m_host_parameters.allocate_by_datatype(data_type, (m - 1) / m);
             batchnorm.reset(new gpu::primitive{
                 [=, &op_desc, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
+                    auto mean = (global_stats ? inputs[3] : outputs[1]);
+                    auto variance = (global_stats ? inputs[4] : outputs[2]);
                     CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardTraining(*m_ctx->cudnn_handle,
                                                                            bn_op,
                                                                            alpha,
@@ -1336,8 +1340,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                                                            inputs[0],
                                                                            inputs[1],
                                                                            exp_avg_factor,
-                                                                           outputs[1],
-                                                                           outputs[2],
+                                                                           mean,
+                                                                           variance,
                                                                            epsilon,
                                                                            NULL,
                                                                            NULL));
@@ -1348,13 +1352,13 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                                    op_desc,
                                                    beta,
                                                    derived_param_desc,
-                                                   outputs[2],
+                                                   variance,
                                                    beta,
                                                    derived_param_desc,
-                                                   outputs[2],
+                                                   variance,
                                                    bias_factor,
                                                    derived_param_desc,
-                                                   outputs[2]));
+                                                   variance));
                     debug_sync();
                 }});
             break;
...
@@ -124,7 +124,8 @@ namespace ngraph
                                      const Prop& direction,
                                      const Shape& tensor_shape,
                                      const Shape& param_shape,
-                                     double epsilon);
+                                     double epsilon,
+                                     bool global_stats = false);
                 size_t build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
                                      const cudnnSoftmaxMode_t& mode,
...
@@ -285,10 +285,12 @@ void runtime::gpu::GPU_Emitter::emit_BatchNorm(EMIT_ARGS)
                 auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
+                bool global_stats = false;
                 CUDNNEmitter::Prop direction;
-                if (batchnorm->get_training_flag() && args.size() == 3)
+                if (batchnorm->get_training_flag())
                 {
                     direction = CUDNNEmitter::Prop::Forward;
+                    global_stats = (batchnorm->get_arguments().size() == 5);
                 }
                 else
                 {
@@ -300,7 +302,8 @@ void runtime::gpu::GPU_Emitter::emit_BatchNorm(EMIT_ARGS)
                     direction,
                     args[2].get_shape(),
                     args[0].get_shape(),
-                    batchnorm->get_eps_value());
+                    batchnorm->get_eps_value(),
+                    global_stats);
                 writer.block_begin();
                 {
...