Commit ae45c984 authored by Chris Sullivan, committed by Scott Cyphers

Nd convolution via blocked GEMM for C{d1,...,dn}N layout (#1131)

* Added blank convolution kernel and refactored coordinate transform kernel helper.

* Added op::Reshape to the CUDAEmitter.

* Added 2-Nd tiled convolution.

* Bug fixes with data_dilation and filter loop. Still need to add test for coverage of register tiling.

* Styling.

* Removed some comments and code added for testing.

* Some tests became enabled during the merge; removing them.
parent 3a43bdac
......@@ -1222,6 +1222,344 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const GPURuntimeContext* ctx,
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_reshape(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape input_order)
{
std::string kernel_name = "reshape_" + join(dtypes, "_");
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss;
ss << kernel_name << "_s" << join(input_shape, "_") << "_ax" << join(input_order, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// if the kernel has not been compiled, build it
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
runtime::gpu::CudaKernelBuilder::get_reshape_op(writer, kernel_name, dtypes);
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
auto input_rank = input_shape.size();
std::vector<size_t> input_strides(input_rank);
std::vector<size_t> output_strides(input_rank);
std::vector<size_t> trans_strides(input_rank);
size_t stride = 1;
for (int64_t i = input_rank - 1; i >= 0; i--)
{
input_strides[i] = stride;
stride *= input_shape[i];
}
stride = 1;
for (int64_t i = input_rank - 1; i >= 0; i--)
{
output_strides[i] = stride;
stride *= input_shape[input_order[i]];
}
for (int64_t i = 0; i < input_rank; i++)
{
trans_strides[input_order[i]] = output_strides[i];
}
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_input_strides =
allocator.reserve_argspace(input_strides.data(), input_strides.size() * sizeof(size_t));
size_t idx_trans_strides =
allocator.reserve_argspace(trans_strides.data(), trans_strides.size() * sizeof(size_t));
size_t nthreads = shape_size(input_shape);
std::unique_ptr<gpu::primitive> reshape(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* input_strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_input_strides);
void* trans_strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_trans_strides);
void* args_list[] = {
&inputs[0], &outputs[0], &input_strides_d, &trans_strides_d, &input_rank, &nthreads};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<int>(nthreads),
1,
1,
1,
1,
1,
0,
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(reshape));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
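// A host-side reference of the index remapping performed by the emitted reshape kernel
// (a minimal sketch for illustration; reference_reshape is not a repository function,
// and it assumes <vector>/<cstddef> are available). The device kernel produced by
// CudaKernelBuilder::get_reshape_op recovers each element's input coordinates from
// input_strides and scatters it to the dot product of those coordinates with trans_strides.
template <typename T>
std::vector<T> reference_reshape(const std::vector<T>& in,
                                 const std::vector<size_t>& shape,
                                 const std::vector<size_t>& order)
{
    size_t rank = shape.size();
    std::vector<size_t> input_strides(rank, 1);
    std::vector<size_t> output_strides(rank, 1);
    std::vector<size_t> trans_strides(rank, 1);
    for (int64_t i = static_cast<int64_t>(rank) - 2; i >= 0; i--)
    {
        input_strides[i] = input_strides[i + 1] * shape[i + 1];
        output_strides[i] = output_strides[i + 1] * shape[order[i + 1]];
    }
    for (size_t i = 0; i < rank; i++)
    {
        trans_strides[order[i]] = output_strides[i];
    }
    std::vector<T> out(in.size());
    for (size_t tid = 0; tid < in.size(); tid++)
    {
        size_t remainder = tid;
        size_t out_idx = 0;
        for (size_t i = 0; i < rank; i++)
        {
            size_t coord = remainder / input_strides[i]; // coordinate along axis i
            remainder -= coord * input_strides[i];
            out_idx += coord * trans_strides[i];
        }
        out[out_idx] = in[tid];
    }
    return out;
}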
size_t runtime::gpu::CUDAEmitter::build_convolution(const GPURuntimeContext* ctx,
const std::array<std::string, 3>& dtypes,
GPUShape input_shape,
GPUShape input_pad_below,
GPUShape input_dilation,
GPUShape filter_shape,
GPUShape filter_stride,
GPUShape filter_dilation,
GPUShape output_shape)
{
// convolution is performed on tensors in the following format
// input_shape: C{di_1,...,di_n}N
// filter_shape: C{df_1,...,df_n}K
// output_shape: K{do_1,...,do_n}N
// The basic strategy performed by this kernel is to convert Nd convolution
// into a single 2D GEMM that can be block multiplied via a hierarchical strategy.
// The spatial dimensions are squashed into a single column axis and the
// batch number N and filter number K are the rows of A and B in the 2D GEMM
// A * B = C, respectively. By keeping N and K in contiguous memory space,
// coalescing and vectorization is maintained regardless of coordinate access
// (e.g. data and filter dilation).
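// For example, for a 2D convolution with input in C{H,W}N layout, filters in
// C{R,S}K layout, and output in K{P,Q}N layout, each output pixel (p,q) is
// effectively a GEMM of shape [N x (C*R*S)] * [(C*R*S) x K] -> [N x K]
// (reduced over the channels and the in-bounds filter pixels), with N and K
// kept in the stride-1 positions so that loads remain coalesced.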
std::string kernel_name = "convolution_fprop_" + join(dtypes, "_");
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
// prerequisites for kernel caching and building
int N = input_shape.back();
int K = filter_shape.back();
int filter_size = 1;
int rank = 0;
for (int i = 1; i < filter_shape.size() - 1; i++)
{ // skip first and last (non-spatial) dimensions
filter_size *= filter_shape[i];
rank++;
}
// additional kernel cache parameters
kernel_name = kernel_name + "_n" + std::to_string(N) + "_k" + std::to_string(K) + "_fsz" +
std::to_string(filter_size) + "_r" + std::to_string(rank);
// primitive cache parameters
std::stringstream ss;
ss << kernel_name << "_s" << join(input_shape, "_") << "_pb" << join(input_pad_below, "_")
<< "_pi" << join(input_dilation, "_") << "_fs" << join(filter_shape, "_") << "_fst"
<< join(filter_stride, "_") << "_fdi" << join(filter_dilation, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// tiling options are determined by
// batch size (N) and number of filters (K)
int reg_tile_size = 1;
int sm_tile_size = 8;
// if N is a multiple of 32 use register tiling
if (N % (sm_tile_size * 4) == 0)
{
reg_tile_size = 4;
}
// if the kernel has not been compiled, build it
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
CudaKernelBuilder::get_convolution_forward(
writer, kernel_name, dtypes, N, K, filter_size, rank, sm_tile_size, reg_tile_size);
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
// ----- build primitive arguments -----
// TODO: as each cuda_emitter has a regular structure
// it would be beneficial to factor these into classes
// with separate methods for compiling the kernel, building
// the primitive, and transferring arguments to device memory
int C = input_shape.front();
int input_channel_size = 1;
int filter_channel_size = 1;
int output_filter_size = 1;
for (int i = 1; i < input_shape.size(); i++)
{
input_channel_size *= input_shape[i];
filter_channel_size *= filter_shape[i];
output_filter_size *= output_shape[i];
}
// vector accesses of width `reg_tile_size` are used,
// reducing the effective tensor array size
input_channel_size /= reg_tile_size;
filter_channel_size /= reg_tile_size;
output_filter_size /= reg_tile_size;
// arguments derived from output tensor
int output_pixels = 1;
int output_pixels_magic;
int output_pixels_shift;
std::vector<int> output_dim_strides(rank, 0);
std::vector<int> output_str_magic(rank, 0);
std::vector<int> output_str_shift(rank, 0);
for (int64_t i = output_shape.size() - 2; i > 0; i--)
{
output_dim_strides[i - 1] = output_pixels;
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(output_pixels);
output_str_magic[i - 1] = magic;
output_str_shift[i - 1] = shift;
output_pixels *= output_shape[i];
}
std::tie(output_pixels_magic, output_pixels_shift) = idiv_magic_u64(output_pixels);
// arguments derived from filter tensor
int filter_sz = 1;
std::vector<int> filter_dim_strides(rank, 0);
std::vector<int> filter_str_magic(rank, 0);
std::vector<int> filter_str_shift(rank, 0);
for (int64_t i = filter_shape.size() - 2; i > 0; i--)
{
filter_dim_strides[i - 1] = filter_sz;
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(filter_sz);
filter_str_magic[i - 1] = magic;
filter_str_shift[i - 1] = shift;
filter_sz *= filter_shape[i];
}
// remaining kernel arguments
std::vector<int> data_dilation_magic(input_dilation.size(), 0);
std::vector<int> data_dilation_shift(input_dilation.size(), 0);
for (int i = 0; i < input_dilation.size(); i++)
{
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(input_dilation[i]);
data_dilation_magic[i] = magic;
data_dilation_shift[i] = shift;
}
GPUShape input_shape_str = row_major_strides(input_shape);
float alpha = 1.0f;
float beta = 0.0f;
// ----- register primitive arguments with device -----
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_pad = allocator.reserve_argspace(input_pad_below);
size_t idx_data_dilation = allocator.reserve_argspace(input_dilation);
size_t idx_data_dilation_magic = allocator.reserve_argspace(data_dilation_magic);
size_t idx_data_dilation_shift = allocator.reserve_argspace(data_dilation_shift);
size_t idx_filter_strides = allocator.reserve_argspace(filter_stride);
size_t idx_filter_dilation = allocator.reserve_argspace(filter_dilation);
size_t idx_input_shape = allocator.reserve_argspace(input_shape);
size_t idx_input_shape_str = allocator.reserve_argspace(input_shape_str);
size_t idx_output_dim_strides = allocator.reserve_argspace(output_dim_strides);
size_t idx_output_str_magic = allocator.reserve_argspace(output_str_magic);
size_t idx_output_str_shift = allocator.reserve_argspace(output_str_shift);
size_t idx_filter_dim_strides = allocator.reserve_argspace(filter_dim_strides);
size_t idx_filter_str_magic = allocator.reserve_argspace(filter_str_magic);
size_t idx_filter_str_shift = allocator.reserve_argspace(filter_str_shift);
// launch arguments:
// each output pixel is its own block. if the batch size is greater than reg_tile_size * sm_tile_size, a single
// output pixel is spread over multiple blocks along the batch axis so that memory coordination is not required
// each block consists of 2 warps in an 8 x 8 array used for accessing the SM block of the GEMM
// do_i = output pixel coordinates
// grid = (do_1*do_2*...*do_N*ceil_div(N, REG_TILE_SIZE*SM_TILE_SIZE), ceil_div(K, REG_TILE_SIZE*SM_TILE_SIZE), 1)
// block = (8, 8, 1)
dim3 blocks(output_pixels * idiv_ceil(N, reg_tile_size * sm_tile_size),
idiv_ceil(K, reg_tile_size * sm_tile_size),
1);
dim3 threads(sm_tile_size, sm_tile_size, 1);
// e.g. for 2d without register tiling
// blocks = (PQ*N/8, K/8, 1)
// threads = (8, 8, 1)
std::unique_ptr<gpu::primitive> conv(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* pad_d = runtime::gpu::invoke_memory_primitive(ctx, idx_pad);
void* data_dilation_d = runtime::gpu::invoke_memory_primitive(ctx, idx_data_dilation);
void* data_dilation_magic_d =
runtime::gpu::invoke_memory_primitive(ctx, idx_data_dilation_magic);
void* data_dilation_shift_d =
runtime::gpu::invoke_memory_primitive(ctx, idx_data_dilation_shift);
void* filter_strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_filter_strides);
void* filter_dilation_d = runtime::gpu::invoke_memory_primitive(ctx, idx_filter_dilation);
void* input_shape_d = runtime::gpu::invoke_memory_primitive(ctx, idx_input_shape);
void* input_shape_str_d = runtime::gpu::invoke_memory_primitive(ctx, idx_input_shape_str);
void* output_dim_strides_d =
runtime::gpu::invoke_memory_primitive(ctx, idx_output_dim_strides);
void* output_str_magic_d = runtime::gpu::invoke_memory_primitive(ctx, idx_output_str_magic);
void* output_str_shift_d = runtime::gpu::invoke_memory_primitive(ctx, idx_output_str_shift);
void* filter_dim_strides_d =
runtime::gpu::invoke_memory_primitive(ctx, idx_filter_dim_strides);
void* filter_str_magic_d = runtime::gpu::invoke_memory_primitive(ctx, idx_filter_str_magic);
void* filter_str_shift_d = runtime::gpu::invoke_memory_primitive(ctx, idx_filter_str_shift);
void* args_list[] = {&inputs[0],
&inputs[1],
&outputs[0],
&alpha,
&beta,
&N,
&C,
&K,
&input_channel_size,
&filter_channel_size,
&output_filter_size,
&output_pixels,
&output_pixels_magic,
&output_pixels_shift,
&pad_d,
&data_dilation_d,
&data_dilation_magic_d,
&data_dilation_shift_d,
&filter_strides_d,
&filter_dilation_d,
&input_shape_d,
&input_shape_str_d,
&output_dim_strides_d,
&output_str_magic_d,
&output_str_shift_d,
&filter_dim_strides_d,
&filter_str_magic_d,
&filter_str_shift_d};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
blocks.x,
blocks.y,
blocks.z,
threads.x,
threads.y,
threads.z,
0,
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
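// A standalone sketch of the launch-shape arithmetic used above (illustrative only;
// idiv_ceil_ref and convolution_grid_shape are not repository symbols, and dim3 is
// assumed to come from the CUDA runtime headers already included by this file).
// Example: a 2D output with P = Q = 14, N = 64 images, and K = 128 filters gives
// reg_tile_size = 4 (since 64 % 32 == 0), so
//   grid  = (14*14 * ceil(64/32), ceil(128/32), 1) = (392, 4, 1)
//   block = (8, 8, 1)
namespace
{
    uint32_t idiv_ceil_ref(int n, int d) { return n / d + (n % d > 0); }

    dim3 convolution_grid_shape(int output_pixels, int N, int K)
    {
        const int sm_tile_size = 8;
        const int reg_tile_size = (N % (sm_tile_size * 4) == 0) ? 4 : 1;
        return dim3(output_pixels * idiv_ceil_ref(N, reg_tile_size * sm_tile_size),
                    idiv_ceil_ref(K, reg_tile_size * sm_tile_size),
                    1);
    }
}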
void runtime::gpu::CUDAEmitter::print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
GPUShape shape)
......@@ -1301,6 +1639,16 @@ __device__ __forceinline__ int division_by_invariant_multiplication(int value, i
"}" : "=r"(result) : "r"(value), "r"(magic), "r"(shift));
return result;
}
__device__ __forceinline__ void idiv_fast(int numerator, int denominator, float rcp,
int& result, int& remainder)
{
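// estimate the quotient with the caller-supplied float reciprocal of the denominator,
// then apply a fix-up for the case where the estimate is one too small, so that
// (result, remainder) match integer division for the small positive operands used
// by the convolution kernel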
result = (int)((float)numerator * rcp);
remainder = numerator - (result * denominator);
result = (remainder >= denominator) ? (result + 1) : result;
remainder = (remainder >= denominator) ? (remainder - denominator) : remainder;
}
__device__ __forceinline__ int mod16(int numerator, int div, int maxdiv)
{
int res;
......
......@@ -127,6 +127,21 @@ namespace ngraph
GPUShape result_shape,
const std::set<size_t>& bcast_axes);
size_t build_reshape(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape input_order);
size_t build_convolution(const GPURuntimeContext* ctx,
const std::array<std::string, 3>& dtypes,
GPUShape input_shape,
GPUShape input_pad_below,
GPUShape input_dilation,
GPUShape filter_shape,
GPUShape filter_stride,
GPUShape filter_dilation,
GPUShape output_shape);
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
......
......@@ -710,15 +710,473 @@ void runtime::gpu::CudaKernelBuilder::get_avg_pool(codegen::CodeWriter& writer,
writer.block_end();
}
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
void runtime::gpu::CudaKernelBuilder::get_convolution_forward(
codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank)
const std::string& name,
const std::array<std::string, 3>& data_types,
int N,
int K,
int filter_size,
int rank,
int sm_tile_size,
int reg_tile_size)
{
writer << "#define NUM_ROWS 8\n";
writer << "#define FILTER_SIZE " << filter_size << "\n";
writer << "#define SM_TILE_SIZE " << sm_tile_size << "\n";
writer << "#define REG_TILE_SIZE " << reg_tile_size << "\n";
// convenient typedef for register tiling
writer << "typedef union Matrix\n";
writer.block_begin();
{
writer << data_types[0] << reg_tile_size << " f" << reg_tile_size << ";\n";
writer << data_types[0] << " f[" << reg_tile_size << "];\n";
}
writer.block_end();
writer << "Matrix;\n\n";
writer << "extern \"C\" __global__ void cuda_" << name << "(";
writer << data_types[0] << "* in, ";
writer << data_types[1] << "* filter, ";
writer << data_types[2] << "* out, ";
// TODO: add alpha/beta support
writer << "float alpha, float beta, "
<< "int N, "
<< "int C, "
<< "int K, "
<< "int input_channel_size, "
<< "int filter_channel_size, "
<< "int output_filter_size, "
<< "int output_pixels, "
<< "int output_pixels_magic, "
<< "int output_pixels_shift, "
<< "int* pad, "
<< "int* data_dilation, "
<< "int* data_dilation_magic, "
<< "int* data_dilation_shift, "
<< "int* filter_strides, "
<< "int* filter_dilation, "
<< "int* in_shape, "
<< "int* in_shape_str, "
<< "int* out_dim_str, "
<< "int* out_str_magic, "
<< "int* out_str_shift, "
<< "int* filter_dim_str, "
<< "int* filter_str_magic, "
<< "int* filter_str_shift"
<< ")\n";
writer.block_begin();
{
writer << "Matrix* I = reinterpret_cast<Matrix*>(in);\n";
writer << "Matrix* F = reinterpret_cast<Matrix*>(filter);\n";
writer << "Matrix* O = reinterpret_cast<Matrix*>(out);\n";
writer << "__shared__ int2 lookup_table[FILTER_SIZE];\n";
writer << "__shared__ int lookup_size;\n";
writer << "__shared__ Matrix a_tile[NUM_ROWS][SM_TILE_SIZE];\n";
writer << "__shared__ Matrix b_tile[NUM_ROWS][SM_TILE_SIZE];\n";
writer << "int lookup_size_local = 0;\n";
writer << "int n_batch = division_by_invariant_multiplication(blockIdx.x, "
"output_pixels_magic, output_pixels_shift);\n";
writer << "int output_pixel_idx = blockIdx.x - n_batch*output_pixels;\n";
// multiply by the number of threads per sm tile to get the offset into the
// image and filter dimensions (stride 1)
writer << "int n_offset = n_batch * blockDim.x;\n";
writer << "int k_offset = blockIdx.y * blockDim.x;\n";
// compute coordinate transform to output tensor axes
// up to the last dimension but not including it
// : out_dim_str { d2*d3*...*dn, d3*...*dn, ..., dn, 1}
// : for 2d {Q, 1}
coordinate_transform_to_multi_d(writer,
"out_dim_str",
"out_str_magic",
"out_str_shift",
"output_pixel_idx",
"out_d",
rank);
// offset tensors by image and filter indices
// each thread is responsible for its own image and filter
// n and k offsets are required because only REG_TILE_SIZE*SM_TILE_SIZE
// images/filters are processed per block
writer << "I = &(I[n_offset + threadIdx.x]);\n";
writer << "F = &(F[k_offset + threadIdx.x]);\n";
// if N is a multiple of reg_tile_size * sm_tile_size then no check is needed
bool need_image_bounds_check = N % (reg_tile_size * sm_tile_size) != 0;
if (need_image_bounds_check)
{
writer << "int image_load_in_bounds = (n_offset + threadIdx.x);\n";
if (reg_tile_size == 4)
{
writer << "image_load_in_bounds <<= 2;\n";
}
writer << "image_load_in_bounds = (image_load_in_bounds < N);\n";
}
// if K is a multiple of reg_tile_size * sm_tile_size then no check is needed
bool need_filter_bounds_check = K % (reg_tile_size * sm_tile_size) != 0;
if (need_filter_bounds_check)
{
writer << "int filter_load_in_bounds = (k_offset + threadIdx.x);\n";
if (reg_tile_size == 4)
{
writer << "filter_load_in_bounds <<= 2;\n";
}
writer << "filter_load_in_bounds = (filter_load_in_bounds < K);\n";
}
writer << "int tid = threadIdx.x + threadIdx.y * blockDim.x;\n";
// build lookup table for loading elements from data and filter tensors
writer << "if (tid < 32)\n";
writer.block_begin();
{
writer << "int filter_pixel = tid;\n";
for (int i = 0; i < rank; i++)
{
writer << "int input_base_d" << i << " = out_d" << i << " * filter_strides[" << i
<< "] - pad[" << i << "];\n";
}
// a mask marking all threads that have tid less than the current thread
writer << "uint32_t mask = (1 << tid) - 1;\n";
// loop over filter coordinates
writer << "while (filter_pixel < FILTER_SIZE)\n";
writer.block_begin();
{
// transform to filter coordinates
// : filter_dim_str is {S, 1} for 2D
coordinate_transform_to_multi_d(writer,
"filter_dim_str",
"filter_str_magic",
"filter_str_shift",
"filter_pixel",
"filter_d",
rank);
// transform from filter coordinate to input coordinates
// and check that each coordinate maps to an input element in the undilated space
writer << "int off_dilation_stride = 0;\n";
writer << "int undilated_coordinate = 0;\n";
for (int i = 0; i < rank; i++)
{
writer << "int input_d" << i << " = input_base_d" << i << " + filter_d" << i
<< " * filter_dilation[" << i << "];\n";
// determine coordinate in undilated input space
writer << "undilated_coordinate = division_by_invariant_multiplication(input_d"
<< i << ", data_dilation_magic[" << i << "], data_dilation_shift[" << i
<< "]);\n";
// if division remainder is 0, then dilated coordinate is on an input element
writer << "off_dilation_stride += (input_d" << i
<< " - undilated_coordinate * data_dilation[" << i << "]);\n";
// reassign dilated coordinate to undilated input coordinate
writer << "input_d" << i << " = undilated_coordinate;\n";
}
// check if the index is in bounds of the input tensor
writer << "bool in_bounds = (off_dilation_stride == 0) && (";
for (int i = 0; i < rank; i++)
{
if (i != 0)
{
writer << "&& ";
}
// in_shape contains the full shape of the input_tensor
// for 2D this is: (C, H, W, N) but rank = 2 and so only [H, W] are used
// condition (input_d0 >=0 && input_d0 < H && input_d1 >= 0 && input_d1 < W)
writer << "input_d" << i << ">= 0 && input_d" << i << " < in_shape[" << i + 1
<< "] ";
}
writer << ");\n";
// check which threads are within bounds of the input tensor
writer << "uint32_t threads = __ballot(in_bounds);\n";
writer << "if (in_bounds)\n";
writer.block_begin();
{
writer << "int2 entry;\n";
// inner product of coordinates and strides up to the last dimension
// for 2D (CHWN) this is: (HWN, WN, N, 1)
// entry.x = input_d0 * WN + input_d1*N
writer << "entry.x = (";
for (int i = 0; i < rank; i++)
{
if (i != 0)
{
writer << "+ ";
}
// skips the first and last stride which correspond
// to the channel and batch coordinate, respectively
writer << "input_d" << i << " * in_shape_str[" << i + 1 << "] ";
}
writer << ")";
// if using register tiling, down shift
// as each thread will compute outer
// product with register tiles
if (reg_tile_size == 4)
{
writer << " >> 2";
}
writer << ";\n";
// multiply by K filters per filter_pixel
writer << "entry.y = (filter_pixel * K)";
if (reg_tile_size == 4)
{
writer << " >> 2";
}
writer << ";\n";
// count the number of active threads with index less than the
// current tid and use this as an offset into the lookup table
writer << "int index = lookup_size_local + __popc(threads & mask);\n";
// save coordinates to shared lookup table for later loading
writer << "lookup_table[index] = entry;\n";
}
writer.block_end();
writer << "lookup_size_local += __popc(threads);\n";
writer << "filter_pixel += 32;\n";
}
writer.block_end();
}
writer.block_end();
// push lookup table size to shared memory so that it is accessible by other threads
writer << "if (tid == 0)\n";
writer.block_begin();
{
writer << "lookup_size = lookup_size_local;\n";
}
writer.block_end();
writer << "__syncthreads();\n";
// pull lookup table size from shared memory
writer << "lookup_size_local = lookup_size;\n";
// declare and zero initialize gemm accumulator
writer << "Matrix result[" << reg_tile_size << "] = {0};\n";
// if the lookup table is empty, no multiplication is needed:
// skip and write out a zero result; otherwise, do the GEMM
writer << "if (lookup_size_local > 0)\n";
writer.block_begin();
{
// calculate the total filter size across all input channels
writer << "int total_filter_size = lookup_size_local * C;\n";
// precompute reciprocal for faster division
writer << "float reciprocal = 1.0f / static_cast<float>(lookup_size_local);\n";
// loop from the back of the filter (highest index) to the front
// in order to handle filter pixel edge conditionals first (outside of gemm loop)
writer << "int total_filter_idx = total_filter_size % NUM_ROWS;\n";
// want total_filter_idx always >=0 in order to mask threads with t.y > total_filter_idx
writer << "total_filter_idx = (total_filter_idx == 0) ? 8 : total_filter_idx;\n";
// first iteration from back of filter
writer << "int c;\n";
writer << "int filter_idx;\n";
writer << "idiv_fast(total_filter_size - threadIdx.y - 1, lookup_size_local, "
"reciprocal, c, filter_idx);\n";
// retrieve the offsets for the data and filter for these filter pixels
// only threads with index less than total_filter_idx are valid; the rest are out of bounds
writer << "int2 entry = ((threadIdx.y & 7) >= total_filter_idx) "
<< "? make_int2(0, 0)\n"
<< ": lookup_table[filter_idx];\n";
// helper to emit call to cuda make_float function
auto make_float_i = [](int n) {
std::stringstream ss;
ss << "make_float" << n << "(";
for (int i = 0; i < n; i++)
{
if (i != 0)
{
ss << ", ";
}
ss << "0";
}
ss << ")";
return ss.str();
};
// use the y index of threads to load data into the tile rows
// threadIdx.x is used for iterating over the fastest moving dimensions
// of the data and filter tensors (N and K respectively)
// --- image load ---
writer << "a_tile[threadIdx.y][threadIdx.x].f" << reg_tile_size << " =\n";
if (need_image_bounds_check)
{
// check if image index is in bounds
writer << "(!image_load_in_bounds)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
writer << ": ";
}
// if filter pixel is out of range,
// set all elements in the relevant sm tile row to 0
writer << "((threadIdx.y & 7) >= total_filter_idx)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
// else load the image data corresponding to this filter pixel
// according to the entry.x offset previously determined
writer << ": I[(c * input_channel_size) + entry.x].f" << reg_tile_size << ";\n";
// --- filter load ---
writer << "b_tile[threadIdx.y][threadIdx.x].f" << reg_tile_size << " =\n";
if (need_filter_bounds_check)
{
// check if filter index is in bounds
writer << "(!filter_load_in_bounds)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
writer << ": ";
}
// if filter pixel is out of range,
// set all elements in the relevant sm tile row to 0
writer << "((threadIdx.y & 7) >= total_filter_idx)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
// else load the filter weights corresponding to this filter pixel
// according to the entry.y offset previously determined
writer << ": F[(c * filter_channel_size) + entry.y].f" << reg_tile_size << ";\n";
// iterate over filter from back to front
writer << "for (total_filter_idx = total_filter_size - total_filter_idx; "
"total_filter_idx > 0; total_filter_idx -= NUM_ROWS)\n";
writer.block_begin();
{
// finish loads
writer << "__syncthreads();\n";
writer << "#pragma unroll\n";
writer << "for (int i = 0; i < NUM_ROWS; i++)\n";
writer.block_begin();
{
writer << "Matrix row;\n";
writer << "Matrix col;\n";
writer << "row.f" << reg_tile_size << " = a_tile[i][threadIdx.x].f"
<< reg_tile_size << ";\n";
writer << "col.f" << reg_tile_size << " = b_tile[i][threadIdx.y].f"
<< reg_tile_size << ";\n";
// accumulate the product
writer << "#pragma unroll\n";
writer << "for (int y = 0; y < " << reg_tile_size << "; y++)\n";
writer.block_begin();
{
writer << "#pragma unroll\n";
writer << "for (int x = 0; x < " << reg_tile_size << "; x++)\n";
writer.block_begin();
{
writer << "result[y].f[x] += (row.f[x] * col.f[y]);\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
writer << "__syncthreads();\n";
// load new data and weights
writer << "idiv_fast(total_filter_idx - threadIdx.y - 1, lookup_size_local, "
"reciprocal, c, filter_idx);\n";
writer << "entry = lookup_table[filter_idx];\n";
// --- image load ---
writer << "a_tile[threadIdx.y][threadIdx.x].f" << reg_tile_size << " =\n";
if (need_image_bounds_check)
{
// check if image index is in bounds
writer << "(!image_load_in_bounds)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
writer << ": ";
}
writer << "I[(c * input_channel_size) + entry.x].f" << reg_tile_size << ";\n";
// --- filter load ---
writer << "b_tile[threadIdx.y][threadIdx.x].f" << reg_tile_size << " =\n";
if (need_filter_bounds_check)
{
// check if filter index is in bounds
writer << "(!filter_load_in_bounds)\n";
writer << "? " << make_float_i(reg_tile_size) << "\n";
writer << ": ";
}
writer << "F[(c * filter_channel_size) + entry.y].f" << reg_tile_size << ";\n";
}
writer.block_end();
writer << "__syncthreads();\n";
// last iteration
writer << "#pragma unroll\n";
writer << "for (int i = 0; i < NUM_ROWS; i++)\n";
writer.block_begin();
{
writer << "Matrix row;\n";
writer << "Matrix col;\n";
writer << "row.f" << reg_tile_size << " = a_tile[i][threadIdx.x].f" << reg_tile_size
<< ";\n";
writer << "col.f" << reg_tile_size << " = b_tile[i][threadIdx.y].f" << reg_tile_size
<< ";\n";
// accumulate the product
writer << "#pragma unroll\n";
writer << "for (int y = 0; y < " << reg_tile_size << "; y++)\n";
writer.block_begin();
{
writer << "#pragma unroll\n";
writer << "for (int x = 0; x < " << reg_tile_size << "; x++)\n";
writer.block_begin();
{
writer << "result[y].f[x] += (row.f[x] * col.f[y]);\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
} // end if (lookup_size_local > 0)
writer.block_end();
// store result block to global memory
writer << "int n = n_offset + threadIdx.x;\n";
std::string k_definition = "int k = (k_offset + threadIdx.y)";
std::string output_pixel = "output_pixel_idx = (output_pixel_idx * N)";
if (reg_tile_size == 4)
{
output_pixel += " >> 2";
k_definition += " << 2";
}
writer << output_pixel << ";\n";
writer << k_definition << ";\n";
writer << "if (k < K && n < N)\n";
writer.block_begin();
{
writer << "#pragma unroll\n";
writer << "for (int x = 0; x < " << reg_tile_size << "; x++)\n";
writer.block_begin();
{
writer << "if (k < K)\n";
writer.block_begin();
{
writer << "int idx = (k * output_filter_size) + output_pixel_idx + n;\n";
writer << "O[idx].f" << reg_tile_size << " = result[x].f" << reg_tile_size
<< ";\n";
}
writer.block_end();
writer << "k++;\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank)
{
// Translation from flat index to dense tensor coordinates:
// Given tensor shape [d0 d1 ... dN] with strides [d1*...*dN, d2*...*dN, ... 1],
......@@ -729,15 +1187,31 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel
// product = product % stride[0]
// d1 = product/stride[1]
// ...
writer << "int coordinate_product = " << i_thread_index << ";\n";
writer << "int coordinate_product = " << i_coord_product << ";\n";
for (size_t i = 0; i < rank; i++)
{
if (i != 0)
{
writer << "coordinate_product -= (" << o_coordinates << i - 1 << " * " << i_strides
<< "[" << i - 1 << "]);\n";
}
writer << "int " << o_coordinates << i << " = division_by_invariant_multiplication("
<< "coordinate_product, " << i_stride_magic << "[" << i << "], " << i_stride_shift
<< "[" << i << "]);\n";
writer << "coordinate_product -= (" << o_coordinates << i << " * " << i_strides << "[" << i
<< "]);\n";
}
}
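// A host-side reference of the transform emitted above (a minimal sketch;
// flat_to_multi_d is illustrative and not a repository function). The generated
// kernel replaces the plain divisions below with division_by_invariant_multiplication
// using the precomputed magic/shift pairs.
std::vector<int> flat_to_multi_d(int flat_index, const std::vector<int>& strides)
{
    std::vector<int> coordinates(strides.size());
    int product = flat_index;
    for (size_t i = 0; i < strides.size(); i++)
    {
        coordinates[i] = product / strides[i]; // coordinate along dimension i
        product -= coordinates[i] * strides[i];
    }
    return coordinates;
}
// e.g. with strides {Q, 1} for a 2D output, flat index p*Q + q maps back to {p, q}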
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank)
{
coordinate_transform_to_multi_d(
writer, i_strides, i_stride_magic, i_stride_shift, i_thread_index, o_coordinates, rank);
// index into reduced tensor from coordinates of non-reduced tensor
std::string reduced_idx = "reduced_idx";
......@@ -750,7 +1224,6 @@ std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_hel
return reduced_idx;
}
void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
......@@ -114,6 +114,16 @@ namespace ngraph
const std::array<std::string, 2>& data_types,
bool include_pad);
static void get_convolution_forward(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 3>& data_types,
int N,
int K,
int filter_size,
int rank,
int sm_tile_size = 8,
int reg_tile_size = 1);
static void add_pod_typedefs(codegen::CodeWriter& writer);
/// \brief Given kernel input variables i_* produce register variables o_coordinates{i}
......@@ -127,6 +137,13 @@ namespace ngraph
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_coord_product,
std::string o_coordinates,
size_t rank);
};
}
}
......
......@@ -149,6 +149,11 @@ namespace ngraph
}
auto convolution = static_cast<const ngraph::op::Convolution*>(node);
auto input_shape = args[0].get_shape();
auto filter_shape = args[1].get_shape();
auto output_shape = out[0].get_shape();
auto rank = input_shape.size();
Strides window_dilation_strides = convolution->get_window_dilation_strides();
Strides window_movement_strides = convolution->get_window_movement_strides();
Strides data_dilation_strides = convolution->get_data_dilation_strides();
......@@ -157,94 +162,190 @@ namespace ngraph
if (padding_below_diff.size() > 3)
{
throw std::runtime_error(node->get_name() +
"with more than 3D is not implemented.");
}
bool is_deconvolution = false;
for (auto a : data_dilation_strides)
{
if (a != 1)
{
is_deconvolution = true;
break;
}
}
// Reshape from NC{d1,...,dn} -> C{d1,...,dn}N
// and from KC{df1,...,dfn} -> C{df1,...,dfn}K.
bool pad_required = (padding_below_diff != padding_above_diff);
// TODO: This should be done via a pass similar to
// what is done for convolution in the IA transformer
// c.f runtime/cpu/pass/cpu_layout.cpp
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below.size(); i++)
{
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
auto input_shape = args[0].get_shape();
Shape input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin();
if (pad_required || is_deconvolution)
{
input_shape_padded = get_padded_shape(
input_shape, padding_below, padding_above, padding_interior);
Shape input_padded_strides = row_major_strides(input_shape_padded);
auto temp_size =
shape_size(input_shape_padded) * args[0].get_element_type().size();
GPUAllocator allocator =
external_function->get_primitive_emitter()->get_memory_allocator();
size_t idx_workspace = allocator.reserve_workspace(temp_size);
writer << "void* pad_buffer = runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_workspace << ");\n";
writer << "std::vector<" << args[0].get_type() << "> pad_buffer_host("
<< shape_size(input_shape_padded) << ", 0);\n";
writer << "runtime::gpu::cuda_memcpyHtD(pad_buffer, pad_buffer_host.data(), "
<< temp_size << ");\n";
size_t transposed_data_idx = allocator.reserve_workspace(
args[0].get_size() * args[0].get_element_type().size());
size_t transposed_filter_idx = allocator.reserve_workspace(
args[1].get_size() * args[1].get_element_type().size());
size_t transposed_output_idx = allocator.reserve_workspace(
out[0].get_size() * out[0].get_element_type().size());
GPUShape input_order;
for (int i = 1; i <= rank; i++)
{
input_order.push_back(i % rank);
}
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_dynamic_index =
cuda_emitter->build_pad_dynamic(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_dynamic_index << ", ";
size_t reshape_data_index =
cuda_emitter->build_reshape(external_function->ctx().get(),
{{args[0].get_type(), args[0].get_type()}},
input_shape,
input_order);
writer << "void* data = gpu::invoke_memory_primitive(ctx, "
<< transposed_data_idx << ");\n";
writer << "gpu::invoke_primitive(ctx, " << reshape_data_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cuDNN does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_convolution(external_function->ctx().get(),
out[0].get_type(),
input_shape_padded,
args[1].get_shape(),
out[0].get_shape(),
window_movement_strides,
window_dilation_strides,
padding_below);
writer << "std::vector<void*>{data}.data());\n";
size_t reshape_filter_index =
cuda_emitter->build_reshape(external_function->ctx().get(),
{{args[1].get_type(), args[1].get_type()}},
filter_shape,
input_order);
writer << "void* filter = gpu::invoke_memory_primitive(ctx, "
<< transposed_filter_idx << ");\n";
writer << "gpu::invoke_primitive(ctx, " << reshape_filter_index << ", ";
writer << "std::vector<void*>{" << args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{filter}.data());\n";
// local helper to reshape tensor shape objects
auto reshape = [](const Shape& shape, const GPUShape& order) {
Shape output(shape.size(), 0);
for (size_t i = 0; i < shape.size(); i++)
{
output[i] = shape[order[i]];
}
return output;
};
// reorder axes of the input shape (NC{d_1,...,d_n} -> C{d_1,...,d_n}N)
input_shape = reshape(input_shape, input_order);
// reorder axes of the filter shape (KC{df_1,...,df_n} -> C{df_1,...,df_n}K)
filter_shape = reshape(filter_shape, input_order);
// reorder axes of the output shape (NK{do_1,...,do_n} -> K{do_1,...,do_n}N)
output_shape = reshape(output_shape, input_order);
size_t conv_index = cuda_emitter->build_convolution(
external_function->ctx().get(),
{{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
input_shape,
padding_below_diff,
data_dilation_strides,
filter_shape,
window_movement_strides,
window_dilation_strides,
output_shape);
writer << "void* output = gpu::invoke_memory_primitive(ctx, "
<< transposed_output_idx << ");\n";
writer << "gpu::invoke_primitive(ctx, " << conv_index << ", ";
writer << "std::vector<void*>{data, filter}.data(), ";
writer << "std::vector<void*>{output}.data());\n";
// reshape output tensor (K{do_1,...,do_n}N -> NK{do_1,...,do_n})
input_order.clear();
input_order.push_back(static_cast<int>(rank - 1));
for (int i = 0; i < rank - 1; i++)
{
input_order.push_back(i);
}
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
if (pad_required || is_deconvolution)
{
writer << "std::vector<void*>{pad_buffer, " << args[1].get_name()
<< "}.data(), ";
size_t reshape_output_index =
cuda_emitter->build_reshape(external_function->ctx().get(),
{{args[1].get_type(), args[1].get_type()}},
output_shape,
input_order);
writer << "gpu::invoke_primitive(ctx, " << reshape_output_index << ", ";
writer << "std::vector<void*>{output}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data());\n";
}
else
{
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
bool is_deconvolution = false;
for (auto a : data_dilation_strides)
{
if (a != 1)
{
is_deconvolution = true;
break;
}
}
bool pad_required = (padding_below_diff != padding_above_diff);
Shape padding_below(padding_below_diff.size(), 0);
Shape padding_above(padding_above_diff.size(), 0);
for (int i = 0; i < padding_below.size(); i++)
{
padding_below[i] = static_cast<size_t>(padding_below_diff[i]);
padding_above[i] = static_cast<size_t>(padding_above_diff[i]);
}
Shape input_shape_padded = input_shape;
Shape padding_interior(data_dilation_strides);
writer.block_begin();
if (pad_required || is_deconvolution)
{
input_shape_padded = get_padded_shape(
input_shape, padding_below, padding_above, padding_interior);
Shape input_padded_strides = row_major_strides(input_shape_padded);
auto temp_size =
shape_size(input_shape_padded) * args[0].get_element_type().size();
GPUAllocator allocator =
external_function->get_primitive_emitter()->get_memory_allocator();
size_t idx_workspace = allocator.reserve_workspace(temp_size);
writer << "void* pad_buffer = runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_workspace << ");\n";
writer << "std::vector<" << args[0].get_type() << "> pad_buffer_host("
<< shape_size(input_shape_padded) << ", 0);\n";
writer
<< "runtime::gpu::cuda_memcpyHtD(pad_buffer, pad_buffer_host.data(), "
<< temp_size << ");\n";
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_dynamic_index = cuda_emitter->build_pad_dynamic(
external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
input_shape,
input_shape_padded,
padding_below,
padding_interior);
writer << "gpu::invoke_primitive(ctx, " << pad_dynamic_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{pad_buffer}.data()";
writer << ");\n";
// asymmetric padding has been applied, zero out padding vectors to
// ensure cuDNN does not assume padding
std::fill(padding_below.begin(), padding_below.end(), 0);
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_convolution(external_function->ctx().get(),
out[0].get_type(),
input_shape_padded,
args[1].get_shape(),
out[0].get_shape(),
window_movement_strides,
window_dilation_strides,
padding_below);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
if (pad_required || is_deconvolution)
{
writer << "std::vector<void*>{pad_buffer, " << args[1].get_name()
<< "}.data(), ";
}
else
{
writer << "std::vector<void*>{" << args[0].get_name() << ","
<< args[1].get_name() << "}.data(), ";
}
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
writer.block_end();
}
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
writer.block_end();
}
template <>
......
......@@ -42,6 +42,12 @@ namespace ngraph
GPUAllocator(const GPUAllocator& g);
~GPUAllocator();
template <typename T>
size_t reserve_argspace(const T& container)
{
return reserve_argspace(container.data(),
container.size() * sizeof(typename T::value_type));
}
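// example usage (illustrative): given GPUShape strides{4, 2, 1},
//   size_t idx = allocator.reserve_argspace(strides);
// reserves 3 * sizeof(int32_t) bytes of argument space backed by the container's host data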
size_t reserve_argspace(const void* data, size_t size);
size_t reserve_workspace(size_t size, bool zero_initialize = true);
......
......@@ -21,6 +21,7 @@
#include <vector>
#include "ngraph/axis_set.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/shape.hpp"
......@@ -30,45 +31,45 @@ namespace ngraph
{
class Shape;
/// \brief Shape for a tensor resident on GPU.
class GPUShape : public std::vector<uint32_t>
class GPUShape : public std::vector<int32_t>
{
public:
GPUShape(const std::initializer_list<uint32_t>& axis_lengths)
: std::vector<uint32_t>(axis_lengths)
GPUShape(const std::initializer_list<int32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
GPUShape(const std::vector<uint32_t>& axis_lengths)
: std::vector<uint32_t>(axis_lengths)
GPUShape(const std::vector<int32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
GPUShape(const GPUShape& axis_lengths)
: std::vector<uint32_t>(axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
explicit GPUShape(size_t n, uint32_t initial_value = 0)
: std::vector<uint32_t>(n, initial_value)
explicit GPUShape(size_t n, int32_t initial_value = 0)
: std::vector<int32_t>(n, initial_value)
{
}
template <class InputIterator>
GPUShape(InputIterator first, InputIterator last)
: std::vector<uint32_t>(first, last)
: std::vector<int32_t>(first, last)
{
}
GPUShape() {}
GPUShape& operator=(const GPUShape& v)
{
static_cast<std::vector<uint32_t>*>(this)->operator=(v);
static_cast<std::vector<int32_t>*>(this)->operator=(v);
return *this;
}
GPUShape& operator=(GPUShape&& v)
{
static_cast<std::vector<uint32_t>*>(this)->operator=(v);
static_cast<std::vector<int32_t>*>(this)->operator=(v);
return *this;
}
......@@ -81,7 +82,7 @@ namespace ngraph
throw std::runtime_error(
"Request exceeds the bitwidth available for GPUShapes (32)");
}
this->push_back(static_cast<uint32_t>(size));
this->push_back(static_cast<int32_t>(size));
}
}
......@@ -95,7 +96,7 @@ namespace ngraph
"Request for Shape which exceeds the bitwidth available for GPUShapes "
"(32)");
}
this->push_back(static_cast<uint32_t>(size));
this->push_back(static_cast<int32_t>(size));
}
}
......@@ -109,7 +110,7 @@ namespace ngraph
"Request for Strides which exceed the bitwidth available for GPUShapes "
"(32)");
}
this->push_back(static_cast<uint32_t>(size));
this->push_back(static_cast<int32_t>(size));
}
}
......@@ -123,21 +124,36 @@ namespace ngraph
"Request for Coordinate which exceed the bitwidth available for GPUShapes "
"(32)");
}
this->push_back(static_cast<uint32_t>(size));
this->push_back(static_cast<int32_t>(size));
}
}
GPUShape(const CoordinateDiff& coord)
{
for (auto const& size : coord)
for (auto const& dim : coord)
{
if (dim > 0 && dim >> 32 != 0)
{
throw std::runtime_error(
"Request for CoordinateDiff which exceed the bitwidth available for "
"GPUShapes "
"(32)");
}
this->push_back(static_cast<int32_t>(dim));
}
}
GPUShape(const AxisVector& vec)
{
for (auto const& size : vec)
{
if (size >> 32 != 0)
{
throw std::runtime_error(
"Request for Coordinate which exceed the bitwidth available for GPUShapes "
"Request for axis vector which exceed the bitwidth available for GPUShapes "
"(32)");
}
this->push_back(static_cast<uint32_t>(size));
this->push_back(static_cast<int32_t>(size));
}
}
};
......
......@@ -186,3 +186,9 @@ std::pair<uint64_t, uint64_t> runtime::gpu::idiv_magic_u64(uint64_t divisor)
{
return magicU64(divisor);
}
uint32_t runtime::gpu::idiv_ceil(int n, int d)
{
// the division and modulo are fused by the compiler into a single operation
return n / d + (n % d > 0);
}
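// e.g. idiv_ceil(64, 32) == 2, idiv_ceil(65, 32) == 3, idiv_ceil(1, 32) == 1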
......@@ -103,6 +103,7 @@ namespace ngraph
void cuda_memset(void* dst, int value, size_t buffer_size);
std::pair<uint64_t, uint64_t> idiv_magic_u32(uint64_t max_numerator, uint64_t divisor);
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
uint32_t idiv_ceil(int n, int d);
template <typename T>
void print_gpu_tensor(const void* p, size_t element_count)
......
......@@ -6,17 +6,6 @@ batch_norm_three_outputs
computation_reuse
#int64 is not supported
concat_matrix_int64
#convolution 4d is work in progress
convolution_4d_2items
convolution_4d_4items
convolution_4d_4items_dilated
convolution_4d_4items_padded_neg
convolution_4d_4items_strided
convolution_4d_4items_strided_dilated
convolution_4d_4items_strided_dilated_padded
convolution_4d_4items_strided_dilated_padded_neg
convolution_4d_4items_strided_dilated_padded_same
#cuDNN does not have arithmetic exceptions
divide_by_zero_int32
#int64 is not supported by cuDNN
dot_matrix_vector_int64
......