Commit fbc3a940 authored by Chris Sullivan, committed by Robert Kimball

Cache and use fprop stats in cudnn batchnorm bprop (#1841)

* Temp bn update commit.

* Add CUDNNBatchNorm, which adds two additional outputs to batchnorm: the batch mean and the batch inverse variance.
The batch mean is the same as the output mean when the cumulative average factor is 1.0. Add a BatchNormCache pass which replaces all BatchNorm ops that are inputs to BatchNormBackprop
with CUDNNBatchNorm, so the saved batch statistics flow directly to the backprop step (see the sketch below).
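For illustration, a minimal sketch of the graph shape this pass targets, written against the same nGraph construction API the tests below use (eps, gamma, beta, input, and delta are placeholder names):

auto bn = std::make_shared<op::BatchNormTraining>(eps, gamma, beta, input);
auto y = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean = std::make_shared<op::GetOutputElement>(bn, 1);
auto var = std::make_shared<op::GetOutputElement>(bn, 2);
auto bprop =
    std::make_shared<op::BatchNormTrainingBackprop>(eps, gamma, beta, y, mean, var, delta);
// After the pass runs, bn is replaced by op::gpu::BatchNormTrainingWithStats and the
// mean/var GetOutputElement nodes are rewired to its outputs 3 and 4 (the saved batch
// mean and saved batch inverse variance), so backprop reuses the cached fprop stats.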

* Updated the bn cache pass, removed extra tests, and added a test checking that the provided stats are used in bprop instead of the batch stats.
This test is disabled for INTERPRETER because the reference kernel needs to be updated to use the provided statistics.

* Formatting.

* Update to new batch norm API.

* CUDNNBatchNorm -> BatchNormTrainingWithStats

* new line

* Preprocess input variance into BN denominator for cudnn (#1885)

* Add an explicit CUDA kernel to calculate what cuDNN describes as the inverse
variance. In reality, the backward cuDNN kernel for BN requires 1.0f / sqrt(variance + eps),
which is the batchnorm denominator for each channel (a numerically stable inverse stddev).
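As a reference point, a host-side sketch of the same per-channel transform (bn_inv_var is a hypothetical helper, not part of this change):

#include <cmath>
#include <vector>

// Convert per-channel biased batch variance into the numerically stable
// inverse stddev that cudnnBatchNormalizationBackward consumes.
std::vector<float> bn_inv_var(const std::vector<float>& variance, float eps)
{
    std::vector<float> inv(variance.size());
    for (size_t c = 0; c < variance.size(); ++c)
    {
        inv[c] = 1.0f / std::sqrt(variance[c] + eps);
    }
    return inv;
}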

This introduces op annotations for batch norm backprop and updates the cudnn_emitter to insert this CUDA kernel when required.
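In sketch form, mirroring the pass and emitter changes in this diff, the annotation handshake is:

// In the BatchNormCache pass: mark the backprop op so the emitter knows the cached
// statistic already arrives as 1/sqrt(var + eps).
auto annotations = std::make_shared<runtime::gpu::BatchNormBackpropAnnotations>();
annotations->set_inverted_variance(true);
bnbp->set_op_annotations(annotations);

// In emit_BatchNormTrainingBackprop: an annotation reporting a non-inverted variance
// triggers insertion of the cudnn_bn_inv_var preprocessing kernel.
if (annotation && !annotation->has_inverted_variance())
{
    needs_variance_inversion = true;
}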

* Disable second test on INTERPRETER.
parent baf8da98
@@ -35,6 +35,8 @@ namespace ngraph
class OpAnnotations
{
public:
virtual ~OpAnnotations() = default;
void add_in_place_oi_pair(const struct oi_pair& oi)
{
for (auto e : m_in_place_oi_pairs)
......
@@ -42,6 +42,8 @@ set(SRC
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp
pass/gpu_batch_norm_cache.cpp
op/batch_norm.cpp
op/rnn.cpp
)
......
@@ -1228,6 +1228,71 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const std::vector<std
return this->m_primitive_emitter->register_primitive(ew, hash);
}
size_t runtime::gpu::CUDAEmitter::build_cudnn_bn_inv_var(const std::vector<std::string>& dtypes,
NVShape tensor_shape,
const double& eps)
{
uint32_t nthreads = static_cast<uint32_t>(shape_size(tensor_shape));
// kernel_name is used to check if the cuda kernel has been previously compiled
std::stringstream kernel_name;
kernel_name << "cudnn_bn_inv_var"
<< "_" << join(dtypes, "_");
// hash is used to check if the emitted primitive already exists
std::stringstream ss;
ss << kernel_name.str() << "_s" << join(tensor_shape, "_") << "_eps" << eps;
auto hash = ss.str();
// if the primitive exists, we are done
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
uint32_t block_size_x = 512;
int num_SMs;
CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, 0));
uint32_t aligned_grid_size_x = fmin(num_SMs * 32, align_to_block_size(nthreads, block_size_x));
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("epsilon", eps)
.add("nthreads", nthreads);
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_cudnn_bn_inv_var_op(writer, kernel_name.str(), args);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDAEmitter::build_primitive(const op::MaxPool* node)
{
auto& args = node->get_inputs();
......
@@ -114,6 +114,10 @@ namespace ngraph
dtypes, tensor_shape, CudaOpMap<T>::op, CudaOpMap<T>::math_kernel);
}
size_t build_cudnn_bn_inv_var(const std::vector<std::string>& dtypes,
NVShape tensor_shape,
const double& eps);
template <typename T>
size_t build_reduce(const std::vector<std::string>& dtypes,
const size_t data_bytes,
......
@@ -1622,7 +1622,9 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
const Shape& tensor_shape,
const Shape& param_shape,
double epsilon,
bool global_stats)
bool global_stats,
bool save_stats,
bool invert_variance)
{
// Assumes NC{d1...dN} format
std::stringstream ss;
@@ -1630,7 +1632,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
ss << "bn_op" << bn_op << "_dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_ts"
<< join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon << "_g"
<< global_stats;
<< global_stats << "_s" << save_stats << "_invvar" << invert_variance;
std::string hash = ss.str();
std::replace(hash.begin(), hash.end(), '.', '_');
@@ -1696,6 +1698,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
[=, &op_desc, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
auto mean = (global_stats ? inputs[3] : outputs[1]);
auto variance = (global_stats ? inputs[4] : outputs[2]);
auto saved_mean = (save_stats ? outputs[3] : nullptr);
auto saved_inv_var = (save_stats ? outputs[4] : nullptr);
CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardTraining(*m_ctx->cudnn_handle,
bn_op,
alpha,
@@ -1711,8 +1715,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
mean,
variance,
epsilon,
NULL,
NULL));
saved_mean,
saved_inv_var));
debug_sync();
// convert to biased variance
@@ -1733,30 +1737,50 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
}
case Prop::Backward:
{
batchnorm.reset(new gpu::primitive{
[=, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationBackward(
*m_ctx->cudnn_handle,
bn_op,
alpha,
beta,
alpha,
beta,
tensor_desc,
inputs[2 /* input tensor x */],
tensor_desc,
inputs[5 /* dy */],
tensor_desc,
outputs[0 /* dx */],
derived_param_desc,
inputs[0 /* gamma */],
outputs[1 /* dgamma */],
outputs[2 /* dbeta */],
epsilon,
NULL, // inputs[3 /* mu batch mean*/],
NULL)); // inputs[4 /* 1/sig**2 batch inverse variance*/]);
debug_sync();
gpu::primitive bnbp = [=, &tensor_desc, &derived_param_desc](void** inputs,
void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationBackward(*m_ctx->cudnn_handle,
bn_op,
alpha,
beta,
alpha,
beta,
tensor_desc,
inputs[2 /* input tensor x */],
tensor_desc,
inputs[5 /* dy */],
tensor_desc,
outputs[0 /* dx */],
derived_param_desc,
inputs[0 /* gamma */],
outputs[1 /* dgamma */],
outputs[2 /* dbeta */],
epsilon,
inputs[3], // batch mean
inputs[4])); // batch inverse variance
debug_sync();
};
if (invert_variance)
{
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t inv_var_idx = allocator.reserve_workspace(tensor_shape[1] * dtype.size());
auto& cuda_emitter = m_primitive_emitter->get_cuda_emitter();
auto reciprocal_idx = cuda_emitter->build_cudnn_bn_inv_var(
{dtype, dtype}, Shape{tensor_shape[1]}, epsilon);
batchnorm.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
void* inv_var = runtime::gpu::invoke_memory_primitive(m_ctx, inv_var_idx);
gpu::invoke_primitive(m_ctx, reciprocal_idx, &inputs[4], &inv_var);
inputs[4] = inv_var;
bnbp(inputs, outputs);
}});
}
else
{
batchnorm.reset(
new gpu::primitive{[=](void** inputs, void** outputs) { bnbp(inputs, outputs); }});
}
break;
}
}
......
@@ -143,7 +143,9 @@ namespace ngraph
const Shape& tensor_shape,
const Shape& param_shape,
double epsilon,
bool global_stats = false);
bool global_stats = false,
bool save_stats = false,
bool invert_variance = false);
size_t build_lrn(const std::string& dtype,
const Prop& direction,
......
@@ -56,6 +56,27 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
return;
}
void runtime::gpu::CudaKernelBuilder::get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args)
{
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "uint32_t step = gridDim.x * blockDim.x; \n";
writer << "for (; tid < nthreads; tid += step)\n";
writer.block_begin();
{
writer << "out[tid] = 1.0f / sqrtf(in[tid] + epsilon);\n";
}
writer.block_end();
}
writer.block_end();
return;
}
void runtime::gpu::CudaKernelBuilder::get_softmax_divide_op(
codegen::CodeWriter& writer,
const std::string& name,
......
@@ -40,6 +40,10 @@ namespace ngraph
const std::string& op,
const std::vector<std::string>& data_types);
static void get_cudnn_bn_inv_var_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args);
static void get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
GPUKernelArgs& args,
......
@@ -106,8 +106,10 @@
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_op_annotations.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp"
@@ -325,55 +327,54 @@ void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
static void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
{
const ngraph::op::BatchNormInference* batchnorm =
static_cast<const ngraph::op::BatchNormInference*>(node);
CUDNNEmitter::Prop direction = CUDNNEmitter::Prop::Inference;
const ngraph::op::BatchNormBase* batchnorm =
static_cast<const ngraph::op::BatchNormBase*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_batchnorm(CUDNN_BATCHNORM_SPATIAL,
out[0].get_type(),
direction,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin();
bool global_stats = false;
if (direction == runtime::gpu::CUDNNEmitter::Prop::Forward)
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
global_stats = (batchnorm->get_arguments().size() == 5);
}
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
{
const ngraph::op::BatchNormTraining* batchnorm =
static_cast<const ngraph::op::BatchNormTraining*>(node);
CUDNNEmitter::Prop direction = CUDNNEmitter::Prop::Forward;
bool global_stats = (batchnorm->get_arguments().size() == 5);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_batchnorm(CUDNN_BATCHNORM_SPATIAL,
out[0].get_type(),
direction,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value(),
global_stats);
global_stats,
save_stats);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "void* input[] = {" << runtime::gpu::GPU_Emitter::node_names(args) << "};\n";
writer << "void* output[] = {" << runtime::gpu::GPU_Emitter::node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
{
::emit_BatchNorm(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Inference, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
{
::emit_BatchNorm(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingWithStats(EMIT_ARGS)
{
::emit_BatchNorm(external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
@@ -381,13 +382,26 @@ void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
bool needs_variance_inversion = false;
auto annotation = batchnorm->get_op_annotations();
if (annotation)
{
auto bnbp_annotation =
std::dynamic_pointer_cast<runtime::gpu::BatchNormBackpropAnnotations>(annotation);
if (bnbp_annotation && bnbp_annotation->has_inverted_variance() == false)
{
needs_variance_inversion = true;
}
}
auto index = cudnn_emitter->build_batchnorm(CUDNN_BATCHNORM_SPATIAL,
out[0].get_type(),
CUDNNEmitter::Prop::Backward,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
batchnorm->get_eps_value(),
false,
false,
needs_variance_inversion);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
......
@@ -77,7 +77,6 @@ namespace ngraph
static void emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t);
private:
/// \brief Create a list of node names for each arg in args
/// \param args list of tensor arguments
/// \param arg_indexes a list of indexes into args for which args to include in
......
@@ -112,7 +112,9 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_batch_norm_cache.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
@@ -568,6 +570,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
#else
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
......
@@ -28,7 +28,17 @@ namespace ngraph
class GPUOpAnnotations : public ngraph::op::util::OpAnnotations
{
public:
GPUOpAnnotations() {}
virtual ~GPUOpAnnotations() = default;
};
class BatchNormBackpropAnnotations : public GPUOpAnnotations
{
public:
~BatchNormBackpropAnnotations() = default;
bool has_inverted_variance() { return m_inv_variance; }
void set_inverted_variance(bool b) { m_inv_variance = b; }
private:
bool m_inv_variance = false;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input)
: ngraph::op::BatchNormTraining(eps, gamma, beta, input)
{
auto output_index = get_output_size();
set_output_size(output_index + 2);
Shape channel_shape{input->get_shape()[1]};
// saved batch mean
set_output_type(output_index++, input->get_element_type(), channel_shape);
// saved batch inverse variance
set_output_type(output_index++, input->get_element_type(), channel_shape);
}
ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training)
: ngraph::op::BatchNormTraining(eps, gamma, beta, input, mean, variance)
{
auto output_index = get_output_size();
set_output_size(output_index + 2);
Shape channel_shape{input->get_shape()[1]};
// saved batch mean
set_output_type(output_index++, input->get_element_type(), channel_shape);
// saved batch inverse variance
set_output_type(output_index++, input->get_element_type(), channel_shape);
}
std::shared_ptr<ngraph::Node> ngraph::op::gpu::BatchNormTrainingWithStats::copy_with_new_args(
const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<ngraph::op::gpu::BatchNormTrainingWithStats>(
get_eps_value(), new_args.at(0), new_args.at(1), new_args.at(2));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace op
{
namespace gpu
{
class BatchNormTrainingWithStats : public ngraph::op::BatchNormTraining
{
public:
BatchNormTrainingWithStats(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
BatchNormTrainingWithStats(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training = false);
protected:
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
}
@@ -18,3 +18,4 @@
#if CUDNN_VERSION >= 7200
NGRAPH_OP(Rnn, ngraph::op::gpu)
#endif
NGRAPH_OP(BatchNormTrainingWithStats, ngraph::op::gpu)
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <unordered_map>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/runtime/gpu/gpu_op_annotations.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/pass/gpu_batch_norm_cache.hpp"
using namespace ngraph;
bool ngraph::runtime::gpu::pass::BatchNormCache::run_on_function(
std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (auto bnbp = std::dynamic_pointer_cast<op::BatchNormTrainingBackprop>(n))
{
// batch norm bprop annotations are used to indicate if variance is in inverse stddev format
auto op_annotations =
std::make_shared<ngraph::runtime::gpu::BatchNormBackpropAnnotations>();
// pass must be run prior to GOE elimination
// collect all batch norm inputs to batch norm backward op
std::vector<std::shared_ptr<op::GetOutputElement>> goes;
for (auto& arg : bnbp->get_arguments())
{
if (auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(arg))
{
if (auto bn = std::dynamic_pointer_cast<op::BatchNormTraining>(
goe->get_arguments().at(0)))
{
goes.push_back(goe);
}
}
}
// only replace if some of the inputs to backprop are from fprop directly
if (goes.size())
{
if (auto target = std::dynamic_pointer_cast<op::BatchNormTraining>(
goes.front()->get_arguments().at(0)))
{
auto replacement = std::make_shared<op::gpu::BatchNormTrainingWithStats>(
target->get_eps_value(),
target->get_argument(0),
target->get_argument(1),
target->get_argument(2));
// replace all users of old batchnorm with cudnn batchnorm
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()), end(target_output.get_inputs())};
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
// for each output of forward op into backprop op
// use the mean and inverse variance from the forward
// cudnn op to avoid recalculation of batch statistics
for (auto& goe : goes)
{
auto out_idx = goe->get_n();
if (out_idx != 0)
{
auto new_goe =
std::make_shared<op::GetOutputElement>(replacement, out_idx + 2);
ngraph::replace_node(goe, new_goe);
}
}
replaced = true;
op_annotations->set_inverted_variance(true);
}
}
bnbp->set_op_annotations(op_annotations);
}
}
return replaced;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace pass
{
class BatchNormCache;
}
}
}
}
class ngraph::runtime::gpu::pass::BatchNormCache : public ngraph::pass::FunctionPass
{
public:
BatchNormCache()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
@@ -3,4 +3,6 @@ batchnorm_fprop_b1c2h2w2
batchnorm_fprop_b2c2h2w1
batchnorm_fprop_globalstats_b2c2w2h1
batchnorm_fprop_inference_b2c2h2w1
batchnorm_fprop_bprop
batchnorm_fprop_bprop_2step
computation_reuse
@@ -23,6 +23,7 @@
#include "gtest/gtest.h"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"
@@ -5865,3 +5866,78 @@ NGRAPH_TEST(${BACKEND_NAME}, quantize_ROUND_DOWN)
EXPECT_EQ((vector<output_c_type>{2, 2, 2, -3, -3, -3, 3, 3, 3, -4, -4, -4}),
read_vector<output_c_type>(y));
}
NGRAPH_TEST(${BACKEND_NAME}, batchnorm_fprop_bprop)
{
Shape sca{1};
Shape vec{1, 1, 1, 2};
double eps = 1.0e-04;
auto g = std::make_shared<op::Parameter>(element::f32, sca);
auto b = std::make_shared<op::Parameter>(element::f32, sca);
auto input = std::make_shared<op::Parameter>(element::f32, vec);
auto bn_fp = std::make_shared<op::BatchNormTraining>(eps, g, b, input);
auto bnorm = std::make_shared<op::GetOutputElement>(bn_fp, 0);
auto mean = std::make_shared<op::GetOutputElement>(bn_fp, 1);
auto var = std::make_shared<op::GetOutputElement>(bn_fp, 2);
auto delta = std::make_shared<op::Parameter>(element::f32, vec);
auto bn_bp =
std::make_shared<op::BatchNormTrainingBackprop>(eps, g, b, bnorm, mean, var, delta);
auto dx = std::make_shared<op::GetOutputElement>(bn_bp, 0);
std::vector<std::vector<float>> args = {
{1.0f}, // gamma
{1.0f}, // beta
{1.1f, 1.0f}, // x
{1.0f, 1.0f}, // dy
};
auto func = std::make_shared<Function>(dx, op::ParameterVector{g, b, input, delta});
auto results = execute(func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(std::vector<float>{350.957, -388.67}, results.at(0)));
}
NGRAPH_TEST(${BACKEND_NAME}, batchnorm_fprop_bprop_2step)
{
Shape sca{1};
Shape vec{1, 1, 1, 2};
double eps = 1.0e-04;
auto g = std::make_shared<op::Parameter>(element::f32, sca);
auto b = std::make_shared<op::Parameter>(element::f32, sca);
auto input = std::make_shared<op::Parameter>(element::f32, vec);
auto bn_fp = std::make_shared<op::BatchNormTraining>(eps, g, b, input);
auto bnorm = std::make_shared<op::GetOutputElement>(bn_fp, 0);
auto mean = std::make_shared<op::GetOutputElement>(bn_fp, 1);
auto var = std::make_shared<op::GetOutputElement>(bn_fp, 2);
auto func_bn =
std::make_shared<Function>(NodeVector{bnorm, mean, var}, op::ParameterVector{g, b, input});
std::vector<std::vector<float>> args = {
{1.0f}, // gamma
{1.0f}, // beta
{1.1f, 1.0f}, // x
};
auto results = execute(func_bn, args, "${BACKEND_NAME}");
g = std::make_shared<op::Parameter>(element::f32, sca);
b = std::make_shared<op::Parameter>(element::f32, sca);
auto bn_output = std::make_shared<op::Parameter>(element::f32, vec);
auto m = std::make_shared<op::Parameter>(element::f32, sca);
auto v = std::make_shared<op::Parameter>(element::f32, sca);
auto delta = std::make_shared<op::Parameter>(element::f32, vec);
auto bn_bp = std::make_shared<op::BatchNormTrainingBackprop>(eps, g, b, bn_output, m, v, delta);
auto dx = std::make_shared<op::GetOutputElement>(bn_bp, 0);
args.pop_back(); // remove x
args.push_back(results.at(0)); // bn_output
args.push_back(results.at(1)); // m
args.push_back(results.at(2)); // v
args.push_back({1.0f, 1.0f}); // dy
auto func = std::make_shared<Function>(dx, op::ParameterVector{g, b, bn_output, m, v, delta});
results = execute(func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(std::vector<float>{350.957, -388.67}, results.at(0)));
}