Commit 154dc47a authored by Fenglei, committed by Robert Kimball

cuda optimize softmax (#1310)

* Updated softmax.

* Formatting.

* Updated convolution.

* Use build_primitive overloading. Add helper to emit type_string given a node.

* Formatting.

* Update ConvolutionBackpropData.

* convolution backprop & max pool memory primitive caching (#1303)

* Updated ConvolutionBackpropFilters.
* Update MaxPool.

* Update Max and Min. (#1307)

* softmax optimization

* fix bug

* fix bugs

* clang format

* remove comments

* add softmax divide

* fix bugs

* fix bug

* fix bug

* clang format

* remove unused header

* register

* using single parameters instead of array

* using build_elementwise instead of build_elementwise_collective

* remove workspace as csullivan suggested
parent 8db7b24b
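
For reference (not part of this diff): the optimized path computes softmax in three separate steps — an elementwise exp, a sum reduction over the softmax axes, and a broadcasted divide. A minimal host-side C++ sketch of that decomposition for a rank-2 input reduced over axis 1; names and shapes here are illustrative assumptions only:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax over axis 1 of a row-major {rows, cols} tensor:
// 1) exponentiate, 2) sum-reduce each row, 3) divide each element by its row sum.
std::vector<float> softmax_rows(const std::vector<float>& in, size_t rows, size_t cols)
{
    std::vector<float> out(in.size());
    std::vector<float> denom(rows, 0.0f); // plays the role of the reduce buffer, shape {rows, 1}
    for (size_t i = 0; i < in.size(); i++)
    {
        out[i] = std::exp(in[i]);
    }
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            denom[r] += out[r * cols + c];
        }
    }
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            out[r * cols + c] /= denom[r];
        }
    }
    return out;
}
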
@@ -21,6 +21,7 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/runtime/gpu/cuda_emitter.hpp"
#include "ngraph/runtime/gpu/cudnn_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
@@ -1254,16 +1255,92 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::MaxPool* node)
    return primitive_index;
}

size_t runtime::gpu::CUDAEmitter::build_softmax_divide(const std::vector<std::string>& dtypes,
                                                        GPUShape input_shape,
                                                        GPUShape reduce_shape,
                                                        std::vector<size_t> axes_flag)
{
    std::string kernel_name =
        "softmax_divide_" + join(dtypes, "_") + "_axes_" + join(axes_flag, "_");
    std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');

    size_t nthreads = shape_size(input_shape);
    std::string hash = kernel_name + "_n" + join(input_shape, "_") + join(reduce_shape, "_") +
                       std::to_string(nthreads);

    // check if the requested kernel is already an inserted primitive
    size_t primitive_index = m_primitive_emitter->lookup(hash);
    if (primitive_index != std::numeric_limits<size_t>::max())
    {
        return primitive_index;
    }

    // if the kernel has not been compiled, build it
    auto compiled_kernel = m_ctx->compiled_kernel_pool->get(hash);
    if (compiled_kernel == nullptr)
    {
        codegen::CodeWriter writer;
        CudaKernelBuilder::add_pod_typedefs(writer);
        writer << include_helpers();
        CudaKernelBuilder::get_softmax_divide_op(
            writer, kernel_name, dtypes, axes_flag, input_shape.size());
        compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
    }

    GPUShape input_strides = row_major_strides(input_shape);
    GPUShape reduce_strides = row_major_strides(reduce_shape);

    GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();

    // TODO: currently we set it to 64, will add tuning method later
    uint32_t block_size_x = 64;
    uint32_t aligned_grid_size_x =
        align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);

    std::unique_ptr<gpu::primitive> pool(
        new gpu::primitive{[=](void** inputs, void** outputs) mutable {
            std::vector<void*> arg_list;
            arg_list.push_back(&inputs[0]);
            arg_list.push_back(&inputs[1]);
            arg_list.push_back(&outputs[0]);
            for (size_t i = 0; i < input_strides.size(); i++)
            {
                arg_list.push_back(&input_strides[i]);
            }
            for (size_t i = 0; i < reduce_strides.size(); i++)
            {
                arg_list.push_back(&reduce_strides[i]);
            }
            arg_list.push_back(&nthreads);
            CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
                                          aligned_grid_size_x,
                                          1,
                                          1, // grid dim
                                          block_size_x,
                                          1,
                                          1, // block dim
                                          0,
                                          NULL, // shared mem and stream
                                          arg_list.data(),
                                          0)); // arguments
            debug_sync();
        }});

    primitive_index = this->m_primitive_emitter->insert(std::move(pool));
    m_primitive_emitter->cache(hash, primitive_index);
    return primitive_index;
}
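
As a rough usage illustration (not part of the commit; the shapes, buffers, and the emitter/context handles below are assumptions): for an input of shape {2, 3, 4} softmaxed over axis 1, the reduce shape is {2, 1, 4} and axes_flag marks the reduced dimension; the primitive then divides the exponentiated input by the per-group sums.

// Hypothetical call site; cuda_emitter, ctx and the device buffers are assumed handles.
GPUShape input_shape{2, 3, 4};
GPUShape reduce_shape{2, 1, 4};
std::vector<size_t> axes_flag{0, 1, 0};
size_t divide_index = cuda_emitter->build_softmax_divide(
    std::vector<std::string>(3, "float"), input_shape, reduce_shape, axes_flag);

// in0: exp(x), 24 floats; in1: per-group sums, 8 floats; out: normalized softmax, 24 floats.
void* divide_inputs[] = {exp_buffer, sum_buffer};
void* divide_outputs[] = {softmax_buffer};
runtime::gpu::invoke_primitive(ctx, divide_index, divide_inputs, divide_outputs);
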

size_t runtime::gpu::CUDAEmitter::build_primitive(const op::Softmax* node)
{
    auto& args = node->get_inputs();
    auto& out = node->get_outputs();
-   auto tensor_shape = args[0].get_shape();
    auto input_shape = args[0].get_shape();
    auto axes = node->get_axes();

    std::stringstream ss;
    ss << "softmax_" << runtime::gpu::kernel::emit_type_string(node) << "_s"
-      << join(tensor_shape, "_") << "_ra" << join(axes, "_");
       << join(input_shape, "_") << "_ra" << join(axes, "_");
    auto hash = ss.str();

    size_t primitive_index = m_primitive_emitter->lookup(hash);
@@ -1273,39 +1350,58 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::Softmax* node)
    }

    // build composite primitive
    auto& cudnn_emitter = m_primitive_emitter->get_cudnn_emitter();

    // reserve a temporary buffer for the intermediate reduction
    GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
-   auto reduced_shape = tensor_shape;
    auto reduced_shape = input_shape;
    std::vector<size_t> axes_flag(input_shape.size(), 0);
    for (auto const& axis : axes)
    {
        reduced_shape[axis] = 1;
        axes_flag[axis] = 1;
    }
    size_t reduced_size = shape_size(reduced_shape);
-   size_t workspace_idx =
-       allocator.reserve_workspace(reduced_size * out[0].get_element_type().size());
    size_t tensor_size = shape_size(input_shape);
    size_t type_size = out[0].get_element_type().size();
    size_t reduce_buffer_idx = allocator.reserve_workspace(reduced_size * type_size);

-   // exponentiate with fused sum reduction to calculate softmax denominator
    auto input_type = args[0].get_element_type().c_type_string();
    auto output_type = out[0].get_element_type().c_type_string();
-   size_t exp_sum_reduce = build_elementwise_collective<ngraph::op::Exp, ngraph::op::Add>(
-       {{input_type, output_type}}, tensor_shape, {}, axes, true /* multi-output */);
-   // inplace binary division with fused broadcast to calculate softmax
-   size_t div_broadcast = build_elementwise_collective<ngraph::op::Divide>(
-       std::vector<std::string>(3, output_type), tensor_shape, {1}, axes);
    auto exp_index = build_elementwise<ngraph::op::Exp>({{input_type, output_type}}, input_shape);
    auto reduce_index = cudnn_emitter->build_reduce_forward(
        CUDNN_REDUCE_TENSOR_ADD, output_type, input_shape, axes);
    size_t divide_index = build_softmax_divide(
        std::vector<std::string>(3, output_type), input_shape, reduced_shape, axes_flag);

-   std::unique_ptr<gpu::primitive> kernel_launch(
-       new gpu::primitive{[=](void** inputs, void** outputs) mutable {
-           void* workspace = runtime::gpu::invoke_memory_primitive(m_ctx, workspace_idx);
-           // cache the elementwise result and the fused result (multi-output)
    if (reduced_size == tensor_size)
    {
        // the result should be all set to 1.
        // TODO: add memset
        std::unique_ptr<gpu::primitive> kernel_launch(
            new gpu::primitive{[=](void** inputs, void** outputs) mutable {
                runtime::gpu::invoke_primitive(
                    m_ctx, divide_index, std::vector<void*>{inputs[0], inputs[0]}.data(), outputs);
            }});
        primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
    }
    else
    {
        std::unique_ptr<gpu::primitive> kernel_launch(new gpu::primitive{[=](
            void** inputs, void** outputs) mutable {
            void* reduce_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
            runtime::gpu::invoke_primitive(m_ctx, exp_index, inputs, outputs);
            runtime::gpu::invoke_primitive(
-               m_ctx, exp_sum_reduce, inputs, std::vector<void*>{workspace, outputs[0]}.data());
                m_ctx, reduce_index, outputs, std::vector<void*>{reduce_buffer}.data());
            runtime::gpu::invoke_primitive(
-               m_ctx, div_broadcast, std::vector<void*>{outputs[0], workspace}.data(), outputs);
                m_ctx, divide_index, std::vector<void*>{outputs[0], reduce_buffer}.data(), outputs);
        }});
-   primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
        primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
    }

    m_primitive_emitter->cache(hash, primitive_index);
    return primitive_index;
}
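
A quick note on the reduced_size == tensor_size branch above: in that case every reduction group contains a single element, so the softmax result is identically 1, and the divide primitive reproduces it as in0 / in0 (hence the in-code TODO about a memset, since 0 / 0 would give NaN). A tiny illustration with assumed values:

// Each element is its own softmax group, so the expected output is all ones.
float x[] = {-2.0f, 0.5f, 7.0f};
float out[3];
for (int i = 0; i < 3; i++)
{
    out[i] = x[i] / x[i]; // 1.0f for every non-zero x[i]
}
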
@@ -158,6 +158,11 @@ namespace ngraph
size_t concat_axis,
GPUShape output_shape);
size_t build_softmax_divide(const std::vector<std::string>& dtypes,
                            GPUShape input_shape,
                            GPUShape reduce_shape,
                            std::vector<size_t> axes_flag);
void debug_sync();
void sync();
@@ -55,6 +55,54 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
    return;
}

void runtime::gpu::CudaKernelBuilder::get_softmax_divide_op(
    codegen::CodeWriter& writer,
    const std::string& name,
    const std::vector<std::string>& data_types,
    std::vector<size_t> axes_flag,
    size_t rank)
{
    writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in0, "
           << data_types[1] << "* in1, " << data_types[2] << "* out,";
    for (size_t i = 0; i < axes_flag.size(); i++)
    {
        writer << "uint32_t input0_strides" << i << ", ";
    }
    for (size_t i = 0; i < axes_flag.size(); i++)
    {
        writer << "uint32_t input1_strides" << i << ", ";
    }
    writer << "uint32_t n)\n";
    writer.block_begin();
    {
        writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
        writer << "if (tid < n)\n";
        writer.block_begin();
        {
            writer << "uint32_t input0_idx = tid;\n";
            writer << "uint32_t input1_idx = 0;\n";
            size_t i = 0;
            for (; i < rank - 1; i++)
            {
                if (axes_flag[i] != 1)
                {
                    writer << "input1_idx += (input0_idx / input0_strides" << i
                           << ") * input1_strides" << i << ";\n";
                }
                writer << "input0_idx %= input0_strides" << i << ";\n";
            }
            if (axes_flag[i] != 1)
            {
                writer << "input1_idx += (input0_idx / input0_strides" << i << ") * input1_strides"
                       << i << ";\n";
            }
            writer << "out[tid] = in0[tid] / in1[input1_idx];\n";
        }
        writer.block_end();
    }
    writer.block_end();
}
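
To see what the generated index arithmetic does, here is a host-side walk-through (not part of the commit) for an assumed input shape {2, 3, 4} reduced over axis 1, i.e. input strides {12, 4, 1} and reduce strides {4, 4, 1}. The generated kernel unrolls the last axis separately and skips the final modulo, but the mapping is the same:

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t rank = 3;
    const uint32_t input0_strides[] = {12, 4, 1}; // row-major strides of {2, 3, 4}
    const uint32_t input1_strides[] = {4, 4, 1};  // row-major strides of {2, 1, 4}
    const uint32_t axes_flag[] = {0, 1, 0};       // axis 1 is the softmax (reduced) axis

    uint32_t tid = 23; // flat index of element (1, 2, 3) in the full tensor
    uint32_t input0_idx = tid;
    uint32_t input1_idx = 0;
    for (uint32_t i = 0; i < rank; i++)
    {
        if (axes_flag[i] != 1)
        {
            // keep the coordinate of every non-reduced axis when projecting into in1
            input1_idx += (input0_idx / input0_strides[i]) * input1_strides[i];
        }
        input0_idx %= input0_strides[i];
    }
    // Element (1, 2, 3) reads its denominator from reduced element (1, 0, 3), flat index 7.
    std::printf("input1_idx = %u\n", input1_idx); // prints 7
    return 0;
}
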
void runtime::gpu::CudaKernelBuilder::get_ew_collective_op(
    codegen::CodeWriter& writer,
    const std::string& name,
@@ -131,6 +131,12 @@ namespace ngraph
int sm_tile_size = 8,
int reg_tile_size = 1);
static void get_softmax_divide_op(codegen::CodeWriter& writer,
                                  const std::string& name,
                                  const std::vector<std::string>& data_types,
                                  std::vector<size_t> axes_flag,
                                  size_t rank);
static void add_pod_typedefs(codegen::CodeWriter& writer);
/// \brief Given kernel input variables i_* produce register variables o_coordinates{i}