Commit f89a679a authored by fenglei.tian's avatar fenglei.tian

fix name bug and apply clang format

parent ca07b7f0
...@@ -30,8 +30,8 @@ namespace ngraph ...@@ -30,8 +30,8 @@ namespace ngraph
class Cuda_function_builder class Cuda_function_builder
{ {
public: public:
static std::shared_ptr<CUfunction> Get(std::string& kernel, static std::shared_ptr<CUfunction> Get(const std::string& name,
std::string& name, const std::string& kernel,
int number_of_options, int number_of_options,
const char** options) const char** options)
{ {
......
...@@ -34,13 +34,13 @@ namespace ngraph ...@@ -34,13 +34,13 @@ namespace ngraph
{ {
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_op_)" + name + "(" + void cuda_)" + name + "(" + data_type +
data_type + "* in, " + data_type + "* out, size_t n)\n" + R"({ "* in, " + data_type + "* out, size_t n)\n" + R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
out[tid] = " + op + "(in[tid]);\n" out[tid] =)" + op + "(in[tid]);\n" +
+R"(} R"(}
})"; })";
return; return;
} }
...@@ -52,9 +52,9 @@ out[tid] = " + op + "(in[tid]);\n" ...@@ -52,9 +52,9 @@ out[tid] = " + op + "(in[tid]);\n"
{ {
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_op_)" + name + "(" + void )" + name + "(" + data_type +
data_type + "* in1, " + data_type + "* in2, " + data_type + "* in1, " + data_type + "* in2, " + data_type + "* out, size_t n)\n" +
"* out, size_t n)\n" + R"({ R"({
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#include <cudnn_v7.h> #include <cudnn_v7.h>
#include <nvrtc.h> #include <nvrtc.h>
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_pool.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
std::string kernel; std::string kernel;
Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel); Cuda_kernel_builder::Get_1_element_op(name, "float", "fabsf", kernel);
Cuda_function_pool::Instance().Set( Cuda_function_pool::Instance().Set(
name, Cuda_function_builder::Get(name, kernel, 2, opts)); name, Cuda_function_builder::Get("cuda_" + name, kernel, 2, opts));
} }
//convert runtime ptr to driver api ptr //convert runtime ptr to driver api ptr
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment