Commit dcdaf26e authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

Fix launch parameter bug for broadcast and pad. (#1261)

* Broadcast and Pad bug fix.
parent d4349db8
...@@ -98,10 +98,11 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const std::array<std::string, 2>& dt ...@@ -98,10 +98,11 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const std::array<std::string, 2>& dt
return primitive_index; return primitive_index;
} }
uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape)); size_t nthreads = shape_size(output_shape);
//TODO: currently we set it to 64, will add tuning method later //TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64; uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x); uint32_t aligned_grid_size_x =
align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
// if the kernel has not been compiled, build it // if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(hash); auto compiled_kernel = m_ctx->compiled_kernel_pool->get(hash);
...@@ -1372,11 +1373,11 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string, ...@@ -1372,11 +1373,11 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string,
float alpha = 1.0f; float alpha = 1.0f;
float beta = 0.0f; float beta = 0.0f;
int nthreads = static_cast<int>(shape_size(result_shape)); size_t nthreads = shape_size(result_shape);
//TODO: currently we set it to 64, will add tuning method later //TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64; uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = uint32_t aligned_grid_size_x =
align_to_block_size(static_cast<uint32_t>(shape_size(result_shape)), block_size_x); align_to_block_size(static_cast<uint32_t>(nthreads), block_size_x);
std::unique_ptr<gpu::primitive> broadcast(new gpu::primitive{[=](void** inputs, std::unique_ptr<gpu::primitive> broadcast(new gpu::primitive{[=](void** inputs,
void** outputs) mutable { void** outputs) mutable {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment