Commit 776569fd authored by fenglei.tian's avatar fenglei.tian

cleanup code

parent 03c58dc6
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <stdio.h> #include <stdio.h>
#include "ngraph/runtime/gpu/gpu_call_frame.hpp" #include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp" #include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_view.hpp" #include "ngraph/runtime/gpu/gpu_tensor_view.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp" #include "ngraph/runtime/gpu/gpu_util.hpp"
...@@ -31,13 +32,8 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction> ...@@ -31,13 +32,8 @@ runtime::gpu::GPU_CallFrame::GPU_CallFrame(std::shared_ptr<GPU_ExternalFunction>
: m_external_function(external_function) : m_external_function(external_function)
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
{ {
CUdevice cuDevice; ngraph::runtime::gpu::Cuda_context_manager::
CUcontext context; Instance(); //this call will init a cuda context and will use by cublas and cudnn automatically
CUmodule module;
CUfunction cuda_op_abs_kernel;
CUDA_SAFE_CALL(cuInit(0));
CUDA_SAFE_CALL(cuDeviceGet(&cuDevice, 0));
CUDA_SAFE_CALL(cuCtxCreate(&context, 0, cuDevice));
cublasStatus_t cublasStatus = cublasCreate(&m_cublas_handle); cublasStatus_t cublasStatus = cublasCreate(&m_cublas_handle);
if (cublasStatus != CUBLAS_STATUS_SUCCESS) if (cublasStatus != CUBLAS_STATUS_SUCCESS)
{ {
......
...@@ -17,12 +17,6 @@ ...@@ -17,12 +17,6 @@
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn_v7.h>
#include <nvrtc.h>
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
......
...@@ -16,6 +16,12 @@ ...@@ -16,6 +16,12 @@
#pragma once #pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn_v7.h>
#include <nvrtc.h>
#define NVRTC_SAFE_CALL(x) \ #define NVRTC_SAFE_CALL(x) \
do \ do \
{ \ { \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment