Commit 31210402 authored by Chris Sullivan, committed by Scott Cyphers

Bind cuda context to thread prior to compilation (#2199)

* Bind cuda context to thread prior to compilation. Small refactoring.

* bind_cuda_context_to_thread in source

* bind_cuda_context_to_thread header
parent ec0a3f5c
...@@ -67,7 +67,7 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext() ...@@ -67,7 +67,7 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext()
// Create context use driver API and make it current, the runtime call will pickup the context // Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis // #interoperability-between-runtime-and-driver-apis
m_cuda_manager->SetContextCurrent(); bind_cuda_context_to_thread();
m_runtime_context->cublas_handle = new cublasHandle_t; m_runtime_context->cublas_handle = new cublasHandle_t;
cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle); cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle);
...@@ -91,13 +91,18 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext() ...@@ -91,13 +91,18 @@ runtime::gpu::GPU_Backend::BackendContext::BackendContext()
void runtime::gpu::GPU_Backend::BackendContext::prepare_runtime_context() void runtime::gpu::GPU_Backend::BackendContext::prepare_runtime_context()
{ {
//set context current each time in case thread changed // set context current each time in case thread changed
m_cuda_manager->SetContextCurrent(); bind_cuda_context_to_thread();
// add pointers to gpu primitives into the gpu runtime context // add pointers to gpu primitives into the gpu runtime context
m_runtime_context->gpu_primitives = m_primitive_emitter->get_primitives().data(); m_runtime_context->gpu_primitives = m_primitive_emitter->get_primitives().data();
m_runtime_context->gpu_memory_primitives = m_primitive_emitter->get_memory_primitives().data(); m_runtime_context->gpu_memory_primitives = m_primitive_emitter->get_memory_primitives().data();
} }
void runtime::gpu::GPU_Backend::BackendContext::bind_cuda_context_to_thread()
{
m_cuda_manager->SetContextCurrent();
}
runtime::gpu::GPU_Backend::BackendContext::~BackendContext() runtime::gpu::GPU_Backend::BackendContext::~BackendContext()
{ {
cublasDestroy(*m_runtime_context->cublas_handle); cublasDestroy(*m_runtime_context->cublas_handle);
...@@ -124,6 +129,7 @@ runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func) ...@@ -124,6 +129,7 @@ runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
FunctionInstance& instance = m_function_map[func]; FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr) if (instance.m_external_function == nullptr)
{ {
m_context->bind_cuda_context_to_thread();
instance.m_external_function = make_shared<GPU_ExternalFunction>(func, m_context); instance.m_external_function = make_shared<GPU_ExternalFunction>(func, m_context);
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled; instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_external_function->compile(); instance.m_external_function->compile();
......
...@@ -70,6 +70,7 @@ namespace ngraph ...@@ -70,6 +70,7 @@ namespace ngraph
BackendContext(); BackendContext();
~BackendContext(); ~BackendContext();
void prepare_runtime_context(); void prepare_runtime_context();
void bind_cuda_context_to_thread();
std::unique_ptr<GPURuntimeContext> m_runtime_context; std::unique_ptr<GPURuntimeContext> m_runtime_context;
std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter; std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment