Commit adb38ab4 authored by Chris Sullivan, committed by Robert Kimball

Properly support global stats in BN (#1753)

* global stats fix

* Formatting.
parent 3d21f6ed
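
Editorial note (not part of the commit message): "global stats" means the batch-norm op normalizes with externally tracked running mean and variance supplied as extra inputs, rather than statistics computed from the current batch. A minimal reference sketch of the normalization itself, using hypothetical names:

#include <cmath>

// Reference-only sketch: normalize one element. Under "global stats" the
// mean/var arguments come from running statistics passed into the op instead
// of being computed from the current batch.
inline float batchnorm_ref(float x, float gamma, float beta,
                           float mean, float var, float eps)
{
    return gamma * (x - mean) / std::sqrt(var + eps) + beta;
}
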
@@ -1253,14 +1253,16 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                                    const Prop& direction,
                                                    const Shape& tensor_shape,
                                                    const Shape& param_shape,
-                                                   double epsilon)
+                                                   double epsilon,
+                                                   bool global_stats)
 {
     // Assumes NC{d1...dN} format
     std::stringstream ss;
     ss.precision(std::numeric_limits<double>::digits10 + 2);
     ss << "bn_op" << bn_op << "_dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_ts"
-       << join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon;
+       << join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon << "_g"
+       << global_stats;
     std::string hash = ss.str();
     std::replace(hash.begin(), hash.end(), '.', '_');
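
Editorial aside: the emitter memoizes compiled primitives by this string key, so the new flag has to be folded into the key; otherwise a batch norm requesting global stats could be handed a cached primitive built without them. A small self-contained sketch of that idea (hypothetical helper, not emitter code):

#include <sstream>
#include <string>

// Build a lookup key that distinguishes otherwise-identical configurations
// by the global_stats flag.
std::string make_bn_key(const std::string& base, double epsilon, bool global_stats)
{
    std::stringstream ss;
    ss << base << "_eps" << epsilon << "_g" << global_stats;
    return ss.str();
}
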
@@ -1324,6 +1326,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
             void* bias_factor = m_host_parameters.allocate_by_datatype(data_type, (m - 1) / m);
             batchnorm.reset(new gpu::primitive{
                 [=, &op_desc, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
+                    auto mean = (global_stats ? inputs[3] : outputs[1]);
+                    auto variance = (global_stats ? inputs[4] : outputs[2]);
                     CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardTraining(*m_ctx->cudnn_handle,
                                         bn_op,
                                         alpha,
@@ -1336,8 +1340,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                         inputs[0],
                                         inputs[1],
                                         exp_avg_factor,
-                                        outputs[1],
-                                        outputs[2],
+                                        mean,
+                                        variance,
                                         epsilon,
                                         NULL,
                                         NULL));
@@ -1348,13 +1352,13 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
                                         op_desc,
                                         beta,
                                         derived_param_desc,
-                                        outputs[2],
+                                        variance,
                                         beta,
                                         derived_param_desc,
-                                        outputs[2],
+                                        variance,
                                         bias_factor,
                                         derived_param_desc,
-                                        outputs[2]));
+                                        variance));
                     debug_sync();
                 }});
                 break;
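
Editorial sketch of what the redirection above amounts to: in cuDNN's forward-training entry point, the running mean/variance result slots now point at the caller-supplied statistics when global_stats is set. The handle, descriptors, and buffer layout below are assumptions standing in for the emitter's members, not the emitter's actual code:

// Editorial sketch only (assumes a standard cuDNN installation). Buffer layout
// implied by the diff:
//   inputs  = { gamma, beta, x, running_mean, running_variance }
//   outputs = { y, computed_mean, computed_variance }
#include <cudnn.h>

cudnnStatus_t bn_forward_training_sketch(cudnnHandle_t handle,
                                         cudnnBatchNormMode_t bn_op,
                                         const void* alpha,
                                         const void* beta,
                                         cudnnTensorDescriptor_t tensor_desc,
                                         cudnnTensorDescriptor_t derived_param_desc,
                                         void** inputs,
                                         void** outputs,
                                         double exp_avg_factor,
                                         double epsilon,
                                         bool global_stats)
{
    // With global stats, cuDNN's resultRunningMean/resultRunningVariance slots
    // point at the caller-supplied statistics; otherwise the freshly computed
    // statistics are written to the primitive's output tensors.
    void* mean = global_stats ? inputs[3] : outputs[1];
    void* variance = global_stats ? inputs[4] : outputs[2];
    return cudnnBatchNormalizationForwardTraining(handle,
                                                  bn_op,
                                                  alpha,
                                                  beta,
                                                  tensor_desc,
                                                  inputs[2],  // x
                                                  tensor_desc,
                                                  outputs[0], // y
                                                  derived_param_desc,
                                                  inputs[0],  // scale (gamma)
                                                  inputs[1],  // bias (beta)
                                                  exp_avg_factor,
                                                  mean,       // resultRunningMean
                                                  variance,   // resultRunningVariance
                                                  epsilon,
                                                  NULL,       // resultSaveMean
                                                  NULL);      // resultSaveInvVariance
}
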
@@ -124,7 +124,8 @@ namespace ngraph
                                  const Prop& direction,
                                  const Shape& tensor_shape,
                                  const Shape& param_shape,
-                                 double epsilon);
+                                 double epsilon,
+                                 bool global_stats = false);

             size_t build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
                                  const cudnnSoftmaxMode_t& mode,
@@ -285,10 +285,12 @@ void runtime::gpu::GPU_Emitter::emit_BatchNorm(EMIT_ARGS)
     auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
     bool global_stats = false;
     CUDNNEmitter::Prop direction;
-    if (batchnorm->get_training_flag() && args.size() == 3)
+    if (batchnorm->get_training_flag())
     {
         direction = CUDNNEmitter::Prop::Forward;
+        global_stats = (batchnorm->get_arguments().size() == 5);
     }
     else
     {
@@ -300,7 +302,8 @@ void runtime::gpu::GPU_Emitter::emit_BatchNorm(EMIT_ARGS)
                                       direction,
                                       args[2].get_shape(),
                                       args[0].get_shape(),
-                                      batchnorm->get_eps_value());
+                                      batchnorm->get_eps_value(),
+                                      global_stats);

     writer.block_begin();
     {
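
Editorial note: the condition no longer keys off args.size() == 3; instead, a training-mode BatchNorm with five arguments (assumed here to be gamma, beta, input, running mean, running variance) is taken to carry externally supplied statistics. A minimal sketch of that decision, using hypothetical names:

#include <cstddef>

// Hypothetical helper mirroring the logic in emit_BatchNorm: global stats are
// requested only for training-mode batch norm that carries running
// mean/variance as its 4th and 5th arguments.
inline bool wants_global_stats(bool training_flag, std::size_t num_arguments)
{
    return training_flag && num_arguments == 5;
}
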