Commit a3133482 authored by Fenglei, committed by Scott Cyphers

nvgpu cuda softmax optimization (#2101)

* add some helper function

* update with new helper function

* update reduce to nd with new helper function

* update float sum to stable sum

* fix bug

* update all reduce to stable sum for float

* fix bug and pass the sum stable test

* remove debug info

* style

* update with shape

* fix bug

* add host parameters to cuda_emitter

* clang format

* fix bugs

* add element::type support

* format

* add a cached value with datatype name

* add init_reduce_value

* unroll loop

* optimization

* remove the need for init_value

* add memset kernel

* add memcpy

* working version

* remove debug info

* add comments, clean up code.

* change in_idx to input_idx

* fix bug

* change args name for memset in emitter

* pass element::Type instead of string

* the op::reduce comes with an init value, add support

* resolve codacy-bot comment

* fix bug

* resolve codacy-bot comment

* add soft_max_block_reduce kernel

* fix bugs

* add softmax_block_reduce to cuda_emitter

* compiling ok, result wrong

* fix bug in kernel

* working version

* removed unused code

* remove unused comments, resolve comments

* cuda reduce for max, min, mul, reduce op init value, format

* use type::info

* use type info for numeric_limits

* remove code from gpu_host_parameters

* header

* remove outdated comments

* add helper to check if stable sum is needed

* add stable sum test for double

* remove extra line

* consolidate helper functions

* no need for the list now.

* remove extra ;

* clang format

* style

* add skip test for cpu and intelGPU side

* resolve more conflicts

* update comment

* fix a warning

* Update src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp

using load.
Co-Authored-By: fengleitian <35274053+fengleitian@users.noreply.github.com>

* using WARPSIZE instead of 32, using lambda

* more WARPSIZE instead of 32

* fix block_size_x bug

* using __expf
parent 6584306c
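
For context on the numerics: the "stable sum" and "__expf" items above refer to the max-subtracted softmax, softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)), with the exp() terms accumulated by Kahan compensated summation; that is what the emitted kernels below implement. A minimal single-slice C++ sketch of the three passes (the function name and plain loops are illustrative only, not part of this patch):

```cpp
#include <cmath>
#include <cstddef>

// Illustrative sketch of the per-slice numerics: find the max, accumulate
// exp(x - max) with a Kahan compensation term, then divide. The generated
// CUDA kernels use __expf and the same r_max / r_sum / c / y / t names.
void softmax_slice(const float* in, float* out, std::size_t n)
{
    float r_max = in[0];
    for (std::size_t i = 1; i < n; i++)
    {
        r_max = r_max > in[i] ? r_max : in[i];
    }

    float r_sum = 0.f;
    float c = 0.f; // running compensation for lost low-order bits
    for (std::size_t i = 0; i < n; i++)
    {
        float input_i = std::exp(in[i] - r_max);
        float y = input_i - c;
        float t = r_sum + y;
        c = (t - r_sum) - y;
        r_sum = t;
    }

    for (std::size_t i = 0; i < n; i++)
    {
        out[i] = std::exp(in[i] - r_max) / r_sum;
    }
}
```
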
@@ -1628,21 +1628,27 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::MaxPool* node)
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>& dtypes,
size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<element::Type>& dtypes,
NVShape input_shape,
NVShape reduce_axis)
{
size_t rank = input_shape.size();
size_t reduce_rank = reduce_axis.size();
std::vector<std::string> dtypes_str = get_string_vector(dtypes);
NVShape simplified_reduce_axis;
NVShape simplified_input_shape;
simplify_reduce_shape(input_shape, reduce_axis, simplified_input_shape, simplified_reduce_axis);
size_t rank = simplified_input_shape.size();
size_t reduce_rank = simplified_reduce_axis.size();
size_t non_reduce_rank = rank - reduce_rank;
// assumes NC{d1,...,dn} format
std::string kernel_name = "softmax_" + join(dtypes, "_");
kernel_name +=
"_ri_" + std::to_string(input_shape.size()) + "_rr_" + std::to_string(reduce_axis.size());
std::string kernel_name = "softmax_" + join(dtypes_str, "_");
kernel_name += "_ri_" + std::to_string(simplified_input_shape.size()) + "_rr_" +
std::to_string(simplified_reduce_axis.size());
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss;
ss << kernel_name << "_s_" << join(input_shape, "_") << "_axis_" << join(reduce_axis, "_");
ss << kernel_name << "_s_" << join(simplified_input_shape, "_") << "_axis_"
<< join(simplified_reduce_axis, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
@@ -1651,79 +1657,157 @@ size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>&
return primitive_index;
}
NVShape reduce_flag(rank, 0);
for (auto a : reduce_axis)
{
reduce_flag[a] = 1;
}
NVShape output_shape;
NVShape non_reduce_shape;
NVShape non_reduce_strides;
NVShape non_reduce_strides_in_input;
NVShape reduce_shape;
NVShape reduce_strides;
NVShape input_strides = row_major_strides(input_shape);
for (int i = 0; i < rank; i++)
NVShape reduce_strides_in_input;
get_reduce_strides(simplified_input_shape,
simplified_reduce_axis,
non_reduce_shape,
non_reduce_strides,
non_reduce_strides_in_input,
reduce_shape,
reduce_strides,
reduce_strides_in_input);
std::vector<int> reduce_strides_magic;
std::vector<int> reduce_strides_shift;
std::vector<int> non_reduce_strides_magic;
std::vector<int> non_reduce_strides_shift;
div_to_mul(reduce_strides, reduce_strides_magic, reduce_strides_shift);
div_to_mul(non_reduce_strides, non_reduce_strides_magic, non_reduce_strides_shift);
uint32_t nthreads = static_cast<uint32_t>(shape_size(non_reduce_shape));
// if the reduce shape is empty, every result should be 1.
if (reduce_shape.empty())
{
if (reduce_flag[i] != 0)
{
reduce_shape.push_back(input_shape[i]);
reduce_strides.push_back(input_strides[i]);
}
else
size_t memset_idx = build_memset(dtypes_str[0], nthreads);
void* init_value =
m_host_parameters->val_by_datatype(dtypes_str[0], static_cast<int64_t>(1));
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
// (lazy) allocation for kernel arguments
size_t idx_init_value = allocator.reserve_argspace(init_value, dtypes[0].size());
std::unique_ptr<gpu::primitive> memset(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* init_value_buff = runtime::gpu::invoke_memory_primitive(m_ctx, idx_init_value);
gpu::invoke_primitive(m_ctx,
memset_idx,
std::vector<void*>{init_value_buff}.data(),
std::vector<void*>{outputs[0]}.data());
}});
return this->m_primitive_emitter->register_primitive(memset, hash);
}
// if the reduce axes do not include the last axis, use this kernel; this heuristic chooses by reduce axis for better cache behavior
// a more accurate but slower way would be to tune with the actual kernel
else if (reduce_strides_in_input.back() != 1)
{
// TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes_str[0], "in")
.add_placeholder(dtypes_str[1], "out")
.add("non_reduce_strides", non_reduce_strides)
.add("non_reduce_strides_in_input", non_reduce_strides_in_input)
.add("reduce_strides_in_input", reduce_strides_in_input)
.add("reduce_shape", reduce_shape)
.add("nthreads", nthreads);
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
non_reduce_strides.push_back(input_strides[i]);
output_shape.push_back(input_shape[i]);
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
runtime::gpu::CudaKernelBuilder::get_softmax_op(
writer, kernel_name, args, dtypes_str, non_reduce_rank, reduce_rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
}
NVShape output_strides = row_major_strides(output_shape);
uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
// TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
if (reduce_flag.back() == 1)
{
block_size_x = 8;
}
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("out_strides", output_strides)
.add("non_reduce_strides", non_reduce_strides)
.add("reduce_shape", reduce_shape)
.add("reduce_strides", reduce_strides)
.add("nthreads", nthreads);
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
runtime::gpu::CudaKernelBuilder::get_softmax_op(
writer, kernel_name, args, dtypes, non_reduce_rank, reduce_rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
std::unique_ptr<gpu::primitive> softmax(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
0,
nullptr,
args_list,
nullptr));
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(softmax, hash);
}
else
{
uint32_t reduce_count = static_cast<uint32_t>(shape_size(reduce_shape));
uint32_t block_size_x = 1;
while ((block_size_x << 1) <= fmin(512, reduce_count))
{
block_size_x <<= 1;
}
uint32_t shared_data_bytes = block_size_x * static_cast<uint32_t>(dtypes[0].size());
uint32_t aligned_grid_size_x = nthreads;
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes_str[0], "in")
.add_placeholder(dtypes_str[1], "out")
.add("non_reduce_strides", non_reduce_strides)
.add("non_reduce_strides_magic", non_reduce_strides_magic)
.add("non_reduce_strides_shift", non_reduce_strides_shift)
.add("non_reduce_strides_in_input", non_reduce_strides_in_input)
.add("reduce_strides", reduce_strides)
.add("reduce_strides_magic", reduce_strides_magic)
.add("reduce_strides_shift", reduce_strides_shift)
.add("reduce_strides_in_input", reduce_strides_in_input)
.add("reduce_count", reduce_count)
.add("nthreads", nthreads);
// if the kernel has not been compiled, build it
kernel_name += "_bs_" + std::to_string(block_size_x);
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
runtime::gpu::CudaKernelBuilder::get_softmax_block_reduce_op(
writer, kernel_name, args, dtypes_str, non_reduce_rank, reduce_rank, block_size_x);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
std::unique_ptr<gpu::primitive> softmax(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
std::unique_ptr<gpu::primitive> softmax(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
0,
nullptr,
args_list,
nullptr));
debug_sync();
}});
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
shared_data_bytes,
nullptr,
args_list,
nullptr));
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(softmax, hash);
return this->m_primitive_emitter->register_primitive(softmax, hash);
}
}
size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<element::Type>& dtypes,
......
@@ -190,7 +190,7 @@ namespace ngraph
size_t concat_axis,
NVShape output_shape);
size_t build_softmax(const std::vector<std::string>& dtypes,
size_t build_softmax(const std::vector<element::Type>& dtypes,
NVShape input_shape,
NVShape reduce_axis);
......
@@ -22,6 +22,7 @@
#include "ngraph/runtime/gpu/type_info.hpp"
using namespace ngraph;
#define WARPSIZE 32
void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
@@ -309,6 +310,20 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
size_t out_rank,
size_t reduce_rank)
{
auto stable_sum_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
writer << "r_sum = t;\n";
};
auto max_lambda = [&]() { writer << "r_max = r_max > input_i ? r_max : input_i;\n"; };
auto divide_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max) / r_sum;\n";
writer << "out[reduce_idx] = input_i;\n";
};
writer << runtime::gpu::nvrtc::helpers();
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
@@ -326,9 +341,9 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
// loop through all non-reduce (output) axes
for (int64_t i = 0; i < static_cast<int64_t>(out_rank); i++)
{
writer << "in_idx += (dim_idx_generator / out_strides" << i
<< ") * non_reduce_strides" << i << ";\n";
writer << "dim_idx_generator %= out_strides" << i << ";\n";
writer << "in_idx += (dim_idx_generator / non_reduce_strides" << i
<< ") * non_reduce_strides_in_input" << i << ";\n";
writer << "dim_idx_generator %= non_reduce_strides" << i << ";\n";
}
writer << "uint32_t init_in_idx = in_idx;\n";
int64_t last_r_idx = static_cast<int64_t>(reduce_rank) - 1;
@@ -348,14 +363,15 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
writer << "uint32_t reduce_idx = in_idx;\n";
for (int64_t j = 0; j < last_r_idx; j++)
{
writer << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
writer << "reduce_idx += idx" << j << " * reduce_strides_in_input" << j
<< ";\n";
}
writer << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
writer << "uint32_t step = reduce_strides_in_input" << last_r_idx << ";\n";
writer << "if(reduce_idx != init_in_idx)\n";
writer.block_begin();
{
writer << "input_i = in[reduce_idx];\n";
writer << "r_max = r_max > input_i ? r_max : input_i;\n";
max_lambda();
}
writer.block_end();
writer << "reduce_idx += step;\n";
@@ -369,7 +385,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
for (int k = 0; k < unroll_num; k++)
{
writer << "input_i = in[reduce_idx];\n";
writer << "r_max = r_max > input_i ? r_max : input_i;\n";
max_lambda();
writer << "reduce_idx += step;\n";
}
}
@@ -379,7 +395,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
writer.block_begin();
{
writer << "input_i = in[reduce_idx];\n";
writer << "r_max = r_max > input_i ? r_max : input_i;\n";
max_lambda();
writer << "reduce_idx += step;\n";
}
writer.block_end();
@@ -406,9 +422,10 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
writer << "uint32_t reduce_idx = in_idx;\n";
for (int64_t j = 0; j < last_r_idx; j++)
{
writer << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
writer << "reduce_idx += idx" << j << " * reduce_strides_in_input" << j
<< ";\n";
}
writer << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
writer << "uint32_t step = reduce_strides_in_input" << last_r_idx << ";\n";
writer << "int idx" << last_r_idx << " = 0;\n";
// unroll last reduction axis
uint32_t unroll_num = 8;
@@ -418,11 +435,8 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
{
for (int k = 0; k < unroll_num; k++)
{
writer << "input_i = expf(in[reduce_idx] - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
writer << "r_sum = t;\n";
writer << "input_i = in[reduce_idx];\n";
stable_sum_lambda();
writer << "reduce_idx += step;\n";
}
}
@@ -431,11 +445,8 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
<< last_r_idx << "++)\n";
writer.block_begin();
{
writer << "input_i = expf(in[reduce_idx] - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
writer << "r_sum = t;\n";
writer << "input_i = in[reduce_idx];\n";
stable_sum_lambda();
writer << "reduce_idx += step;\n";
}
writer.block_end();
@@ -458,9 +469,10 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
writer << "uint32_t reduce_idx = in_idx;\n";
for (int64_t j = 0; j < last_r_idx; j++)
{
writer << "reduce_idx += idx" << j << " * reduce_strides" << j << ";\n";
writer << "reduce_idx += idx" << j << " * reduce_strides_in_input" << j
<< ";\n";
}
writer << "uint32_t step = reduce_strides" << last_r_idx << ";\n";
writer << "uint32_t step = reduce_strides_in_input" << last_r_idx << ";\n";
writer << "int idx" << last_r_idx << " = 0;\n";
// unroll last reduction axis
uint32_t unroll_num = 8;
@@ -470,8 +482,8 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
{
for (int k = 0; k < unroll_num; k++)
{
writer << "input_i = expf(in[reduce_idx] - r_max) / r_sum;\n";
writer << "out[reduce_idx] = input_i;\n";
writer << "input_i = in[reduce_idx];\n";
divide_lambda();
writer << "reduce_idx += step;\n";
}
}
@@ -480,8 +492,8 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
<< last_r_idx << "++)\n";
writer.block_begin();
{
writer << "input_i = expf(in[reduce_idx] - r_max) / r_sum;\n";
writer << "out[reduce_idx] = input_i;\n";
writer << "input_i = in[reduce_idx];\n";
divide_lambda();
writer << "reduce_idx += step;\n";
}
writer.block_end();
@@ -498,6 +510,267 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
return;
}
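
The index arithmetic emitted above maps each thread's flat index over the non-reduced output onto an offset into the input tensor via non_reduce_strides and non_reduce_strides_in_input (the block-reduce kernel below does the same through collective_coordinate_transform_helper, with the *_magic / *_shift arguments standing in for the division). A hedged device-code sketch of that mapping using plain division and modulo; the helper name is illustrative only, not part of the patch:

```cpp
#include <cstdint>

// Sketch: decompose a flat non-reduced index with non_reduce_strides, then
// rebuild the offset into the full input tensor (which still contains the
// reduced axes) with non_reduce_strides_in_input.
__device__ uint32_t non_reduce_input_offset(uint32_t flat_idx,
                                            const uint32_t* non_reduce_strides,
                                            const uint32_t* non_reduce_strides_in_input,
                                            uint32_t non_reduce_rank)
{
    uint32_t remaining = flat_idx;
    uint32_t input_idx = 0;
    for (uint32_t i = 0; i < non_reduce_rank; i++)
    {
        input_idx += (remaining / non_reduce_strides[i]) * non_reduce_strides_in_input[i];
        remaining %= non_reduce_strides[i];
    }
    return input_idx;
}
```
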
void runtime::gpu::CudaKernelBuilder::get_softmax_block_reduce_op(
codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
size_t non_reduce_rank,
size_t reduce_rank,
size_t block_size_x)
{
writer << runtime::gpu::nvrtc::helpers();
writer << runtime::gpu::nvrtc::define_non_coherent_load(data_types[0], "load");
auto get_reduce_input_lambda = [&]() {
collective_coordinate_transform_helper(writer,
"reduce_idx",
"reduce_strides",
"reduce_strides_magic",
"reduce_strides_shift",
"reduce_strides_in_input",
"reduce_coordinate",
reduce_rank,
true,
"reduce_input_index");
writer << "input_idx = reduce_input_index + non_reduce_input_index;\n";
writer << "input_i = load(in, input_idx);\n";
};
auto stable_sum_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
writer << "r_sum = t;\n";
};
auto max_lambda = [&]() { writer << "r_max = r_max > input_i ? r_max : input_i;\n"; };
auto divide_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max) / r_sum;\n";
writer << "out[input_idx] = input_i;\n";
};
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "extern __shared__ " << data_types[1] << " sdata[];\n";
if (non_reduce_rank > 0)
{
writer << "uint32_t bid = blockIdx.x;\n";
}
writer << "uint32_t tid = threadIdx.x;\n";
writer << "uint32_t step = blockDim.x; \n";
collective_coordinate_transform_helper(writer,
"bid",
"non_reduce_strides",
"non_reduce_strides_magic",
"non_reduce_strides_shift",
"non_reduce_strides_in_input",
"non_reduce_coordinate",
non_reduce_rank,
true,
"non_reduce_input_index");
writer << "uint32_t input_idx;\n";
writer << "uint32_t reduce_idx = tid;\n";
writer << data_types[1] << " r_max;\n";
writer << data_types[1] << " input_i;\n";
// find max
writer.block_begin();
{
get_reduce_input_lambda();
writer << "r_max = input_i;\n";
writer << "reduce_idx += step;\n";
}
writer.block_end();
writer << "while (reduce_idx + 7 * step < reduce_count)\n";
writer.block_begin();
{
for (int i = 0; i < 8; i++)
{
writer.block_begin();
get_reduce_input_lambda();
max_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
}
writer.block_end();
writer << "while (reduce_idx < reduce_count)\n";
writer.block_begin();
{
writer.block_begin();
get_reduce_input_lambda();
max_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
writer.block_end();
// reduction max
// accumulate WARPSIZE = 32 threads for each warp
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (block_size_x > i)
{
writer << "input_i = __shfl_down_sync(0xffffffff, r_max, " << i << ", " << WARPSIZE
<< ");\n";
max_lambda();
}
}
if (block_size_x > WARPSIZE)
{
writer << "uint32_t lane_idx = threadIdx.x & " << WARPSIZE - 1 << "; \n";
writer << "uint32_t warp_idx = threadIdx.x >> 5; \n";
writer << "if(lane_idx == 0)\n";
writer.block_begin();
{
writer << "sdata[warp_idx] = r_max;\n";
}
writer.block_end();
writer << "__syncthreads();\n";
uint32_t num_of_warp = block_size_x >> 5;
writer << "if(tid < " << num_of_warp << ")\n";
writer.block_begin();
{
writer << "r_max = sdata[tid];\n";
}
writer.block_end();
//accumulate WARPSIZE threads
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (num_of_warp > i)
{
writer << "input_i = __shfl_down_sync(0xffffffff, r_max, " << i << ", "
<< WARPSIZE << ");\n";
max_lambda();
}
}
}
// save and broadcast
writer << "if(tid == 0)\n";
writer.block_begin();
{
writer << "sdata[0] = r_max;\n";
;
}
writer.block_end();
writer << "__syncthreads();\n";
writer << "r_max = sdata[0];\n";
// exp and sum, https://en.wikipedia.org/wiki/Kahan_summation_algorithm
writer << data_types[1] << " r_sum = 0;\n";
writer << data_types[1] << " c = 0;\n";
writer << data_types[1] << " y;\n";
writer << data_types[1] << " t;\n";
writer << "reduce_idx = tid;\n";
writer << "while (reduce_idx + 7 * step < reduce_count)\n";
writer.block_begin();
{
for (int i = 0; i < 8; i++)
{
writer.block_begin();
get_reduce_input_lambda();
stable_sum_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
}
writer.block_end();
writer << "while (reduce_idx < reduce_count)\n";
writer.block_begin();
{
writer.block_begin();
get_reduce_input_lambda();
stable_sum_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
writer.block_end();
// reduction sum
// accumulate WARPSIZE = 32 threads for each warp
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (block_size_x > i)
{
writer << "r_sum += __shfl_down_sync(0xffffffff, r_sum, " << i << ", " << WARPSIZE
<< ");\n";
}
}
if (block_size_x > WARPSIZE)
{
writer << "if(lane_idx == 0)\n";
writer.block_begin();
{
writer << "sdata[warp_idx] = r_sum;\n";
}
writer.block_end();
writer << "__syncthreads();\n";
uint32_t num_of_warp = block_size_x >> 5;
writer << "if(tid < " << num_of_warp << ")\n";
writer.block_begin();
{
writer << "r_sum = sdata[tid];\n";
}
writer.block_end();
//accumulate WARPSIZE = 32 threads
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (num_of_warp > i)
{
writer << "r_sum += __shfl_down_sync(0xffffffff, r_sum, " << i << ", "
<< WARPSIZE << ");\n";
}
}
}
// save and broadcast
writer << "__syncthreads();\n";
writer << "if(tid == 0)\n";
writer.block_begin();
{
writer << "sdata[0] = r_sum;\n";
;
}
writer.block_end();
writer << "__syncthreads();\n";
writer << "r_sum = sdata[0];\n";
// divide
writer << "reduce_idx = tid;\n";
writer << "while (reduce_idx + 7 * step < reduce_count)\n";
writer.block_begin();
{
for (int i = 0; i < 8; i++)
{
writer.block_begin();
get_reduce_input_lambda();
divide_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
}
writer.block_end();
writer << "while (reduce_idx < reduce_count)\n";
writer.block_begin();
{
writer.block_begin();
get_reduce_input_lambda();
divide_lambda();
writer << "reduce_idx += step;\n";
writer.block_end();
}
writer.block_end();
}
writer.block_end();
return;
}
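
get_softmax_block_reduce_op assigns one thread block per non-reduced output element and reduces along the softmax axis in two stages: register shuffles inside each warp, then per-warp partials combined through shared memory and broadcast from sdata[0]. A hedged stand-alone sketch of that pattern for the max pass (assuming blockDim.x is a power-of-two multiple of WARPSIZE and sdata holds at least one element per warp; the function name is illustrative, and the sum pass is analogous with '+' in place of max):

```cpp
#include <cstdint>

#define WARPSIZE 32

// Sketch of the block-wide max reduction the emitter builds for r_max.
__device__ float block_reduce_max(float r_max, float* sdata)
{
    // stage 1: reduce each warp in registers
    for (int offset = WARPSIZE >> 1; offset >= 1; offset >>= 1)
    {
        float other = __shfl_down_sync(0xffffffff, r_max, offset, WARPSIZE);
        r_max = r_max > other ? r_max : other;
    }

    uint32_t lane_idx = threadIdx.x & (WARPSIZE - 1);
    uint32_t warp_idx = threadIdx.x >> 5;
    uint32_t num_of_warp = blockDim.x >> 5;

    // stage 2: warp leaders publish their partials, then one more shuffle pass
    if (lane_idx == 0)
    {
        sdata[warp_idx] = r_max;
    }
    __syncthreads();
    if (threadIdx.x < num_of_warp)
    {
        r_max = sdata[threadIdx.x];
    }
    // all threads shuffle so the full 0xffffffff mask stays valid; for a max
    // reduction the lanes holding stale stage-1 values cannot exceed the true max
    for (int offset = WARPSIZE >> 1; offset >= 1; offset >>= 1)
    {
        float other = __shfl_down_sync(0xffffffff, r_max, offset, WARPSIZE);
        r_max = r_max > other ? r_max : other;
    }

    // broadcast the block maximum to every thread
    __syncthreads(); // make sure all reads of sdata are done before overwriting
    if (threadIdx.x == 0)
    {
        sdata[0] = r_max;
    }
    __syncthreads();
    return sdata[0];
}
```
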
// each thread calculates the whole reduction for one output element
void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op(
codegen::CodeWriter& writer,
@@ -673,19 +946,19 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op(
}
writer.block_end();
//accumulate 32 threads for each warp
for (int i = 16; i >= 1; i >>= 1)
//accumulate WARPSIZE threads for each warp
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (block_size_x > i)
{
writer << "r = " << reduce_op << "(r, __shfl_down_sync(0xffffffff, r, " << i
<< ", 32));\n";
writer << "r = " << reduce_op << "(r, __shfl_down_sync(0xffffffff, r, " << i << ", "
<< WARPSIZE << "));\n";
}
}
if (block_size_x > 32)
if (block_size_x > WARPSIZE)
{
writer << "uint32_t lane_idx = tid & 0x1f; \n";
writer << "uint32_t lane_idx = tid & " << WARPSIZE - 1 << "; \n";
writer << "uint32_t warp_idx = tid >> 5; \n";
writer << "if(lane_idx == 0)\n";
writer.block_begin();
@@ -695,21 +968,21 @@ void runtime::gpu::CudaKernelBuilder::get_reduce_to_scalar_op(
writer.block_end();
writer << "__syncthreads();\n";
uint32_t warp_size = block_size_x >> 5;
uint32_t num_of_warp = block_size_x >> 5;
writer << "if(tid < " << warp_size << ")\n";
writer << "if(tid < " << num_of_warp << ")\n";
writer.block_begin();
{
writer << "r = sdata[tid];\n";
}
writer.block_end();
//accumulate 32 threads
for (int i = 16; i >= 1; i >>= 1)
//accumulate WARPSIZE threads
for (int i = (WARPSIZE >> 1); i >= 1; i >>= 1)
{
if (warp_size > i)
if (num_of_warp > i)
{
writer << "r = " << reduce_op << "(r, __shfl_down_sync(0xffffffff, r, " << i
<< ", 32));\n";
<< ", " << WARPSIZE << "));\n";
}
}
}
@@ -1931,6 +2204,10 @@ void runtime::gpu::CudaKernelBuilder::coordinate_transform_to_multi_d(codegen::C
size_t rank,
bool register_arguments)
{
if (rank == 0)
{
return;
}
std::string brace_open = (register_arguments) ? "" : "[";
std::string brace_close = (register_arguments) ? "" : "]";
......
@@ -208,6 +208,14 @@
size_t out_rank,
size_t reduce_rank);
static void get_softmax_block_reduce_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
size_t non_reduce_rank,
size_t reduce_rank,
size_t block_size_x);
static void add_pod_typedefs(codegen::CodeWriter& writer);
static void coordinate_transform_to_multi_d(codegen::CodeWriter& writer,
......
@@ -1477,9 +1477,9 @@ void runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS)
writer.block_begin();
{
auto axes_set = softmax->get_axes();
std::vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
std::vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_set);
......