Commit f89a679a authored by fenglei.tian

fix name bug and apply clang format

parent ca07b7f0
......@@ -30,10 +30,10 @@ namespace ngraph
class Cuda_function_builder
{
public:
static std::shared_ptr<CUfunction> Get(std::string& kernel,
std::string& name,
int number_of_options,
const char** options)
static std::shared_ptr<CUfunction> Get(const std::string& name,
const std::string& kernel,
int number_of_options,
const char** options)
{
nvrtcProgram prog;
NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog,
......
......@@ -26,35 +26,35 @@ namespace ngraph
{
class Cuda_kernel_builder
{
public:
public:
static void Get_1_element_op(const std::string& name,
const std::string& data_type,
const std::string& op,
std::string& kernel)
const std::string& data_type,
const std::string& op,
std::string& kernel)
{
kernel = R"(
extern "C" __global__
void cuda_op_)" + name + "(" +
data_type + "* in, " + data_type + "* out, size_t n)\n" + R"({
void cuda_)" + name + "(" + data_type +
"* in, " + data_type + "* out, size_t n)\n" + R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
out[tid] = " + op + "(in[tid]);\n"
+R"(}
out[tid] =)" + op + "(in[tid]);\n" +
R"(}
})";
return;
}
static void Get_2_element_op(const std::string& name,
const std::string& data_type,
const std::string op,
std::string& kernel)
const std::string& data_type,
const std::string op,
std::string& kernel)
{
kernel = R"(
extern "C" __global__
void cuda_op_)" + name + "(" +
data_type + "* in1, " + data_type + "* in2, " + data_type +
"* out, size_t n)\n" + R"({
void )" + name + "(" + data_type +
"* in1, " + data_type + "* in2, " + data_type + "* out, size_t n)\n" +
R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
......@@ -65,9 +65,9 @@ out[tid] = in1[tid] )" + op + "in2[tid]\n" +
}
static void Get_n_element_op(const std::string& name,
const std::string& data_type,
const std::vector<std::string> ops,
std::string& kernel)
const std::string& data_type,
const std::vector<std::string> ops,
std::string& kernel)
{
kernel = "";
return;
......
......@@ -23,10 +23,10 @@
#include <cudnn_v7.h>
#include <nvrtc.h>
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
namespace ngraph
{
......@@ -49,7 +49,7 @@ namespace ngraph
std::string kernel;
Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel);
Cuda_function_pool::Instance().Set(
name, Cuda_function_builder::Get(name, kernel, 2, opts));
name, Cuda_function_builder::Get("cuda_" + name, kernel, 2, opts));
}
//convert runtime ptr to driver api ptr
......
......@@ -23,7 +23,7 @@
if (result != NVRTC_SUCCESS) \
{ \
throw std::runtime_error("\nerror: " #x " failed with error " + \
std::string(nvrtcGetErrorString(result))); \
std::string(nvrtcGetErrorString(result))); \
} \
} while (0)
......@@ -35,7 +35,7 @@
{ \
const char* msg; \
cuGetErrorName(result, &msg); \
throw std::runtime_error("\nerror: " #x " failed with error " + std::string(msg)); \
throw std::runtime_error("\nerror: " #x " failed with error " + std::string(msg)); \
} \
} while (0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment