Commit 13f00048 authored by Chris Sullivan's avatar Chris Sullivan Committed by Nick Korovaiko

Add extra hash parameters to broadcast and max pool (#1163)

* Move maxpool and avgpool into CudaKernelBuilder and add cache parameters to kernel name for broadcast which are required for correct lookup.

* Styling.

* Add space before avg_pool.
parent b69f0734
......@@ -26,7 +26,6 @@
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
......@@ -422,11 +421,13 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
auto input_width = input_shape.back();
auto output_width = output_shape.back();
std::stringstream ss;
ss << "maxpool"
<< "_i" << input_width << "_o" << output_width << "_w" << window_width << "_s"
<< window_stride;
auto hash = ss.str();
std::string kernel_name = "maxpool_" + join(dtypes, "_") + "_iw" + std::to_string(input_width) +
"_ow" + std::to_string(output_width) + "_ww" +
std::to_string(window_width) + "_wst" + std::to_string(window_stride);
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
// primitive hash and kernel name are equivalent for maxpool_1d
auto hash = kernel_name;
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
......@@ -435,62 +436,34 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
return primitive_index;
}
size_t nthreads = shape_size(output_shape);
// if the kernel has not been compiled, build it
auto compiled_kernel = ctx->compiled_kernel_pool->get(hash);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
// assumes data is in NCW format
writer << "extern \"C\" __global__ void cuda_" << hash << "(" << dtypes[0] << "* in, "
<< dtypes[1] << "* out)\n";
writer.block_begin();
{
// index into output tensor
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < " << nthreads << ")\n";
writer.block_begin();
{
// index into input tensor
writer << "size_t start = (tid / " << output_width << ") * " << input_width << " + "
<< " (tid % " << output_width << ") * " << window_stride << ";\n";
writer << dtypes[0] << " max_val = " << TypeInfo::Get(dtypes[0])->lowest() << ";\n";
writer << "for (size_t i = start; i < start + " << window_width << "; i++)\n";
writer.block_begin();
{
writer << "const " << dtypes[0] << " input = in[i];\n";
writer << "if (input > max_val)\n";
writer.block_begin();
{
writer << "max_val = input;\n";
}
writer.block_end();
}
writer.block_end();
writer << "out[tid] = max_val;\n";
}
writer.block_end();
}
writer.block_end();
CudaKernelBuilder::get_max_pool_1d(
writer, kernel_name, dtypes, input_width, output_width, window_width, window_stride);
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
std::unique_ptr<gpu::primitive> pool(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}});
size_t nthreads = shape_size(output_shape);
std::unique_ptr<gpu::primitive> pool(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* args_list[] = {&inputs[0], &outputs[0], &nthreads};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
......@@ -599,111 +572,7 @@ size_t runtime::gpu::CUDAEmitter::build_avg_pool(const GPURuntimeContext* ctx,
{
codegen::CodeWriter writer;
writer << include_helpers();
// In the pooling operation out = P(in) where in: NCDHW -> out: NKMPQ
// via pooling window: JTRS. Currently feature pooling
// is not supported and so K = C and J is unused
writer << "extern \"C\" __global__ void cuda_" << kernel_name << "(" << dtypes[0]
<< "* in, " << dtypes[1] << "* out, "
<< "float alpha, float beta, "
<< "int N, int C, int D, int H, int W, "
<< "int HW, int DHW, int CDHW, int magic_N, int shift_N, "
<< "int P, int Q, int magic_P, int shift_P, "
<< "int PQ, int MPQ, int KMPQ, "
<< "int S, int RS, int TRS, "
<< "int magic_S, int shift_S, int magic_RS, int shift_RS, "
<< "int str_d, int str_h, int str_w, "
<< "int pad_d, int pad_h, int pad_w"
<< ")\n";
writer.block_begin();
{
writer << "const int tid = threadIdx.x;\n";
writer << "if (tid < 32)\n";
writer.block_begin();
{
writer << "const int q = blockIdx.x;\n";
writer << "const int mp = blockIdx.y;\n";
writer << "const int nk = blockIdx.z;\n";
writer << "const int k = division_by_invariant_multiplication(nk, magic_N, "
"shift_N);\n";
writer << "const int n = nk - k * N;\n";
writer << "const int m = division_by_invariant_multiplication(mp, magic_P, "
"shift_P);\n";
writer << "const int p = mp - m * P;\n";
writer << "out += n*KMPQ + k*MPQ + m*PQ + mad16(p, Q, q);\n";
// coordinate transform factors from MPQ to DHW
writer << "int qs = q * str_w - pad_w;\n";
writer << "int pr = p * str_h - pad_h;\n";
writer << "int mt = m * str_d - pad_d;\n";
writer << "int pool_size = ";
auto pool_size = include_pad ? "TRS" : "0";
writer << pool_size << ";\n";
writer << "float sum = 0.0f;\n";
writer << "float rcp_pool_size = 1.0f;\n";
// each warp operates on a single pooling window and
// reduces the contents of the window within the warp
writer << "for (int trs = tid; trs < TRS; trs += 32)\n";
writer.block_begin();
{
writer << "int t = division_by_invariant_multiplication(trs, magic_RS, "
"shift_RS);\n";
writer << "int rs = mod16(trs, t, RS);\n";
writer
<< "int r = division_by_invariant_multiplication(rs, magic_S, shift_S);\n";
writer << "int s = mod16(rs, r, S);\n";
// coordinate transformation from TRS to DHW
// via MPQ transform factors above
writer << "int x = qs + s;\n";
writer << "int y = pr + r;\n";
writer << "int z = mt + t;\n";
// helper to check participating threads
writer << "bool bounds_x = (x >= 0) && (x < W);\n";
writer << "bool bounds_y = (y >= 0) && (y < H);\n";
writer << "bool bounds_z = (z >= 0) && (z < D);\n";
writer << "bool within_tensor_bounds = bounds_x && bounds_y && bounds_z;\n";
if (include_pad == false)
{
// count the number of (non-padded) elements
writer << "pool_size += __popc(__ballot_sync(0xffffffff, "
"within_tensor_bounds));\n";
}
// this will need to change to k->c once
// feature pooling support is added
writer << "int idx = n*CDHW + k*DHW + z*HW + y*W + x;\n";
writer << "sum += load(in,idx,within_tensor_bounds);\n";
}
writer.block_end();
writer << "rcp_pool_size = 1.0f / (float)pool_size;\n";
// reduce pooling window within warp.
// this could be improved by calculating the
// pooling windows each thread can partake in to
// reduce loads and increase coalescing. in that case,
// multiple warps per block would be required and the
// warp reduced sums would need to be accumulated in
// shared memory
writer << "for (int i = 16; i > 0; i >>= 1)\n";
writer.block_begin();
{
writer << "sum += __shfl_xor_sync(0xffffffff,sum,i,32);\n";
}
writer.block_end();
// write result to output
writer << "if (tid == 0)\n";
writer.block_begin();
{
writer << "*out = sum * rcp_pool_size;\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
CudaKernelBuilder::get_avg_pool(writer, kernel_name, dtypes, include_pad);
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
......@@ -1260,11 +1129,12 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const GPURuntimeContext* ctx,
const std::set<size_t>& reduce_axes)
{
// assumes NC{d1,...,dn} format
std::string kernel_name = "broadcast_" + join(dtypes, "_");
std::string kernel_name =
"broadcast_" + join(dtypes, "_") + "_r" + std::to_string(result_shape.size());
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss;
ss << kernel_name << "_s" << join(result_shape, "_") << "_r" << join(reduce_axes, "_");
ss << kernel_name << "_s" << join(result_shape, "_") << "_rs" << join(reduce_axes, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
......
......@@ -17,6 +17,7 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/type_info.hpp"
using namespace ngraph;
......@@ -556,6 +557,159 @@ void runtime::gpu::CudaKernelBuilder::get_replace_slice_op(
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_max_pool_1d(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
size_t input_width,
size_t output_width,
size_t window_width,
size_t window_stride)
{
// assumes data is in NCW format
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, size_t nthreads)\n";
writer.block_begin();
{
// index into output tensor
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < nthreads)\n";
writer.block_begin();
{
// index into input tensor
writer << "size_t start = (tid / " << output_width << ") * " << input_width << " + "
<< " (tid % " << output_width << ") * " << window_stride << ";\n";
writer << data_types[0] << " max_val = " << TypeInfo::Get(data_types[0])->lowest()
<< ";\n";
writer << "for (size_t i = start; i < start + " << window_width << "; i++)\n";
writer.block_begin();
{
writer << "const " << data_types[0] << " input = in[i];\n";
writer << "if (input > max_val)\n";
writer.block_begin();
{
writer << "max_val = input;\n";
}
writer.block_end();
}
writer.block_end();
writer << "out[tid] = max_val;\n";
}
writer.block_end();
}
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_avg_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
bool include_pad)
{
// In the pooling operation out = P(in) where in: NCDHW -> out: NKMPQ
// via pooling window: JTRS. Currently feature pooling
// is not supported and so K = C and J is unused
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, "
<< "float alpha, float beta, "
<< "int N, int C, int D, int H, int W, "
<< "int HW, int DHW, int CDHW, int magic_N, int shift_N, "
<< "int P, int Q, int magic_P, int shift_P, "
<< "int PQ, int MPQ, int KMPQ, "
<< "int S, int RS, int TRS, "
<< "int magic_S, int shift_S, int magic_RS, int shift_RS, "
<< "int str_d, int str_h, int str_w, "
<< "int pad_d, int pad_h, int pad_w"
<< ")\n";
writer.block_begin();
{
writer << "const int tid = threadIdx.x;\n";
writer << "if (tid < 32)\n";
writer.block_begin();
{
writer << "const int q = blockIdx.x;\n";
writer << "const int mp = blockIdx.y;\n";
writer << "const int nk = blockIdx.z;\n";
writer << "const int k = division_by_invariant_multiplication(nk, magic_N, "
"shift_N);\n";
writer << "const int n = nk - k * N;\n";
writer << "const int m = division_by_invariant_multiplication(mp, magic_P, "
"shift_P);\n";
writer << "const int p = mp - m * P;\n";
writer << "out += n*KMPQ + k*MPQ + m*PQ + mad16(p, Q, q);\n";
// coordinate transform factors from MPQ to DHW
writer << "int qs = q * str_w - pad_w;\n";
writer << "int pr = p * str_h - pad_h;\n";
writer << "int mt = m * str_d - pad_d;\n";
writer << "int pool_size = ";
auto pool_size = include_pad ? "TRS" : "0";
writer << pool_size << ";\n";
writer << "float sum = 0.0f;\n";
writer << "float rcp_pool_size = 1.0f;\n";
// each warp operates on a single pooling window and
// reduces the contents of the window within the warp
writer << "for (int trs = tid; trs < TRS; trs += 32)\n";
writer.block_begin();
{
writer << "int t = division_by_invariant_multiplication(trs, magic_RS, "
"shift_RS);\n";
writer << "int rs = mod16(trs, t, RS);\n";
writer << "int r = division_by_invariant_multiplication(rs, magic_S, shift_S);\n";
writer << "int s = mod16(rs, r, S);\n";
// coordinate transformation from TRS to DHW
// via MPQ transform factors above
writer << "int x = qs + s;\n";
writer << "int y = pr + r;\n";
writer << "int z = mt + t;\n";
// helper to check participating threads
writer << "bool bounds_x = (x >= 0) && (x < W);\n";
writer << "bool bounds_y = (y >= 0) && (y < H);\n";
writer << "bool bounds_z = (z >= 0) && (z < D);\n";
writer << "bool within_tensor_bounds = bounds_x && bounds_y && bounds_z;\n";
if (include_pad == false)
{
// count the number of (non-padded) elements
writer << "pool_size += __popc(__ballot_sync(0xffffffff, "
"within_tensor_bounds));\n";
}
// this will need to change to k->c once
// feature pooling support is added
writer << "int idx = n*CDHW + k*DHW + z*HW + y*W + x;\n";
writer << "sum += load(in,idx,within_tensor_bounds);\n";
}
writer.block_end();
writer << "rcp_pool_size = 1.0f / (float)pool_size;\n";
// reduce pooling window within warp.
// this could be improved by calculating the
// pooling windows each thread can partake in to
// reduce loads and increase coalescing. in that case,
// multiple warps per block would be required and the
// warp reduced sums would need to be accumulated in
// shared memory
writer << "for (int i = 16; i > 0; i >>= 1)\n";
writer.block_begin();
{
writer << "sum += __shfl_xor_sync(0xffffffff,sum,i,32);\n";
}
writer.block_end();
// write result to output
writer << "if (tid == 0)\n";
writer.block_begin();
{
writer << "*out = sum * rcp_pool_size;\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
}
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
codegen::CodeWriter& writer,
std::string i_thread_index,
......
......@@ -101,6 +101,19 @@ namespace ngraph
bool save_elementwise,
size_t rank);
static void get_max_pool_1d(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
size_t input_width,
size_t output_width,
size_t window_width,
size_t window_stride);
static void get_avg_pool(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types,
bool include_pad);
static void add_pod_typedefs(codegen::CodeWriter& writer);
/// \brief Given kernel input variables i_* produce register variables o_coordinates{i}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment