Commit 7d29490f authored by Fenglei's avatar Fenglei Committed by Robert Kimball

fix invalid context when run mxnet and nbench (#1047)

parent 5dcd835f
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace ngraph; using namespace ngraph;
runtime::gpu::CudaContextManager& runtime::gpu::CudaContextManager::instance() runtime::gpu::CudaContextManager& runtime::gpu::CudaContextManager::Instance()
{ {
static CudaContextManager manager; static CudaContextManager manager;
return manager; return manager;
...@@ -32,7 +32,6 @@ runtime::gpu::CudaContextManager::CudaContextManager() ...@@ -32,7 +32,6 @@ runtime::gpu::CudaContextManager::CudaContextManager()
CUDA_SAFE_CALL(cuInit(0)); CUDA_SAFE_CALL(cuInit(0));
CUDA_SAFE_CALL(cuDeviceGet(&m_device, 0)); CUDA_SAFE_CALL(cuDeviceGet(&m_device, 0));
CUDA_SAFE_CALL(cuDevicePrimaryCtxRetain(&m_context, m_device)); CUDA_SAFE_CALL(cuDevicePrimaryCtxRetain(&m_context, m_device));
m_context_ptr = std::make_shared<CUcontext>(m_context);
} }
runtime::gpu::CudaContextManager::~CudaContextManager() runtime::gpu::CudaContextManager::~CudaContextManager()
......
...@@ -30,19 +30,19 @@ namespace ngraph ...@@ -30,19 +30,19 @@ namespace ngraph
class CudaContextManager class CudaContextManager
{ {
public: public:
static CudaContextManager& instance(); static CudaContextManager& Instance();
CudaContextManager(CudaContextManager const&) = delete; CudaContextManager(CudaContextManager const&) = delete;
CudaContextManager(CudaContextManager&&) = delete; CudaContextManager(CudaContextManager&&) = delete;
CudaContextManager& operator=(CudaContextManager const&) = delete; CudaContextManager& operator=(CudaContextManager const&) = delete;
CudaContextManager& operator=(CudaContextManager&&) = delete; CudaContextManager& operator=(CudaContextManager&&) = delete;
std::shared_ptr<CUcontext> get_context() { return m_context_ptr; } CUcontext GetContext() { return m_context; }
void SetContextCurrent() { cuCtxSetCurrent(m_context); }
protected: protected:
CudaContextManager(); CudaContextManager();
~CudaContextManager(); ~CudaContextManager();
CUdevice m_device; CUdevice m_device;
CUcontext m_context; CUcontext m_context;
std::shared_ptr<CUcontext> m_context_ptr;
}; };
} }
} }
......
...@@ -259,7 +259,7 @@ runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction( ...@@ -259,7 +259,7 @@ runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
// Create context use driver API and make it current, the runtime call will pickup the context // Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis // #interoperability-between-runtime-and-driver-apis
ngraph::runtime::gpu::CudaContextManager::instance(); ngraph::runtime::gpu::CudaContextManager::Instance().SetContextCurrent();
m_ctx->compiled_kernel_pool = new CudaFunctionPool; m_ctx->compiled_kernel_pool = new CudaFunctionPool;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment