Commit be9f031e authored by Fenglei, committed by Robert Kimball

nvgpu softmax cuda version (#2014)

* add softmax cuda support

* optimize block size

* remove debug info

* remove debug

* style

* remove unused

* remove cudnn softmax

* format

* using nullptr

* move helper, add test

* fix style

* using all_close_f

* using kahansum

* style

* remove commented-out code
parent 702d465a
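The "using kahansum" item in the commit message refers to compensated summation in the reduction that accumulates the softmax denominator. The emitted device kernel is generated by CudaKernelBuilder::get_softmax_op and is not shown in this diff; the following is only a minimal host-side sketch of Kahan summation, under the assumption that the kernel's accumulation follows the same pattern:

#include <vector>

// Illustrative only: compensated (Kahan) summation, conceptually what the
// softmax denominator reduction uses. The real device kernel is emitted at
// runtime and is not part of this diff.
float kahan_sum(const std::vector<float>& values)
{
    float sum = 0.0f;
    float compensation = 0.0f; // running record of lost low-order bits
    for (float v : values)
    {
        float y = v - compensation;
        float t = sum + y;            // low-order bits of y are lost here...
        compensation = (t - sum) - y; // ...and recovered into the compensation term
        sum = t;
    }
    return sum;
}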
......@@ -1568,19 +1568,22 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::MaxPool* node)
return this->m_primitive_emitter->register_primitive(kernel_launch, hash);
}
size_t runtime::gpu::CUDAEmitter::build_softmax_divide(const std::vector<std::string>& dtypes,
NVShape input_shape,
NVShape reduce_shape,
std::vector<size_t> axes_flag)
size_t runtime::gpu::CUDAEmitter::build_softmax(const std::vector<std::string>& dtypes,
NVShape input_shape,
NVShape reduce_axis)
{
std::string kernel_name =
"softmax_divide_" + join(dtypes, "_") + "_axes_" + join(axes_flag, "_");
size_t rank = input_shape.size();
size_t reduce_rank = reduce_axis.size();
size_t out_rank = rank - reduce_rank;
// assumes NC{d1,...,dn} format
std::string kernel_name = "softmax_" + join(dtypes, "_");
kernel_name +=
"_ri_" + std::to_string(input_shape.size()) + "_rr_" + std::to_string(reduce_axis.size());
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
size_t nthreads = shape_size(input_shape);
std::string hash = kernel_name + "_n" + join(input_shape, "_") + join(reduce_shape, "_") +
std::to_string(nthreads);
std::stringstream ss;
ss << kernel_name << "_s_" << join(input_shape, "_") << "_axis_" << join(reduce_axis, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
......@@ -1588,57 +1591,79 @@ size_t runtime::gpu::CUDAEmitter::build_softmax_divide(const std::vector<std::st
return primitive_index;
}
NVShape reduce_flag(rank, 0);
for (auto a : reduce_axis)
{
reduce_flag[a] = 1;
}
NVShape output_shape;
NVShape non_reduce_strides;
NVShape reduce_shape;
NVShape reduce_strides;
NVShape input_strides = row_major_strides(input_shape);
for (int i = 0; i < rank; i++)
{
if (reduce_flag[i] != 0)
{
reduce_shape.push_back(input_shape[i]);
reduce_strides.push_back(input_strides[i]);
}
else
{
non_reduce_strides.push_back(input_strides[i]);
output_shape.push_back(input_shape[i]);
}
}
NVShape output_strides = row_major_strides(output_shape);
uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
// TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
if (reduce_flag.back() == 1)
{
block_size_x = 8;
}
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("out_strides", output_strides)
.add("non_reduce_strides", non_reduce_strides)
.add("reduce_shape", reduce_shape)
.add("reduce_strides", reduce_strides)
.add("nthreads", nthreads);
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(hash);
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_softmax_divide_op(
writer, kernel_name, dtypes, axes_flag, input_shape.size());
runtime::gpu::CudaKernelBuilder::get_softmax_op(
writer, kernel_name, args, dtypes, out_rank, reduce_rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
NVShape input_strides = row_major_strides(input_shape);
NVShape reduce_strides = row_major_strides(reduce_shape);
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
// TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x =
align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
std::unique_ptr<gpu::primitive> pool(
std::unique_ptr<gpu::primitive> softmax(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
std::vector<void*> arg_list;
arg_list.push_back(&inputs[0]);
arg_list.push_back(&inputs[1]);
arg_list.push_back(&outputs[0]);
for (size_t i = 0; i < input_strides.size(); i++)
{
arg_list.push_back(&input_strides[i]);
}
for (size_t i = 0; i < reduce_strides.size(); i++)
{
arg_list.push_back(&reduce_strides[i]);
}
arg_list.push_back(&nthreads);
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
1,
block_size_x,
1,
1, // block dim
1,
0,
nullptr, // shared mem and stream
arg_list.data(),
nullptr)); // arguments
nullptr,
args_list,
nullptr));
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(pool, hash);
return this->m_primitive_emitter->register_primitive(softmax, hash);
}
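For reference, the shape and stride bookkeeping in the new build_softmax splits the input into reduced and non-reduced components before launching the kernel. Below is a minimal host-side sketch of that split, using std::vector<size_t> in place of NVShape (an assumption made only to keep the example self-contained, not the backend's actual type):

#include <cstddef>
#include <iostream>
#include <vector>

// Host-side sketch of the shape/stride split performed in build_softmax.
std::vector<size_t> row_major_strides(const std::vector<size_t>& shape)
{
    std::vector<size_t> strides(shape.size(), 1);
    for (size_t i = shape.size(); i-- > 1;)
    {
        strides[i - 1] = strides[i] * shape[i];
    }
    return strides;
}

int main()
{
    std::vector<size_t> input_shape{2, 3, 4};
    std::vector<size_t> reduce_axis{1};

    std::vector<char> reduce_flag(input_shape.size(), 0);
    for (auto a : reduce_axis)
    {
        reduce_flag[a] = 1;
    }

    auto input_strides = row_major_strides(input_shape);
    std::vector<size_t> output_shape, non_reduce_strides, reduce_shape, reduce_strides;
    for (size_t i = 0; i < input_shape.size(); i++)
    {
        if (reduce_flag[i]) // reduced dimension: keep its extent and input stride
        {
            reduce_shape.push_back(input_shape[i]);
            reduce_strides.push_back(input_strides[i]);
        }
        else // non-reduced dimension: contributes to the output shape
        {
            output_shape.push_back(input_shape[i]);
            non_reduce_strides.push_back(input_strides[i]);
        }
    }
    // For {2, 3, 4} reduced over axis 1: output_shape = {2, 4},
    // non_reduce_strides = {12, 1}, reduce_shape = {3}, reduce_strides = {4}.
    std::cout << "output elements: " << output_shape[0] * output_shape[1] << "\n";
}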
size_t runtime::gpu::CUDAEmitter::build_reduce_to_nd(const std::vector<std::string>& dtypes,
......@@ -1978,82 +2003,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce(const std::vector<std::string>& d
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_primitive(const op::Softmax* node)
{
auto& args = node->get_inputs();
auto& out = node->get_outputs();
auto input_shape = args[0].get_shape();
auto axes = node->get_axes();
std::stringstream ss;
ss << "softmax_" << runtime::gpu::kernel::emit_type_string(node) << "_s"
<< join(input_shape, "_") << "_ra" << join(axes, "_");
auto hash = ss.str();
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// build composite primitive
auto& cudnn_emitter = m_primitive_emitter->get_cudnn_emitter();
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
auto reduced_shape = input_shape;
std::vector<size_t> axes_flag(input_shape.size(), 0);
for (auto const& axis : axes)
{
reduced_shape[axis] = 1;
axes_flag[axis] = 1;
}
size_t reduced_size = shape_size(reduced_shape);
size_t tensor_size = shape_size(input_shape);
size_t type_size = out[0].get_element_type().size();
size_t reduce_buffer_idx = allocator.reserve_workspace(reduced_size * type_size);
// exponentiate with fused sum reduction to calculate softmax denominator
auto input_type = args[0].get_element_type().c_type_string();
auto output_type = out[0].get_element_type().c_type_string();
auto exp_index = build_elementwise<ngraph::op::Exp>({input_type, output_type}, input_shape);
std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
auto reduce_index = cudnn_emitter->build_reduce_forward(
CUDNN_REDUCE_TENSOR_ADD, dtypes, input_shape, axes, CUDNNEmitter::ReductionMode::Reduce);
size_t divide_index = build_softmax_divide(
std::vector<std::string>(3, output_type), input_shape, reduced_shape, axes_flag);
if (reduced_size == tensor_size)
{
// the result should be all set to 1.
// TODO: add memset
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
runtime::gpu::invoke_primitive(
m_ctx, divide_index, std::vector<void*>{inputs[0], inputs[0]}.data(), outputs);
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
else
{
std::unique_ptr<gpu::primitive> kernel_launch(new gpu::primitive{[=](
void** inputs, void** outputs) mutable {
void* reduce_buffer = runtime::gpu::invoke_memory_primitive(m_ctx, reduce_buffer_idx);
runtime::gpu::invoke_primitive(m_ctx, exp_index, inputs, outputs);
runtime::gpu::invoke_primitive(
m_ctx, reduce_index, outputs, std::vector<void*>{reduce_buffer}.data());
runtime::gpu::invoke_primitive(
m_ctx, divide_index, std::vector<void*>{outputs[0], reduce_buffer}.data(), outputs);
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
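The removed build_primitive(const op::Softmax*) above composed three existing primitives: an element-wise Exp, a cuDNN add-reduction over the softmax axes into a workspace of reduced_size elements, and the build_softmax_divide kernel for the broadcast divide. As a rough CPU analogue of that pipeline (fixed to a 2-D tensor reduced over axis 1 purely for brevity, not the emitter's actual generality):

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative CPU analogue of the removed composite primitive: element-wise
// exp, an add-reduction over the softmax axis into a small workspace buffer,
// then a broadcast divide.
void softmax_composite(const float* in, float* out, size_t rows, size_t cols)
{
    std::vector<float> workspace(rows, 0.0f); // reduced_size elements, like reduce_buffer

    // step 1: exp (the build_elementwise<op::Exp> primitive)
    for (size_t i = 0; i < rows * cols; i++)
    {
        out[i] = std::exp(in[i]);
    }
    // step 2: add-reduction over the reduce axis (the CUDNN_REDUCE_TENSOR_ADD call)
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            workspace[r] += out[r * cols + c];
        }
    }
    // step 3: divide by the broadcast denominator (the build_softmax_divide kernel)
    for (size_t r = 0; r < rows; r++)
    {
        for (size_t c = 0; c < cols; c++)
        {
            out[r * cols + c] /= workspace[r];
        }
    }
}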
size_t
runtime::gpu::CUDAEmitter::build_fused_ew_to_collective(const std::vector<std::string>& dtypes,
NVShape tensor_shape,
......
......@@ -44,7 +44,6 @@ namespace ngraph
friend class GPUPrimitiveEmitter;
public:
size_t build_primitive(const op::Softmax* node);
size_t build_primitive(const op::Convolution* node);
size_t build_primitive(const op::MaxPool* node);
size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op);
......@@ -186,10 +185,9 @@ namespace ngraph
size_t concat_axis,
NVShape output_shape);
size_t build_softmax_divide(const std::vector<std::string>& dtypes,
NVShape input_shape,
NVShape reduce_shape,
std::vector<size_t> axes_flag);
size_t build_softmax(const std::vector<std::string>& dtypes,
NVShape input_shape,
NVShape reduce_axis);
void debug_sync();
void sync();
......
......@@ -1874,74 +1874,6 @@ size_t runtime::gpu::CUDNNEmitter::build_lrn(const std::string& dtype,
return primitive_index;
}
size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape)
{
// construct hash to determine if kernel needs to be emitted
// or if it already exists in the primitive list
std::stringstream ss;
ss << "softmax_op_" << mode << "_dtype_" << dtype << "_alg" << algorithm << "_dir"
<< static_cast<int>(direction) << "_s" << join(tensor_shape, "_");
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
cudnnDataType_t data_type = get_cudnn_datatype(dtype);
cudnnTensorFormat_t tensor_format = CUDNN_TENSOR_NCHW;
auto& tensor_desc = tensor_descriptor_from_shape(tensor_shape, data_type, tensor_format);
void* alpha = m_host_parameters.allocate_by_datatype(data_type, 1.0);
void* beta = m_host_parameters.allocate_by_datatype(data_type, 0);
std::unique_ptr<runtime::gpu::primitive> softmax;
switch (direction)
{
case Prop::Forward:
case Prop::Inference:
{
softmax.reset(new gpu::primitive{[=, &tensor_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnSoftmaxForward(*m_ctx->cudnn_handle,
algorithm,
mode,
alpha,
tensor_desc,
inputs[0],
beta,
tensor_desc,
outputs[0]));
debug_sync();
}});
break;
}
case Prop::Backward:
{
softmax.reset(new gpu::primitive{[=, &tensor_desc](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnSoftmaxBackward(*m_ctx->cudnn_handle,
algorithm,
mode,
alpha,
tensor_desc,
inputs[0],
tensor_desc,
inputs[1],
beta,
tensor_desc,
outputs[0]));
debug_sync();
}});
break;
}
}
return this->m_primitive_emitter->register_primitive(softmax, hash);
}
void runtime::gpu::CUDNNEmitter::sync()
{
CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
......
......@@ -155,12 +155,6 @@ namespace ngraph
const double lrn_bias,
const size_t lrn_size);
size_t build_softmax(const cudnnSoftmaxAlgorithm_t& algorithm,
const cudnnSoftmaxMode_t& mode,
const std::string& dtype,
const Prop& direction,
const Shape& tensor_shape);
void debug_sync();
void sync();
......
......@@ -186,11 +186,12 @@ namespace ngraph
int sm_tile_size = 8,
int reg_tile_size = 1);
static void get_softmax_divide_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
std::vector<size_t> axes_flag,
size_t rank);
static void get_softmax_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::vector<std::string>& data_types,
size_t out_rank,
size_t reduce_rank);
static void add_pod_typedefs(codegen::CodeWriter& writer);
......
......@@ -1528,23 +1528,17 @@ void runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS)
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
writer.block_begin();
{
size_t index;
if (softmax->get_axes().size() != args[0].get_shape().size())
auto axes_set = softmax->get_axes();
ngraph::AxisVector axes_vec;
for (auto a : axes_set)
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_primitive(softmax);
}
else
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
index = cudnn_emitter->build_softmax(CUDNN_SOFTMAX_FAST,
CUDNN_SOFTMAX_MODE_INSTANCE,
out[0].get_type(),
CUDNNEmitter::Prop::Forward,
args[0].get_shape());
axes_vec.push_back(a);
}
std::vector<string> dtypes;
dtypes.push_back(args[0].get_type());
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_vec);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
......
......@@ -4187,6 +4187,34 @@ NGRAPH_TEST(${BACKEND_NAME}, softmax_underflow)
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, softmax_overflow)
{
Shape shape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Softmax>(A, AxisSet{0}), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto high = std::numeric_limits<float>::max();
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{high, 1, 2, 3, 4, 5});
auto result = backend->create_tensor(element::f32, shape);
auto d0 = expf(high - high) + expf(3 - high);
auto d1 = expf(1) + expf(4);
auto d2 = expf(2) + expf(5);
backend->call_with_validate(f, {result}, {a});
vector<float> expected{expf(high - high) / d0,
expf(1) / d1,
expf(2) / d2,
expf(3 - high) / d0,
expf(4) / d1,
expf(5) / d2};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
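Note on the expected values in softmax_overflow: d0 is written as expf(high - high) + expf(3 - high), i.e. the column containing FLT_MAX is normalized after subtracting the column maximum, so the test only passes if the kernel computes a max-subtracted softmax. A small standalone check of that arithmetic (illustrative only, not backend code):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>

// Sketch of why the expected values stay finite: with the column maximum
// subtracted, softmax over {FLT_MAX, 3} evaluates to approximately {1, 0}
// instead of inf/nan.
int main()
{
    float high = std::numeric_limits<float>::max();
    float col0[2] = {high, 3.0f};         // column 0 of the 2x3 input, softmax over axis 0
    float m = std::max(col0[0], col0[1]); // column maximum
    float d0 = std::exp(col0[0] - m) + std::exp(col0[1] - m);
    std::cout << std::exp(col0[0] - m) / d0 << " " << std::exp(col0[1] - m) / d0 << "\n";
    // prints approximately 1 0, matching expf(high - high) / d0 and expf(3 - high) / d0
}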
NGRAPH_TEST(${BACKEND_NAME}, multiple_backends)
{
Shape shape{2, 2};
......