Commit 31210402 authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

Bind cuda context to thread prior to compilation (#2199)

* Bind cuda context to thread prior to compilation. Small refactoring.

* bind_cuda_context_to_thread in source

* bind_cuda_context_to_thread header
parent ec0a3f5c
......@@ -67,7 +67,7 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext()
// Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis
m_cuda_manager->SetContextCurrent();
bind_cuda_context_to_thread();
m_runtime_context->cublas_handle = new cublasHandle_t;
cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle);
......@@ -91,13 +91,18 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext()
void runtime::gpu::GPU_Backend::BackendContext::prepare_runtime_context()
{
//set context current each time in case thread changed
m_cuda_manager->SetContextCurrent();
// set context current each time in case thread changed
bind_cuda_context_to_thread();
// add pointers to gpu primitives into the gpu runtime context
m_runtime_context->gpu_primitives = m_primitive_emitter->get_primitives().data();
m_runtime_context->gpu_memory_primitives = m_primitive_emitter->get_memory_primitives().data();
}
void runtime::gpu::GPU_Backend::BackendContext::bind_cuda_context_to_thread()
{
m_cuda_manager->SetContextCurrent();
}
runtime::gpu::GPU_Backend::BackendContext::~BackendContext()
{
cublasDestroy(*m_runtime_context->cublas_handle);
......@@ -124,6 +129,7 @@ runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr)
{
m_context->bind_cuda_context_to_thread();
instance.m_external_function = make_shared<GPU_ExternalFunction>(func, m_context);
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_external_function->compile();
......
......@@ -70,6 +70,7 @@ namespace ngraph
BackendContext();
~BackendContext();
void prepare_runtime_context();
void bind_cuda_context_to_thread();
std::unique_ptr<GPURuntimeContext> m_runtime_context;
std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment