Commit 8009b475 authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

Fix bugs in StaticInitializer and CudaContextManager (#1321)

* Bug fix: StaticInitializer.

* Make CudaContextManager a member of GPU_Backend::BackendContext.

* fix formatting
parent 8ab89b29
......@@ -53,11 +53,12 @@ runtime::gpu::GPU_Backend::GPU_Backend()
runtime::gpu::GPU_Backend::BackendContext::BackendContext()
: m_runtime_context(new GPURuntimeContext)
, m_primitive_emitter(new GPUPrimitiveEmitter(m_runtime_context))
, m_cuda_manager(new CudaContextManager)
{
// Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis
ngraph::runtime::gpu::CudaContextManager::Instance().SetContextCurrent();
m_cuda_manager->SetContextCurrent();
m_runtime_context->cublas_handle = new cublasHandle_t;
cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle);
......
......@@ -33,6 +33,7 @@ namespace ngraph
class GPU_CallFrame;
class GPUPrimitiveEmitter;
struct GPURuntimeContext;
class CudaContextManager;
class GPU_Backend : public Backend
{
......@@ -71,6 +72,9 @@ namespace ngraph
std::unique_ptr<GPURuntimeContext> m_runtime_context;
std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter;
private:
std::unique_ptr<CudaContextManager> m_cuda_manager;
};
private:
......
......@@ -21,12 +21,6 @@
using namespace ngraph;
runtime::gpu::CudaContextManager& runtime::gpu::CudaContextManager::Instance()
{
static CudaContextManager manager;
return manager;
}
runtime::gpu::CudaContextManager::CudaContextManager()
{
CUDA_SAFE_CALL(cuInit(0));
......
......@@ -30,7 +30,9 @@ namespace ngraph
class CudaContextManager
{
public:
static CudaContextManager& Instance();
CudaContextManager();
~CudaContextManager();
CudaContextManager(CudaContextManager const&) = delete;
CudaContextManager(CudaContextManager&&) = delete;
CudaContextManager& operator=(CudaContextManager const&) = delete;
......@@ -39,8 +41,6 @@ namespace ngraph
CUcontext GetContext() { return m_context; }
void SetContextCurrent() { cuCtxSetCurrent(m_context); }
protected:
CudaContextManager();
~CudaContextManager();
CUdevice m_device;
CUcontext m_context;
};
......
......@@ -112,10 +112,10 @@ using namespace ngraph;
static const string s_output_dir = "gpu_codegen";
static std::mutex s_compilation;
class StaticInitializers
class GPUStaticInitializers
{
public:
StaticInitializers()
GPUStaticInitializers()
{
file_util::remove_directory(s_output_dir);
file_util::make_directory(s_output_dir);
......@@ -154,7 +154,7 @@ static string emit_string_array(const vector<string>& s, size_t max_line_length)
return ss.str();
}
static StaticInitializers s_static_initializers;
static GPUStaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment