Commit 58bd00de authored by Fenglei, committed by Scott Cyphers

nvgpu concat split (#1894)

* add split concat

* fix bug

* fix bug

* fix bug

* add test

* fix test bug

* add comments

* format

* return instead of checking processed

* remove .back() since it's not a vector anymore.

* format

* change to parameterized tests based on Geoff's comments

* types-> type

* change split size to 256
parent bbf66498
@@ -75,18 +75,18 @@ runtime::gpu::CUDAEmitter::CUDAEmitter(runtime::gpu::GPUPrimitiveEmitter* emitte
m_ctx = ctx;
}
size_t runtime::gpu::CUDAEmitter::build_concat(const std::vector<std::string>& dtypes,
size_t runtime::gpu::CUDAEmitter::build_concat(const std::string& dtype,
std::vector<NVShape> input_shapes,
size_t concat_axis,
NVShape output_shape)
{
std::stringstream kernel_name;
size_t input_size = input_shapes.size();
kernel_name << "concat_" << join(dtypes, "_") << "_r_" << input_size;
size_t input_num = input_shapes.size();
kernel_name << "concat_" << dtype << "_r_" << input_num;
std::stringstream hash;
hash << kernel_name.str() << "_o_" << join(output_shape, "_") << "_a_" << concat_axis;
for (size_t i = 0; i < input_size; i++)
for (size_t i = 0; i < input_num; i++)
{
hash << "_i_" << join(input_shapes[i], "_");
}
@@ -104,64 +104,109 @@ size_t runtime::gpu::CUDAEmitter::build_concat(const std::vector<std::string>& d
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
// recompile the kernel and then create the primitive
size_t split_input_size = 256; // max number of inputs that fit in the 4KB kernel parameter space: 256 pointers * 8 bytes plus 7 additional arguments
size_t residue = input_num % split_input_size;
std::stringstream kernel_name_1;
std::stringstream kernel_name_2;
kernel_name_1 << "concat_" << dtype << "_r_" << split_input_size;
kernel_name_2 << "concat_" << dtype << "_r_" << residue;
auto compiled_kernel_1 = m_ctx->compiled_kernel_pool->get(kernel_name_1.str());
if (compiled_kernel_1 == nullptr && input_num >= split_input_size)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_concat_op(writer, kernel_name.str(), dtypes, input_shapes.size());
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
CudaKernelBuilder::get_concat_op(writer, kernel_name_1.str(), dtype, split_input_size);
compiled_kernel_1 =
m_ctx->compiled_kernel_pool->set(kernel_name_1.str(), writer.get_code());
}
auto compiled_kernel_2 = m_ctx->compiled_kernel_pool->get(kernel_name_2.str());
if (compiled_kernel_2 == nullptr && residue != 0)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_concat_op(writer, kernel_name_2.str(), dtype, residue);
compiled_kernel_2 =
m_ctx->compiled_kernel_pool->set(kernel_name_2.str(), writer.get_code());
}
std::vector<uint32_t> block_strides(input_size, 1);
uint32_t block_size = 0;
for (size_t i = 0; i < input_size; i++)
std::vector<uint32_t> inputs_strides(input_num, 1);
uint32_t output_stride = 0;
for (size_t i = 0; i < input_num; i++)
{
auto arg_rank = input_shapes[i].size();
for (size_t j = concat_axis; j < arg_rank; j++)
{
block_strides[i] *= input_shapes[i][j];
inputs_strides[i] *= input_shapes[i][j];
}
block_size += block_strides[i];
output_stride += inputs_strides[i];
}
uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
// TODO: currently we set it to 64; will add a tuning method later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
std::vector<uint32_t> split_nthreads;
std::vector<uint32_t> split_output_strides;
std::vector<uint32_t> split_input_stride_offsets;
std::vector<uint32_t> split_aligned_grid_size_x;
split_input_stride_offsets.push_back(0);
size_t split_input_stride_offset = 0;
for (uint32_t i = 0; i < input_num; i += split_input_size)
{
uint32_t nthread = 0;
uint32_t split_output_stride = 0;
for (uint32_t j = i; j < i + split_input_size && j < input_num; j++)
{
nthread += shape_size(input_shapes[j]);
split_output_stride += inputs_strides[j];
}
split_input_stride_offset += split_output_stride;
split_input_stride_offsets.push_back(split_input_stride_offset);
split_output_strides.push_back(split_output_stride);
split_nthreads.push_back(static_cast<uint32_t>(nthread));
split_aligned_grid_size_x.push_back(
align_to_block_size(split_nthreads.back(), block_size_x));
}
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_block_strides =
allocator.reserve_argspace(block_strides.data(), block_strides.size() * sizeof(uint32_t));
size_t idx_inputs_strides =
allocator.reserve_argspace(inputs_strides.data(), inputs_strides.size() * sizeof(uint32_t));
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* param_block_strides = runtime::gpu::invoke_memory_primitive(m_ctx, idx_block_strides);
std::vector<void*> args_list;
for (size_t i = 0; i < input_size; i++)
void* param_inputs_strides =
runtime::gpu::invoke_memory_primitive(m_ctx, idx_inputs_strides);
for (uint32_t i = 0, n = 0; i < input_num; i += split_input_size, n++)
{
args_list.push_back(&inputs[i]);
std::vector<void*> args_list;
for (uint32_t j = i; j < i + split_input_size && j < input_num; j++)
{
args_list.push_back(&inputs[j]);
}
args_list.push_back(&outputs[0]);
args_list.push_back(&param_inputs_strides);
args_list.push_back(&output_stride);
args_list.push_back(&split_output_strides[n]);
args_list.push_back(&split_input_stride_offsets[n]);
args_list.push_back(&i);
args_list.push_back(&split_nthreads[n]);
auto compiled_kernel =
(args_list.size() == split_input_size + 7) ? compiled_kernel_1 : compiled_kernel_2;
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
split_aligned_grid_size_x[n],
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
}
args_list.push_back(&outputs[0]);
args_list.push_back(&param_block_strides);
args_list.push_back(&block_size);
args_list.push_back(&nthreads);
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
}});
return this->m_primitive_emitter->register_primitive(kernel_launch, hash.str());
......
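The chunking done by build_concat above is easiest to see in isolation. Below is a minimal, self-contained C++ sketch (not part of this commit; names such as LaunchPlan and plan_concat_launches are hypothetical) of how input_num concat arguments are partitioned into launches of at most split_input_size = 256 inputs, where full chunks use the first compiled kernel and the final residue chunk uses the second.

// Minimal sketch (illustration only): mirrors the "i += split_input_size"
// loop in build_concat, which decides how many kernel launches a large
// concat needs and how many inputs each launch handles.
#include <cstddef>
#include <cstdio>
#include <vector>

struct LaunchPlan
{
    size_t input_offset; // index of the first input handled by this launch
    size_t num_inputs;   // split_input_size for full chunks, the residue for the last chunk
};

std::vector<LaunchPlan> plan_concat_launches(size_t input_num, size_t split_input_size = 256)
{
    std::vector<LaunchPlan> launches;
    for (size_t i = 0; i < input_num; i += split_input_size)
    {
        size_t remaining = input_num - i;
        launches.push_back({i, remaining < split_input_size ? remaining : split_input_size});
    }
    return launches;
}

int main()
{
    // 999 inputs -> three full launches of 256 inputs plus one residue launch of 231;
    // the full launches would use compiled_kernel_1, the residue launch compiled_kernel_2.
    for (const auto& launch : plan_concat_launches(999))
    {
        std::printf("offset %zu, %zu inputs\n", launch.input_offset, launch.num_inputs);
    }
    return 0;
}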
@@ -174,7 +174,7 @@ namespace ngraph
NVShape input_dilation,
NVDiff input_pad_below);
size_t build_concat(const std::vector<std::string>& dtypes,
size_t build_concat(const std::string& dtype,
std::vector<NVShape> input_shapes,
size_t concat_axis,
NVShape output_shape);
......
@@ -629,38 +629,39 @@ void runtime::gpu::CudaKernelBuilder::get_reshape_op_3d(codegen::CodeWriter& wri
void runtime::gpu::CudaKernelBuilder::get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
const std::string& data_type,
size_t num_inputs)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++)
{
writer << data_types[i] << "* in" << i << ", ";
writer << data_type << "* in" << i << ", ";
}
writer << data_types[num_inputs]
<< "* out, uint32_t* block_strides, uint32_t block_size, uint32_t n)\n";
writer << data_type << "* out, uint32_t* inputs_strides, uint32_t output_stride, uint32_t "
"split_output_stride, uint32_t split_input_stride_offset, uint32_t "
"input_offset, uint32_t n)\n";
writer.block_begin();
{
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if(tid < n)\n";
writer.block_begin();
{
writer << "out[tid] = 1;\n";
writer << "uint32_t output_idx = tid;\n";
writer << "uint32_t block_id = tid / block_size;\n";
writer << "uint32_t block_idx = tid % block_size;\n";
writer << "bool processed = false;\n";
writer << "uint32_t block_id = tid / split_output_stride;\n";
writer << "uint32_t block_idx = tid % split_output_stride;\n";
writer << "uint32_t output_idx = block_id * output_stride + block_idx + "
"split_input_stride_offset;\n";
writer << "out[output_idx] = 1;\n";
for (size_t i = 0; i < num_inputs; i++)
{
writer << "if(!processed && (block_idx < block_strides[" << i << "]))\n";
writer << "if(block_idx < inputs_strides[" << i << " + input_offset])\n";
writer.block_begin();
{
writer << "out[output_idx] = in" << i << "[block_id * block_strides[" << i
<< "] + block_idx];";
writer << "processed = true;\n";
writer << "out[output_idx] = in" << i << "[block_id * inputs_strides[" << i
<< " + input_offset] + block_idx];\n";
writer << "return;\n";
}
writer.block_end();
writer << "block_idx -= block_strides[" << i << "];\n";
writer << "block_idx -= inputs_strides[" << i << " + input_offset];\n";
}
}
writer.block_end();
......
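For reference, here is a hedged reconstruction of the CUDA source that the new get_concat_op emits for a two-input float concat (the kernel name cuda_concat_float_r_2 and the "float" dtype string are assumed for this example; CudaKernelBuilder::add_pod_typedefs emits the POD typedefs such as uint32_t ahead of the kernel). Note that the initial out[output_idx] = 1 store is immediately overwritten by whichever input branch matches.

extern "C" __global__ void cuda_concat_float_r_2(float* in0, float* in1, float* out,
                                                 uint32_t* inputs_strides, uint32_t output_stride,
                                                 uint32_t split_output_stride,
                                                 uint32_t split_input_stride_offset,
                                                 uint32_t input_offset, uint32_t n)
{
    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n)
    {
        // locate this thread within the slab of the output covered by this launch
        uint32_t block_id = tid / split_output_stride;
        uint32_t block_idx = tid % split_output_stride;
        uint32_t output_idx = block_id * output_stride + block_idx + split_input_stride_offset;
        out[output_idx] = 1;
        if (block_idx < inputs_strides[0 + input_offset])
        {
            out[output_idx] = in0[block_id * inputs_strides[0 + input_offset] + block_idx];
            return;
        }
        block_idx -= inputs_strides[0 + input_offset];
        if (block_idx < inputs_strides[1 + input_offset])
        {
            out[output_idx] = in1[block_id * inputs_strides[1 + input_offset] + block_idx];
            return;
        }
        block_idx -= inputs_strides[1 + input_offset];
    }
}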
@@ -51,7 +51,7 @@ namespace ngraph
static void get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
const std::string& data_type,
size_t num_inputs);
static void get_onehot_op(codegen::CodeWriter& writer,
......
@@ -458,17 +458,15 @@ void runtime::gpu::GPU_Emitter::emit_Concat(EMIT_ARGS)
auto concat = static_cast<const ngraph::op::Concat*>(node);
auto axis = concat->get_concatenation_axis();
vector<string> dtypes;
vector<NVShape> input_shapes;
for (auto arg : args)
{
dtypes.push_back(arg.get_type());
input_shapes.push_back(arg.get_shape());
}
dtypes.push_back(out[0].get_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_concat(dtypes, input_shapes, axis, out[0].get_shape());
auto index =
cuda_emitter->build_concat(out[0].get_type(), input_shapes, axis, out[0].get_shape());
writer.block_begin();
{
......
@@ -441,6 +441,55 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_matrix_int64)
read_vector<int64_t>(result));
}
// Params to drive concat_vector_large testing variations
class concat_vector_params : public ::testing::TestWithParam<int>
{
protected:
concat_vector_params() { num_inputs = GetParam(); }
uint32_t num_inputs;
};
NGRAPH_TEST_P(${BACKEND_NAME}, concat_vector_params, concat_vector_large)
{
Shape shape_a{1};
NodeVector inputs;
op::ParameterVector inputs_param;
for (uint32_t i = 0; i < num_inputs; i++)
{
auto A = make_shared<op::Parameter>(element::f32, shape_a);
inputs_param.push_back(A);
inputs.push_back(A);
}
Shape shape_r{num_inputs};
auto f = make_shared<Function>(make_shared<op::Concat>(inputs, 0), inputs_param);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
std::vector<std::shared_ptr<runtime::Tensor>> inputs_value;
std::vector<float> ref_result;
for (uint32_t i = 0; i < num_inputs; i++)
{
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{static_cast<float>(i)});
ref_result.push_back(static_cast<float>(i));
inputs_value.push_back(a);
}
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, inputs_value);
EXPECT_EQ(ref_result, read_vector<float>(result));
}
// concat_vector_large case generation
// Add these tests to cover kernel parameter space overflow:
// the CUDA kernel parameter space has a size limit, and with a large number of
// concat inputs the kernel parameters would overflow that space.
NGRAPH_INSTANTIATE_TEST_CASE_P(${BACKEND_NAME},
input_sizes,
concat_vector_params,
testing::Values(100, 128, 999));
NGRAPH_TEST(${BACKEND_NAME}, concat_vector)
{
Shape shape_a{4};
......
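Rough arithmetic behind the chosen test sizes (assuming 64-bit pointers and the 4KB parameter space noted in build_concat): with the previous single-kernel approach, a concat of 999 inputs would need 999 float* arguments, i.e. 999 * 8 = 7992 bytes of kernel parameters before the output pointer and scalar arguments, exceeding the 4096-byte limit. With the split, the same case becomes three launches of 256 inputs (256 * 8 = 2048 bytes of input pointers each) plus one residue launch of 999 - 3 * 256 = 231 inputs, while the 100- and 128-input cases stay below split_input_size and run entirely through the residue kernel.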