Commit 7525d6e1 authored by Fenglei Tian

add broadcast kernel

parent e1b2f54c
...@@ -65,7 +65,50 @@ namespace ngraph ...@@ -65,7 +65,50 @@ namespace ngraph
// Broadcast `in` into `out` on the GPU using a kernel compiled at runtime.
//
// For every output index i in [0, count) the kernel performs
//     out[i] = in[(i / (repeat_size * repeat_times)) * repeat_size + i % repeat_size]
// so each contiguous run of `repeat_size` input elements shows up
// `repeat_times` times in a row in the output.
//
// Parameters:
//   in, out       pointers handed to the kernel as CUdeviceptr values; the
//                 kernel reads/writes them as float arrays.
//   repeat_size   length of the contiguous input run that is repeated (kernel `m`).
//   repeat_times  number of times each run is repeated (kernel `k`).
//   count         total number of output elements to write (kernel `n`).
//
// The kernel is compiled once and cached in CudaFunctionPool under the name
// "broadcast". The call blocks until the kernel has finished
// (cuCtxSynchronize), so `out` is complete when this function returns.
void emit_broadcast(
    void* in, void* out, size_t repeat_size, size_t repeat_times, size_t count)
{
    // Nothing to do for an empty output. This also avoids two failure modes:
    // cuLaunchKernel rejects a zero-sized grid, and the kernel divides by
    // (repeat_size * repeat_times).
    if (count == 0 || repeat_size == 0 || repeat_times == 0)
    {
        return;
    }

    std::string name = "broadcast";
    // Compile the kernel with NVRTC on first use and cache it in the pool.
    if (CudaFunctionPool::instance().get(name) == nullptr)
    {
        const char* opts[] = {"--gpu-architecture=compute_35",
                              "--relocatable-device-code=true"};
        std::string kernel;
        std::string data_type("float");
        // Grid-stride loop: correct for any grid size, so the host is free to
        // clamp the number of blocks. Index math is done in size_t to avoid
        // 32-bit overflow of blockIdx.x * blockDim.x.
        kernel = R"(
extern "C" __global__
void cuda_)" + name + "(const " + data_type + "* in, " + data_type +
                 "* out, size_t m, size_t k, size_t n)\n" + R"(
{
    size_t stride = (size_t)gridDim.x * blockDim.x;
    for (size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x; tid < n; tid += stride)
    {
        size_t idx = tid / (m * k) * m + tid % m;
        out[tid] = in[idx];
    }
})";
        CudaFunctionPool::instance().set(
            name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts));
    }

    // Convert runtime pointers to driver API pointers.
    CUdeviceptr d_ptr_in, d_ptr_out;
    d_ptr_in = (CUdeviceptr)in;
    d_ptr_out = (CUdeviceptr)out;
    void* args_list[] = {&d_ptr_in, &d_ptr_out, &repeat_size, &repeat_times, &count};

    // Launch configuration: full blocks of threads instead of one thread per
    // block, with the block count clamped to the grid x-dimension limit
    // (2^31 - 1 on compute capability 3.0 and later). The grid-stride loop in
    // the kernel covers any remainder when the clamp applies.
    const size_t threads_per_block = 256;
    const size_t max_blocks = 2147483647;
    size_t blocks = (count - 1) / threads_per_block + 1; // ceil-div, count > 0 here
    if (blocks > max_blocks)
    {
        blocks = max_blocks;
    }

    CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name).get(),
                                  static_cast<unsigned int>(blocks),
                                  1,
                                  1, // grid dim
                                  static_cast<unsigned int>(threads_per_block),
                                  1,
                                  1, // block dim
                                  0,
                                  NULL, // shared mem and stream
                                  args_list,
                                  NULL)); // kernel arguments, no extra options
    // Block until the kernel completes; this also surfaces asynchronous
    // execution errors through CUDA_SAFE_CALL.
    CUDA_SAFE_CALL(cuCtxSynchronize());
}
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment