Commit af946b7d authored by Fenglei, committed by Scott Cyphers

add gpu reverse (#952)

* add code to gpu reverse

* add reverse emitter and kernel builder

* working version
parent a2ba10b5
@@ -166,6 +166,41 @@ void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_reverse_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1]
<< "* out, size_t* input_shape, size_t* reverse_axes, size_t rank, size_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
writer << "size_t idx_in = tid;\n";
writer << "size_t idx_out = 0;\n";
writer << "size_t stride = 1;\n";
writer << "for(size_t i = rank; i > 0; i--)\n";
writer.block_begin();
{
writer << "size_t idx = i - 1;\n";
writer << "size_t axes_i_in = idx_in % input_shape[idx];\n";
writer << "idx_in /= input_shape[idx];\n";
writer << "size_t axes_i_out = reverse_axes[idx] ? input_shape[idx] - axes_i_in - "
"1 : axes_i_in;\n";
writer << "idx_out += axes_i_out * stride;\n";
writer << "stride *= input_shape[idx];\n";
}
writer.block_end();
writer << "out[idx_out] = in[tid];\n";
}
writer.block_end();
}
writer.block_end();
}
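
The generated kernel maps each flat input index to its output index by peeling off per-axis coordinates from the innermost dimension outward, flipping the coordinate wherever reverse_axes is set, and re-accumulating with the same strides. A minimal host-side C++ sketch of the same mapping, for illustration only (reverse_index is not part of the builder):

    #include <cstddef>
    #include <vector>

    // Reference implementation of the index arithmetic emitted above.
    size_t reverse_index(size_t idx_in,
                         const std::vector<size_t>& input_shape,
                         const std::vector<size_t>& reverse_axes) // 1 = reversed
    {
        size_t idx_out = 0;
        size_t stride = 1;
        for (size_t i = input_shape.size(); i > 0; i--)
        {
            size_t idx = i - 1;
            size_t axes_i_in = idx_in % input_shape[idx]; // coordinate on this axis
            idx_in /= input_shape[idx];
            size_t axes_i_out =
                reverse_axes[idx] ? input_shape[idx] - axes_i_in - 1 : axes_i_in;
            idx_out += axes_i_out * stride;
            stride *= input_shape[idx];
        }
        return idx_out;
    }

    // Example: with input_shape {4, 3} and reverse_axes {1, 0} (reverse axis 0),
    // reverse_index(0, ...) == 9, i.e. element (0, 0) lands at (3, 0).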
void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
@@ -55,6 +55,10 @@ namespace ngraph
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_reverse_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
static void get_device_helper(codegen::CodeWriter& writer,
const std::string& name,
const std::string& math_kernel,
......
@@ -173,3 +173,40 @@ void runtime::gpu::emit_slice(const std::string& name,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Block until the kernel completes.
}
void runtime::gpu::emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_shapes,
CUdeviceptr reverse_axes,
size_t rank,
size_t count)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reverse_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &input_shapes, &reverse_axes, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Block until the kernel completes.
}
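
emit_reverse compiles the kernel once per name/type signature, caches it in the runtime context's kernel pool, and launches one thread per output element (grid dimension count, block dimension 1), mirroring emit_slice above. A hedged usage sketch; the device pointers and sizes here are hypothetical and assumed to be already populated:

    // in_dev/out_dev hold the tensors; shape_dev/axes_dev each hold `rank`
    // size_t values (the input shape, and a 1/0 reverse flag per axis).
    std::array<std::string, 2> types = {"float", "float"};
    runtime::gpu::emit_reverse("reverse", in_dev, out_dev, types, ctx,
                               shape_dev, axes_dev, /*rank=*/2, /*count=*/12);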
@@ -76,6 +76,16 @@ namespace ngraph
size_t rank,
size_t count);
void emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_shape,
CUdeviceptr reverse_axes,
size_t rank,
size_t count);
template <typename T, typename... Inputs>
void emit_elementwise_op(const std::string& name,
const std::vector<std::string>& data_types,
......
@@ -971,6 +971,62 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Reverse)
{
if (out[0].get_size() == 0)
{
return;
}
auto reverse = static_cast<const op::Reverse*>(node);
const auto arg_shape = args[0].get_shape();
const auto arg_rank = arg_shape.size();
const auto result_shape = out[0].get_shape();
const auto reverse_axes = reverse->get_reversed_axes();
std::vector<size_t> reverse_axes_flag(arg_rank, 0);
for (auto a : reverse_axes)
{
reverse_axes_flag[a] = 1;
}
writer.block_begin(" // " + node->get_name());
if (out[0].get_size() == 1)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
else
{
writer << "size_t rank = " << arg_rank << ";\n";
writer << "std::vector<size_t> input_shapes_h = {" << join(arg_shape, "UL,")
<< "UL};\n";
writer << "std::vector<size_t> reverse_axes_h = {"
<< join(reverse_axes_flag, "UL,") << "UL};\n";
writer << "void* input_shapes_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer << "void* reverse_axes_d = "
"runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);\n";
writer << "runtime::gpu::cuda_memcpyHtD(input_shapes_d, input_shapes_h.data(), "
"sizeof(size_t) * rank);\n";
writer << "runtime::gpu::cuda_memcpyHtD(reverse_axes_d, "
"reverse_axes_h.data(), "
"sizeof(size_t) * rank);\n";
writer << "runtime::gpu::emit_reverse(\"" << node->description()
<< "\", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", {\"" << args[0].get_type() << "\", \"" << out[0].get_type()
<< "\"}"
<< ", "
<< "ctx, "
<< "CUdeviceptr(input_shapes_d), CUdeviceptr(reverse_axes_d), "
<< arg_rank << ", " << out[0].get_size() << ");\n";
writer << "runtime::gpu::free_gpu_buffer(input_shapes_d);\n";
writer << "runtime::gpu::free_gpu_buffer(reverse_axes_d);\n";
}
writer.block_end();
}
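
For a concrete picture of what this emitter writes, consider a float32 tensor of shape {4, 3} reversed along axis 0, so reverse_axes_flag is {1, 0} and the output has 12 elements. The emitted source would look roughly like the following sketch; the tensor names arg0 and out0 are illustrative placeholders for whatever names the tensor wrappers carry:

    {   // Reverse (sketch of emitted code)
        size_t rank = 2;
        std::vector<size_t> input_shapes_h = {4UL,3UL};
        std::vector<size_t> reverse_axes_h = {1UL,0UL};
        void* input_shapes_d = runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);
        void* reverse_axes_d = runtime::gpu::create_gpu_buffer(sizeof(size_t) * rank);
        runtime::gpu::cuda_memcpyHtD(input_shapes_d, input_shapes_h.data(), sizeof(size_t) * rank);
        runtime::gpu::cuda_memcpyHtD(reverse_axes_d, reverse_axes_h.data(), sizeof(size_t) * rank);
        runtime::gpu::emit_reverse("Reverse", CUdeviceptr(arg0), CUdeviceptr(out0),
                                   {"float", "float"}, ctx,
                                   CUdeviceptr(input_shapes_d), CUdeviceptr(reverse_axes_d),
                                   2, 12);
        runtime::gpu::free_gpu_buffer(input_shapes_d);
        runtime::gpu::free_gpu_buffer(reverse_axes_d);
    }

Note that the shape and axis-flag buffers are allocated, copied, and freed around every invocation, while the compiled kernel itself is reused through the kernel pool keyed on name and data types.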
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
......
@@ -4526,7 +4526,6 @@ TEST(${BACKEND_NAME}, not)
TEST(${BACKEND_NAME}, reverse_0d)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{}), op::ParameterVector{A});
@@ -4544,7 +4543,6 @@ TEST(${BACKEND_NAME}, reverse_0d)
TEST(${BACKEND_NAME}, reverse_1d_nochange)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{}), op::ParameterVector{A});
@@ -4562,7 +4560,6 @@ TEST(${BACKEND_NAME}, reverse_1d_nochange)
TEST(${BACKEND_NAME}, reverse_1d_0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{0}), op::ParameterVector{A});
@@ -4580,7 +4577,6 @@ TEST(${BACKEND_NAME}, reverse_1d_0)
TEST(${BACKEND_NAME}, reverse_2d_nochange)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{}), op::ParameterVector{A});
@@ -4601,7 +4597,6 @@ TEST(${BACKEND_NAME}, reverse_2d_nochange)
TEST(${BACKEND_NAME}, reverse_2d_0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{0}), op::ParameterVector{A});
@@ -4622,7 +4617,6 @@ TEST(${BACKEND_NAME}, reverse_2d_0)
TEST(${BACKEND_NAME}, reverse_2d_1)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{1}), op::ParameterVector{A});
@@ -4643,7 +4637,6 @@ TEST(${BACKEND_NAME}, reverse_2d_1)
TEST(${BACKEND_NAME}, reverse_2d_01)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
@@ -4665,7 +4658,6 @@ TEST(${BACKEND_NAME}, reverse_2d_01)
TEST(${BACKEND_NAME}, reverse_3d_nochange)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{}), op::ParameterVector{A});
@@ -4689,7 +4681,6 @@ TEST(${BACKEND_NAME}, reverse_3d_nochange)
TEST(${BACKEND_NAME}, reverse_3d_0)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{0}), op::ParameterVector{A});
@@ -4713,7 +4704,6 @@ TEST(${BACKEND_NAME}, reverse_3d_0)
TEST(${BACKEND_NAME}, reverse_3d_1)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{1}), op::ParameterVector{A});
@@ -4737,7 +4727,6 @@ TEST(${BACKEND_NAME}, reverse_3d_1)
TEST(${BACKEND_NAME}, reverse_3d_2)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{2}), op::ParameterVector{A});
@@ -4761,7 +4750,6 @@ TEST(${BACKEND_NAME}, reverse_3d_2)
TEST(${BACKEND_NAME}, reverse_3d_01)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
@@ -4786,7 +4774,6 @@ TEST(${BACKEND_NAME}, reverse_3d_01)
TEST(${BACKEND_NAME}, reverse_3d_02)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
@@ -4811,7 +4798,6 @@ TEST(${BACKEND_NAME}, reverse_3d_02)
TEST(${BACKEND_NAME}, reverse_3d_12)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
@@ -4836,7 +4822,6 @@ TEST(${BACKEND_NAME}, reverse_3d_12)
TEST(${BACKEND_NAME}, reverse_3d_012)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Reverse>(A, AxisSet{0, 1, 2}),
......
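
The test hunks above each drop one line, the SKIP_TEST_FOR("GPU", "${BACKEND_NAME}") guard, so the existing reverse tests now run on the GPU backend as well. As a reminder of the semantics they verify, reversing a 2-D tensor along axis 0 flips the row order and leaves columns intact (illustrative values):

    // reverse_2d_0 on a 4x3 input with AxisSet{0}:
    //   input:  {{ 1,  2,  3},        output: {{10, 11, 12},
    //            { 4,  5,  6},                 { 7,  8,  9},
    //            { 7,  8,  9},                 { 4,  5,  6},
    //            {10, 11, 12}}                 { 1,  2,  3}}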