Commit 63a233b6 authored by Chris Sullivan, committed by Scott Cyphers

Utilize GPUKernelArgs parameter for ew-collective, nd-conv, replace_slice. (#1346)

* Support GPUKernelArgs in Elementwise-collective and Nd-Convolution.

* Update op::ReplaceSlice to use GPUKernelArgs and unroll coordinate transform loop.

* Formatting.

* Moved function signature for global kernels back to emitter body.

* Formatting.
parent 14019ab9
This diff is collapsed.
......@@ -31,23 +31,19 @@ namespace ngraph
{
namespace gpu
{
class GPUKernelArgs;
class CudaKernelBuilder
{
public:
// Streams the opening of a CUDA kernel definition into `writer`: the literal
// `extern "C" __global__ void cuda_` prefix, immediately followed by `name`
// and then `input_signature`. Nothing is inserted between the pieces and
// nothing is appended afterwards.
static void get_kernel_signature(codegen::CodeWriter& writer,
                                 const std::string& name,
                                 const std::string& input_signature)
{
    writer << "extern \"C\" __global__ void cuda_" << name << input_signature;
}
static void get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name,
const std::string& op,
const std::vector<std::string>& data_types);
static void get_broadcast_op(codegen::CodeWriter& writer, const size_t rank);
static void get_broadcast_op(codegen::CodeWriter& writer,
const std::string& name,
GPUKernelArgs& args,
const size_t rank);
static void get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -74,8 +70,8 @@ namespace ngraph
static void get_replace_slice_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 3>& data_types,
int nthreads_per_block);
GPUKernelArgs& args,
const size_t rank);
static void get_reduce_window_op(codegen::CodeWriter& writer,
const std::string& name,
......@@ -101,6 +97,7 @@ namespace ngraph
static void get_ew_collective_op(codegen::CodeWriter& writer,
const std::string& name,
GPUKernelArgs& args,
const std::string& op,
const std::string& reduce_op,
const std::vector<std::string>& data_types,
......@@ -124,10 +121,11 @@ namespace ngraph
static void get_convolution_forward(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 3>& data_types,
GPUKernelArgs& args,
int N,
int K,
int filter_size,
int rank,
int filter_size,
int sm_tile_size = 8,
int reg_tile_size = 1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment