Commit 23ac5e5a authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

CUDNN BatchNorm (inference/forward/backward) (#893)

* Added cudnn batch norm operation to GPU transformer.
Brought batchnorm tests out of cpu_tests and into
backend_tests. Need to add JIRA ticket for interpreter
SKIPS.

* CUDNN batchnorm is implemented. In the ForwardTraining branch
CUDNN seems to calculate the batch mean correctly but the batch variance incorrectly.
Currently the batchnorm output and mean are calculated correctly for tests:
* GPU.batchnorm_fprop_b2c2h3w3_mean_var
* GPU.batchnorm_fprop_b1c2h2w2
* GPU.batchnorm_fprop_b2c2h2w1
but the variance calculated for the batches in these tests is incorrectly calculated by CUDNN.

Also added an additional test and cleaned up some of the old tests.

* MKLDNN internally utilizes the biased estimate of the population variance
and the tests have been crafted to suit MKLDNN. According to the original
batchnorm publication (https://arxiv.org/pdf/1502.03167v3.pdf), population
(unbiased) statistics should be used for inference, and mini-batch (biased)
statistics should be used training (forward/backward). For the variance this
means utlitizing the following equations, respectively:

  (biased)   Var[X] = 1/m * Sum_i(x_i-mu)^2      :: used in training
  (unbiased) Var[X] = 1/(m-1) * Sum_i(x_i-mu)^2  :: used in inference

  s.t. x_i are elements of X and m = N*D*H*W.

For large batch sizes in inference this may not impact convergence as m >> 1,
but for small batch sizes it will. CUDNN internally utilizes the unbiased
variance.

Changes:
* Added Multiply op to Forward pass of batchnorm to convert
  the unbiased variance to a biased one. The op utilizes the
  blending scaling factors to apply the bias factor.
* Adds emission for the BatchNormBackprop kernel and cleans up
  the emitter implementation.

* Added hashing to cudnn::batchnorm op.

* Formatting.

* Changed hashing of epsilon in cudnn batchnorm.

* Remove implicit conversion and default case in switch for bn.

* Added skips for IE transformer on batchnorm.

* add cudnn include path to compiler.cpp

* seperate two path

* PR #892 and #825 which were recently merged both forgot skips for the GPU backend.
Adding them in as they are unimplemented ops.

* The allocation and deletion of primitives was occuring in seperate
translation units with raw c pointers. Because of this, it was not
clear that these were being freed appropriate, nor did it indicate
ownership of the pointers.

In this commit these raw pointers have been converted over to
std::unique_ptrs such that the construction/destruction is managed
automatically. Furthermore, GPUPrimitiveEmitter::insert now only
takes an r-value reference, requiring move-semantics to indicate
that when inserting a primitive, the GPUPrimitiveEmitter takes
ownership of the pointer.

All instances of primitive creation have been modified.

* CUDNN_SAFE_CALL

* Removed redundant comment and made variable names more verbose.

* Change from conditionals to case-switch in pooling to conform to
batchnorm per @fengleitian's suggestion.
parent b0421577
......@@ -150,12 +150,12 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
gpu::primitive* pad = nullptr;
std::unique_ptr<gpu::primitive> pad;
// if the pad value is statically provided, the kernel call signature is different
if (pad_value == "") // pad value provided at runtime (dynamic)
{
pad = new gpu::primitive{[=](void** inputs, void** outputs) {
pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[1], &inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -169,11 +169,11 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
}
else // pad value provided at compile time (static)
{
pad = new gpu::primitive{[=](void** inputs, void** outputs) {
pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -187,10 +187,10 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
}
primitive_index = this->m_primitive_emitter->insert(pad);
primitive_index = this->m_primitive_emitter->insert(std::move(pad));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
......@@ -259,7 +259,7 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
auto pool = new gpu::primitive{[=](void** inputs, void** outputs) {
std::unique_ptr<gpu::primitive> pool(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -273,9 +273,9 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
primitive_index = this->m_primitive_emitter->insert(pool);
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
......
This diff is collapsed.
......@@ -50,6 +50,7 @@ namespace ngraph
public:
enum class Prop
{
Inference,
Forward,
Backward
};
......@@ -69,6 +70,13 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_batchnorm(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnBatchNormMode_t& bn_op,
const Prop& direction,
const Shape& tensor_shape,
const Shape& param_shape,
double epsilon);
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter);
GPUPrimitiveEmitter* m_primitive_emitter;
......
......@@ -1415,6 +1415,100 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm)
{
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
CUDNNEmitter::Prop direction;
if (batchnorm->get_training_flag() && args.size() == 3)
{
direction = CUDNNEmitter::Prop::Forward;
}
else
{
direction = CUDNNEmitter::Prop::Inference;
}
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
direction,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
for (size_t i = 1; i < args.size(); i++)
{
writer << ", " << args[i].get_name();
}
writer << "}.data(), ";
writer << "std::vector<void*>{" << out.front().get_name();
for (size_t i = 1; i < out.size(); i++)
{
writer << ", " << out[i].get_name();
}
writer << "}.data()";
writer << ");\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormBackprop)
{
const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node);
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
CUDNNEmitter::Prop::Backward,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
for (size_t i = 1; i < args.size(); i++)
{
writer << ", " << args[i].get_name();
}
writer << "}.data(), ";
writer << "std::vector<void*>{" << out.front().get_name();
for (size_t i = 1; i < out.size(); i++)
{
writer << ", " << out[i].get_name();
}
writer << "}.data()";
writer << ");\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::GetOutputElement)
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer.block_begin(" // " + node->get_name());
writer << "runtime::gpu::cuda_memcpyDtH(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// assumes NC{d1,d2,d3,...} format
Shape get_padded_shape(const Shape& input_shape,
const Shape& padding_below,
......
......@@ -27,13 +27,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
, m_cudnn_emitter(new CUDNNEmitter(this))
{
}
GPUPrimitiveEmitter::~GPUPrimitiveEmitter()
{
for (auto& primitive : m_gpu_primitives)
{
delete primitive;
}
}
std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter()
{
return m_cuda_emitter;
......@@ -42,9 +36,10 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{
return m_cudnn_emitter;
}
size_t GPUPrimitiveEmitter::insert(gpu::primitive* f)
size_t GPUPrimitiveEmitter::insert(std::unique_ptr<gpu::primitive>&& f)
{
m_gpu_primitives.push_back(f);
m_managed_primitives.emplace_back(std::move(f));
m_gpu_primitives.push_back(m_managed_primitives.back().get());
return m_gpu_primitives.size() - 1;
}
size_t GPUPrimitiveEmitter::lookup(std::string hash)
......
......@@ -35,11 +35,10 @@ namespace ngraph
{
public:
GPUPrimitiveEmitter();
~GPUPrimitiveEmitter();
std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::vector<gpu::primitive*>& get_primitives() { return m_gpu_primitives; }
size_t insert(gpu::primitive* f);
size_t insert(std::unique_ptr<gpu::primitive>&& f);
size_t lookup(std::string hash);
void cache(const std::string& hash, const size_t& index);
......@@ -48,6 +47,7 @@ namespace ngraph
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::vector<gpu::primitive*> m_gpu_primitives;
std::unordered_map<std::string, size_t> m_primitive_map;
std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives;
};
}
}
......
......@@ -68,6 +68,11 @@ void runtime::gpu::cuda_memcpyHtD(void* dst, void* src, size_t buffer_size)
cudaMemcpy(dst, src, buffer_size, cudaMemcpyHostToDevice);
}
void runtime::gpu::cuda_memcpyDtH(void* dst, void* src, size_t buffer_size)
{
cudaMemcpy(dst, src, buffer_size, cudaMemcpyDeviceToHost);
}
void runtime::gpu::cuda_memset(void* dst, int value, size_t buffer_size)
{
cudaMemset(dst, value, buffer_size);
......
......@@ -95,6 +95,7 @@ namespace ngraph
void free_gpu_buffer(void* buffer);
void cuda_memcpyDtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyHtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyDtH(void* dst, void* src, size_t buffer_size);
void cuda_memset(void* dst, int value, size_t buffer_size);
}
}
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment