Commit 2ef72022 authored by fenglei.tian

more code style

parent 59df2998
@@ -41,7 +41,7 @@ namespace ngraph
             CudaContextManager& operator=(CudaContextManager const&) = delete;
             CudaContextManager& operator=(CudaContextManager&&) = delete;
-            std::shared_ptr<CUcontext> GetContext() { return context_ptr; }
+            std::shared_ptr<CUcontext> GetContext() { return m_context_ptr; }
         protected:
             CudaContextManager()
             {
...
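The hunk above shows only a slice of CudaContextManager. As a point of reference, a minimal sketch of the singleton shape it implies follows; the Instance() accessor, the choice of device 0, and the context setup are assumptions made for illustration. Only the deleted copy/move operations, the protected constructor, and GetContext() returning m_context_ptr appear in the diff.

    // Hypothetical sketch, not the repository's implementation.
    #include <cuda.h>
    #include <memory>

    class CudaContextManager
    {
    public:
        // Meyers singleton: constructed on first use, never copied or moved.
        static CudaContextManager& Instance()
        {
            static CudaContextManager manager;
            return manager;
        }

        CudaContextManager(CudaContextManager const&) = delete;
        CudaContextManager(CudaContextManager&&) = delete;
        CudaContextManager& operator=(CudaContextManager const&) = delete;
        CudaContextManager& operator=(CudaContextManager&&) = delete;

        std::shared_ptr<CUcontext> GetContext() { return m_context_ptr; }

    protected:
        CudaContextManager()
        {
            // Assumed setup: one driver context on device 0; cleanup omitted for brevity.
            cuInit(0);
            CUdevice device;
            cuDeviceGet(&device, 0);
            m_context_ptr = std::make_shared<CUcontext>();
            cuCtxCreate(m_context_ptr.get(), 0, device);
        }

        std::shared_ptr<CUcontext> m_context_ptr;
    };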
@@ -24,7 +24,7 @@ namespace ngraph
     {
         namespace gpu
         {
-            class Cuda_kernel_builder
+            class CudaKernelBuilder
             {
             public:
                 static void Get_1_element_op(const std::string& name,
...
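The next hunk calls Get_1_element_op(name, "float", "fabsf", kernel), so the builder presumably emits the source string for a unary elementwise kernel. A hedged guess at what such a helper might produce, purely for illustration (the generated code itself is not part of this diff):

    #include <string>

    // Illustrative only: emits an "apply op to every element" kernel as a source string.
    static void Get_1_element_op(const std::string& name,
                                 const std::string& data_type,
                                 const std::string& op,
                                 std::string& kernel)
    {
        kernel = "extern \"C\" __global__ void cuda_" + name + "(" + data_type +
                 "* in, " + data_type + "* out, unsigned int n)\n"
                 "{\n"
                 "    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;\n"
                 "    if (tid < n) out[tid] = " + op + "(in[tid]);\n"
                 "}\n";
    }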
@@ -36,24 +36,24 @@ namespace ngraph
                 {
                     std::string name = "abs";
                     // Create an instance of nvrtcProgram with the code string.
-                    if (Cuda_function_pool::Instance().Get(name) == nullptr)
+                    if (CudaFunctionPool::Instance().Get(name) == nullptr)
                     {
                         const char* opts[] = {"--gpu-architecture=compute_35",
                                               "--relocatable-device-code=true"};
                         std::string kernel;
-                        Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel);
-                        Cuda_function_pool::Instance().Set(
+                        CudaKernelBuilder::Get_1_element_op(name, "float", "fabsf", kernel);
+                        CudaFunctionPool::Instance().Set(
                             name, CudaFunctionBuilder::Get("cuda_" + name, kernel, 2, opts));
                     }
                     //convert runtime ptr to driver api ptr
-                    CUdeviceptr dPtrIn, dPtrOut;
-                    dPtrIn = (CUdeviceptr)in;
-                    dPtrOut = (CUdeviceptr)out;
-                    void* argsList[] = {&dPtrIn, &dPtrOut, &count};
+                    CUdeviceptr d_ptr_in, d_ptr_out;
+                    d_ptr_in = (CUdeviceptr)in;
+                    d_ptr_out = (CUdeviceptr)out;
+                    void* args_list[] = {&d_ptr_in, &d_ptr_out, &count};
                     CUDA_SAFE_CALL(
-                        cuLaunchKernel(*Cuda_function_pool::Instance().Get(name).get(),
+                        cuLaunchKernel(*CudaFunctionPool::Instance().Get(name).get(),
                                        count,
                                        1,
                                        1, // grid dim
@@ -62,7 +62,7 @@ namespace ngraph
                                        1, // block dim
                                        0,
                                        NULL, // shared mem and stream
-                                       argsList,
+                                       args_list,
                                        0)); // arguments
                     CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
                 }
...
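To show the whole path the hunk above drives through CudaKernelBuilder, CudaFunctionBuilder, and CudaFunctionPool, here is a self-contained sketch: compile a CUDA C string with NVRTC, load the resulting PTX through the driver API, and launch it with cuLaunchKernel using one thread per element. This is an illustration under assumptions, not the repository's code: the kernel string, the CU_CHECK/NVRTC_CHECK macros, the direct cuModuleLoadData call, and the cuMemAlloc'd buffers (the repo instead casts runtime-API pointers to CUdeviceptr, as shown above) stand in for the cached, generated pieces.

    // Standalone illustration; not the nGraph implementation.
    #include <cuda.h>
    #include <nvrtc.h>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    #define CU_CHECK(call)                                                          \
        do { CUresult rc = (call);                                                  \
             if (rc != CUDA_SUCCESS) { std::fprintf(stderr, "CUDA error %d\n", (int)rc); std::exit(1); } \
        } while (0)

    #define NVRTC_CHECK(call)                                                       \
        do { nvrtcResult rc = (call);                                               \
             if (rc != NVRTC_SUCCESS) { std::fprintf(stderr, "NVRTC error %d\n", (int)rc); std::exit(1); } \
        } while (0)

    // Source of the kind a Get_1_element_op-style builder could emit for "abs".
    static const char* kAbsSource =
        "extern \"C\" __global__ void cuda_abs(float* in, float* out, unsigned int n)\n"
        "{\n"
        "    unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;\n"
        "    if (tid < n) out[tid] = fabsf(in[tid]);\n"
        "}\n";

    int main()
    {
        // 1. Compile CUDA C to PTX with NVRTC.
        nvrtcProgram prog;
        NVRTC_CHECK(nvrtcCreateProgram(&prog, kAbsSource, "abs.cu", 0, nullptr, nullptr));
        const char* opts[] = {"--gpu-architecture=compute_35"};
        NVRTC_CHECK(nvrtcCompileProgram(prog, 1, opts));
        size_t ptx_size = 0;
        NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
        std::vector<char> ptx(ptx_size);
        NVRTC_CHECK(nvrtcGetPTX(prog, ptx.data()));
        NVRTC_CHECK(nvrtcDestroyProgram(&prog));

        // 2. Load the PTX with the driver API and fetch the kernel handle.
        CU_CHECK(cuInit(0));
        CUdevice device;
        CUcontext context;
        CU_CHECK(cuDeviceGet(&device, 0));
        CU_CHECK(cuCtxCreate(&context, 0, device));
        CUmodule module;
        CUfunction function;
        CU_CHECK(cuModuleLoadData(&module, ptx.data()));
        CU_CHECK(cuModuleGetFunction(&function, module, "cuda_abs"));

        // 3. Device buffers (the repo instead converts cudaMalloc'd pointers to CUdeviceptr).
        unsigned int count = 8;
        std::vector<float> host_in = {-1.f, 2.f, -3.f, 4.f, -5.f, 6.f, -7.f, 8.f};
        std::vector<float> host_out(count);
        CUdeviceptr d_ptr_in, d_ptr_out;
        CU_CHECK(cuMemAlloc(&d_ptr_in, count * sizeof(float)));
        CU_CHECK(cuMemAlloc(&d_ptr_out, count * sizeof(float)));
        CU_CHECK(cuMemcpyHtoD(d_ptr_in, host_in.data(), count * sizeof(float)));

        // 4. Launch: a grid of `count` blocks with one thread each, mirroring the hunk above.
        void* args_list[] = {&d_ptr_in, &d_ptr_out, &count};
        CU_CHECK(cuLaunchKernel(function, count, 1, 1, 1, 1, 1, 0, nullptr, args_list, nullptr));
        CU_CHECK(cuCtxSynchronize());

        // 5. Copy back and print.
        CU_CHECK(cuMemcpyDtoH(host_out.data(), d_ptr_out, count * sizeof(float)));
        for (float v : host_out)
            std::printf("%g ", v);
        std::printf("\n");

        CU_CHECK(cuMemFree(d_ptr_in));
        CU_CHECK(cuMemFree(d_ptr_out));
        CU_CHECK(cuModuleUnload(module));
        CU_CHECK(cuCtxDestroy(context));
        return 0;
    }

Note that each entry in args_list is the address of an argument (a CUdeviceptr or a scalar); cuLaunchKernel copies the pointed-to values into the kernel's parameter space, which is why the diff takes &d_ptr_in, &d_ptr_out, and &count rather than the values themselves.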