Commit cb84305e authored by Fenglei, committed by Robert Kimball

move onehot and reverse ops to cuda_emitter (#1266)

* move to cuda_emitter

* fix bug, clang format

* size_t to uint32_t

* reverse_axes

* add rank back, clang format

* remove unused code and file

* remove unused code and file

* manually merge with master
parent 167844e4
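The core of this change is that OneHot and Reverse now go through CUDAEmitter's two-level cache: a hash of kernel name plus shapes memoizes the launch primitive, and the kernel pool memoizes the compiled kernel. A tiny self-contained analogue of that primitive cache (toy code; only the lookup/insert/cache names come from the diff, everything else is invented for illustration):

```cpp
// Toy analogue of the PrimitiveEmitter cache used by build_onehot/build_reverse
// below; only the lookup/insert/cache names come from the diff.
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyPrimitiveEmitter
{
    std::vector<std::function<void()>> primitives;
    std::unordered_map<std::string, size_t> index_by_hash;

    // returns SIZE_MAX on a miss, mirroring the sentinel checked in the diff
    size_t lookup(const std::string& hash) const
    {
        auto it = index_by_hash.find(hash);
        return it == index_by_hash.end() ? std::numeric_limits<size_t>::max()
                                         : it->second;
    }
    size_t insert(std::function<void()> p)
    {
        primitives.push_back(std::move(p));
        return primitives.size() - 1;
    }
    void cache(const std::string& hash, size_t i) { index_by_hash[hash] = i; }
};

int main()
{
    ToyPrimitiveEmitter emitter;
    // shape-qualified key, in the spirit of the hash built in build_onehot
    std::string hash = "onehot_float_float_i_4_o_4_4_axis_0";
    size_t index = emitter.lookup(hash);
    if (index == std::numeric_limits<size_t>::max())
    {
        index = emitter.insert([] { std::cout << "launch kernel\n"; });
        emitter.cache(hash, index);
    }
    emitter.primitives[index](); // a second build of the same op reuses this index
}
```

On a cache hit the emitter hands back the existing index, so repeated instantiations of the same op and shape cost one hash lookup instead of a recompile.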
......@@ -27,7 +27,6 @@ set(SRC
gpu_cuda_function_builder.cpp
gpu_cuda_function_pool.cpp
gpu_cuda_kernel_builder.cpp
gpu_cuda_kernel_emitters.cpp
gpu_emitter.cpp
gpu_external_function.cpp
gpu_invoke.cpp
......
......@@ -163,6 +163,149 @@ size_t runtime::gpu::CUDAEmitter::build_concat(const std::vector<std::string>& d
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_onehot(const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape output_shape,
size_t one_hot_axis)
{
std::stringstream kernel_name;
kernel_name << "onehot_" << join(dtypes, "_");
std::string hash = kernel_name.str() + "_i_" + join(input_shape, "_") + "_o_" +
join(output_shape, "_") + std::to_string(one_hot_axis);
// For backwards compatibility we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_onehot_op(writer, kernel_name.str(), dtypes);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
// TODO: currently set to 64; a tuning method will be added later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
uint32_t repeat_times = static_cast<uint32_t>(output_shape[one_hot_axis]);
uint32_t repeat_size = 1;
for (size_t i = one_hot_axis + 1; i < output_shape.size(); i++)
{
repeat_size *= output_shape[i];
}
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
std::vector<void*> args_list{
&inputs[0], &outputs[0], &repeat_size, &repeat_times, &nthreads};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
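For intuition, here is the one-hot index math from the generated kernel replayed on the host (standalone sketch; the shapes and values are invented, with m = repeat_size, k = repeat_times, n = nthreads as computed in build_onehot above):

```cpp
// Host-side replay of the one-hot index math from the generated kernel
// (standalone sketch; shapes and values are invented).
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // input shape {4} of class indices, one_hot_axis = 0 -> output shape {4, 4}
    std::vector<uint32_t> in{1, 0, 3, 2};
    uint32_t n = 4; // nthreads == shape_size(input_shape)
    uint32_t k = 4; // repeat_times == output_shape[one_hot_axis]
    uint32_t m = 4; // repeat_size == product of output dims after the axis

    // the kernel only writes the 1s, so the output must be zero-filled first
    std::vector<uint32_t> out(m * k, 0);
    for (uint32_t tid = 0; tid < n; tid++)
        out[(tid / m) * m * k + m * in[tid] + tid % m] = 1;

    for (uint32_t row = 0; row < k; row++)
    {
        for (uint32_t col = 0; col < m; col++)
            std::cout << out[row * m + col] << " ";
        std::cout << "\n"; // one 1 per column, at row in[col]
    }
}
```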
size_t runtime::gpu::CUDAEmitter::build_reverse(const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
std::vector<uint32_t> reverse_axes)
{
uint32_t rank = static_cast<uint32_t>(input_shape.size());
std::stringstream kernel_name;
kernel_name << "reverse_" << join(dtypes, "_");
std::string hash = kernel_name.str() + "_i_" + join(input_shape, "_") + "_axes_" +
join(reverse_axes, "_") + std::to_string(rank);
// For backwards compatibility we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reverse_op(writer, kernel_name.str(), dtypes);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
// TODO: currently set to 64; a tuning method will be added later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_input_shape =
allocator.reserve_argspace(input_shape.data(), input_shape.size() * sizeof(uint32_t));
size_t idx_reverse_axes =
allocator.reserve_argspace(reverse_axes.data(), reverse_axes.size() * sizeof(uint32_t));
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* param_input_shape = runtime::gpu::invoke_memory_primitive(m_ctx, idx_input_shape);
void* param_reverse_axes = runtime::gpu::invoke_memory_primitive(m_ctx, idx_reverse_axes);
std::vector<void*> args_list{
&inputs[0], &outputs[0], &param_input_shape, &param_reverse_axes, &rank, &nthreads};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
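The reverse primitive hands input_shape and reverse_axes to the kernel through reserved argspace; the index arithmetic itself is easiest to see on the host. A standalone sketch (values invented; the get_reverse_op hunk further down elides the kernel's tail, so the final out[output_idx] = in[tid] copy here is an inference):

```cpp
// Host-side replay of the index arithmetic in the generated reverse kernel.
// Standalone illustration; shapes and values are invented, and the final
// out[output_idx] = in[tid] copy is inferred from the kernel's purpose.
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<uint32_t> input_shape{2, 3};  // 2x3 tensor, row-major
    std::vector<uint32_t> reverse_axes{0, 1}; // flag per axis: reverse axis 1 only
    std::vector<int> in{0, 1, 2, 3, 4, 5};
    std::vector<int> out(in.size());
    uint32_t rank = 2, n = 6;

    for (uint32_t tid = 0; tid < n; tid++)
    {
        uint32_t input_idx = tid, output_idx = 0, stride = 1;
        for (uint32_t i = rank; i > 0; i--) // walk axes innermost-first
        {
            uint32_t idx = i - 1;
            uint32_t axes_i_in = input_idx % input_shape[idx];
            input_idx /= input_shape[idx];
            uint32_t axes_i_out =
                reverse_axes[idx] ? input_shape[idx] - axes_i_in - 1 : axes_i_in;
            output_idx += axes_i_out * stride;
            stride *= input_shape[idx];
        }
        out[output_idx] = in[tid];
    }

    for (int v : out)
        std::cout << v << " "; // prints: 2 1 0 5 4 3
    std::cout << "\n";
}
```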
size_t runtime::gpu::CUDAEmitter::build_pad(const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape output_shape,
......@@ -505,8 +648,8 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
......
......@@ -86,6 +86,15 @@ namespace ngraph
size_t batch_axis,
size_t sequence_axis);
size_t build_onehot(const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
GPUShape output_shape,
size_t one_hot_axis);
size_t build_reverse(const std::array<std::string, 2>& dtypes,
GPUShape input_shape,
std::vector<uint32_t> reverse_axes);
template <typename T>
size_t build_elementwise(const std::vector<std::string>& dtypes,
GPUShape tensor_shape)
......
......@@ -190,23 +190,19 @@ void runtime::gpu::CudaKernelBuilder::get_onehot_op(codegen::CodeWriter& writer,
const std::array<std::string, 2>& data_types)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, size_t m, size_t k, size_t n)\n";
writer << "{\n";
writer.indent++;
<< data_types[1] << "* out, uint32_t m, uint32_t k, uint32_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer << "{\n";
writer.indent++;
writer.block_begin();
{
writer << "size_t idx = (tid / m) * m * k + (m * in[tid]) + tid % m;\n";
writer << "uint32_t idx = (tid / m) * m * k + (m * in[tid]) + tid % m;\n";
writer << "out[idx] = 1;\n";
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
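Stitched together, the writer calls above generate CUDA source roughly like the following (reconstruction for a hypothetical float-to-float instantiation; the "cuda_onehot_float_float" name follows the kernel_name scheme in build_onehot but is written out here only for illustration):

```cuda
// Reconstruction of what get_onehot_op emits for data_types {"float", "float"}.
extern "C" __global__ void cuda_onehot_float_float(
    float* in, float* out, uint32_t m, uint32_t k, uint32_t n)
{
    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n)
    {
        uint32_t idx = (tid / m) * m * k + (m * in[tid]) + tid % m;
        out[idx] = 1;
    }
}
```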
void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer,
......@@ -406,24 +402,25 @@ void runtime::gpu::CudaKernelBuilder::get_reverse_op(codegen::CodeWriter& writer
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1]
<< "* out, size_t* input_shape, size_t* reverse_axes, size_t rank, size_t n)\n";
<< "* out, uint32_t* input_shape, uint32_t* reverse_axes, uint32_t rank, uint32_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
writer << "if (tid < n)\n";
writer.block_begin();
{
writer << "size_t input_idx = tid;\n";
writer << "size_t output_idx = 0;\n";
writer << "size_t stride = 1;\n";
writer << "for(size_t i = rank; i > 0; i--)\n";
writer << "uint32_t input_idx = tid;\n";
writer << "uint32_t output_idx = 0;\n";
writer << "uint32_t stride = 1;\n";
writer << "for(uint32_t i = rank; i > 0; i--)\n";
writer.block_begin();
{
writer << "size_t idx = i - 1;\n";
writer << "size_t axes_i_in = input_idx % input_shape[idx];\n";
writer << "uint32_t idx = i - 1;\n";
writer << "uint32_t axes_i_in = input_idx % input_shape[idx];\n";
writer << "input_idx /= input_shape[idx];\n";
writer << "size_t axes_i_out = reverse_axes[idx] ? input_shape[idx] - axes_i_in - "
"1 : axes_i_in;\n";
writer
<< "uint32_t axes_i_out = reverse_axes[idx] ? input_shape[idx] - axes_i_in - "
"1 : axes_i_in;\n";
writer << "output_idx += axes_i_out * stride;\n";
writer << "stride *= input_shape[idx];\n";
}
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <map>
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
using namespace ngraph;
using namespace ngraph::runtime::gpu;
void runtime::gpu::emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
size_t repeat_size,
size_t repeat_times,
size_t count)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
// Create an instance of nvrtcProgram with the code string.
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_onehot_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
void runtime::gpu::emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_shapes,
CUdeviceptr reverse_axes,
size_t rank,
size_t count)
{
std::string name_signature = name + "_" + data_types[0] + "_" + data_types[1];
std::replace(name_signature.begin(), name_signature.end(), ' ', '_');
auto compiled_kernel = ctx->compiled_kernel_pool->get(name_signature);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reverse_op(writer, name_signature, data_types);
std::string kernel = writer.get_code();
compiled_kernel = ctx->compiled_kernel_pool->set(name_signature, kernel);
}
void* args_list[] = {&in, &out, &input_shapes, &reverse_axes, &rank, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(count),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <array>
#include <string>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
template <typename T>
struct CudaOpMap;
void emit_onehot(const std::string& name,
std::array<std::string, 2> data_types,
GPURuntimeContext* ctx,
CUdeviceptr in,
CUdeviceptr out,
size_t repeat_size,
size_t repeat_times,
size_t count);
void emit_reverse(const std::string& name,
CUdeviceptr in,
CUdeviceptr out,
const std::array<std::string, 2>& data_types,
GPURuntimeContext* ctx,
CUdeviceptr input_shape,
CUdeviceptr reverse_axes,
size_t rank,
size_t count);
}
}
}
......@@ -93,7 +93,6 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
......@@ -1006,7 +1005,7 @@ namespace ngraph
const auto arg_rank = arg_shape.size();
const auto result_shape = out[0].get_shape();
const auto reverse_axes = reverse->get_reversed_axes();
std::vector<size_t> reverse_axes_flag(arg_rank, 0);
std::vector<uint32_t> reverse_axes_flag(arg_rank, 0);
for (auto a : reverse_axes)
{
reverse_axes_flag[a] = 1;
......@@ -1018,30 +1017,15 @@ namespace ngraph
}
else
{
GPUAllocator allocator =
external_function->get_primitive_emitter()->get_memory_allocator();
size_t idx_arg_shape = allocator.reserve_argspace(
arg_shape.data(), arg_shape.size() * sizeof(size_t));
size_t idx_reverse_axes_flag = allocator.reserve_argspace(
reverse_axes_flag.data(), reverse_axes_flag.size() * sizeof(size_t));
writer << "size_t rank = " << arg_rank << ";\n";
writer << "void* input_shapes_d = "
<< " runtime::gpu::invoke_memory_primitive(ctx, " << idx_arg_shape
<< ");\n";
writer << "void* reverse_axes_d = "
<< " runtime::gpu::invoke_memory_primitive(ctx, "
<< idx_reverse_axes_flag << ");\n";
writer << "runtime::gpu::emit_reverse(\"" << node->description()
<< "\", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", {\"" << args[0].get_type() << "\", \"" << out[0].get_type()
<< "\"}"
<< ", "
<< "ctx, "
<< "CUdeviceptr(input_shapes_d), CUdeviceptr(reverse_axes_d), "
<< arg_rank << ", " << out[0].get_size() << ");\n";
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_reverse(
{{args[0].get_type(), out[0].get_type()}}, arg_shape, reverse_axes_flag);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......@@ -1111,23 +1095,19 @@ namespace ngraph
auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
size_t idx = onehot->get_one_hot_axis();
size_t repeat_times = result_shape[idx];
size_t repeat_size = 1;
for (size_t i = idx + 1; i < result_shape.size(); i++)
{
repeat_size *= result_shape[i];
}
writer.block_begin();
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * " << out[0].get_element_type().size() << ");\n";
writer << "runtime::gpu::emit_onehot(\"" << node->description() << "\", {\""
<< args[0].get_type() << "\", \"" << out[0].get_type() << "\"}"
<< ", ctx"
<< ", CUdeviceptr(" << args[0].get_name() << "), CUdeviceptr("
<< out[0].get_name() << ")"
<< ", " << repeat_size << ", " << repeat_times << ", " << args[0].get_size()
<< ");\n";
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_onehot(
{{args[0].get_type(), out[0].get_type()}}, arg_shape, result_shape, idx);
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......
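Net effect on the generated code: instead of an inline emit_onehot/emit_reverse call with raw CUdeviceptr casts, the compiled function now carries a baked-in primitive index. Roughly what the rewritten emitters write out (reconstruction; the index 42 and the tensor names are invented):

```cpp
// Reconstruction of the code the OneHot emitter now writes into the compiled
// function; the primitive index 42 and the tensor names are invented.
gpu::invoke_primitive(ctx, 42,
                      std::vector<void*>{tensor_0}.data(),
                      std::vector<void*>{tensor_1}.data());
```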
......@@ -99,7 +99,6 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
......@@ -279,7 +278,6 @@ void runtime::gpu::GPU_ExternalFunction::emit_header()
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/gpu/cudnn_descriptors.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
......