Commit 57d58e50 authored by Fenglei, committed by Scott Cyphers

add gpu concat op (#931)

* add concat op

* change to concat

* add more code for gpu concat

* compile success version with bug

* add emit_concat_op

* runnable with wrong result

* working version

* add some comments

* delete old comments.

* delete old comments.

* remove buggy doxygen comments
parent 443b51b7
@@ -135,6 +135,47 @@ void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
size_t num_inputs)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++)
{
writer << data_types[i] << "* in" << i << ", ";
}
writer << data_types[num_inputs]
<< "* out, size_t* block_strides, size_t block_size, size_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if(tid < n)\n";
writer.block_begin();
{
writer << "out[tid] = 1;\n";
writer << "size_t idx_out = tid;\n";
writer << "size_t block_id = tid / block_size;\n";
writer << "size_t block_idx = tid % block_size;\n";
writer << "bool processed = false;\n";
for (size_t i = 0; i < num_inputs; i++)
{
writer << "if(!processed && (block_idx < block_strides[" << i << "]))\n";
writer.block_begin();
{
writer << "out[idx_out] = in" << i << "[block_id * block_strides[" << i
<< "] + block_idx];";
writer << "processed = true;\n";
}
writer.block_end();
writer << "block_idx -= block_strides[" << i << "];\n";
}
}
writer.block_end();
}
writer.block_end();
}
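// For reference, a sketch of the kernel this builder emits for two float
// inputs (illustrative only: the real kernel name carries the caller's
// type-signature suffix, and the pod typedefs come from add_pod_typedefs).
// Each thread copies one output element, walking the per-input strides
// until it finds the input that owns its slot:
//
//   extern "C" __global__ void cuda_concat(
//       float* in0, float* in1, float* out,
//       size_t* block_strides, size_t block_size, size_t n)
//   {
//       size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
//       if (tid < n)
//       {
//           size_t idx_out = tid;
//           size_t block_id = tid / block_size;   // which concatenated block
//           size_t block_idx = tid % block_size;  // offset within the block
//           bool processed = false;
//           if (!processed && (block_idx < block_strides[0]))
//           {
//               out[idx_out] = in0[block_id * block_strides[0] + block_idx];
//               processed = true;
//           }
//           block_idx -= block_strides[0];
//           if (!processed && (block_idx < block_strides[1]))
//           {
//               out[idx_out] = in1[block_id * block_strides[1] + block_idx];
//               processed = true;
//           }
//           block_idx -= block_strides[1];
//       }
//   }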
void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
@@ -43,6 +43,11 @@ namespace ngraph
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
size_t num_inputs);
static void get_onehot_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
@@ -120,7 +120,6 @@ namespace ngraph
compiled_kernel = ctx->compiled_kernel_pool->set(name + type_signature, kernel);
}
//convert runtime ptr to driver api ptr
void* args_list[] = {&inputs..., &out, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
count,
@@ -135,6 +134,46 @@ namespace ngraph
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // block until the kernel completes
}
template <typename... Inputs>
void emit_concat_op(const std::string& name,
const std::vector<std::string>& data_types,
GPURuntimeContext* ctx,
size_t count,
size_t block_size,
CUdeviceptr block_strides,
CUdeviceptr out,
Inputs&&... inputs)
{
std::string type_signature = "_" + join(data_types, "_");
std::replace(type_signature.begin(), type_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name + type_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_concat_op(
writer, name + type_signature, data_types, sizeof...(inputs));
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name + type_signature, kernel);
}
void* args_list[] = {&inputs..., &out, &block_strides, &block_size, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
count,
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // block until the kernel completes
}
}
}
}
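A minimal sketch of how this launcher might be called for two inputs (illustrative, not part of the commit: `ctx`, `strides_d`, `out_d`, `in0_d`, and `in1_d` are assumed to be a valid GPURuntimeContext and CUdeviceptr values prepared elsewhere). Note that the launch configuration above runs one thread per output element, as `count` grid blocks of a single thread each:

// Hypothetical call site: concatenate two float buffers.
// count = total output elements; block_size = sum of per-input strides.
ngraph::runtime::gpu::emit_concat_op("concat",
                                     std::vector<std::string>{"float", "float", "float"},
                                     ctx,
                                     count,
                                     block_size,
                                     strides_d, // device copy of block_strides
                                     out_d,
                                     in0_d,
                                     in1_d);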
@@ -787,6 +787,58 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
}
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Concat)
{
if (out[0].get_size() == 0)
{
return;
}
auto concat = static_cast<const ngraph::op::Concat*>(node);
auto axis = concat->get_concatenation_axis();
std::vector<size_t> block_strides(args.size(), 1);
size_t block_size = 0;
for (size_t i = 0; i < args.size(); i++)
{
auto arg_shape = args[i].get_shape();
auto arg_rank = arg_shape.size();
for (size_t j = axis; j < arg_rank; j++)
{
block_strides[i] *= arg_shape[j];
}
block_size += block_strides[i];
}
writer.block_begin(" // " + node->get_name());
writer << "int count = " << out[0].get_size() << ";\n";
writer << "int num_inputs = " << args.size() << ";\n";
writer << "std::vector<size_t> block_strides_h = {" << join(block_strides)
<< "};\n";
writer << "void* block_strides_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * num_inputs);\n";
writer << "runtime::gpu::cuda_memcpyHtD(block_strides_d, block_strides_h.data(), "
"sizeof(size_t) * num_inputs);\n";
writer << "ngraph::runtime::gpu::emit_concat_op(\"" << node->description() << "\""
<< ", std::vector<std::string>{";
for (size_t i = 0; i < args.size(); i++)
{
writer << "\"" << args[i].get_type() << "\", ";
}
writer << "\"" << out[0].get_type() << "\"}"
<< ", ctx"
<< ", count"
<< ", " << block_size << ", CUdeviceptr(block_strides_d)"
<< ", CUdeviceptr(" << out[0].get_name() << ")";
for (size_t i = 0; i < args.size(); i++)
{
writer << ", CUdeviceptr(" << args[i].get_name() << ")";
}
writer << ");\n";
writer.block_end();
}
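// Worked example (illustrative): concatenating shapes {2, 2} and {2, 3}
// along axis 1. block_strides = {2, 3} (product of each input's
// dimensions from the axis onward), so block_size = 2 + 3 = 5; the
// output shape is {2, 5} and count = 10. In the kernel, tid = 7 gives
// block_id = 1 and block_idx = 2; 2 is not less than block_strides[0],
// so block_idx drops to 0 and the element is read from in1[1 * 3 + 0].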
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Constant)
{
@@ -38,10 +38,17 @@ void runtime::gpu::kernel::emit_memset(codegen::CodeWriter& writer,
void runtime::gpu::kernel::emit_memcpyDtD(codegen::CodeWriter& writer,
const GPU_TensorViewWrapper& dst,
const GPU_TensorViewWrapper& src,
size_t buffer_size)
{
    // A buffer_size of 0 means "copy the whole destination tensor".
    if (buffer_size == 0)
    {
        writer << "runtime::gpu::cuda_memcpyDtD(" << dst.get_name() << ", " << src.get_name()
               << ", " << dst.get_size() << " * " << dst.get_element_type().size() << ");\n";
        return;
    }
    // Otherwise copy exactly buffer_size bytes.
    writer << "runtime::gpu::cuda_memcpyDtD(" << dst.get_name() << ", " << src.get_name() << ", "
           << buffer_size << ");\n";
}
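// Example of the generated code for both paths (hypothetical tensor
// names _t0/_t1, destination of 4 float elements):
//   emit_memcpyDtD(writer, dst, src);    // emits: runtime::gpu::cuda_memcpyDtD(_t0, _t1, 4 * 4);
//   emit_memcpyDtD(writer, dst, src, 8); // emits: runtime::gpu::cuda_memcpyDtD(_t0, _t1, 8);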
@@ -36,7 +36,8 @@ namespace ngraph
void emit_memcpyDtD(codegen::CodeWriter& writer,
const GPU_TensorViewWrapper& dst,
const GPU_TensorViewWrapper& src,
size_t buffer_size = 0);
void emit_cudnnConvolutionDescriptor(codegen::CodeWriter& writer,
const std::string& name,
@@ -393,7 +393,6 @@ TEST(${BACKEND_NAME}, ceiling)
TEST(${BACKEND_NAME}, concat_matrix_colwise)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 3};
@@ -422,7 +421,6 @@ TEST(${BACKEND_NAME}, concat_matrix_colwise)
TEST(${BACKEND_NAME}, concat_matrix_rowwise)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{3, 2};
@@ -480,7 +478,6 @@ TEST(${BACKEND_NAME}, concat_matrix_int64)
TEST(${BACKEND_NAME}, concat_vector)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{6};
@@ -508,7 +505,6 @@ TEST(${BACKEND_NAME}, concat_vector)
TEST(${BACKEND_NAME}, concat_4d_tensor)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{1, 1, 1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
@@ -534,7 +530,6 @@ TEST(${BACKEND_NAME}, concat_4d_tensor)
TEST(${BACKEND_NAME}, concat_2d_tensor)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
@@ -605,7 +600,6 @@ TEST(${BACKEND_NAME}, concat_2d_tensor)
// 2069. 2070. 2071. 2072.]
TEST(${BACKEND_NAME}, concat_5d)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
vector<float> a_data(2 * 3 * 4 * 3 * 2);
for (int i = 0; i < 2 * 3 * 4 * 3 * 2; i++)
{