Commit eef2b19d authored by Fenglei, committed by Nick Korovaiko

enable cudnn datatype support (#1122)

* enable multi-datatype support for cuDNN; refactor binary ops using cuDNN

* fix bugs

* add tests that cuDNN does not support to the skip list

* no int support in cuDNN for backward pooling

* no GPU.dot_4d_5d_multi_axis_big_fp64_VERY_SLOW test anymore

* clang format

* throw if datatype is int8 or int32 for backward pooling

* comments

* fix list in unit_test.manifest

* add type support for alpha, beta

* fix bugs

* datatype support for alpha, beta

* missing ()

* clang format

* batchnorm backward bug fix

* remove debug info

* change member function name to snake case. remove comments

* use nullptr instead of NULL

* code style, use cuDNN everywhere in comments

* add cudnn host parameters memory manager.

* change name to allocate_by_datatype

* compiled

* debug

* fix bug: use list instead of vector; vector element addresses change each time it resizes

* add CUDNN_DATA_UINT8 and CUDNN_DATA_UINT8x4
parent 35b04e6a
......@@ -28,8 +28,8 @@
using namespace ngraph;
cudnnTensorDescriptor_t&
runtime::gpu::CUDNNEmitter::tensor_descriptor_from_shape(const Shape& shape)
cudnnTensorDescriptor_t& runtime::gpu::CUDNNEmitter::tensor_descriptor_from_shape(
const Shape& shape, const cudnnDataType_t data_type, const cudnnTensorFormat_t tensor_format)
{
cudnnTensorDescriptor_t& desc = m_descriptors.build<cudnnTensorDescriptor_t>();
if (shape.size() < 4)
......@@ -45,8 +45,8 @@ cudnnTensorDescriptor_t&
dimensions[pos++] = static_cast<int>(shape[i]);
}
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(desc,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
tensor_format,
data_type,
dimensions[0],
dimensions[1],
dimensions[2],
......@@ -55,8 +55,8 @@ cudnnTensorDescriptor_t&
else if (shape.size() == 4)
{
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(desc,
CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT,
tensor_format,
data_type,
static_cast<int>(shape[0]),
static_cast<int>(shape[1]),
static_cast<int>(shape[2]),
......@@ -71,7 +71,7 @@ cudnnTensorDescriptor_t&
}
CUDNN_SAFE_CALL(cudnnSetTensorNdDescriptor(
desc,
CUDNN_DATA_FLOAT,
data_type,
static_cast<int>(dimensions.size()),
dimensions.data(),
runtime::gpu::cudnn_util::compute_strides(dimensions).data()));
......@@ -112,13 +112,30 @@ runtime::gpu::CUDNNEmitter::CUDNNEmitter(GPUPrimitiveEmitter* emitter)
{
}
cudnnDataType_t runtime::gpu::CUDNNEmitter::get_cudnn_datatype(std::string dtype)
{
static const std::unordered_map<std::string, cudnnDataType_t> datatype_map{
{"float", CUDNN_DATA_FLOAT},
{"double", CUDNN_DATA_DOUBLE},
{"int8_t", CUDNN_DATA_INT8},
{"int32_t", CUDNN_DATA_INT32}};
auto p = datatype_map.find(dtype);
if (p == datatype_map.end())
{
std::string err = dtype + " is not supported by cuDNN";
throw std::runtime_error(err);
}
return p->second;
}
size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnReduceTensorOp_t& reduce_op,
const std::string& dtype,
const Shape& input_shape,
const AxisSet& reduction_axes)
{
std::stringstream ss;
ss << "reduce_op" << reduce_op << "_i" << join(input_shape, "_") << "_ra"
ss << "reduce_op_" << reduce_op << "_dtype_" << dtype << "_i" << join(input_shape, "_") << "_ra"
<< join(reduction_axes, "_");
std::string hash = ss.str();
......@@ -130,14 +147,16 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPUR
}
auto& desc = m_descriptors.build<cudnnReduceTensorDescriptor_t>();
auto& input_desc = tensor_descriptor_from_shape(input_shape);
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& input_desc = tensor_descriptor_from_shape(input_shape, data_type, tensor_format);
Shape output_shape = input_shape;
// mark reduced axes of input tensor for output tensor descriptor
for (auto const& idx_dim : reduction_axes)
{
output_shape[idx_dim] = 1;
}
auto& output_desc = tensor_descriptor_from_shape(output_shape);
auto& output_desc = tensor_descriptor_from_shape(output_shape, data_type, tensor_format);
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
......@@ -145,30 +164,30 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPUR
CUDNN_SAFE_CALL(cudnnGetReductionWorkspaceSize(
*ctx->cudnn_handle, desc, input_desc, output_desc, &workspace_size));
size_t workspace_idx = allocator.reserve_workspace(workspace_size);
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
// emit reduce operation
std::unique_ptr<gpu::primitive> reduce(
new gpu::primitive{[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnSetReduceTensorDescriptor(desc,
reduce_op,
CUDNN_DATA_FLOAT,
data_type,
CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES,
CUDNN_32BIT_INDICES));
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(ctx, workspace_idx);
float alpha = 1.0, beta = 0.0;
CUDNN_SAFE_CALL(cudnnReduceTensor(*ctx->cudnn_handle,
desc,
nullptr,
0,
workspace_ptr,
workspace_size,
&alpha,
alpha,
input_desc,
inputs[0],
&beta,
beta,
output_desc,
outputs[0]));
}});
......@@ -178,6 +197,58 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPUR
return primitive_index;
}
size_t runtime::gpu::CUDNNEmitter::build_tensor_op(const GPURuntimeContext* ctx,
const cudnnOpTensorOp_t& tensor_op,
const std::string& dtype,
const Shape& input_shape,
const double alpha0,
const double alpha1,
const double beta)
{
std::stringstream ss;
ss << "tensor_op" << tensor_op << "_dtype_" << dtype << "_i" << join(input_shape, "_");
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
auto& opTensorDesc = m_descriptors.build<cudnnOpTensorDescriptor_t>();
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& descriptor = tensor_descriptor_from_shape(input_shape, data_type, tensor_format);
void* alpha_dt0 = m_host_parameters.allocate_by_datatype(data_type, alpha0);
void* alpha_dt1 = m_host_parameters.allocate_by_datatype(data_type, alpha1);
void* beta_dt = m_host_parameters.allocate_by_datatype(data_type, beta);
// emit tensor binary operation
std::unique_ptr<gpu::primitive> tensor(
new gpu::primitive{[=, &opTensorDesc, &descriptor](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(
opTensorDesc, tensor_op, data_type, CUDNN_NOT_PROPAGATE_NAN));
CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,
opTensorDesc,
alpha_dt0,
descriptor,
inputs[0],
alpha_dt1,
descriptor,
inputs[1],
beta_dt,
descriptor,
outputs[0]));
}});
primitive_index = this->m_primitive_emitter->insert(std::move(tensor));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
cudnnFilterDescriptor_t& runtime::gpu::CUDNNEmitter::get_cudnn_filter_descriptor(
const Shape& shape, const cudnnDataType_t data_type, const cudnnTensorFormat_t tensor_format)
{
......@@ -256,7 +327,7 @@ cudnnConvolutionDescriptor_t& runtime::gpu::CUDNNEmitter::get_cudnn_convolution_
}
size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_tensor_shape,
const Shape& input_filter_shape,
const Shape& output_tensor_shape,
......@@ -267,7 +338,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURunt
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "convolution_op" << data_type << "_i" << join(input_tensor_shape, "_") << "_w"
ss << "convolution_op_" << dtype << "_i" << join(input_tensor_shape, "_") << "_w"
<< join(input_filter_shape, "_") << "_o" << join(output_tensor_shape, "_") << "_ws"
<< join(window_movement_strides, "_") << "_wd" << join(window_dilation_strides, "_") << "_p"
<< join(padding_below, "_");
......@@ -278,15 +349,21 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURunt
{
return primitive_index;
}
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
const cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
auto& tensor_desc_0 = tensor_descriptor_from_shape(input_tensor_shape);
auto& tensor_desc_1 = tensor_descriptor_from_shape(output_tensor_shape);
auto& tensor_desc_0 =
tensor_descriptor_from_shape(input_tensor_shape, data_type, tensor_format);
auto& tensor_desc_1 =
tensor_descriptor_from_shape(output_tensor_shape, data_type, tensor_format);
auto& filter_desc = get_cudnn_filter_descriptor(input_filter_shape, data_type, tensor_format);
auto& conv_desc = get_cudnn_convolution_descriptor(
padding_below, window_movement_strides, window_dilation_strides, mode, data_type);
const cudnnConvolutionFwdAlgo_t conv_fwd_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
size_t workspace_size_in_bytes = 0;
CUDNN_SAFE_CALL(cudnnGetConvolutionForwardWorkspaceSize(*ctx->cudnn_handle,
......@@ -305,11 +382,9 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURunt
std::unique_ptr<gpu::primitive> conv;
conv.reset(new gpu::primitive{[=, &conv_desc, &tensor_desc_0, &filter_desc, &tensor_desc_1](
void** inputs, void** outputs) {
float alpha = 1.0;
float beta = 0.0;
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(ctx, workspace_idx);
CUDNN_SAFE_CALL(cudnnConvolutionForward(*ctx->cudnn_handle,
&alpha,
alpha,
tensor_desc_0,
inputs[0],
filter_desc,
......@@ -318,7 +393,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURunt
conv_fwd_algo,
workspace_ptr,
workspace_size_in_bytes,
&beta,
beta,
tensor_desc_1,
outputs[0]));
}});
......@@ -330,7 +405,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const runtime::gpu::GPURunt
size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_filter_shape,
const Shape& input_tensor_shape,
const Shape& output_tensor_shape,
......@@ -341,7 +416,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "convolution_bp_data_op" << data_type << "_i" << join(input_tensor_shape, "_") << "_w"
ss << "convolution_bp_data_op_" << dtype << "_i" << join(input_tensor_shape, "_") << "_w"
<< join(input_filter_shape, "_") << "_o" << join(output_tensor_shape, "_") << "_ws"
<< join(window_movement_strides, "_") << "_wd" << join(window_dilation_strides, "_") << "_p"
<< join(padding_below, "_");
......@@ -352,15 +427,20 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
{
return primitive_index;
}
const cudnnDataType_t data_type = get_cudnn_datatype(dtype);
const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
const cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
auto& tensor_desc_0 = tensor_descriptor_from_shape(input_tensor_shape);
auto& tensor_desc_1 = tensor_descriptor_from_shape(output_tensor_shape);
auto& tensor_desc_0 =
tensor_descriptor_from_shape(input_tensor_shape, data_type, tensor_format);
auto& tensor_desc_1 =
tensor_descriptor_from_shape(output_tensor_shape, data_type, tensor_format);
auto& filter_desc = get_cudnn_filter_descriptor(input_filter_shape, data_type, tensor_format);
auto& conv_desc = get_cudnn_convolution_descriptor(
padding_below, window_movement_strides, window_dilation_strides, mode, data_type);
const cudnnConvolutionBwdDataAlgo_t conv_bwd_data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
size_t workspace_size_in_bytes = 0;
CUDNN_SAFE_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(*ctx->cudnn_handle,
......@@ -379,11 +459,9 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
std::unique_ptr<gpu::primitive> conv;
conv.reset(new gpu::primitive{[=, &conv_desc, &tensor_desc_0, &filter_desc, &tensor_desc_1](
void** inputs, void** outputs) {
float alpha = 1.0;
float beta = 0.0;
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(ctx, workspace_idx);
CUDNN_SAFE_CALL(cudnnConvolutionBackwardData(*ctx->cudnn_handle,
&alpha,
alpha,
filter_desc,
inputs[0],
tensor_desc_0,
......@@ -392,7 +470,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
conv_bwd_data_algo,
workspace_ptr,
workspace_size_in_bytes,
&beta,
beta,
tensor_desc_1,
outputs[0]));
}});
......@@ -404,7 +482,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_tensor_shape_0,
const Shape& input_tensor_shape_1,
const Shape& output_filter_shape,
......@@ -414,8 +492,9 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
{
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "convolution_bp_filter_op" << data_type << "_i" << join(input_tensor_shape_0, "_") << "_w"
ss << "convolution_bp_filter_op_" << dtype << "_i" << join(input_tensor_shape_0, "_") << "_w"
<< join(output_filter_shape, "_") << "_o" << join(input_tensor_shape_1, "_") << "_ws"
<< join(window_movement_strides, "_") << "_wd" << join(window_dilation_strides, "_") << "_p"
<< join(padding_below, "_");
......@@ -426,11 +505,14 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
{
return primitive_index;
}
const cudnnDataType_t data_type = get_cudnn_datatype(dtype);
const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
const cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
auto& tensor_desc_0 = tensor_descriptor_from_shape(input_tensor_shape_0);
auto& tensor_desc_1 = tensor_descriptor_from_shape(input_tensor_shape_1);
auto& tensor_desc_0 =
tensor_descriptor_from_shape(input_tensor_shape_0, data_type, tensor_format);
auto& tensor_desc_1 =
tensor_descriptor_from_shape(input_tensor_shape_1, data_type, tensor_format);
auto& filter_desc = get_cudnn_filter_descriptor(output_filter_shape, data_type, tensor_format);
auto& conv_desc = get_cudnn_convolution_descriptor(
padding_below, window_movement_strides, window_dilation_strides, mode, data_type);
......@@ -450,15 +532,15 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
// (lazy) allocation for kernel arguments
size_t workspace_idx = allocator.reserve_workspace(workspace_size_in_bytes);
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
std::unique_ptr<gpu::primitive> conv;
conv.reset(new gpu::primitive{[=, &conv_desc, &tensor_desc_0, &filter_desc, &tensor_desc_1](
void** inputs, void** outputs) {
float alpha = 1.0;
float beta = 0.0;
void* workspace_ptr = runtime::gpu::invoke_memory_primitive(ctx, workspace_idx);
CUDNN_SAFE_CALL(cudnnConvolutionBackwardFilter(*ctx->cudnn_handle,
&alpha,
alpha,
tensor_desc_0,
inputs[0],
tensor_desc_1,
......@@ -467,11 +549,10 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
conv_bwd_filter_algo,
workspace_ptr,
workspace_size_in_bytes,
&beta,
beta,
filter_desc,
outputs[0]));
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
......@@ -479,6 +560,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnPoolingMode_t& pool_op,
const std::string& dtype,
const Prop& direction,
const Shape& input_shape,
const Shape& output_shape,
......@@ -490,7 +572,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "pool_op" << pool_op << "_dir" << static_cast<int>(direction) << "_i"
ss << "pool_op" << pool_op << "dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_i"
<< join(input_shape, "_") << "_o" << join(output_shape, "_") << "_ws"
<< join(window_shape, "_") << "_wst" << join(window_strides, "_") << "_pb"
<< join(padding_below, "_") << "_pb" << join(padding_above, "_");
......@@ -503,9 +585,11 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
return primitive_index;
}
const cudnnDataType_t data_type = get_cudnn_datatype(dtype);
const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& desc = m_descriptors.build<cudnnPoolingDescriptor_t>();
auto& input_desc = tensor_descriptor_from_shape(input_shape);
auto& output_desc = tensor_descriptor_from_shape(output_shape);
auto& input_desc = tensor_descriptor_from_shape(input_shape, data_type, tensor_format);
auto& output_desc = tensor_descriptor_from_shape(output_shape, data_type, tensor_format);
if (input_shape.size() == 4)
{
......@@ -544,6 +628,8 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
}
std::unique_ptr<gpu::primitive> pool;
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
switch (direction)
{
......@@ -552,13 +638,12 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
{
pool.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
float alpha = 1.0, beta = 0.0;
CUDNN_SAFE_CALL(cudnnPoolingForward(*ctx->cudnn_handle,
desc,
&alpha,
alpha,
input_desc,
inputs[0],
&beta,
beta,
output_desc,
outputs[0]));
}});
......@@ -566,15 +651,18 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
}
case (Prop::Backward):
{
if (data_type == CUDNN_DATA_INT8 || data_type == CUDNN_DATA_INT32)
{
throw std::runtime_error("Pooling does not support int type by cuDNN.");
}
pool.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
float alpha = 1.0, beta = 0.0;
// cuDNN requires the output tensor of the maxpool fprop to be passed even though
// it is not mathematically necessary. It appears, however, that it is not actually
// used as the adjoints are passed in place and the correct result is achieved.
CUDNN_SAFE_CALL(cudnnPoolingBackward(*ctx->cudnn_handle,
desc,
&alpha,
alpha,
// output (wrt maxpool) tensor
output_desc,
inputs[1],
......@@ -584,7 +672,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
// input (wrt maxpool) tensor
input_desc,
inputs[0],
&beta,
beta,
// adjoint of input
input_desc,
outputs[0]));
......@@ -600,6 +688,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeC
size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnBatchNormMode_t& bn_op,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape,
const Shape& param_shape,
......@@ -609,7 +698,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
std::stringstream ss;
ss.precision(std::numeric_limits<double>::digits10 + 2);
ss << "bn_op" << bn_op << "_dir" << static_cast<int>(direction) << "_ts"
ss << "bn_op" << bn_op << "_dtype_" << dtype << "_dir" << static_cast<int>(direction) << "_ts"
<< join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon;
std::string hash = ss.str();
std::replace(hash.begin(), hash.end(), '.', '_');
......@@ -625,11 +714,13 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
throw std::runtime_error("Batch Norm epsilon is less than CUDNN_BN_MIN_EPSILON");
}
const cudnnDataType_t data_type = get_cudnn_datatype(dtype);
const cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& derived_param_desc = m_descriptors.build<cudnnTensorDescriptor_t>();
auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape);
auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape, data_type, tensor_format);
CUDNN_SAFE_CALL(cudnnDeriveBNTensorDescriptor(derived_param_desc, tensor_desc, bn_op));
float alpha = 1.0, beta = 0.0;
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
std::unique_ptr<gpu::primitive> batchnorm;
switch (direction)
{
......@@ -639,8 +730,8 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
[=, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardInference(*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
alpha,
beta,
tensor_desc,
inputs[2], // tensor
tensor_desc,
......@@ -658,24 +749,23 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
{
auto& op_desc = m_descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(
op_desc, CUDNN_OP_TENSOR_MUL, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
op_desc, CUDNN_OP_TENSOR_MUL, data_type, CUDNN_NOT_PROPAGATE_NAN));
// currently not using the cudnn moving average
// currently not using the cuDNN moving average
// calculation so this factor needs to be set to 1.0
double exp_avg_factor = 1.0;
// factor to convert unbiased variance to biased variance estimate
// mini-batch statistics (variance of the sample) should be used
// in training and population statistics (sample variance) used
// during inference. see commit note for 3b081ce for more details.
float m = shape_size(tensor_shape) / tensor_shape[1];
float bias_factor = (m - 1) / m;
double m = shape_size(tensor_shape) / tensor_shape[1];
void* bias_factor = m_host_parameters.allocate_by_datatype(data_type, (m - 1) / m);
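// Illustrative numbers (not part of the commit): for tensor_shape {32, 3, 224, 224},
// m = shape_size / C = 32 * 224 * 224 = 1,605,632, so bias_factor = (m - 1) / m ≈ 0.9999994;
// multiplying the saved variance by this factor converts the unbiased estimate into the
// biased (mini-batch) estimate, as described in the comment above.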
batchnorm.reset(new gpu::primitive{
[=, &op_desc, &tensor_desc, &derived_param_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardTraining(*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
alpha,
beta,
tensor_desc,
inputs[2],
tensor_desc,
......@@ -693,13 +783,13 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
// convert to biased variance
CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,
op_desc,
&beta,
beta,
derived_param_desc,
outputs[2],
&beta,
beta,
derived_param_desc,
outputs[2],
&bias_factor,
bias_factor,
derived_param_desc,
outputs[2]));
}});
......@@ -712,10 +802,10 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
CUDNN_SAFE_CALL(cudnnBatchNormalizationBackward(
*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
&alpha,
&beta,
alpha,
beta,
alpha,
beta,
tensor_desc,
inputs[2 /* input tensor x */],
tensor_desc,
......@@ -742,14 +832,15 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntim
size_t runtime::gpu::CUDNNEmitter::build_softmax(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape)
{
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "softmax_op" << mode << "_alg" << algorithm << "_dir" << static_cast<int>(direction)
<< "_s" << join(tensor_shape, "_");
ss << "softmax_op_" << mode << "_dtype_" << dtype << "_alg" << algorithm << "_dir"
<< static_cast<int>(direction) << "_s" << join(tensor_shape, "_");
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
......@@ -759,9 +850,11 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const runtime::gpu::GPURuntimeC
return primitive_index;
}
auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape);
float alpha = 1.0, beta = 0.0;
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape, data_type, tensor_format);
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
std::unique_ptr<runtime::gpu::primitive> softmax;
switch (direction)
{
......@@ -772,10 +865,10 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const runtime::gpu::GPURuntimeC
CUDNN_SAFE_CALL(cudnnSoftmaxForward(*ctx->cudnn_handle,
algorithm,
mode,
&alpha,
alpha,
tensor_desc,
inputs[0],
&beta,
beta,
tensor_desc,
outputs[0]));
}});
......@@ -787,12 +880,12 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const runtime::gpu::GPURuntimeC
CUDNN_SAFE_CALL(cudnnSoftmaxBackward(*ctx->cudnn_handle,
algorithm,
mode,
&alpha,
alpha,
tensor_desc,
inputs[0],
tensor_desc,
inputs[1],
&beta,
beta,
tensor_desc,
outputs[0]));
}});
......
......@@ -26,6 +26,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/runtime/gpu/cudnn_descriptors.hpp"
#include "ngraph/runtime/gpu/cudnn_host_parameters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/shape.hpp"
......@@ -56,7 +57,7 @@ namespace ngraph
};
size_t build_convolution(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_tensor_shape,
const Shape& input_filter_shape,
const Shape& output_tensor_shape,
......@@ -65,7 +66,7 @@ namespace ngraph
const Shape& padding_below);
size_t build_convolution_backward_data(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_filter_shape,
const Shape& input_tensor_shape,
const Shape& output_tensor_shape,
......@@ -74,7 +75,7 @@ namespace ngraph
const Shape& padding_below);
size_t build_convolution_backward_filter(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnDataType_t data_type,
const std::string& dtype,
const Shape& input_tensor_shape_0,
const Shape& input_tensor_shape_1,
const Shape& output_filter_shape,
......@@ -84,11 +85,21 @@ namespace ngraph
size_t build_reduce_forward(const GPURuntimeContext* ctx,
const cudnnReduceTensorOp_t& reduce_op,
const std::string& dtype,
const Shape& input_shape,
const AxisSet& reduction_axes);
size_t build_tensor_op(const GPURuntimeContext* ctx,
const cudnnOpTensorOp_t& tensor_op,
const std::string& dtype,
const Shape& input_shape,
const double alpha0,
const double alpha1,
const double beta);
size_t build_pooling(const GPURuntimeContext* ctx,
const cudnnPoolingMode_t& pool_op,
const std::string& dtype,
const Prop& direction,
const ngraph::Shape& input_shape,
const ngraph::Shape& output_shape,
......@@ -99,6 +110,7 @@ namespace ngraph
size_t build_batchnorm(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnBatchNormMode_t& bn_op,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape,
const Shape& param_shape,
......@@ -107,10 +119,21 @@ namespace ngraph
size_t build_softmax(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape);
cudnnTensorDescriptor_t& tensor_descriptor_from_shape(const Shape& shape);
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter);
void* get_data_by_type(cudnnDataType_t data_type, double value);
cudnnDataType_t get_cudnn_datatype(std::string dtype);
cudnnTensorDescriptor_t&
tensor_descriptor_from_shape(const Shape& shape,
const cudnnDataType_t data_type,
const cudnnTensorFormat_t tensor_format);
cudnnFilterDescriptor_t&
get_cudnn_filter_descriptor(const Shape& shape,
const cudnnDataType_t data_type,
......@@ -122,10 +145,9 @@ namespace ngraph
cudnnConvolutionMode_t mode,
cudnnDataType_t data_type);
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter);
CUDNNDescriptors m_descriptors;
CUDNNHostParameters m_host_parameters;
GPUPrimitiveEmitter* m_primitive_emitter;
};
}
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <list>
#include <memory>
#include <cudnn.h>
#include "ngraph/log.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
/// \brief A factory which builds cuDNN host parameters
/// and manages their creation and destruction.
class CUDNNHostParameters
{
public:
CUDNNHostParameters() = default;
~CUDNNHostParameters() = default;
void* allocate_by_datatype(const cudnnDataType_t data_type, const double value)
{
void* r = nullptr;
switch (data_type)
{
case CUDNN_DATA_FLOAT:
m_host_parameters_float.push_back(static_cast<float>(value));
r = static_cast<void*>(&m_host_parameters_float.back());
break;
case CUDNN_DATA_DOUBLE:
m_host_parameters_double.push_back(value);
r = static_cast<void*>(&m_host_parameters_double.back());
break;
case CUDNN_DATA_INT8:
m_host_parameters_int8_t.push_back(static_cast<int8_t>(value));
r = static_cast<void*>(&m_host_parameters_int8_t.back());
break;
case CUDNN_DATA_INT32:
m_host_parameters_int32_t.push_back(static_cast<int32_t>(value));
r = static_cast<void*>(&m_host_parameters_int32_t.back());
break;
case CUDNN_DATA_HALF:
case CUDNN_DATA_INT8x4:
case CUDNN_DATA_UINT8:
case CUDNN_DATA_UINT8x4:
std::string err = "datatype is not supported by cuDNN";
throw std::runtime_error(err);
}
return r;
}
private:
std::list<int8_t> m_host_parameters_int8_t;
std::list<int32_t> m_host_parameters_int32_t;
std::list<float> m_host_parameters_float;
std::list<double> m_host_parameters_double;
};
}
}
}
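The std::list members above are the fix described in the commit message ("use list instead of vector"): allocate_by_datatype hands raw host pointers to cuDNN, so the storage behind them must never move. A standalone sketch, not part of the commit, illustrating the difference:

#include <cstdint>
#include <iostream>
#include <list>
#include <vector>

int main()
{
    // std::vector may reallocate on push_back and move every element, so a
    // pointer taken from &vec.back() earlier can end up dangling.
    std::vector<float> vec;
    vec.push_back(1.0f);
    auto vec_before = reinterpret_cast<std::uintptr_t>(&vec.front());
    for (int i = 0; i < 1000; ++i)
    {
        vec.push_back(0.0f);
    }
    auto vec_after = reinterpret_cast<std::uintptr_t>(&vec.front());
    std::cout << std::boolalpha << "vector storage moved: " << (vec_before != vec_after) << "\n";

    // std::list nodes are allocated individually and never relocated, so the
    // address returned by allocate_by_datatype stays valid for the lifetime
    // of the CUDNNHostParameters instance.
    std::list<float> lst;
    lst.push_back(1.0f);
    auto lst_before = reinterpret_cast<std::uintptr_t>(&lst.front());
    for (int i = 0; i < 1000; ++i)
    {
        lst.push_back(0.0f);
    }
    auto lst_after = reinterpret_cast<std::uintptr_t>(&lst.front());
    std::cout << "list front moved: " << (lst_before != lst_after) << "\n";
    return 0;
}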
......@@ -288,6 +288,13 @@ namespace ngraph
static constexpr const char* math_kernel = "!x0";
};
template <>
struct CudaOpMap<ngraph::op::Negative>
{
static constexpr const char* op = "negative";
static constexpr const char* math_kernel = "-x0";
};
template <>
struct CudaOpMap<ngraph::op::Select>
{
......
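With this commit, Negative goes through the CUDA elementwise path (see the dispatcher change further down) instead of a cuDNN cudnnOpTensor call, and the CudaOpMap specialization above supplies the one-line fragment "-x0". The following is only an illustrative sketch of how such a math_kernel fragment could be spliced into an elementwise kernel source string; the helper name and kernel layout are hypothetical and not taken from nGraph's code generator.

#include <iostream>
#include <string>

// Hypothetical helper: expand a CudaOpMap-style math_kernel fragment (which
// refers to its loaded input as x0) into the source of a unary CUDA kernel.
std::string make_unary_kernel_source(const std::string& op, const std::string& math_kernel)
{
    return "extern \"C\" __global__ void " + op +
           "_kernel(const float* in0, float* out, size_t n)\n"
           "{\n"
           "    size_t i = blockIdx.x * blockDim.x + threadIdx.x;\n"
           "    if (i < n)\n"
           "    {\n"
           "        float x0 = in0[i];\n"
           "        out[i] = " + math_kernel + ";\n"
           "    }\n"
           "}\n";
}

int main()
{
    // For Negative the fragment is "-x0", yielding "out[i] = -x0;"
    std::cout << make_unary_kernel_source("negative", "-x0");
    return 0;
}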
......@@ -120,33 +120,23 @@ namespace ngraph
return;
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_ADD,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[1].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_tensor_op(external_function->ctx().get(),
CUDNN_OP_TENSOR_ADD,
out[0].get_type(),
args[0].get_shape(),
1.0,
1.0,
0);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......@@ -226,15 +216,14 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
// ensure cuDNN does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
size_t index = cudnn_emitter->build_convolution(external_function->ctx().get(),
data_type,
out[0].get_type(),
input_shape_padded,
args[1].get_shape(),
out[0].get_shape(),
......@@ -346,16 +335,14 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
// ensure cuDNN does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
size_t index =
cudnn_emitter->build_convolution_backward_data(external_function->ctx().get(),
data_type,
out[0].get_type(),
args[0].get_shape(),
args[1].get_shape(),
output_shape_padded,
......@@ -499,17 +486,15 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding
// ensure cuDNN does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
size_t index =
cudnn_emitter->build_convolution_backward_filter(external_function->ctx().get(),
data_type,
out[0].get_type(),
input_shape_padded,
args[1].get_shape(),
out[0].get_shape(),
......@@ -569,7 +554,8 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
{
writer.block_begin();
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * sizeof(float));\n";
<< out[0].get_size() << " * " << out[0].get_element_type().size()
<< ");\n";
writer.block_end();
return;
}
......@@ -705,33 +691,23 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MAX,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[1].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_tensor_op(external_function->ctx().get(),
CUDNN_OP_TENSOR_MAX,
out[0].get_type(),
args[0].get_shape(),
1.0,
1.0,
0);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......@@ -743,71 +719,23 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MIN,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[1].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Negative)
{
if (out[0].get_size() == 0)
{
return;
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_tensor_op(external_function->ctx().get(),
CUDNN_OP_TENSOR_MIN,
out[0].get_type(),
args[0].get_shape(),
1.0,
1.0,
0);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = -1.0, alpha2 = 0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_ADD,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[0].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
writer.block_end();
}
......@@ -1192,33 +1120,23 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MUL,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[1].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_tensor_op(external_function->ctx().get(),
CUDNN_OP_TENSOR_MUL,
out[0].get_type(),
args[0].get_shape(),
1.0,
1.0,
0);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......@@ -1261,33 +1179,23 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
writer.block_begin();
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 0, beta = 0;
auto& descriptor = descriptors.build<cudnnTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NCHW,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
auto& opTensorDesc = descriptors.build<cudnnOpTensorDescriptor_t>();
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_SQRT,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
)";
writer << "CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor," << args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor," << args[0].get_name() << ","
<< "&beta,"
<< "descriptor," << out[0].get_name() << "));\n";
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_tensor_op(external_function->ctx().get(),
CUDNN_OP_TENSOR_SQRT,
out[0].get_type(),
args[0].get_shape(),
1.0,
0,
0);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......@@ -1334,6 +1242,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto max_index =
cudnn_emitter->build_reduce_forward(external_function->ctx().get(),
CUDNN_REDUCE_TENSOR_MAX,
out[0].get_type(),
args[0].get_shape(),
max_op->get_reduction_axes());
......@@ -1382,6 +1291,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto min_index =
cudnn_emitter->build_reduce_forward(external_function->ctx().get(),
CUDNN_REDUCE_TENSOR_MIN,
out[0].get_type(),
args[0].get_shape(),
min_op->get_reduction_axes());
......@@ -1421,6 +1331,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto sum_index =
cudnn_emitter->build_reduce_forward(external_function->ctx().get(),
CUDNN_REDUCE_TENSOR_ADD,
out[0].get_type(),
args[0].get_shape(),
sum->get_reduction_axes());
......@@ -1446,9 +1357,9 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
// one of args[] axes has zero size, fill output with 1
if (args[0].get_size() == 0)
{
writer << "float init_value = 1;\n";
writer << "std::vector<float> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << out[0].get_type() << " init_value = 1;\n";
writer << "std::vector<" << out[0].get_type() << "> temp("
<< out[0].get_size() << ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
......@@ -1465,6 +1376,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto index =
cudnn_emitter->build_reduce_forward(external_function->ctx().get(),
CUDNN_REDUCE_TENSOR_MUL,
out[0].get_type(),
args[0].get_shape(),
product->get_reduction_axes());
......@@ -1506,12 +1418,12 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
{
writer << "float init_value;\n";
writer << out[0].get_type() << " init_value;\n";
writer << "runtime::gpu::cuda_memcpyDtH(&init_value, "
<< args[1].get_name() << " ,"
<< args[1].get_element_type().size() << ");\n";
writer << "std::vector<float> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "std::vector<" << out[0].get_type() << "> temp("
<< out[0].get_size() << ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
......@@ -1562,6 +1474,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto reduce_index = cudnn_emitter->build_reduce_forward(
external_function->ctx().get(),
reduce_tensor_op,
out[0].get_type(),
args[0].get_shape(),
reduce_op->get_reduction_axes());
......@@ -1595,12 +1508,12 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
{
writer << "float init_value;\n";
writer << out[0].get_type() << " init_value;\n";
writer << "runtime::gpu::cuda_memcpyDtH(&init_value, "
<< args[1].get_name() << " ,"
<< args[1].get_element_type().size() << ");\n";
writer << "std::vector<float> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "std::vector<" << out[0].get_type() << "> temp("
<< out[0].get_size() << ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
......@@ -1777,7 +1690,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cudnn does not assume padding during pooling
// ensure cuDNN does not assume padding during pooling
std::fill(padding_below.begin(), padding_below.end(), 0);
std::fill(padding_above.begin(), padding_above.end(), 0);
}
......@@ -1817,6 +1730,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
max_pool_index = cudnn_emitter->build_pooling(
external_function->ctx().get(),
CUDNN_POOLING_MAX,
out[0].get_type(),
CUDNNEmitter::Prop::Forward,
shape_to_pool,
result_shape,
......@@ -1864,6 +1778,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto max_pool_bp_index =
cudnn_emitter->build_pooling(external_function->ctx().get(),
CUDNN_POOLING_MAX,
out[0].get_type(),
CUDNNEmitter::Prop::Backward,
fp_input_shape,
fp_output_shape,
......@@ -1903,6 +1818,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
out[0].get_type(),
direction,
args[2].get_shape(),
args[0].get_shape(),
......@@ -1939,6 +1855,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
out[0].get_type(),
CUDNNEmitter::Prop::Backward,
args[2].get_shape(),
args[0].get_shape(),
......@@ -2057,6 +1974,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
avg_pool_index = cudnn_emitter->build_pooling(
external_function->ctx().get(),
cudnn_avg_type,
out[0].get_type(),
CUDNNEmitter::Prop::Forward,
input_shape,
result_shape,
......@@ -2101,6 +2019,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto avg_pool_bp_index =
cudnn_emitter->build_pooling(external_function->ctx().get(),
cudnn_avg_type,
out[0].get_type(),
CUDNNEmitter::Prop::Backward,
output_shape,
delta_shape,
......@@ -2110,7 +2029,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
apb->get_padding_above());
writer << "gpu::invoke_primitive(ctx, " << avg_pool_bp_index << ", ";
// CUDNN backwards pooling requests input and output tensors from
// cuDNN backwards pooling requests input and output tensors from
// the forward pass but does not use them. It also behaves differently
// for max pool vs avg pool. The repetition of args below is to address
// this interface in a way that supports both max and avg pooling
......@@ -2246,6 +2165,7 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
cudnn_emitter->build_softmax(external_function->ctx().get(),
CUDNN_SOFTMAX_FAST,
CUDNN_SOFTMAX_MODE_INSTANCE,
out[0].get_type(),
CUDNNEmitter::Prop::Forward,
tensor_shape);
writer << "gpu::invoke_primitive(ctx, " << softmax_index << ", ";
......
......@@ -187,7 +187,7 @@ static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::Log), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Log>},
{TI(ngraph::op::Maximum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Minimum>},
{TI(ngraph::op::Negative), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Negative>},
{TI(ngraph::op::Negative), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Negative>},
{TI(ngraph::op::NotEqual), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::NotEqual>},
{TI(ngraph::op::Power), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Power>},
{TI(ngraph::op::Select), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Select>},
......@@ -431,7 +431,7 @@ using namespace std;
}
}
}
// Add cudnn descriptor factory for descriptor management.
// Add cuDNN descriptor factory for descriptor management.
// After the cuDNN code emitted in gpu_emitter.cc is refactored
// into the CUDNNEmitter class, this can be removed.
writer << "static runtime::gpu::CUDNNDescriptors descriptors;\n\n";
......
#int64 is not supported by cuDNN
abc_int64
batch_norm_one_output
batch_norm_three_outputs
#need to check
computation_reuse
#int64 is not supported
concat_matrix_int64
#convolution 4d is work in progress
convolution_4d_2items
convolution_4d_4items
convolution_4d_4items_dilated
......@@ -12,16 +16,27 @@ convolution_4d_4items_strided_dilated
convolution_4d_4items_strided_dilated_padded
convolution_4d_4items_strided_dilated_padded_neg
convolution_4d_4items_strided_dilated_padded_same
#cuDNN does not have arithmetic exceptions
divide_by_zero_int32
dot_4d_5d_multi_axis_big_fp64_VERY_SLOW
#int64 is not supported by cuDNN
dot_matrix_vector_int64
#no mkldnn on GPU
mkldnn_layouts
#error throw is not the same on GPU, not supported yet
one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3
one_hot_vector_1_barely_oob
one_hot_vector_1_far_oob
one_hot_vector_1_fp_nonint
#select_and_scatter is deprecated
select_and_scatter_3d_without_overlap
select_and_scatter_with_overlap
select_and_scatter_without_overlap
#custom_mem is not implemented on GPU
tensorview_custom_mem
#integer is not supported by cuDNN for backward pooling
backwards_maxpool_n4_c1_hw4_2x2_max
backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_avgpool_n1_c1_hw2x2
backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw4x4