Commit f243d035 authored by Fenglei, committed by Scott Cyphers

gpu slice optimization (#1172)

* add gpu_timer to external function

* compiled version

* working version

* using block_begin and block_end

* add the missing ';'

* move slice to cuda emiter

* change size_t to uint32_t in kernel

* working version

* change block size from 1 to 64

* fix bugs

* nthreads need to be size_t in broadcast op

* add rank to kernel name hash

* update slice in convolution

* resolve index conflict

* change align to align_to_blocksize, add overflow check

* add gird size check and fix pool merge bug

* code style, change names
parent c4c24cb0
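The recurring change in the hunks below swaps the old one-thread-per-block launches (grid x dimension = nthreads, block size 1) for a fixed block size of 64 and a grid rounded up to cover all output elements, with an `if (tid < n)` guard inside each generated kernel. A minimal standalone sketch of that pattern, for orientation only (copy_kernel and launch_copy are illustrative names, not part of this commit, and the ceil-division below is the textbook form rather than the commit's own align_to_block_size helper shown later in the diff):

    #include <cstdint>

    // Illustrative CUDA sketch of the launch pattern adopted in this commit.
    __global__ void copy_kernel(const float* in, float* out, uint32_t n)
    {
        uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n) // the grid may cover more threads than there are elements
        {
            out[tid] = in[tid];
        }
    }

    void launch_copy(const float* in, float* out, uint32_t n)
    {
        uint32_t block_size_x = 64;                                   // fixed block size, as in the diff
        uint32_t grid_size_x = (n + block_size_x - 1) / block_size_x; // enough blocks to cover n elements
        copy_kernel<<<grid_size_x, block_size_x>>>(in, out, n);
    }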
@@ -97,7 +97,10 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
         return primitive_index;
     }
 
-    size_t nthreads = shape_size(output_shape);
+    uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
 
     // if the kernel has not been compiled, build it
     auto compiled_kernel = ctx->compiled_kernel_pool->get(hash);
@@ -193,10 +196,10 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
     pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
         void* args_list[] = {&inputs[1], &inputs[0], &outputs[0]};
         CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                      static_cast<uint32_t>(nthreads),
+                                      aligned_grid_size_x,
                                       1,
                                       1, // grid dim
-                                      1,
+                                      block_size_x,
                                       1,
                                       1, // block dim
                                       0,
@@ -211,10 +214,10 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
     pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
         void* args_list[] = {&inputs[0], &outputs[0]};
         CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                      static_cast<uint32_t>(nthreads),
+                                      aligned_grid_size_x,
                                       1,
                                       1, // grid dim
-                                      1,
+                                      block_size_x,
                                       1,
                                       1, // block dim
                                       0,
@@ -330,6 +333,95 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
     return primitive_index;
 }
 
+size_t runtime::gpu::CUDAEmitter::build_slice(const runtime::gpu::GPURuntimeContext* ctx,
+                                              const std::array<std::string, 2>& dtypes,
+                                              GPUShape input_shape,
+                                              GPUShape lower_bounds,
+                                              GPUShape slice_strides,
+                                              GPUShape output_shape)
+{
+    std::stringstream kernel_name;
+    kernel_name << "slice_" << join(dtypes, "_") << "_r_" << output_shape.size();
+
+    std::string hash = kernel_name.str() + "_i_" + join(input_shape, "_") + "_o_" +
+                       join(output_shape, "_") + "_lb_" + join(lower_bounds, "_") + "_ss_" +
+                       join(slice_strides, "_");
+
+    // For backwards compatability we currently use two unordered maps
+    // 1. one looks up the compiled cuda kernel (CudaFunctionPool)
+    // 2. the other looks to see if this kernel is already in the primitive list
+    // check if the requested kernel is already an inserted primitive
+    size_t primitive_index = m_primitive_emitter->lookup(hash);
+    if (primitive_index != std::numeric_limits<size_t>::max())
+    {
+        return primitive_index;
+    }
+
+    // check if the kernel has already been compiled. if so, create
+    // a launch primitive for it based on the input tensor shape
+    // but do not recompile the kernel. otherwise, do it all:
+    // recompile the kernel and then create the primitive
+    auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name.str());
+    if (compiled_kernel == nullptr)
+    {
+        codegen::CodeWriter writer;
+        CudaKernelBuilder::add_pod_typedefs(writer);
+        CudaKernelBuilder::get_slice_op(writer, kernel_name.str(), dtypes, output_shape.size());
+        compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
+    }
+
+    uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
+
+    GPUShape output_strides = row_major_strides(output_shape);
+    GPUShape input_strides = row_major_strides(input_shape);
+
+    // get an allocator for transient per kernel gpu memory
+    GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
+    size_t idx_input_strides =
+        allocator.reserve_argspace(input_strides.data(), input_strides.size() * sizeof(uint32_t));
+    size_t idx_output_strides =
+        allocator.reserve_argspace(output_strides.data(), output_strides.size() * sizeof(uint32_t));
+    size_t idx_lower_bounds =
+        allocator.reserve_argspace(lower_bounds.data(), lower_bounds.size() * sizeof(uint32_t));
+    size_t idx_slice_strides =
+        allocator.reserve_argspace(slice_strides.data(), slice_strides.size() * sizeof(uint32_t));
+
+    // create the launch primitive
+    std::unique_ptr<gpu::primitive> kernel_runner(new gpu::primitive{[=](void** inputs,
+                                                                         void** outputs) mutable {
+        void* param_input_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_input_strides);
+        void* param_output_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_output_strides);
+        void* param_lower_bounds = runtime::gpu::invoke_memory_primitive(ctx, idx_lower_bounds);
+        void* param_slice_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_slice_strides);
+
+        std::vector<void*> args_list{&inputs[0],
+                                     &outputs[0],
+                                     &param_input_strides,
+                                     &param_lower_bounds,
+                                     &param_slice_strides,
+                                     &param_output_strides,
+                                     &nthreads};
+        CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
+                                      aligned_grid_size_x,
+                                      1,
+                                      1, // grid dim
+                                      block_size_x,
+                                      1,
+                                      1, // block dim
+                                      0,
+                                      NULL, // shared mem and stream
+                                      args_list.data(),
+                                      0)); // arguments
+        CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
+    }});
+
+    primitive_index = this->m_primitive_emitter->insert(std::move(kernel_runner));
+    m_primitive_emitter->cache(hash, primitive_index);
+    return primitive_index;
+}
+
 size_t runtime::gpu::CUDAEmitter::build_reverse_sequence(const runtime::gpu::GPURuntimeContext* ctx,
                                                          const std::array<std::string, 3>& dtypes,
                                                          GPUShape input_shape0,
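The build_slice primitive above computes, for every output element, the flat input index it should read, using the output strides, the slice lower bounds, and the slice strides. A host-side restatement of that per-thread index arithmetic, as a sketch for clarity only (the function name and use of std::vector are illustrative, not part of the commit):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch: mirrors the index mapping the generated slice kernel performs for one thread id.
    uint32_t slice_input_index(uint32_t tid,
                               const std::vector<uint32_t>& input_strides,
                               const std::vector<uint32_t>& output_strides,
                               const std::vector<uint32_t>& lower_bounds,
                               const std::vector<uint32_t>& slice_strides)
    {
        uint32_t input_idx = 0;
        uint32_t output_idx = tid;
        for (std::size_t i = 0; i < output_strides.size(); i++)
        {
            uint32_t coord = output_idx / output_strides[i]; // coordinate of tid along axis i
            input_idx += (coord * slice_strides[i] + lower_bounds[i]) * input_strides[i];
            output_idx %= output_strides[i];
        }
        return input_idx; // the kernel then does out[tid] = in[input_idx]
    }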
@@ -370,6 +462,9 @@ size_t runtime::gpu::CUDAEmitter::build_reverse_sequence(const runtime::gpu::GPU
     }
 
     uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
 
     GPUShape output_strides = row_major_strides(output_shape);
     // get an allocator for transient per kernel gpu memory
@@ -380,33 +475,32 @@ size_t runtime::gpu::CUDAEmitter::build_reverse_sequence(const runtime::gpu::GPU
         allocator.reserve_argspace(output_strides.data(), output_strides.size() * sizeof(uint32_t));
 
     // create the launch primitive
-    std::unique_ptr<gpu::primitive> reserve_sequence(
-        new gpu::primitive{[=](void** inputs, void** outputs) mutable {
-            void* param_output_shape = runtime::gpu::invoke_memory_primitive(ctx, idx_output_shape);
-            void* param_output_strides =
-                runtime::gpu::invoke_memory_primitive(ctx, idx_output_strides);
-            std::vector<void*> args_list{&inputs[0],
-                                         &inputs[1],
-                                         &outputs[0],
-                                         &param_output_shape,
-                                         &param_output_strides,
-                                         &nthreads};
+    std::unique_ptr<gpu::primitive> kernel_runner(new gpu::primitive{[=](void** inputs,
+                                                                         void** outputs) mutable {
+        void* param_output_shape = runtime::gpu::invoke_memory_primitive(ctx, idx_output_shape);
+        void* param_output_strides = runtime::gpu::invoke_memory_primitive(ctx, idx_output_strides);
+
+        std::vector<void*> args_list{&inputs[0],
+                                     &inputs[1],
+                                     &outputs[0],
+                                     &param_output_shape,
+                                     &param_output_strides,
+                                     &nthreads};
         CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                      static_cast<uint32_t>(nthreads),
+                                      aligned_grid_size_x,
                                       1,
                                       1, // grid dim
-                                      1,
+                                      block_size_x,
                                       1,
                                       1, // block dim
                                       0,
                                       NULL, // shared mem and stream
                                       args_list.data(),
                                       0)); // arguments
         CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
     }});
 
-    primitive_index = this->m_primitive_emitter->insert(std::move(reserve_sequence));
+    primitive_index = this->m_primitive_emitter->insert(std::move(kernel_runner));
     m_primitive_emitter->cache(hash, primitive_index);
     return primitive_index;
 }
@@ -436,6 +530,12 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
         return primitive_index;
     }
 
+    size_t nthreads = shape_size(output_shape);
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x =
+        align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
+
     // if the kernel has not been compiled, build it
     auto compiled_kernel = ctx->compiled_kernel_pool->get(hash);
     if (compiled_kernel == nullptr)
@@ -446,16 +546,14 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
         compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
     }
 
-    size_t nthreads = shape_size(output_shape);
     std::unique_ptr<gpu::primitive> pool(
         new gpu::primitive{[=](void** inputs, void** outputs) mutable {
             void* args_list[] = {&inputs[0], &outputs[0], &nthreads};
             CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                          static_cast<uint32_t>(nthreads),
+                                          aligned_grid_size_x,
                                           1,
                                           1, // grid dim
-                                          1,
+                                          block_size_x,
                                           1,
                                           1, // block dim
                                           0,
@@ -702,6 +800,10 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const GPURuntimeConte
         compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
     }
 
     size_t nthreads = shape_size(tensor_shape);
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x =
+        align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
     // create the launch primitive
     std::unique_ptr<gpu::primitive> ew(
@@ -714,10 +816,10 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const GPURuntimeConte
             args_list.push_back(&outputs[0]);
             args_list.push_back(&nthreads);
             CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                          static_cast<uint32_t>(nthreads),
+                                          aligned_grid_size_x,
                                           1,
                                           1, // grid dim
-                                          1,
+                                          block_size_x,
                                           1,
                                           1, // block dim
                                           0,
@@ -1194,6 +1296,10 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const GPURuntimeContext* ctx,
     float beta = 0.0f;
     int nthreads = static_cast<int>(shape_size(result_shape));
+    //TODO: currently we set it to 64, will add tuning method later
+    uint32_t block_size_x = 64;
+    uint32_t aligned_grid_size_x =
+        align_to_block_size(static_cast<uint32_t>(shape_size(result_shape)), block_size_x);
 
     std::unique_ptr<gpu::primitive> broadcast(new gpu::primitive{[=](void** inputs,
                                                                      void** outputs) mutable {
@@ -1212,8 +1318,17 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const GPURuntimeContext* ctx,
                                      &beta,
                                      &nthreads};
 
-        CUDA_SAFE_CALL(
-            cuLaunchKernel(*compiled_kernel.get(), nthreads, 1, 1, 1, 1, 1, 0, NULL, args_list, 0));
+        CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
+                                      aligned_grid_size_x,
+                                      1,
+                                      1,
+                                      block_size_x,
+                                      1,
+                                      1,
+                                      0,
+                                      NULL,
+                                      args_list,
+                                      0));
         CUDA_SAFE_CALL(cuCtxSynchronize());
     }});
@@ -1697,3 +1812,17 @@ __device__ __forceinline__ int64_t load(const int64_t* __restrict__ in, int i=
 )";
     return ss.str();
 }
+
+uint32_t runtime::gpu::CUDAEmitter::align_to_block_size(uint32_t grid_size, uint32_t block_size)
+{
+    if (grid_size > (1u << 31) - 1)
+    {
+        throw std::runtime_error("Cuda can't handle grid_size_x > 2^31 - 1.");
+    }
+    uint32_t r = (grid_size + block_size - 1) / block_size * block_size;
+    if (grid_size > (1u << 31) - 1)
+    {
+        throw std::runtime_error("Cuda can't handle grid_size_x > 2^31 - 1.");
+    }
+    return r;
+}
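A quick worked check of the rounding expression used above; round_up is a local, illustrative copy of the formula, not a function from the commit:

    #include <cstdint>

    // (n + block - 1) / block * block rounds n up to the next multiple of block.
    constexpr uint32_t round_up(uint32_t n, uint32_t block) { return (n + block - 1) / block * block; }

    static_assert(round_up(1, 64) == 64, "a single element still occupies one full 64-wide block");
    static_assert(round_up(64, 64) == 64, "already a multiple of the block size");
    static_assert(round_up(100, 64) == 128, "(100 + 63) / 64 * 64 == 128");

Any extra threads implied by the rounded-up size are discarded by the `if (tid < n)` guard emitted in the kernels above.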
@@ -47,7 +47,7 @@ namespace ngraph
                              GPUShape pad_interior,
                              const std::string& pad_value = "");
 
-            size_t build_pad_dynamic(const runtime::gpu::GPURuntimeContext* ctx,
+            size_t build_pad_dynamic(const GPURuntimeContext* ctx,
                                      const std::array<std::string, 2>& dtypes,
                                      GPUShape input_shape,
                                      GPUShape output_shape,
@@ -70,6 +70,13 @@ namespace ngraph
                                    GPUShape padding_below,
                                    bool include_pad = false);
 
+            size_t build_slice(const GPURuntimeContext* ctx,
+                               const std::array<std::string, 2>& dtypes,
+                               GPUShape input_shape,
+                               GPUShape lower_bounds,
+                               GPUShape slice_strides,
+                               GPUShape output_shape);
+
             size_t build_reduce_window(const GPURuntimeContext* ctx,
                                        const OpName op_name,
                                        const std::vector<std::string>& dtypes,
@@ -144,6 +151,7 @@ namespace ngraph
         private:
             CUDAEmitter(GPUPrimitiveEmitter* emitter);
+            uint32_t align_to_block_size(uint32_t grid_size, uint32_t block_size);
             void print_tensor_from_gpu(codegen::CodeWriter& writer,
                                        const std::string& tensor_name,
                                        GPUShape shape);
...
@@ -363,30 +363,36 @@ void runtime::gpu::CudaKernelBuilder::get_reverse_sequence_op(
 }
 
 void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
                                                    const std::string& name,
-                                                   const std::array<std::string, 2>& data_types)
+                                                   const std::array<std::string, 2>& data_types,
+                                                   size_t rank)
 {
     writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
-           << data_types[1] << "* out, size_t* input_strides, size_t* lower_bounds, size_t* "
-              "slice_strides, size_t* output_strides, size_t rank, size_t n)\n";
+           << data_types[1] << "* out, uint32_t* input_strides, uint32_t* lower_bounds, uint32_t* "
+              "slice_strides, uint32_t* output_strides, uint32_t n)\n";
     writer.block_begin();
     {
-        writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
+        writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
         writer << "if (tid < n)\n";
         writer.block_begin();
        {
-            writer << "size_t input_idx = 0;\n";
-            writer << "size_t output_idx = tid;\n";
-            writer << "for(size_t i = 0; i < rank; i++)\n";
-            writer.block_begin();
+            writer << "uint32_t input_idx = 0;\n";
+            writer << "uint32_t output_idx = tid;\n";
+            size_t i = 0;
+            for (; i < rank - 1; i++)
             {
-                writer << "input_idx += (((output_idx / output_strides[i]) * slice_strides[i]) + "
-                          "lower_bounds[i]) * input_strides[i];\n";
-                writer << "output_idx %= output_strides[i];\n";
+                writer << "input_idx += (((output_idx / output_strides[" << i
+                       << "]) * slice_strides[" << i << "]) + "
+                          "lower_bounds["
+                       << i << "]) * input_strides[" << i << "];\n";
+                writer << "output_idx %= output_strides[" << i << "];\n";
             }
-            writer.block_end();
+            writer << "input_idx += (((output_idx / output_strides[" << i << "]) * slice_strides["
+                   << i << "]) + "
+                      "lower_bounds["
+                   << i << "]) * input_strides[" << i << "];\n";
             writer << "out[tid] = in[input_idx];\n";
        }
        writer.block_end();
    }
    writer.block_end();
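For a concrete picture of what get_slice_op now emits, the generated source for a rank-2 float slice would look roughly as follows. This is reconstructed by hand from the writer calls above; the kernel name follows build_slice's "slice_<dtypes>_r_<rank>" scheme and is shown only as an example:

    extern "C" __global__ void cuda_slice_float_float_r_2(
        float* in, float* out, uint32_t* input_strides, uint32_t* lower_bounds,
        uint32_t* slice_strides, uint32_t* output_strides, uint32_t n)
    {
        uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n)
        {
            uint32_t input_idx = 0;
            uint32_t output_idx = tid;
            // axis 0: produced by the host-side for loop above
            input_idx += (((output_idx / output_strides[0]) * slice_strides[0]) + lower_bounds[0]) * input_strides[0];
            output_idx %= output_strides[0];
            // axis 1: the last axis is emitted after the loop, without the modulo
            input_idx += (((output_idx / output_strides[1]) * slice_strides[1]) + lower_bounds[1]) * input_strides[1];
            out[tid] = in[input_idx];
        }
    }

Unrolling the rank at build time is what removes the runtime `rank` parameter and the per-element loop from the old kernel.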
...
@@ -59,7 +59,8 @@ namespace ngraph
                 static void get_slice_op(codegen::CodeWriter& writer,
                                          const std::string& name,
-                                         const std::array<std::string, 2>& data_types);
+                                         const std::array<std::string, 2>& data_types,
+                                         size_t rank);
 
                 static void get_reverse_op(codegen::CodeWriter& writer,
                                            const std::string& name,
...
@@ -97,46 +97,6 @@ void runtime::gpu::emit_reshape(const std::string& name,
     CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
 }
 
-void runtime::gpu::emit_slice(const std::string& name,
-                              CUdeviceptr in,
-                              CUdeviceptr out,
-                              const std::array<std::string, 2>& data_types,
-                              GPURuntimeContext* ctx,
-                              CUdeviceptr input_strides,
-                              CUdeviceptr lower_bounds,
-                              CUdeviceptr slice_strides,
-                              CUdeviceptr output_strides,
-                              size_t rank,
-                              size_t count)
-{
-    std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
-    std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
-    auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
-    if (compiled_kernel == nullptr)
-    {
-        codegen::CodeWriter writer;
-        CudaKernelBuilder::add_pod_typedefs(writer);
-        CudaKernelBuilder::get_slice_op(writer, name_signature, data_types);
-        std::string kernel = writer.get_code();
-        compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
-    }
-
-    void* args_list[] = {
-        &in, &out, &input_strides, &lower_bounds, &slice_strides, &output_strides, &rank, &count};
-    CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
-                                  static_cast<uint32_t>(count),
-                                  1,
-                                  1, // grid dim
-                                  1,
-                                  1,
-                                  1, // block dim
-                                  0,
-                                  NULL, // shared mem and stream
-                                  args_list,
-                                  0)); // arguments
-    CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
-}
-
 void runtime::gpu::emit_reverse(const std::string& name,
                                 CUdeviceptr in,
                                 CUdeviceptr out,
...
@@ -55,18 +55,6 @@ namespace ngraph
                           size_t rank,
                           size_t count);
 
-            void emit_slice(const std::string& name,
-                            CUdeviceptr in,
-                            CUdeviceptr out,
-                            const std::array<std::string, 2>& data_types,
-                            GPURuntimeContext* ctx,
-                            CUdeviceptr input_strides,
-                            CUdeviceptr lower_bounds,
-                            CUdeviceptr slice_strides,
-                            CUdeviceptr output_strides,
-                            size_t rank,
-                            size_t count);
-
             void emit_reverse(const std::string& name,
                               CUdeviceptr in,
                               CUdeviceptr out,
...
@@ -466,45 +466,20 @@ namespace ngraph
                         // since we padded output with temp buffer, we need to copy back to real ouput
                         if (pad_required || is_deconvolution)
                         {
-                            const auto arg_rank = output_shape.size();
-                            const auto input_strides = row_major_strides(output_shape_padded);
-                            const auto output_strides = row_major_strides(output_shape);
-                            GPUAllocator allocator =
-                                external_function->get_primitive_emitter()->get_memory_allocator();
-                            size_t idx_input_strides = allocator.reserve_argspace(
-                                input_strides.data(), input_strides.size() * sizeof(size_t));
-                            size_t idx_output_strides = allocator.reserve_argspace(
-                                output_strides.data(), output_strides.size() * sizeof(size_t));
-                            size_t idx_lower_bounds = allocator.reserve_argspace(
-                                padding_below_back.data(), padding_below_back.size() * sizeof(size_t));
-                            size_t idx_slice_strides =
-                                allocator.reserve_argspace(padding_interior_back.data(),
-                                                           padding_interior_back.size() * sizeof(size_t));
-                            writer << "size_t rank = " << arg_rank << ";\n";
-                            writer << "void* input_strides_d = "
-                                   << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_input_strides
-                                   << ");\n";
-                            writer << "void* output_strides_d = "
-                                   << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_output_strides
-                                   << ");\n";
-                            writer << "void* slice_strides_d = "
-                                   << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_slice_strides
-                                   << ");\n";
-                            writer << "void* lower_bounds_d = "
-                                   << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_lower_bounds
-                                   << ");\n";
-                            writer << "runtime::gpu::emit_slice(\"" << node->description()
-                                   << "\", CUdeviceptr(pad_buffer), CUdeviceptr(" << out[0].get_name()
-                                   << ")"
-                                   << ", {\"" << args[0].get_type() << "\", \"" << out[0].get_type()
-                                   << "\"}"
-                                   << ", "
-                                   << "ctx, "
-                                   << "CUdeviceptr(input_strides_d), CUdeviceptr(lower_bounds_d), "
-                                      "CUdeviceptr(slice_strides_d), CUdeviceptr(output_strides_d)"
-                                   << ", " << arg_rank << ", " << out[0].get_size() << ");\n";
+                            auto& cuda_emitter =
+                                external_function->get_primitive_emitter()->get_cuda_emitter();
+                            auto slice_index =
+                                cuda_emitter->build_slice(external_function->ctx().get(),
+                                                          {{args[0].get_type(), out[0].get_type()}},
+                                                          output_shape_padded,
+                                                          padding_below_back,
+                                                          padding_interior_back,
+                                                          output_shape);
+
+                            writer << "gpu::invoke_primitive(ctx, " << slice_index << ", ";
+                            writer << "std::vector<void*>{pad_buffer}.data(), ";
+                            writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
+                            writer << ");\n";
                         }
                         writer.block_end();
                     }
@@ -1075,12 +1050,9 @@ namespace ngraph
                     auto slice = static_cast<const op::Slice*>(node);
 
                     const auto arg_shape = args[0].get_shape();
-                    const auto arg_rank = arg_shape.size();
                     const auto result_shape = out[0].get_shape();
                     const Coordinate& lower_bounds = slice->get_lower_bounds();
                     const Strides slice_strides = slice->get_strides();
-                    const auto input_strides = row_major_strides(arg_shape);
-                    const auto output_strides = row_major_strides(result_shape);
 
                     writer.block_begin();
                     if (args[0].get_size() == out[0].get_size())
@@ -1089,41 +1061,20 @@ namespace ngraph
                     }
                     else
                     {
-                        GPUAllocator allocator =
-                            external_function->get_primitive_emitter()->get_memory_allocator();
-                        size_t idx_input_strides = allocator.reserve_argspace(
-                            input_strides.data(), input_strides.size() * sizeof(size_t));
-                        size_t idx_output_strides = allocator.reserve_argspace(
-                            output_strides.data(), output_strides.size() * sizeof(size_t));
-                        size_t idx_lower_bounds = allocator.reserve_argspace(
-                            lower_bounds.data(), lower_bounds.size() * sizeof(size_t));
-                        size_t idx_slice_strides = allocator.reserve_argspace(
-                            slice_strides.data(), slice_strides.size() * sizeof(size_t));
-                        writer << "size_t rank = " << arg_rank << ";\n";
-                        writer << "void* input_strides_d = "
-                               << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_input_strides
-                               << ");\n";
-                        writer << "void* output_strides_d = "
-                               << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_output_strides
-                               << ");\n";
-                        writer << "void* slice_strides_d = "
-                               << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_slice_strides
-                               << ");\n";
-                        writer << "void* lower_bounds_d = "
-                               << " runtime::gpu::invoke_memory_primitive(ctx, " << idx_lower_bounds
-                               << ");\n";
-                        writer << "runtime::gpu::emit_slice(\"" << node->description()
-                               << "\", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
-                               << out[0].get_name() << ")"
-                               << ", {\"" << args[0].get_type() << "\", \"" << out[0].get_type()
-                               << "\"}"
-                               << ", "
-                               << "ctx, "
-                               << "CUdeviceptr(input_strides_d), CUdeviceptr(lower_bounds_d), "
-                                  "CUdeviceptr(slice_strides_d), CUdeviceptr(output_strides_d)"
-                               << ", " << arg_rank << ", " << out[0].get_size() << ");\n";
+                        auto& cuda_emitter =
+                            external_function->get_primitive_emitter()->get_cuda_emitter();
+                        auto index =
+                            cuda_emitter->build_slice(external_function->ctx().get(),
+                                                      {{args[0].get_type(), out[0].get_type()}},
+                                                      arg_shape,
+                                                      lower_bounds,
+                                                      slice_strides,
+                                                      result_shape);
+
+                        writer << "gpu::invoke_primitive(ctx, " << index << ", ";
+                        writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
+                        writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
+                        writer << ");\n";
                     }
                     writer.block_end();
                 }
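After this change, the code emitted for a Slice node (and for the convolution copy-back above) is a single call into the cached primitive. Roughly, the writer lines produce something like the following; the primitive index and tensor names are placeholders, not values from the commit:

    // Approximate shape of the generated code; "11", "arg0" and "out0" are placeholder values.
    gpu::invoke_primitive(ctx, 11, std::vector<void*>{arg0}.data(), std::vector<void*>{out0}.data());

All of the stride, bound, and rank bookkeeping that the old generated code carried at runtime now happens once, at build time, inside build_slice.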
...