Commit e6c3d5e3 authored by Fenglei, committed by Robert Kimball

make cuda call async, add cuda async timer (#1204)

* using async gpu timers

* remove sync for cuda calls, add async gpu stopwatch, add count to timing-detail

* add debug sync

* make timer static

* move timer to runtime context
parent 1a074e5a
......@@ -207,7 +207,7 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const std::array<std::string, 2>& dt
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
}
else // pad value provided at compile time (static)
......@@ -225,7 +225,7 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const std::array<std::string, 2>& dt
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
}
......@@ -326,7 +326,7 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const std::array<std::string
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pad_dynamic));
......@@ -339,7 +339,7 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
{
auto rank = input_shape.size();
std::stringstream kernel_name;
kernel_name << "slice_" << join(dtypes, "_") << "_r_" << rank;
kernel_name << "reshape_" << join(dtypes, "_") << "_r_" << rank;
std::string hash =
kernel_name.str() + "_i_" + join(input_shape, "_") + "_o_" + join(input_order, "_");
......@@ -500,7 +500,7 @@ size_t runtime::gpu::CUDAEmitter::build_slice(const std::array<std::string, 2>&
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
......@@ -583,7 +583,7 @@ size_t runtime::gpu::CUDAEmitter::build_reverse_sequence(const std::array<std::s
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
......@@ -643,7 +643,7 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const std::array<std::string
NULL, // shared mem and stream
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
......@@ -831,8 +831,7 @@ size_t runtime::gpu::CUDAEmitter::build_avg_pool(const std::array<std::string, 2
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
......@@ -907,7 +906,7 @@ size_t runtime::gpu::CUDAEmitter::build_elementwise_n_to_1(const std::vector<std
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(ew));
......@@ -1043,7 +1042,7 @@ size_t
NULL,
args_list.data(),
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(ew_collective));
......@@ -1157,8 +1156,8 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const OpName op_name,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}});
primitive_index = this->m_primitive_emitter->insert(std::move(f));
......@@ -1296,7 +1295,7 @@ size_t runtime::gpu::CUDAEmitter::build_replace_slice(const std::array<std::stri
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(replace_slice));
......@@ -1407,7 +1406,7 @@ size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string,
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(broadcast));
......@@ -1663,7 +1662,7 @@ size_t runtime::gpu::CUDAEmitter::build_convolution(const std::array<std::string
NULL,
args_list,
0));
CUDA_SAFE_CALL(cuCtxSynchronize());
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
......@@ -1818,3 +1817,17 @@ uint32_t runtime::gpu::CUDAEmitter::align_to_block_size(uint32_t threads, uint32
uint32_t r = (threads + block_size - 1) / block_size;
return r;
}
// Synchronizes with the GPU by calling cuCtxSynchronize() through
// CUDA_SAFE_CALL. Unlike debug_sync(), the call is compiled in for every
// build configuration.
void runtime::gpu::CUDAEmitter::sync()
{
    CUDA_SAFE_CALL(cuCtxSynchronize());
}
// Debug-only synchronization point: the cuCtxSynchronize() call is compiled
// in only when NGRAPH_DEBUG_ENABLE is defined; otherwise this function does
// nothing.
void runtime::gpu::CUDAEmitter::debug_sync()
{
#if defined(NGRAPH_DEBUG_ENABLE)
    CUDA_SAFE_CALL(cuCtxSynchronize());
#endif
}
......@@ -135,6 +135,9 @@ namespace ngraph
GPUShape filter_dilation,
GPUShape output_shape);
void debug_sync();
void sync();
private:
CUDAEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
uint32_t align_to_block_size(uint32_t threads, uint32_t block_size);
......
......@@ -190,6 +190,7 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const cudnnReduceTensorO
beta,
output_desc,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(reduce));
......@@ -241,6 +242,7 @@ size_t runtime::gpu::CUDNNEmitter::build_tensor_op(const cudnnOpTensorOp_t& tens
beta_dt,
descriptor,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(tensor));
......@@ -394,6 +396,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution(const std::string& dtype,
beta,
tensor_desc_1,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
......@@ -470,6 +473,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_data(
beta,
tensor_desc_1,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
......@@ -548,6 +552,7 @@ size_t runtime::gpu::CUDNNEmitter::build_convolution_backward_filter(
beta,
filter_desc,
outputs[0]));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(conv));
m_primitive_emitter->cache(hash, primitive_index);
......@@ -641,6 +646,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
beta,
output_desc,
outputs[0]));
debug_sync();
}});
break;
}
......@@ -671,6 +677,7 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
// adjoint of input
input_desc,
outputs[0]));
debug_sync();
}});
break;
}
......@@ -736,6 +743,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
inputs[3], // mean
inputs[4], // variance
epsilon));
debug_sync();
}});
break;
}
......@@ -773,6 +781,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
epsilon,
NULL,
NULL));
debug_sync();
// convert to biased variance
CUDNN_SAFE_CALL(cudnnOpTensor(*m_ctx->cudnn_handle,
......@@ -786,6 +795,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
bias_factor,
derived_param_desc,
outputs[2]));
debug_sync();
}});
break;
}
......@@ -813,6 +823,7 @@ size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const cudnnBatchNormMode_t& b
epsilon,
NULL, // inputs[3 /* mu batch mean*/],
NULL)); // inputs[4 /* 1/sig**2 batch inverse variance*/]);
debug_sync();
}});
break;
}
......@@ -864,6 +875,7 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t&
beta,
tensor_desc,
outputs[0]));
debug_sync();
}});
break;
}
......@@ -881,6 +893,7 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t&
beta,
tensor_desc,
outputs[0]));
debug_sync();
}});
break;
}
......@@ -890,3 +903,17 @@ size_t runtime::gpu::CUDNNEmitter::build_softmax(const cudnnSoftmaxAlgorithm_t&
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
// Synchronizes with the GPU by calling cudaDeviceSynchronize() through
// CUDA_RT_SAFE_CALL. Unlike debug_sync(), the call is compiled in for every
// build configuration.
void runtime::gpu::CUDNNEmitter::sync()
{
    CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
}
// Debug-only synchronization point: the cudaDeviceSynchronize() call is
// compiled in only when NGRAPH_DEBUG_ENABLE is defined; otherwise this
// function does nothing.
void runtime::gpu::CUDNNEmitter::debug_sync()
{
#if defined(NGRAPH_DEBUG_ENABLE)
    CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
#endif
}
......@@ -115,6 +115,9 @@ namespace ngraph
const Prop& direction,
const Shape& tensor_shape);
void debug_sync();
void sync();
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
......
......@@ -299,6 +299,7 @@ using namespace std;
// to register cleanup handlers. We use it, and not atexit(), because
// atexit() happens too late, when the JIT is no longer alive
m_writer << "void *__dso_handle = 0;\n\n";
m_writer << "static gpu::GPURuntimeContext* m_runtime_context = nullptr;\n";
}
void runtime::gpu::GPU_ExternalFunction::emit_timer_functions()
......@@ -319,7 +320,13 @@ void runtime::gpu::GPU_ExternalFunction::emit_timer_functions()
}
}
}
m_writer << "ngraph::stopwatch timers[" << names.size() << "];\n";
if (m_shared_context->m_runtime_context->stopwatch_pool == nullptr)
{
m_shared_context->m_runtime_context->stopwatch_pool = new StopWatchPool;
}
m_offset = m_shared_context->m_runtime_context->stopwatch_pool->size();
m_shared_context->m_runtime_context->stopwatch_pool->allocate(names.size());
m_writer << "extern \"C\" size_t get_debug_timer_count() { return " << names.size()
<< "; }\n";
m_writer << "extern \"C\" const char* get_debug_timer_name(size_t index)\n";
......@@ -340,13 +347,15 @@ void runtime::gpu::GPU_ExternalFunction::emit_timer_functions()
m_writer << "extern \"C\" const size_t get_debug_timer_microseconds(size_t index)\n";
m_writer.block_begin();
m_writer << "return (index < " << names.size()
<< " ? timers[index].get_total_microseconds() : 0);\n";
<< " ? runtime::gpu::us_stopwatch(m_runtime_context, index + " << m_offset
<< ") : 0);\n";
m_writer.block_end();
m_writer << "extern \"C\" const size_t get_debug_timer_call_count(size_t index)\n";
m_writer.block_begin();
m_writer << "return (index < " << names.size()
<< " ? timers[index].get_call_count() : 0);\n";
<< " ? runtime::gpu::count_stopwatch(m_runtime_context, index + " << m_offset
<< ") : 0);\n";
m_writer.block_end();
m_writer << "\n";
}
......@@ -379,7 +388,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_constant_declarations()
}
m_writer << "\nstatic bool is_constant_mem_ptr_null = true;\n\n";
m_writer << "static void invoke_constant_mem_ptr(gpu::GPURuntimeContext* ctx)\n";
m_writer << "static void invoke_constant_mem_ptr()\n";
m_writer.block_begin();
{
m_writer << "if(is_constant_mem_ptr_null)\n";
......@@ -396,7 +405,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_constant_declarations()
node->get_outputs()[0].get_tensor_view();
m_writer << tv->get_tensor().get_name() << " = reinterpret_cast<"
<< tv->get_tensor().get_element_type().c_type_string()
<< "*>(runtime::gpu::invoke_memory_primitive(ctx, "
<< "*>(runtime::gpu::invoke_memory_primitive(m_runtime_context, "
<< tv->get_tensor().get_name() << "_idx));\n";
}
}
......@@ -539,8 +548,9 @@ void runtime::gpu::GPU_ExternalFunction::emit_functions()
<< "gpu::GPURuntimeContext* ctx)\n";
m_writer.block_begin();
{
m_writer << "m_runtime_context = ctx;\n";
//set constant pointers during the first run
m_writer << "invoke_constant_mem_ptr(ctx);\n";
m_writer << "invoke_constant_mem_ptr();\n";
// allocate temp memory pool
emit_temp_mem_pool_allocation(current_function);
......@@ -746,7 +756,8 @@ void runtime::gpu::GPU_ExternalFunction::emit_debug_function_entry(Node* node)
{
if (m_emit_timing)
{
m_writer << "timers[" << m_name_index_map[node->get_name()] << "].start();\n";
m_writer << "runtime::gpu::start_stopwatch(ctx, "
<< m_name_index_map[node->get_name()] + m_offset << ");\n";
}
}
......@@ -754,7 +765,8 @@ void runtime::gpu::GPU_ExternalFunction::emit_debug_function_exit(Node* node)
{
if (m_emit_timing)
{
m_writer << "timers[" << m_name_index_map[node->get_name()] << "].stop();\n";
m_writer << "runtime::gpu::stop_stopwatch(ctx, "
<< m_name_index_map[node->get_name()] + m_offset << ");\n";
}
}
......
......@@ -112,6 +112,7 @@ namespace ngraph
bool m_is_compiled;
bool m_release_function;
bool m_temporaries_used;
size_t m_offset;
std::string m_function_name;
std::string m_pch_header_source;
......
......@@ -18,3 +18,21 @@
using namespace ngraph;
using namespace ngraph::runtime::gpu;
// C-linkage entry point: starts the stopwatch stored at position `idx` of the
// context's stopwatch pool. No null or range checking is performed here.
extern "C" void ngraph::runtime::gpu::start_stopwatch(GPURuntimeContext* ctx, size_t idx)
{
    auto& watch = ctx->stopwatch_pool->get(idx);
    watch.start();
}
// C-linkage entry point: stops the stopwatch stored at position `idx` of the
// context's stopwatch pool. No null or range checking is performed here.
extern "C" void ngraph::runtime::gpu::stop_stopwatch(GPURuntimeContext* ctx, size_t idx)
{
    auto& watch = ctx->stopwatch_pool->get(idx);
    watch.stop();
}
// C-linkage entry point: returns the call count reported by the stopwatch
// stored at position `idx` of the context's stopwatch pool.
extern "C" size_t ngraph::runtime::gpu::count_stopwatch(GPURuntimeContext* ctx, size_t idx)
{
    auto& watch = ctx->stopwatch_pool->get(idx);
    return watch.get_call_count();
}
// C-linkage entry point: returns the total time, in microseconds, reported by
// the stopwatch stored at position `idx` of the context's stopwatch pool.
extern "C" size_t ngraph::runtime::gpu::us_stopwatch(GPURuntimeContext* ctx, size_t idx)
{
    auto& watch = ctx->stopwatch_pool->get(idx);
    return watch.get_total_microseconds();
}
......@@ -40,6 +40,7 @@ namespace ngraph
gpu::primitive* const* gpu_primitives;
const gpu::memory_primitive* gpu_memory_primitives;
CudaFunctionPool* compiled_kernel_pool;
StopWatchPool* stopwatch_pool = nullptr;
// Note that in its current state, calling methods of CudaFunctionPool
// or other native compiled C++ functions in ngraph from the JIT code is
// unsafe and will fail if the GLIBCXX versions are different for the
......@@ -47,6 +48,11 @@ namespace ngraph
// to use the GPUPrimitiveEmitter, the above pointer can be removed. It is left now
// for backward compatibility.
};
// Helpers that forward to the StopWatch at index `idx` of ctx->stopwatch_pool.
// The generated timing code refers to these by name instead of calling
// StopWatch methods directly.
// Starts the stopwatch at `idx`.
void start_stopwatch(GPURuntimeContext* ctx, size_t idx);
// Stops the stopwatch at `idx`.
void stop_stopwatch(GPURuntimeContext* ctx, size_t idx);
// Returns the call count of the stopwatch at `idx`.
size_t count_stopwatch(GPURuntimeContext* ctx, size_t idx);
// Returns the total microseconds reported by the stopwatch at `idx`.
size_t us_stopwatch(GPURuntimeContext* ctx, size_t idx);
}
}
}
......
......@@ -192,3 +192,62 @@ uint32_t runtime::gpu::idiv_ceil(int n, int d)
// compiler fused modulo and division
return n / d + (n % d > 0);
}
void runtime::gpu::StopWatch::start()
{
if (m_active == false)
{
m_total_count++;
m_active = true;
cudaEvent_t start;
cudaEventCreate(&start);
cudaEventRecord(start);
starts.push_back(start);
}
}
void runtime::gpu::StopWatch::stop()
{
if (m_active == true)
{
cudaEvent_t stop;
cudaEventCreate(&stop);
cudaEventRecord(stop);
stops.push_back(stop);
m_active = false;
}
}
// Returns how many times start() has opened a new interval (start() calls made
// while an interval was already open are not counted).
size_t runtime::gpu::StopWatch::get_call_count()
{
return m_total_count;
}
size_t runtime::gpu::StopWatch::get_total_seconds()
{
return runtime::gpu::StopWatch::get_total_nanoseconds() / 1e9;
}
size_t runtime::gpu::StopWatch::get_total_milliseconds()
{
return runtime::gpu::StopWatch::get_total_nanoseconds() / 1e6;
}
size_t runtime::gpu::StopWatch::get_total_microseconds()
{
return runtime::gpu::StopWatch::get_total_nanoseconds() / 1e3;
}
// Returns the summed duration, in nanoseconds, of every completed
// start()/stop() interval, and caches it in m_total_time_in_ns.
//
// Blocks the host until the most recent stop event has completed. An interval
// that is still open (start() without a matching stop()) is not included.
// Returns 0 if no interval has ever been completed.
size_t runtime::gpu::StopWatch::get_total_nanoseconds()
{
    // A stopwatch that never completed an interval has no stop events; calling
    // stops.back() on an empty vector would be undefined behavior.
    if (stops.empty())
    {
        m_total_time_in_ns = 0;
        return m_total_time_in_ns;
    }

    // Only need to sync the last stop.
    // NOTE(review): relies on all events being recorded in order on the same
    // stream (cudaEventRecord is called without a stream argument) - confirm.
    cudaEventSynchronize(stops.back());

    // Accumulate in double: a float total loses sub-microsecond precision once
    // it is scaled from milliseconds to nanoseconds.
    double total_ms = 0.0;
    // starts.size() >= stops.size() because stop() only records after start(),
    // so indexing starts[i] for every i < stops.size() is in range.
    for (size_t i = 0; i < stops.size(); i++)
    {
        float milliseconds = 0;
        cudaEventElapsedTime(&milliseconds, starts[i], stops[i]);
        total_ms += milliseconds;
    }
    m_total_time_in_ns = static_cast<size_t>(total_ms * 1e6);
    return m_total_time_in_ns;
}
......@@ -127,6 +127,41 @@ namespace ngraph
cuda_memcpyDtH(local.data(), p, size_in_bytes);
std::cout << "{" << ngraph::join(local) << "}" << std::endl;
}
// Accumulating timer built on CUDA events. Each start()/stop() pair records
// one interval; the get_total_* accessors report the sum over all completed
// intervals.
class StopWatch
{
public:
// Opens a new interval; ignored if one is already open.
void start();
// Closes the open interval; ignored if none is open.
void stop();
// Number of intervals opened by start().
size_t get_call_count();
// Total time accessors; all are derived from get_total_nanoseconds().
size_t get_total_seconds();
size_t get_total_milliseconds();
size_t get_total_microseconds();
size_t get_total_nanoseconds();
private:
// Start/stop events of the recorded intervals; starts[i] pairs with stops[i].
std::vector<cudaEvent_t> starts;
std::vector<cudaEvent_t> stops;
// Incremented each time start() opens an interval.
size_t m_total_count = 0;
// Last value computed by get_total_nanoseconds().
size_t m_total_time_in_ns = 0;
// True between a start() and the matching stop().
bool m_active = false;
};
class StopWatchPool
{
public:
void allocate(size_t num)
{
for (size_t i = 0; i < num; i++)
{
pool.push_back(StopWatch());
}
}
StopWatch& get(size_t idx) { return pool[idx]; }
size_t size() { return pool.size(); }
private:
std::vector<StopWatch> pool;
};
}
}
}
......@@ -45,18 +45,20 @@ multimap<size_t, string>
}
unordered_map<string, size_t> timing;
unordered_map<string, size_t> count;
for (const runtime::PerformanceCounter& p : perf_data)
{
shared_ptr<Node> node = node_map.at(p.name());
string op = p.name().substr(0, p.name().find('_'));
string shape_name = "{" + join(node->get_outputs()[0].get_shape()) + "}";
string shape_name = " {" + join(node->get_outputs()[0].get_shape()) + "} ";
timing[op + shape_name] += p.microseconds();
count[op + shape_name] += 1;
}
multimap<size_t, string> rc;
for (const pair<string, size_t>& t : timing)
{
rc.insert({t.second, t.first});
rc.insert({t.second, t.first + to_string(count[t.first])});
}
return rc;
}
......@@ -253,6 +255,6 @@ void run_benchmark(shared_ptr<Function> f,
cout << "\n---- Aggregate times per op type ----\n";
print_times(timing);
cout << "\n---- Aggregate times per op type/shape ----\n";
cout << "\n---- Aggregate times per op type/shape/count ----\n";
print_times(timing_details);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment