Commit 041dd524 authored by Fenglei's avatar Fenglei Committed by Scott Cyphers

gpu slice (#843)

* add slice op, first version

* change size to output size

* fix bugs

* working version

* using existing functions for join and strides

* clang format

* revert accidental change
parent b14d5665
......@@ -134,6 +134,37 @@ void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
                                                   const std::string& name,
                                                   const std::array<std::string, 2>& data_types)
{
    // Emit the CUDA kernel signature: input/output element pointers, the
    // device-resident slice metadata arrays (strides / lower bounds), the
    // tensor rank, and the output element count n.
    writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
           << data_types[1]
           << "* out, size_t* input_strides, size_t* lower_bounds, size_t* "
              "slice_strides, size_t* output_strides, size_t rank, size_t n)\n";
    writer.block_begin();
    {
        // One thread per output element; the guard handles any extra threads
        // in the final block of the launch.
        writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n"
               << "if (tid < n)\n";
        writer.block_begin();
        {
            writer << "size_t idx_in = 0;\n"
                   << "size_t idx_out = tid;\n"
                   << "for(size_t i = 0; i < rank; i++)\n";
            writer.block_begin();
            {
                // Peel off the coordinate for dimension i from the flat output
                // index, then map it into the input as
                // lower_bound + coord * slice_stride (scaled by the input stride).
                writer << "idx_in += (((idx_out / output_strides[i]) * slice_strides[i]) + "
                          "lower_bounds[i]) * input_strides[i];\n"
                       << "idx_out %= output_strides[i];\n";
            }
            writer.block_end();
            writer << "out[tid] = in[idx_in];\n";
        }
        writer.block_end();
    }
    writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_device_helper(
codegen::CodeWriter& writer,
const std::string& name,
......
......@@ -51,6 +51,10 @@ namespace ngraph
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
......@@ -128,3 +128,42 @@ void runtime::gpu::emit_reshape(const std::string& name,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
void runtime::gpu::emit_slice(const std::string& name,
                              CUdeviceptr in,
                              CUdeviceptr out,
                              const std::array<std::string, 2>& data_types,
                              CUdeviceptr input_strides,
                              CUdeviceptr lower_bounds,
                              CUdeviceptr slice_strides,
                              CUdeviceptr output_strides,
                              size_t rank,
                              size_t count)
{
    // A zero-sized launch is invalid in the driver API; nothing to do anyway.
    if (count == 0)
    {
        return;
    }

    // The compiled kernel is specialized on the in/out element types; spaces
    // in type names (e.g. "unsigned int") are not valid in identifiers.
    std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
    std::replace(name_signature.begin(), name_signature.end(), ' ', '_');

    // Compile the kernel on first use and cache it in the function pool.
    if (CudaFunctionPool::instance().get(name_signature) == nullptr)
    {
        codegen::CodeWriter writer;
        CudaKernelBuilder::add_pod_typedefs(writer);
        CudaKernelBuilder::get_slice_op(writer, name_signature, data_types);
        std::string kernel = writer.get_code();
        CudaFunctionPool::instance().set(name_signature, kernel);
    }

    void* args_list[] = {
        &in, &out, &input_strides, &lower_bounds, &slice_strides, &output_strides, &rank, &count};

    // Launch one thread per output element. The previous configuration used
    // `count` blocks of a single thread each, leaving all but one lane of
    // every warp idle; use a fixed block size and round the grid up instead.
    // The generated kernel guards with `if (tid < n)`, so overshoot is safe.
    constexpr unsigned int block_size = 64;
    unsigned int grid_size = static_cast<unsigned int>((count + block_size - 1) / block_size);
    CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name_signature).get(),
                                  grid_size,
                                  1,
                                  1, // grid dim
                                  block_size,
                                  1,
                                  1, // block dim
                                  0,
                                  NULL, // shared mem and stream
                                  args_list,
                                  0)); // arguments
    CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
......@@ -59,6 +59,17 @@ namespace ngraph
size_t rank,
size_t count);
void emit_slice(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
CUdeviceptr input_strides,
CUdeviceptr lower_bounds,
CUdeviceptr slice_strides,
CUdeviceptr output_strides,
size_t rank,
size_t count);
template <typename T, typename... Inputs>
void emit_elementwise_op(const std::string& name,
const std::array<std::string, 2>& data_types,
......
......@@ -881,6 +881,77 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
// Code-generation emitter for op::Slice: writes C++ into `writer` that, at
// runtime, uploads the slice metadata (strides/bounds) to device buffers and
// invokes runtime::gpu::emit_slice to launch the CUDA slice kernel.
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Slice)
{
// Zero-sized result: nothing to emit.
if (out[0].get_size() == 0)
{
return;
}
auto slice = static_cast<const op::Slice*>(node);
const auto arg_shape = args[0].get_shape();
const auto arg_rank = arg_shape.size();
const auto result_shape = out[0].get_shape();
const Coordinate& lower_bounds = slice->get_lower_bounds();
const Strides slice_strides = slice->get_strides();
// Row-major strides of input and output; the kernel uses these to map a flat
// output index back to a flat input index.
const auto input_strides = row_major_strides(arg_shape);
const auto output_strides = row_major_strides(result_shape);
writer.block_begin(" // " + node->get_name());
// Fast path: if input and output have the same element count, the slice must
// select every element (each dimension's slice count equals the dimension
// size, forcing lower bound 0 and stride 1), so a device-to-device copy
// suffices. This also covers the rank-0 (scalar) case.
if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
else
{
// General path: emit runtime code that stages the metadata arrays on the
// host, copies them to freshly allocated GPU buffers, launches the kernel,
// and frees the buffers.
// NOTE(review): join(...) with an empty container would emit the invalid
// initializer "{UL}", but rank 0 always takes the memcpy path above, so
// this branch only runs with rank >= 1.
writer << "size_t rank = " << arg_rank << ";\n";
writer << "std::vector<size_t> input_strides_h = {"
<< join(input_strides, "UL,") << "UL};\n";
writer << "std::vector<size_t> output_strides_h = {"
<< join(output_strides, "UL,") << "UL};\n";
writer << "std::vector<size_t> lower_bounds_h = {" << join(lower_bounds, "UL,")
<< "UL};\n";
writer << "std::vector<size_t> slice_strides_h = {"
<< join(slice_strides, "UL,") << "UL};\n";
// One device buffer of `rank` size_t entries per metadata array.
writer << "void* input_strides_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer << "void* output_strides_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer << "void* slice_strides_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer << "void* lower_bounds_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer
<< "runtime::gpu::cuda_memcpyHtD(input_strides_d, input_strides_h.data(), "
"sizeof(size_t) * rank);\n";
writer << "runtime::gpu::cuda_memcpyHtD(output_strides_d, "
"output_strides_h.data(), "
"sizeof(size_t) * rank);\n";
writer
<< "runtime::gpu::cuda_memcpyHtD(slice_strides_d, slice_strides_h.data(), "
"sizeof(size_t) * rank);\n";
writer << "runtime::gpu::cuda_memcpyHtD(lower_bounds_d, lower_bounds_h.data(), "
"sizeof(size_t) * rank);\n";
// Emit the kernel invocation: raw tensor pointers, element type names for
// kernel specialization, the staged metadata buffers, rank, and count.
writer << "runtime::gpu::emit_slice(\"" << node->description()
<< "\", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", {\"" << args[0].get_type() << "\", \"" << out[0].get_type()
<< "\"}"
<< ", "
<< "CUdeviceptr(input_strides_d), CUdeviceptr(lower_bounds_d), "
"CUdeviceptr(slice_strides_d), CUdeviceptr(output_strides_d)"
<< ", " << arg_rank << ", " << out[0].get_size() << ");\n";
// emit_slice synchronizes before returning, so the metadata buffers can be
// freed immediately after the call.
writer << "runtime::gpu::free_gpu_buffer(input_strides_d);\n";
writer << "runtime::gpu::free_gpu_buffer(output_strides_d);\n";
writer << "runtime::gpu::free_gpu_buffer(slice_strides_d);\n";
writer << "runtime::gpu::free_gpu_buffer(lower_bounds_d);\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
......
......@@ -2579,7 +2579,6 @@ TEST(${BACKEND_NAME}, exp)
TEST(${BACKEND_NAME}, slice_scalar)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r{};
......@@ -2599,7 +2598,6 @@ TEST(${BACKEND_NAME}, slice_scalar)
TEST(${BACKEND_NAME}, slice_matrix)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4, 4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r{3, 2};
......@@ -2619,7 +2617,6 @@ TEST(${BACKEND_NAME}, slice_matrix)
TEST(${BACKEND_NAME}, slice_vector)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{16};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r{12};
......@@ -2639,7 +2636,6 @@ TEST(${BACKEND_NAME}, slice_vector)
TEST(${BACKEND_NAME}, slice_matrix_strided)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{4, 4};
......@@ -2661,7 +2657,6 @@ TEST(${BACKEND_NAME}, slice_matrix_strided)
TEST(${BACKEND_NAME}, slice_3d)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_a{4, 4, 4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r{2, 2, 2};
......@@ -2687,7 +2682,6 @@ TEST(${BACKEND_NAME}, slice_3d)
TEST(${BACKEND_NAME}, slice_3d_strided)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{4, 4, 4};
......@@ -2715,7 +2709,6 @@ TEST(${BACKEND_NAME}, slice_3d_strided)
TEST(${BACKEND_NAME}, slice_3d_strided_different_strides)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
Shape shape_a{4, 4, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment