Commit 8009b475 authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

Fix bugs in StaticInitializer and CudaContextManager (#1321)

* Bug fix: StaticInitializer.

* Make CudaContextManager a member of GPU_Backend::BackendContext.

* fix formatting
parent 8ab89b29
...@@ -53,11 +53,12 @@ runtime::gpu::GPU_Backend::GPU_Backend() ...@@ -53,11 +53,12 @@ runtime::gpu::GPU_Backend::GPU_Backend()
runtime::gpu::GPU_Backend::BackendContext::BackendContext() runtime::gpu::GPU_Backend::BackendContext::BackendContext()
: m_runtime_context(new GPURuntimeContext) : m_runtime_context(new GPURuntimeContext)
, m_primitive_emitter(new GPUPrimitiveEmitter(m_runtime_context)) , m_primitive_emitter(new GPUPrimitiveEmitter(m_runtime_context))
, m_cuda_manager(new CudaContextManager)
{ {
// Create context use driver API and make it current, the runtime call will pickup the context // Create context use driver API and make it current, the runtime call will pickup the context
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
// #interoperability-between-runtime-and-driver-apis // #interoperability-between-runtime-and-driver-apis
ngraph::runtime::gpu::CudaContextManager::Instance().SetContextCurrent(); m_cuda_manager->SetContextCurrent();
m_runtime_context->cublas_handle = new cublasHandle_t; m_runtime_context->cublas_handle = new cublasHandle_t;
cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle); cublasStatus_t cublasStatus = cublasCreate(m_runtime_context->cublas_handle);
......
...@@ -33,6 +33,7 @@ namespace ngraph ...@@ -33,6 +33,7 @@ namespace ngraph
class GPU_CallFrame; class GPU_CallFrame;
class GPUPrimitiveEmitter; class GPUPrimitiveEmitter;
struct GPURuntimeContext; struct GPURuntimeContext;
class CudaContextManager;
class GPU_Backend : public Backend class GPU_Backend : public Backend
{ {
...@@ -71,6 +72,9 @@ namespace ngraph ...@@ -71,6 +72,9 @@ namespace ngraph
std::unique_ptr<GPURuntimeContext> m_runtime_context; std::unique_ptr<GPURuntimeContext> m_runtime_context;
std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter; std::unique_ptr<GPUPrimitiveEmitter> m_primitive_emitter;
private:
std::unique_ptr<CudaContextManager> m_cuda_manager;
}; };
private: private:
......
...@@ -21,12 +21,6 @@ ...@@ -21,12 +21,6 @@
using namespace ngraph; using namespace ngraph;
runtime::gpu::CudaContextManager& runtime::gpu::CudaContextManager::Instance()
{
static CudaContextManager manager;
return manager;
}
runtime::gpu::CudaContextManager::CudaContextManager() runtime::gpu::CudaContextManager::CudaContextManager()
{ {
CUDA_SAFE_CALL(cuInit(0)); CUDA_SAFE_CALL(cuInit(0));
......
...@@ -30,7 +30,9 @@ namespace ngraph ...@@ -30,7 +30,9 @@ namespace ngraph
class CudaContextManager class CudaContextManager
{ {
public: public:
static CudaContextManager& Instance(); CudaContextManager();
~CudaContextManager();
CudaContextManager(CudaContextManager const&) = delete; CudaContextManager(CudaContextManager const&) = delete;
CudaContextManager(CudaContextManager&&) = delete; CudaContextManager(CudaContextManager&&) = delete;
CudaContextManager& operator=(CudaContextManager const&) = delete; CudaContextManager& operator=(CudaContextManager const&) = delete;
...@@ -39,8 +41,6 @@ namespace ngraph ...@@ -39,8 +41,6 @@ namespace ngraph
CUcontext GetContext() { return m_context; } CUcontext GetContext() { return m_context; }
void SetContextCurrent() { cuCtxSetCurrent(m_context); } void SetContextCurrent() { cuCtxSetCurrent(m_context); }
protected: protected:
CudaContextManager();
~CudaContextManager();
CUdevice m_device; CUdevice m_device;
CUcontext m_context; CUcontext m_context;
}; };
......
...@@ -112,10 +112,10 @@ using namespace ngraph; ...@@ -112,10 +112,10 @@ using namespace ngraph;
static const string s_output_dir = "gpu_codegen"; static const string s_output_dir = "gpu_codegen";
static std::mutex s_compilation; static std::mutex s_compilation;
class StaticInitializers class GPUStaticInitializers
{ {
public: public:
StaticInitializers() GPUStaticInitializers()
{ {
file_util::remove_directory(s_output_dir); file_util::remove_directory(s_output_dir);
file_util::make_directory(s_output_dir); file_util::make_directory(s_output_dir);
...@@ -154,7 +154,7 @@ static string emit_string_array(const vector<string>& s, size_t max_line_length) ...@@ -154,7 +154,7 @@ static string emit_string_array(const vector<string>& s, size_t max_line_length)
return ss.str(); return ss.str();
} }
static StaticInitializers s_static_initializers; static GPUStaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment