Commit b5e69eaa authored by Fenglei's avatar Fenglei Committed by Jayaram Bobba

gpu reshape optimization (#1174)

* add gpu_timer to external function

* compiled version

* working version

* using block_begin and block_end

* add the missing '
;'

* move slice to cuda emiter

* change size_t to uint32_t in kernel

* working version

* change block size from 1 to 64

* fix bugs

* nthreads need to be size_t in broadcast op

* add rank to kernel name hash

* change reshape to cuda_emitter

* fix bugs

* bug, remove rank from kernel

* clang format

* update slice in convolution

* resolve index conflict

* change align to align_to_blocksize, add overflow check

* add gird size check and fix pool merge bug

* code style, change names

* fix merge conflict

* change kernel_runner to kernel_launch
parent cf568ef9
This diff is collapsed.
......@@ -211,27 +211,29 @@ void runtime::gpu::CudaKernelBuilder::get_onehot_op(codegen::CodeWriter& writer,
void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
const std::array<std::string, 2>& data_types,
size_t rank)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1]
<< "* out, size_t* input_strides, size_t* trans_strides, size_t rank, size_t n)\n";
<< "* out, uint32_t* input_strides, uint32_t* trans_strides, uint32_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
writer << "size_t input_idx = tid;\n";
writer << "size_t output_idx = 0;\n";
writer << "for(size_t i = 0; i < rank; i++)\n";
writer.block_begin();
writer << "uint32_t input_idx = tid;\n";
writer << "uint32_t output_idx = 0;\n";
size_t i = 0;
for (; i < rank - 1; i++)
{
writer << "output_idx += (input_idx / input_strides[i]) * trans_strides[i];\n";
writer << "input_idx %= input_strides[i];\n";
writer << "output_idx += (input_idx / input_strides[" << i << "]) * trans_strides["
<< i << "];\n";
writer << "input_idx %= input_strides[" << i << "];\n";
}
writer.block_end();
writer << "output_idx += (input_idx / input_strides[" << i << "]) * trans_strides[" << i
<< "];\n";
writer << "out[output_idx] = in[tid];\n";
}
writer.block_end();
......
......@@ -55,7 +55,8 @@ namespace ngraph
static void get_reshape_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
const std::array<std::string, 2>& data_types,
size_t rank);
static void get_slice_op(codegen::CodeWriter& writer,
const std::string& name,
......
......@@ -60,43 +60,6 @@ void runtime::gpu::emit_onehot(const std::string& name,
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
void runtime::gpu::emit_reshape(const std::string& name,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
CUdeviceptr input_strides,
CUdeviceptr trans_strides,
size_t rank,
size_t count)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reshape_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &input_strides, &trans_strides, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
void runtime::gpu::emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
......
......@@ -45,16 +45,6 @@ namespace ngraph
size_t repeat_times,
size_t count);
void emit_reshape(const std::string& name,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
CUdeviceptr input_strides,
CUdeviceptr trans_strides,
size_t rank,
size_t count);
void emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
......
......@@ -921,12 +921,8 @@ namespace ngraph
auto result_shape = out[0].get_shape();
auto input_order = reshape->get_input_order();
bool same_layout = is_sorted(input_order.begin(), input_order.end());
size_t result_shape_product = 1;
size_t result_shape_product = shape_size(result_shape);
for (auto i : result_shape)
{
result_shape_product *= i;
}
// If there is no layout change or we are just going from 1^n to 1^m or a zero-size tensor,
// we can just copy.
if (same_layout || result_shape_product < 2)
......@@ -956,47 +952,18 @@ namespace ngraph
// Other cases (reordering of axes for tensors with rank>2).
else
{
std::vector<size_t> input_strides(arg_rank);
std::vector<size_t> output_strides(arg_rank);
std::vector<size_t> trans_strides(arg_rank);
size_t stride = 1;
for (int i = static_cast<int>(arg_rank) - 1; i >= 0; i--)
{
input_strides[i] = stride;
stride *= arg_shape[i];
}
stride = 1;
for (int i = static_cast<int>(arg_rank) - 1; i >= 0; i--)
{
output_strides[i] = stride;
stride *= arg_shape[input_order[i]];
}
for (int i = 0; i < arg_rank; i++)
{
trans_strides[input_order[i]] = output_strides[i];
}
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto index =
cuda_emitter->build_reshape(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
arg_shape,
input_order);
GPUAllocator allocator =
external_function->get_primitive_emitter()->get_memory_allocator();
size_t idx_input_strides = allocator.reserve_argspace(
input_strides.data(), input_strides.size() * sizeof(size_t));
size_t idx_trans_strides = allocator.reserve_argspace(
trans_strides.data(), trans_strides.size() * sizeof(size_t));
writer << "void* input_strides_d = "
"runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_input_strides << ");\n";
writer << "void* trans_strides_d = "
"runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_trans_strides << ");\n";
writer << "runtime::gpu::emit_reshape(\"" << node->description() << "\", {\""
<< args[0].get_type() << "\", \"" << out[0].get_type() << "\"}"
<< ", ctx"
<< ", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", "
<< "CUdeviceptr(input_strides_d), CUdeviceptr(trans_strides_d)"
<< ", " << arg_rank << ", " << args[0].get_size() << ");\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment