Commit fe06f325 authored by gcwenger's avatar gcwenger Committed by Robert Kimball

Support LRN for NVGPU Backend (#1740)

* LRN WIP

* Explicit lambda captures.

* Switched to Ayan's new caching routine.

* Remove commented out lrn from manifest.

* Fixed clang 3.9 error.

* Corrected lrn hash. Only call cudnnSetLRNDescriptor once.

* Simplified lrn hash. Removed redundant parameters. No longer passing CUDNN_LRN_CROSS_CHANNEL_DIM1 as parameter because it's the only choice for cudnnLRNCrossChannelForward.
parent c8858ef2
......@@ -1665,6 +1665,58 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
return this->m_primitive_emitter->register_primitive(batchnorm, hash);
}
size_t runtime::gpu::CUDNNEmitter::build_lrn(const std::string& dtype,
const Prop& direction,
const Shape& io_shape,
const double lrn_alpha,
const double lrn_beta,
const double lrn_bias,
const size_t lrn_size)
{
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "lrn_dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_io"
<< join(io_shape, "_") << "_alpha_" << lrn_alpha << "_beta_" << lrn_beta << "_bias_"
<< lrn_bias << "_size_" << lrn_size;
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& io_desc = tensor_descriptor_from_shape(io_shape, data_type, tensor_format);
auto& lrn_descriptor = m_descriptors.build<cudnnLRNDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetLRNDescriptor(
lrn_descriptor, static_cast<unsigned int>(lrn_size), lrn_alpha, lrn_beta, lrn_bias));
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
// emit lrn operation
std::unique_ptr<gpu::primitive> lrn(new gpu::primitive{
[&lrn_descriptor, &io_desc, this, alpha, beta](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnLRNCrossChannelForward(*m_ctx->cudnn_handle,
lrn_descriptor,
CUDNN_LRN_CROSS_CHANNEL_DIM1,
alpha,
io_desc,
inputs[0],
beta,
io_desc,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->register_primitive(lrn, hash);
return primitive_index;
}
size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
......
......@@ -129,6 +129,14 @@ namespace ngraph
double epsilon,
bool global_stats = false);
size_t build_lrn(const std::string& dtype,
const Prop& direction,
const Shape& io_shape,
const double lrn_alpha,
const double lrn_beta,
const double lrn_bias,
const size_t lrn_size);
size_t build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
......
......@@ -624,6 +624,25 @@ void runtime::gpu::GPU_Emitter::emit_Log(EMIT_ARGS)
void runtime::gpu::GPU_Emitter::emit_LRN(EMIT_ARGS)
{
auto lrn = static_cast<const ngraph::op::LRN*>(node);
auto& input_shape = args[0].get_shape();
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_lrn(out[0].get_type(),
CUDNNEmitter::Prop::Forward,
input_shape,
lrn->get_alpha(),
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_Max(EMIT_ARGS)
......
......@@ -12,7 +12,6 @@ divide_by_zero_int32
dot_matrix_vector_int64
#no mkldnn on GPU
#error throw is not the same on GPU, not supported yet
lrn
one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3
one_hot_vector_1_barely_oob
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment