Commit f89a679a authored by fenglei.tian's avatar fenglei.tian

fix name bug and apply clang format

parent ca07b7f0
...@@ -30,10 +30,10 @@ namespace ngraph ...@@ -30,10 +30,10 @@ namespace ngraph
class Cuda_function_builder class Cuda_function_builder
{ {
public: public:
static std::shared_ptr<CUfunction> Get(std::string& kernel, static std::shared_ptr<CUfunction> Get(const std::string& name,
std::string& name, const std::string& kernel,
int number_of_options, int number_of_options,
const char** options) const char** options)
{ {
nvrtcProgram prog; nvrtcProgram prog;
NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog, NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog,
......
...@@ -26,35 +26,35 @@ namespace ngraph ...@@ -26,35 +26,35 @@ namespace ngraph
{ {
class Cuda_kernel_builder class Cuda_kernel_builder
{ {
public: public:
static void Get_1_element_op(const std::string& name, static void Get_1_element_op(const std::string& name,
const std::string& data_type, const std::string& data_type,
const std::string& op, const std::string& op,
std::string& kernel) std::string& kernel)
{ {
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_op_)" + name + "(" + void cuda_)" + name + "(" + data_type +
data_type + "* in, " + data_type + "* out, size_t n)\n" + R"({ "* in, " + data_type + "* out, size_t n)\n" + R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
out[tid] = " + op + "(in[tid]);\n" out[tid] =)" + op + "(in[tid]);\n" +
+R"(} R"(}
})"; })";
return; return;
} }
static void Get_2_element_op(const std::string& name, static void Get_2_element_op(const std::string& name,
const std::string& data_type, const std::string& data_type,
const std::string op, const std::string op,
std::string& kernel) std::string& kernel)
{ {
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_op_)" + name + "(" + void )" + name + "(" + data_type +
data_type + "* in1, " + data_type + "* in2, " + data_type + "* in1, " + data_type + "* in2, " + data_type + "* out, size_t n)\n" +
"* out, size_t n)\n" + R"({ R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
...@@ -65,9 +65,9 @@ out[tid] = in1[tid] )" + op + "in2[tid]\n" + ...@@ -65,9 +65,9 @@ out[tid] = in1[tid] )" + op + "in2[tid]\n" +
} }
static void Get_n_element_op(const std::string& name, static void Get_n_element_op(const std::string& name,
const std::string& data_type, const std::string& data_type,
const std::vector<std::string> ops, const std::vector<std::string> ops,
std::string& kernel) std::string& kernel)
{ {
kernel = ""; kernel = "";
return; return;
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#include <cudnn_v7.h> #include <cudnn_v7.h>
#include <nvrtc.h> #include <nvrtc.h>
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
std::string kernel; std::string kernel;
Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel); Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel);
Cuda_function_pool::Instance().Set( Cuda_function_pool::Instance().Set(
name, Cuda_function_builder::Get(name, kernel, 2, opts)); name, Cuda_function_builder::Get("cuda_" + name, kernel, 2, opts));
} }
//convert runtime ptr to driver api ptr //convert runtime ptr to driver api ptr
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
if (result != NVRTC_SUCCESS) \ if (result != NVRTC_SUCCESS) \
{ \ { \
throw std::runtime_error("\nerror: " #x " failed with error " + \ throw std::runtime_error("\nerror: " #x " failed with error " + \
std::string(nvrtcGetErrorString(result))); \ std::string(nvrtcGetErrorString(result))); \
} \ } \
} while (0) } while (0)
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
{ \ { \
const char* msg; \ const char* msg; \
cuGetErrorName(result, &msg); \ cuGetErrorName(result, &msg); \
throw std::runtime_error("\nerror: " #x " failed with error " + std::string(msg)); \ throw std::runtime_error("\nerror: " #x " failed with error " + std::string(msg)); \
} \ } \
} while (0) } while (0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment