Commit 83e6aa5f authored by Chris Sullivan, committed by Robert Kimball

CUDA softmax kernel and broadcast kernel support for multiple non-consecutive axes (#1070)

* Added op::ReplaceSlice and enabled respective tests.

* div64 -> division_by_invariant_multiplication
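
The rename reflects what the helper actually does: it replaces integer division by a kernel-invariant divisor with a multiply-high and a shift. Below is a minimal device-side sketch of the idea, assuming the (magic, shift) pair was precomputed on the host (by a helper such as the idiv_magic_u64 used later in this diff); the exact helper emitted by include_helpers() in this commit may differ in detail:

    __device__ __forceinline__ int division_by_invariant_multiplication(int value, int magic, int shift)
    {
        // magic == 1 signals a power-of-two divisor: only the shift is applied
        if (magic == 1)
        {
            return value >> shift;
        }
        // otherwise take the high 32 bits of the 64-bit product, then shift
        unsigned int hi = __umulhi(static_cast<unsigned int>(value), static_cast<unsigned int>(magic));
        return static_cast<int>(hi >> shift);
    }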

* Added GPUMemoryManager for aggregating memory allocations and copies into a single operation for kernel arguments, and a reusable memory space for workspace allocations.

* Added GPUShape and reworked Shape helpers to be
compatible with different shape types.
Shape is now implicitly convertible to GPUShape.
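
A minimal sketch of what such an implicit conversion can look like (hypothetical member layout; the real GPUShape in this commit presumably stores 32-bit dimensions and rejects shapes that do not fit, per the later note about checking the high bits):

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // illustrative only: GPUShape as a vector of 32-bit extents, implicitly
    // constructible from the 64-bit ngraph Shape/Coordinate types
    class GPUShape : public std::vector<uint32_t>
    {
    public:
        GPUShape(const std::vector<size_t>& shape)
        {
            for (size_t dim : shape)
            {
                // fail if any of the high 32 bits are set
                if (dim > 0xFFFFFFFFull)
                {
                    throw std::invalid_argument("Shape dimension does not fit in GPUShape");
                }
                this->push_back(static_cast<uint32_t>(dim));
            }
        }
    };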

* Updated shape helpers signature and add conversion operators/constructors for GPUShape.

* Removed several unnecessary static_casts now that GPUShape is utilized. GPUTensorViewWrapper had a few functions returning std::vector<size_t> instead of Shape/Strides; these were updated as well to take advantage of the GPUShape conversion operators.

* Forgot to fix lambda for workspace allocations to match that of argspace allocations.

* Adjust row_major_strides to avoid reversed-copy.

* Moved declaration out of loop for clang.

* Moved gpu_shape to gpu transformer.

* Removed no longer necessary headers.

* Added stdexcept header to gpu_shape.hpp

* Coordinate->GPUShape

* Refactored replace_slice into CudaKernelBuilder. Simplified allocations using new GPUAllocator and GPUMemoryManager.

* Refactor allocations to make use of primitive emitter.
Now memory primitives are registered at compile time and the GPU memory address is resolved at runtime by invoking the primitive.
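
The pattern, as it appears later in this diff: at compile time the emitter reserves argument space and receives back an index; inside the emitted gpu::primitive lambda that index is resolved to a device pointer at call time. Condensed from the build_fused_ew_to_collective code below:

    // compile time: stage host data into the argument space and keep the index
    GPUAllocator allocator = m_primitive_emitter->get_memory_allocator();
    size_t idx_strides = allocator.reserve_argspace(strides.data(), strides.size() * sizeof(int));

    // run time: the primitive resolves the index to an actual device address
    std::unique_ptr<gpu::primitive> launch(new gpu::primitive{[=](void** inputs, void** outputs) {
        void* strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_strides);
        // ... pass strides_d to cuLaunchKernel along with the other resolved pointers ...
    }});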

* Changed the 64-bit shape check to test whether any of the high bits are set.

* Added const qualifier to data being copied in GPUAllocator::reserve_argspace

* Replaced runtime host to device memcpys with GPUAllocator reservations in order to move them to compile time.

* Forgot to remove no longer necessary buffer freeing from op emitters.

* Removed replace slice.

* Removed more replace_slice diffs.

* Updated replace_slice op to utilize GPUShape and GPUMemoryManager.

* Added back missing changes after timeline resolution.

* Added spacing between functions in GPUShape and boolean operators in shape.hpp.

* Template parameters are UPPER_SNAKE_CASE.

* Added unit tests for GPUMemoryManager and added checks that ensure the
device memory is allocated prior to address resolution by the memory_primitives.
Also exposed the allocation size of the memory manager.

* Return type of shape_size should be large enough to encapsulate the full stride of the tensor.
This should be 64 bits wide regardless of the underlying value_type of the ShapeType.

* Upstreaming changes to shape_size (which returns size_t).

* cuDNN softmax implementation for all-axis activation.

* Added catch for per-axis activations.

* Removed commented-out headers.

* Added an explicit function for queueing kernel argument data rather than doing it inline in the reservation function, per @fengleitian's recommendation.

* Add softmax CUDA kernel. It relies on atomic memory addition to global
memory, which adds contention and should be optimized in the
future. A multilevel reduction can be found in
cs/gpu_softmax_cuda_shfl, but it requires some further engineering.
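
A stripped-down sketch of that style of reduction (illustrative only; the kernel actually emitted in this commit is generated by CudaKernelBuilder, fuses the element-wise op, and computes the reduced index from strides rather than reading it from a table):

    // each thread exponentiates its element and atomically accumulates the result
    // into the slot of the reduced (denominator) tensor; out_sum must be zeroed first
    __global__ void exp_sum_reduce(const float* in, float* out_exp, float* out_sum,
                                   const int* reduced_idx, int n)
    {
        int tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n)
        {
            float e = expf(in[tid]);
            out_exp[tid] = e;
            atomicAdd(&out_sum[reduced_idx[tid]], e); // global-memory contention lives here
        }
    }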

* Refactored the reduce coordinate transform code into a helper and applied it to broadcast.
Broadcast was added to CUDAEmitter and now supports multiple non-consecutive axes.
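
For example (hypothetical numbers): with result shape {2, 3, 4} the row-major strides are {12, 4, 1}, so thread id 17 decomposes to coordinates (1, 1, 1); if axis 1 is the reduced/broadcast axis, the reduced strides become {4, 0, 1} and the matching index into the reduced tensor is 1*4 + 1*0 + 1*1 = 5.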

* Removed change to data_types variable and updated/removed comments.

* Refactored softmax into the emission of two fused elementwise collective ops.
Added fused elementwise + collective kernels. Softmax is then just the combination of exp_sum_reduce + div_broadcast.
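
Concretely, softmax over reduction axes A is exp(x) divided by the sum of exp(x) over A. A plain C++ reference sketch of the two-step decomposition for the 2D row-wise case (illustrative; the GPU version below runs both steps as generated CUDA kernels):

    #include <cmath>
    #include <vector>

    // step 1 (exp_sum_reduce): y = exp(x), denom[i] = sum_j y[i][j]
    // step 2 (div_broadcast):  y[i][j] /= denom[i]
    void softmax_rows(std::vector<std::vector<float>>& y)
    {
        for (auto& row : y)
        {
            float denom = 0.0f;
            for (auto& v : row)
            {
                v = std::exp(v);
                denom += v;
            }
            for (auto& v : row)
            {
                v /= denom;
            }
        }
    }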

* Added a default parameter to GPUAllocator::reserve_workspace to request memory initialization for each invocation of the memory primitive.

* GPU workspace memory is zero initialized by default but can be turned off if desired.
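
Usage as it appears later in this diff: allocator.reserve_workspace(nbytes) returns a memory-primitive index and, by default, zero-fills the workspace each time that primitive is invoked; passing false for the trailing default parameter (its name is not shown in this excerpt) skips the initialization.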

* Added template parameter REDUCE_OP_TYPE to CUDAEmitter::build_elementwise_collective
to specify the ngraph op type to use for the reduction in the fused ew_collective kernel.
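
For instance, the softmax emitter below calls build_elementwise_collective<ngraph::op::Exp, ngraph::op::Add>(...) to generate the fused exp-plus-sum-reduce kernel, and build_elementwise_collective<ngraph::op::Divide>(...) (with the default Nop reduction) for the broadcasted division.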

* Renamed variables and updated a comment.

* Removed outdated softmax kernel to avoid confusion. Can be added later when atomic reduce is replaced.

* Clang complained about the lack of an explicit destructor for AxisSet. Since cuda_emitter doesn't need AxisSet specifically, switched to std::set<size_t>.
This also has the benefit that if, in the future, we wish to emit kernels without ngraph core (for example, in a standalone binary via a
serialized graph manifest), we don't depend on AxisSet.

* softmax -> broadcast in build_broadcast.

* Separate elementwise and elementwise_collective.
parent 692101a7
......@@ -682,6 +682,142 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const GPURuntimeConte
return primitive_index;
}
size_t
runtime::gpu::CUDAEmitter::build_fused_ew_to_collective(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
GPUShape tensor_shape,
const std::set<size_t>& reduced_tensors,
const std::set<size_t>& axes,
const char* op,
const char* kernel,
const char* reduce_op,
bool save_elementwise)
{
// kernel_name is used to check if the cuda kernel has been previously compiled
std::stringstream kernel_name;
kernel_name << "ew_collective"
<< "_" << op << "_" << join(dtypes, "_") << "_" << reduce_op
// multi-output op
<< "_mo" << int(save_elementwise);
// hash is used to check if the emitted primitive already exists
std::stringstream ss;
ss << kernel_name.str() << "_s" << join(tensor_shape, "_");
auto hash = ss.str();
// if the primitive exists, we are done
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
writer << include_helpers();
if (kernel)
{
CudaKernelBuilder::get_device_helper(writer, op, kernel, dtypes);
}
CudaKernelBuilder::get_ew_collective_op(writer,
kernel_name.str(),
op,
reduce_op,
dtypes,
reduced_tensors,
save_elementwise,
tensor_shape.size());
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// calculate strides
GPUShape strides = row_major_strides(tensor_shape);
// precalculate invariants for integer division via multiplication
std::vector<int> stride_magic;
std::vector<int> stride_shift;
for (int i = 0; i < strides.size(); i++)
{
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(strides[i]);
stride_magic.push_back(magic);
stride_shift.push_back(shift);
}
// calculate reduced tensor strides with 0s inserted for reduced axes
GPUShape reduced_shape = tensor_shape;
for (auto const& axis : axes)
{
reduced_shape[axis] = 1;
}
GPUShape reduced_strides = row_major_strides(reduced_shape);
for (auto const& axis : axes)
{
reduced_strides[axis] = 0;
}
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_strides = allocator.reserve_argspace(strides.data(), strides.size() * sizeof(int));
size_t idx_stride_magic =
allocator.reserve_argspace(stride_magic.data(), stride_magic.size() * sizeof(int));
size_t idx_stride_shift =
allocator.reserve_argspace(stride_shift.data(), stride_shift.size() * sizeof(int));
size_t idx_reduced_strides =
allocator.reserve_argspace(reduced_strides.data(), reduced_strides.size() * sizeof(int));
size_t nthreads = shape_size(tensor_shape);
constexpr const int nthreads_per_block = 32;
int nblocks = 1 + ((static_cast<int>(nthreads) - 1) / nthreads_per_block);
// TODO: check if mutable is necessary
std::unique_ptr<gpu::primitive> ew_collective(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_strides);
void* stride_magic_d = runtime::gpu::invoke_memory_primitive(ctx, idx_stride_magic);
void* stride_shift_d = runtime::gpu::invoke_memory_primitive(ctx, idx_stride_shift);
void* reduced_strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_reduced_strides);
std::vector<void*> args_list;
for (auto i = 0u; i < dtypes.size() - 1; i++)
{
args_list.push_back(&inputs[i]);
}
args_list.push_back(&outputs[0]);
if (save_elementwise)
{
args_list.push_back(&outputs[1]);
}
args_list.push_back(&strides_d);
args_list.push_back(&stride_magic_d);
args_list.push_back(&stride_shift_d);
args_list.push_back(&reduced_strides_d);
args_list.push_back(&nthreads);
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
nblocks,
1,
1,
nthreads_per_block,
1,
1,
0,
NULL,
args_list.data(),
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(ew_collective));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_reduce_window(const GPURuntimeContext* ctx,
const OpName op_name,
const std::vector<std::string>& dtypes,
......@@ -937,6 +1073,104 @@ size_t runtime::gpu::CUDAEmitter::build_replace_slice(const GPURuntimeContext* c
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_broadcast(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape result_shape,
const std::set<size_t>& reduce_axes)
{
// assumes NC{d1,...,dn} format
std::string kernel_name = "broadcast_" + join(dtypes, "_");
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss;
ss << kernel_name << "_s" << join(result_shape, "_") << "_r" << join(reduce_axes, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// if the kernel has not been compiled, build it
auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
runtime::gpu::CudaKernelBuilder::get_broadcast_op(
writer, kernel_name, dtypes, result_shape.size());
compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
// calculate strides
GPUShape strides = row_major_strides(result_shape);
// precalculate invariants for integer division via multiplication
std::vector<int> stride_magic;
std::vector<int> stride_shift;
for (int i = 0; i < strides.size(); i++)
{
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(strides[i]);
stride_magic.push_back(magic);
stride_shift.push_back(shift);
}
// calculate reduced tensor strides with 0s inserted for reduced axes
GPUShape reduced_shape = result_shape;
for (auto const& axis : reduce_axes)
{
reduced_shape[axis] = 1;
}
GPUShape reduced_strides = row_major_strides(reduced_shape);
for (auto const& axis : reduce_axes)
{
reduced_strides[axis] = 0;
}
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_strides = allocator.reserve_argspace(strides.data(), strides.size() * sizeof(int));
size_t idx_stride_magic =
allocator.reserve_argspace(stride_magic.data(), stride_magic.size() * sizeof(int));
size_t idx_stride_shift =
allocator.reserve_argspace(stride_shift.data(), stride_shift.size() * sizeof(int));
size_t idx_reduced_strides =
allocator.reserve_argspace(reduced_strides.data(), reduced_strides.size() * sizeof(int));
// TODO: blending factors are not currently implemented
float alpha = 1.0f;
float beta = 0.0f;
int nthreads = static_cast<int>(shape_size(result_shape));
std::unique_ptr<gpu::primitive> broadcast(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_strides);
void* stride_magic_d = runtime::gpu::invoke_memory_primitive(ctx, idx_stride_magic);
void* stride_shift_d = runtime::gpu::invoke_memory_primitive(ctx, idx_stride_shift);
void* reduced_strides_d = runtime::gpu::invoke_memory_primitive(ctx, idx_reduced_strides);
void* args_list[] = {&inputs[0],
&outputs[0],
&strides_d,
&stride_magic_d,
&stride_shift_d,
&reduced_strides_d,
&alpha,
&beta,
&nthreads};
CUDA_SAFE_CALL(
cuLaunchKernel(*compiled_kernel.get(), nthreads, 1, 1, 1, 1, 1, 0, NULL, args_list, 0));
CUDA_SAFE_CALL(cuCtxSynchronize());
}});
primitive_index = this->m_primitive_emitter->insert(std::move(broadcast));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
void runtime::gpu::CUDAEmitter::print_tensor_from_gpu(codegen::CodeWriter& writer,
const std::string& tensor_name,
GPUShape shape)
......@@ -977,12 +1211,24 @@ std::string runtime::gpu::CUDAEmitter::include_helpers()
std::stringstream ss;
#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
ss << R"(
#define WARP_SIZE 32
#define __ballot_sync(mask, predicate) __ballot(predicate)
#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
#define __shfl_xor_sync(mask, val, laneMask, width) __shfl_xor(val, laneMask, width)
)";
#endif
// add modern type definitions
ss << "typedef signed char int8_t;\n";
ss << "typedef signed short int16_t;\n";
ss << "typedef signed int int32_t;\n";
ss << "typedef signed long int int64_t;\n";
ss << "typedef unsigned char uint8_t;\n";
ss << "typedef unsigned short uint16_t;\n";
ss << "typedef unsigned int uint32_t;\n";
ss << "typedef unsigned long int uint64_t;\n";
ss << "\n";
// division_by_invariant_multiplication:
// fast integer division via invariant multiplication and shifting
// if the value is a power of 2, magic will be 1 and only the shift is applied
......@@ -1031,6 +1277,15 @@ __device__ __forceinline__ float load(const float* __restrict__ in, int i=0, b
}
return v;
}
__device__ __forceinline__ int64_t load(const int64_t* __restrict__ in, int i=0, bool b=true)
{
int64_t v = 0;
if (b)
{
v = __ldg(in + i);
}
return v;
}
)";
return ss.str();
}
......@@ -80,6 +80,25 @@ namespace ngraph
ctx, dtypes, tensor_shape, CudaOpMap<T>::op, CudaOpMap<T>::math_kernel);
}
template <typename ELEMENTWISE_OP_TYPE, typename REDUCE_OP_TYPE = ngraph::op::Nop>
size_t build_elementwise_collective(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
GPUShape tensor_shape,
const std::set<size_t>& reduced_tensors = {},
const std::set<size_t>& axes = {},
bool save_elementwise = false)
{
return build_fused_ew_to_collective(ctx,
dtypes,
tensor_shape,
reduced_tensors,
axes,
CudaOpMap<ELEMENTWISE_OP_TYPE>::op,
CudaOpMap<ELEMENTWISE_OP_TYPE>::math_kernel,
CudaOpMap<REDUCE_OP_TYPE>::atomic,
save_elementwise);
}
size_t build_replace_slice(const GPURuntimeContext* ctx,
const std::array<std::string, 3>& dtypes,
GPUShape tensor_shape,
......@@ -88,6 +107,11 @@ namespace ngraph
GPUShape upper_bounds,
GPUShape slice_stride);
size_t build_broadcast(const GPURuntimeContext* ctx,
const std::array<std::string, 2>& dtypes,
GPUShape result_shape,
const std::set<size_t>& bcast_axes);
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter);
void print_tensor_from_gpu(codegen::CodeWriter& writer,
......@@ -99,6 +123,15 @@ namespace ngraph
GPUShape tensor_shape,
const char* op,
const char* kernel);
size_t build_fused_ew_to_collective(const GPURuntimeContext* ctx,
const std::vector<std::string>& dtypes,
GPUShape tensor_shape,
const std::set<size_t>& reduced_tensors,
const std::set<size_t>& axes,
const char* op,
const char* kernel,
const char* reduce_op,
bool save_elementwise);
GPUPrimitiveEmitter* m_primitive_emitter;
};
......
......@@ -57,28 +57,131 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
return;
}
void runtime::gpu::CudaKernelBuilder::get_ew_collective_op(
codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::string& reduce_op,
const std::vector<std::string>& data_types,
const std::set<size_t>& reduced_tensors,
bool save_elementwise,
size_t rank)
{
auto num_inputs = data_types.size() - 1;
writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++)
{
writer << data_types[i] << "* in" << i << ", ";
}
writer << data_types[num_inputs] << "* out0, ";
// multi-output to save intermediate elementwise op if requested
if (save_elementwise)
{
writer << data_types[num_inputs] << "* out1, ";
}
writer << "int* strides, "
<< "int* stride_magic, "
<< "int* stride_shift, "
<< "int* reduced_strides, "
<< "size_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "if (tid < n)\n";
writer.block_begin();
{
std::string reduced_idx = collective_coordinate_transform_helper(writer,
"tid",
"strides",
"stride_magic",
"stride_shift",
"reduced_strides",
"coordinate",
rank);
// element-wise operation
writer << data_types[num_inputs] << " output = " << op << "(";
for (size_t i = 0; i < num_inputs; i++)
{
if (i > 0)
{
writer << ", ";
}
writer << "in" << i << "[";
if (reduced_tensors.count(i) > 0)
{
writer << reduced_idx;
}
else
{
writer << "tid";
}
writer << "]";
}
writer << ");\n";
// global collective reduce or broadcast
if (reduce_op != "")
{
// TODO: mediate atomic memory access contention
writer << reduce_op << "(&out0[" << reduced_idx << "], output);\n";
if (save_elementwise)
{
writer << "out1["
<< "tid"
<< "] = output;\n";
}
}
else
{
writer << "out0[tid] = output;\n";
if (save_elementwise)
{
writer << "out1[" << reduced_idx << "] = output;\n";
}
}
}
writer.block_end();
}
writer.block_end();
return;
}
void runtime::gpu::CudaKernelBuilder::get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
const std::array<std::string, 2>& data_types,
const size_t rank)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, size_t m, size_t k, size_t n)\n";
writer << "{\n";
writer.indent++;
<< data_types[1] << "* out, "
<< "int* strides, "
<< "int* stride_magic, "
<< "int* stride_shift, "
<< "int* reduced_strides, "
<< "float alpha, float beta, "
<< "size_t nthreads"
<< ")\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer << "{\n";
writer.indent++;
writer << "const int tid = blockDim.x*blockIdx.x + threadIdx.x;\n";
writer << "if (tid < nthreads)\n";
writer.block_begin();
{
writer << "size_t idx = tid / (m * k) * m + tid % m;\n";
writer << "out[tid] = in[idx];\n";
// calculate tensor coordinates (inverse tensor reduction)
std::string reduced_idx = collective_coordinate_transform_helper(writer,
"tid",
"strides",
"stride_magic",
"stride_shift",
"reduced_strides",
"coordinate",
rank);
writer << "out[tid] = load(in, " << reduced_idx << ");\n";
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_onehot_op(codegen::CodeWriter& writer,
......@@ -372,6 +475,47 @@ void runtime::gpu::CudaKernelBuilder::get_replace_slice_op(
writer.block_end();
}
std::string runtime::gpu::CudaKernelBuilder::collective_coordinate_transform_helper(
codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank)
{
// Translation from flat index to dense tensor coordinates:
// Given tensor shape [d0 d1 ... dN] with strides [d1*...*dN, d2*...*dN, ... 1],
// calculate coordinates as:
//
// product = tid
// d0 = product/stride[0]
// product = product % stride[0]
// d1 = product/stride[1]
// ...
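// e.g. for a rank-2 tensor, with the argument names used at the call sites
// above, the emitted code is roughly:
//   int coordinate_product = tid;
//   int coordinate0 = division_by_invariant_multiplication(coordinate_product, stride_magic[0], stride_shift[0]);
//   coordinate_product -= (coordinate0 * strides[0]);
//   int coordinate1 = division_by_invariant_multiplication(coordinate_product, stride_magic[1], stride_shift[1]);
//   coordinate_product -= (coordinate1 * strides[1]);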
writer << "int coordinate_product = " << i_thread_index << ";\n";
for (size_t i = 0; i < rank; i++)
{
writer << "int " << o_coordinates << i << " = division_by_invariant_multiplication("
<< "coordinate_product, " << i_stride_magic << "[" << i << "], " << i_stride_shift
<< "[" << i << "]);\n";
writer << "coordinate_product -= (" << o_coordinates << i << " * " << i_strides << "[" << i
<< "]);\n";
}
// index into reduced tensor from coordinates of non-reduced tensor
std::string reduced_idx = "reduced_idx";
writer << "int " << reduced_idx << " = 0;\n";
for (size_t i = 0; i < rank; i++)
{
writer << "reduced_idx += " << o_coordinates << i << " * " << i_reduced_strides << "[" << i
<< "];\n";
}
return reduced_idx;
}
void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
......@@ -17,6 +17,7 @@
#pragma once
#include <array>
#include <set>
#include <string>
#include <vector>
......@@ -40,7 +41,8 @@ namespace ngraph
static void get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
const std::array<std::string, 2>& data_types,
const size_t rank);
static void get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -63,23 +65,44 @@ namespace ngraph
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_replace_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 3>& data_types,
int nthreads_per_block);
static void get_reduce_window_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::vector<std::string>& data_types,
const size_t rank);
static void get_replace_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 3>& data_types,
int nthreads_per_block);
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
const std::vector<std::string>& data_types);
static void get_ew_collective_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::string& reduce_op,
const std::vector<std::string>& data_types,
const std::set<size_t>& reduced_tensors,
bool save_elementwise,
size_t rank);
static void add_pod_typedefs(codegen::CodeWriter& writer);
/// \brief Given kernel input variables i_* produce register variables o_coordinates{i}
/// of the non-reduced tensor and return the string name of integer index into reduced tensor
static std::string
collective_coordinate_transform_helper(codegen::CodeWriter& writer,
std::string i_thread_index,
std::string i_strides,
std::string i_stride_magic,
std::string i_stride_shift,
std::string i_reduced_strides,
std::string o_coordinates,
size_t rank);
};
}
}
......
......@@ -23,43 +23,6 @@
using namespace ngraph;
using namespace ngraph::runtime::gpu;
void runtime::gpu::emit_broadcast(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
size_t repeat_size,
size_t repeat_times,
size_t count)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_broadcast_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
void runtime::gpu::emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
......
......@@ -36,15 +36,6 @@ namespace ngraph
template <typename T>
struct CudaOpMap;
void emit_broadcast(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
size_t repeat_size,
size_t repeat_times,
size_t count);
void emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
......
......@@ -60,6 +60,7 @@ namespace ngraph
class Select;
class And;
class Or;
class Nop;
}
namespace runtime
{
......@@ -207,6 +208,7 @@ namespace ngraph
{
static constexpr const char* op = "subtractf";
static constexpr const char* math_kernel = "x0-x1";
static constexpr const char* atomic = "atomicSub";
};
template <>
......@@ -305,6 +307,7 @@ namespace ngraph
{
static constexpr const char* op = "logical_and";
static constexpr const char* math_kernel = "x0 & x1";
static constexpr const char* atomic = "atomicAnd";
};
template <>
......@@ -312,6 +315,7 @@ namespace ngraph
{
static constexpr const char* op = "logical_or";
static constexpr const char* math_kernel = "x0 | x1";
static constexpr const char* atomic = "atomicOr";
};
template <>
......@@ -319,6 +323,7 @@ namespace ngraph
{
static constexpr const char* op = "add";
static constexpr const char* math_kernel = "x0 + x1";
static constexpr const char* atomic = "atomicAdd";
};
template <>
......@@ -333,6 +338,7 @@ namespace ngraph
{
static constexpr const char* op = "min";
static constexpr const char* math_kernel = "x0 > x1 ? x1 : x0";
static constexpr const char* atomic = "atomicMin";
};
template <>
......@@ -340,6 +346,15 @@ namespace ngraph
{
static constexpr const char* op = "max";
static constexpr const char* math_kernel = "x0 > x1 ? x0 : x1";
static constexpr const char* atomic = "atomicMax";
};
template <>
struct CudaOpMap<ngraph::op::Nop>
{
static constexpr const char* op = "";
static constexpr const char* math_kernel = "";
static constexpr const char* atomic = "";
};
}
}
......
......@@ -22,6 +22,7 @@
#include <cudnn.h>
#include <iostream>
#include <nvrtc.h>
#include <set>
#include <string>
#include <typeindex>
#include <unordered_map>
......@@ -825,50 +826,17 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
return;
}
// broadcast axes size is 1, or can be group to 1 (consecutive axes, like 01 or 12 or 123 etc)
vector<int> axes_v;
std::copy(axes.begin(), axes.end(), std::back_inserter(axes_v));
std::sort(axes_v.begin(), axes_v.end());
bool is_one_axes = true;
if (axes.size() != 1)
{
for (int i = 1; i < axes_v.size(); i++)
{
if (axes_v[i] != axes_v[i - 1] + 1)
{
is_one_axes = false;
break;
}
}
}
if (is_one_axes)
{
int repeat_times = 1;
for (int i = 0; i < axes_v.size(); i++)
{
repeat_times *= result_shape[axes_v[i]];
}
int repeat_size = 1;
for (int i = *axes_v.rbegin() + 1; i < result_shape.size(); i++)
{
repeat_size *= result_shape[i];
}
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
writer.block_begin(" // " + node->get_name());
writer << "runtime::gpu::emit_broadcast(\"" << node->description() << "\", {\""
<< args[0].get_type() << "\", \"" << out[0].get_type() << "\"}"
<< ", ctx"
<< ", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", " << repeat_size << ", " << repeat_times << ", "
<< out[0].get_size() << ");\n";
writer.block_end();
}
else
{
throw std::runtime_error(node->get_name() + " is not implemented.");
}
auto bcast_index =
cuda_emitter->build_broadcast(external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
result_shape,
axes);
writer << "gpu::invoke_primitive(ctx, " << bcast_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
template <>
......@@ -2165,26 +2133,76 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
auto tensor_shape = args[0].get_shape();
auto axes = softmax->get_axes();
if (axes.size() != tensor_shape.size())
{
throw std::runtime_error(
"Softmax implementation currently only supports all axis activation.");
}
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t softmax_index =
cudnn_emitter->build_softmax(external_function->ctx().get(),
CUDNN_SOFTMAX_FAST,
CUDNN_SOFTMAX_MODE_INSTANCE,
CUDNNEmitter::Prop::Forward,
tensor_shape);
if (axes.size() != tensor_shape.size())
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
writer << "gpu::invoke_primitive(ctx, " << softmax_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
// reserve a temporary buffer for the intermediate reduction
GPUAllocator allocator =
external_function->get_primitive_emitter()->get_memory_allocator();
auto reduced_shape = tensor_shape;
for (auto const& axis : axes)
{
reduced_shape[axis] = 1;
}
size_t reduced_size = shape_size(reduced_shape);
size_t workspace_idx = allocator.reserve_workspace(
reduced_size * out[0].get_element_type().size());
// exponentiate with fused sum reduction to calculate softmax denominator
size_t exp_sum_reduce =
cuda_emitter
->build_elementwise_collective<ngraph::op::Exp, ngraph::op::Add>(
external_function->ctx().get(),
{{args[0].get_type(), out[0].get_type()}},
args[0].get_shape(),
{},
axes,
true /* multi-output */);
writer << "void* workspace = gpu::invoke_memory_primitive(ctx, "
<< workspace_idx << ");\n";
writer << "gpu::invoke_primitive(ctx, " << exp_sum_reduce << ", ";
writer << "std::vector<void*>{" << args[0].get_name();
writer << "}.data(), ";
// cache the elementwise result and the fused result (multi-output)
writer << "std::vector<void*>{ workspace, ";
writer << out[0].get_name() << "}.data()";
writer << ");\n";
// inplace binary division with fused broadcast to calculate softmax
size_t div_broadcast =
cuda_emitter->build_elementwise_collective<ngraph::op::Divide>(
external_function->ctx().get(),
{{out[0].get_type(), out[0].get_type(), out[0].get_type()}},
out[0].get_shape(),
{1},
axes);
writer << "gpu::invoke_primitive(ctx, " << div_broadcast << ", ";
writer << "std::vector<void*>{" << out[0].get_name();
writer << ", workspace}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
else
{
size_t softmax_index =
cudnn_emitter->build_softmax(external_function->ctx().get(),
CUDNN_SOFTMAX_FAST,
CUDNN_SOFTMAX_MODE_INSTANCE,
CUDNNEmitter::Prop::Forward,
tensor_shape);
writer << "gpu::invoke_primitive(ctx, " << softmax_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
}
writer.block_end();
}
......
......@@ -771,6 +771,10 @@ using namespace std;
// End generated function
writer += "}\n\n";
}
// allocate device buffers for primitive arguments and workspace
m_primitive_emitter->allocate_primitive_memory();
// TODO: Cleanup and make this a utility function
// allocate device buffers for primitive arguments and workspace
......
......@@ -2,14 +2,10 @@ abc_int64
backwards_reverse_sequence_n4d2c3h2w2
backwards_reverse_sequence_n3_c2_h3
backwards_slice
backwards_softmax_3d
backwards_softmax_axis
batch_norm_one_output
batch_norm_three_outputs
broadcast_vector_rowwise_int64
computation_reuse
concat_matrix_int64
constant_broadcast
constant_equality_bool
convolution_2d_1item_1o1i_data_dilated
convolution_2d_1item_2o1i_data_dilated
......@@ -58,8 +54,6 @@ scalar_constant_int64
select_and_scatter_3d_without_overlap
select_and_scatter_with_overlap
select_and_scatter_without_overlap
softmax_axis
softmax_underflow
tensor_constant
tensor_constant_float32
tensor_constant_int64
......
......@@ -7699,6 +7699,42 @@ NGRAPH_TEST(${BACKEND_NAME}, softmax_all)
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_axis_3d)
{
Shape shape{2, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Softmax>(A, AxisSet{0}), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-10, -20, -30, -40, -50, -60, -1, -2, -3, -4, -5, -6});
auto result = backend->create_tensor(element::f32, shape);
auto d0 = expf(-10) + expf(-1);
auto d1 = expf(-20) + expf(-2);
auto d2 = expf(-30) + expf(-3);
auto d3 = expf(-40) + expf(-4);
auto d4 = expf(-50) + expf(-5);
auto d5 = expf(-60) + expf(-6);
backend->call(f, {result}, {a});
vector<float> expected{expf(-10) / d0,
expf(-20) / d1,
expf(-30) / d2,
expf(-40) / d3,
expf(-50) / d4,
expf(-60) / d5,
expf(-1) / d0,
expf(-2) / d1,
expf(-3) / d2,
expf(-4) / d3,
expf(-5) / d4,
expf(-6) / d5};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_axis)
{
Shape shape{2, 3};
......@@ -7724,6 +7760,49 @@ NGRAPH_TEST(${BACKEND_NAME}, softmax_axis)
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_axis_2)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Softmax>(A, AxisSet{0}), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-10, -20, -30, -40, -50, -60});
auto result = backend->create_tensor(element::f32, shape);
auto d0 = expf(-10) + expf(-40);
auto d1 = expf(-20) + expf(-50);
auto d2 = expf(-30) + expf(-60);
backend->call(f, {result}, {a});
vector<float> expected{expf(-10) / d0,
expf(-20) / d1,
expf(-30) / d2,
expf(-40) / d0,
expf(-50) / d1,
expf(-60) / d2};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_axis_3d_trivial)
{
Shape shape{1, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Softmax>(A, AxisSet{0}), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{-10, -20, -30, -40, -50, -60});
auto result = backend->create_tensor(element::f32, shape);
backend->call(f, {result}, {a});
vector<float> expected{1, 1, 1, 1, 1, 1};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_underflow)
{
Shape shape{2, 3};
......