Commit eba9439b authored by Fenglei's avatar Fenglei Committed by Robert Kimball

gpu element op optimize (#1287)

* move add,mult,min,max,sqrt to elementwise_op, increase op per threads
parent bcd1daa2
......@@ -1117,11 +1117,12 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const std::vector<std
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
size_t nthreads = shape_size(tensor_shape);
uint32_t nthreads = static_cast<uint32_t>(shape_size(tensor_shape));
//TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x =
align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
uint32_t block_size_x = 512;
int num_SMs;
CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, 0));
uint32_t aligned_grid_size_x = fmin(num_SMs * 32, align_to_block_size(nthreads, block_size_x));
// create the launch primitive
std::unique_ptr<gpu::primitive> ew(
......
......@@ -33,14 +33,13 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
writer << data_types[i] << "* in" << i << ", ";
}
writer << data_types[num_inputs] << "* out, "
<< "size_t n)\n";
writer << "{\n";
writer.indent++;
<< "uint32_t n)\n";
writer.block_begin();
{
writer << "size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "if (tid < n)\n";
writer << "{\n";
writer.indent++;
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; \n";
writer << "uint32_t step = gridDim.x * blockDim.x; \n";
writer << "for ( ;tid < n; tid += step)\n";
writer.block_begin();
{
writer << "out[tid] = " << op << "(";
for (size_t i = 0; i < num_inputs - 1; i++)
......@@ -49,11 +48,9 @@ void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& wr
}
writer << "in" << num_inputs - 1 << "[tid]);\n";
}
writer.indent--;
writer << "}\n";
writer.block_end();
}
writer.indent--;
writer << "}\n";
writer.block_end();
return;
}
......
......@@ -156,9 +156,9 @@ static StaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x))
static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::Add), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Add>},
{TI(ngraph::op::Add), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Add>},
{TI(ngraph::op::Dot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Dot>},
{TI(ngraph::op::Multiply), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Multiply>},
{TI(ngraph::op::Multiply), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Multiply>},
{TI(ngraph::op::Parameter), &runtime::gpu::GPU_Emitter::nop},
{TI(ngraph::op::Abs), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Abs>},
{TI(ngraph::op::Concat), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Concat>},
......@@ -172,8 +172,8 @@ static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::Less), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Less>},
{TI(ngraph::op::LessEq), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::LessEq>},
{TI(ngraph::op::Log), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Log>},
{TI(ngraph::op::Maximum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Minimum>},
{TI(ngraph::op::Maximum), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Minimum>},
{TI(ngraph::op::Negative), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Negative>},
{TI(ngraph::op::NotEqual), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::NotEqual>},
{TI(ngraph::op::Power), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Power>},
......@@ -202,7 +202,7 @@ static const runtime::gpu::OpMap dispatcher{
{TI(ngraph::op::OneHot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::OneHot>},
{TI(ngraph::op::Floor), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Floor>},
{TI(ngraph::op::Ceiling), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Ceiling>},
{TI(ngraph::op::Sqrt), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Sqrt>},
{TI(ngraph::op::Sqrt), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Sqrt>},
{TI(ngraph::op::Convolution), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropFilters),
&runtime::gpu::GPU_Emitter::emit<ngraph::op::ConvolutionBackpropFilters>},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment