Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
68eb2e7d
Commit
68eb2e7d
authored
Sep 13, 2018
by
Fenglei
Committed by
Robert Kimball
Sep 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
pass args instead of pointer to array (#1591)
parent
309bfdf0
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
33 deletions
+33
-33
cuda_emitter.cpp
src/ngraph/runtime/gpu/cuda_emitter.cpp
+25
-25
gpu_cuda_kernel_builder.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
+7
-8
gpu_cuda_kernel_builder.hpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp
+1
-0
No files found.
src/ngraph/runtime/gpu/cuda_emitter.cpp
View file @
68eb2e7d
...
...
@@ -588,19 +588,6 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
return
primitive_index
;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto
compiled_kernel
=
m_ctx
->
compiled_kernel_pool
->
get
(
kernel_name
.
str
());
if
(
compiled_kernel
==
nullptr
)
{
codegen
::
CodeWriter
writer
;
CudaKernelBuilder
::
add_pod_typedefs
(
writer
);
CudaKernelBuilder
::
get_reshape_op
(
writer
,
kernel_name
.
str
(),
dtypes
,
rank
);
compiled_kernel
=
m_ctx
->
compiled_kernel_pool
->
set
(
kernel_name
.
str
(),
writer
.
get_code
());
}
uint32_t
nthreads
=
static_cast
<
uint32_t
>
(
shape_size
(
input_shape
));
// TODO: currently we set it to 64, will add tuning method later
uint32_t
block_size_x
=
64
;
...
...
@@ -620,19 +607,32 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
}
// get an allocator for transient per kernel gpu memory
GPUAllocator
allocator
=
this
->
m_primitive_emitter
->
get_memory_allocator
();
size_t
idx_input_strides
=
allocator
.
reserve_argspace
(
input_strides
.
data
(),
input_strides
.
size
()
*
sizeof
(
uint32_t
));
size_t
idx_trans_strides
=
allocator
.
reserve_argspace
(
trans_strides
.
data
(),
trans_strides
.
size
()
*
sizeof
(
uint32_t
));
auto
args
=
m_primitive_emitter
->
add_kernel_args
();
args
.
add_placeholder
(
dtypes
[
0
],
"in"
)
.
add_placeholder
(
dtypes
[
1
],
"out"
)
.
add
(
"input_strides"
,
input_strides
)
.
add
(
"trans_strides"
,
trans_strides
)
.
add
(
"n"
,
nthreads
);
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto
compiled_kernel
=
m_ctx
->
compiled_kernel_pool
->
get
(
kernel_name
.
str
());
if
(
compiled_kernel
==
nullptr
)
{
codegen
::
CodeWriter
writer
;
CudaKernelBuilder
::
add_pod_typedefs
(
writer
);
CudaKernelBuilder
::
get_reshape_op
(
writer
,
kernel_name
.
str
(),
args
,
dtypes
,
rank
);
compiled_kernel
=
m_ctx
->
compiled_kernel_pool
->
set
(
kernel_name
.
str
(),
writer
.
get_code
());
}
// create the launch primitive
std
::
unique_ptr
<
gpu
::
primitive
>
kernel_launch
(
new
gpu
::
primitive
{[
=
](
void
**
inputs
,
void
**
outputs
)
mutable
{
void
*
param_input_strides
=
runtime
::
gpu
::
invoke_memory_primitive
(
m_ctx
,
idx_input_strides
);
void
*
param_trans_strides
=
runtime
::
gpu
::
invoke_memory_primitive
(
m_ctx
,
idx_trans_strides
);
std
::
vector
<
void
*>
args_list
{
&
inputs
[
0
],
&
outputs
[
0
],
&
param_input_strides
,
&
param_trans_strides
,
&
nthreads
};
std
::
unique_ptr
<
gpu
::
primitive
>
kernel_launch
(
new
gpu
::
primitive
{[
=
](
void
**
inputs
,
void
**
outputs
)
mutable
{
void
**
args_list
=
args
.
resolve_placeholder
(
0
,
&
inputs
[
0
])
.
resolve_placeholder
(
1
,
&
outputs
[
0
])
.
get_argument_list
();
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
aligned_grid_size_x
,
...
...
@@ -643,7 +643,7 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
1
,
// block dim
0
,
NULL
,
// shared mem and stream
args_list
.
data
()
,
args_list
,
0
));
// arguments
debug_sync
();
}});
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
View file @
68eb2e7d
...
...
@@ -452,12 +452,11 @@ void runtime::gpu::CudaKernelBuilder::get_onehot_op(codegen::CodeWriter& writer,
void
runtime
::
gpu
::
CudaKernelBuilder
::
get_reshape_op
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
runtime
::
gpu
::
GPUKernelArgs
&
args
,
const
std
::
array
<
std
::
string
,
2
>&
data_types
,
size_t
rank
)
{
writer
<<
"extern
\"
C
\"
__global__ void cuda_"
<<
name
<<
"("
<<
data_types
[
0
]
<<
"* in, "
<<
data_types
[
1
]
<<
"* out, uint32_t* input_strides, uint32_t* trans_strides, uint32_t n)
\n
"
;
writer
<<
"extern
\"
C
\"
__global__ void cuda_"
<<
name
<<
args
.
get_input_signature
();
writer
.
block_begin
();
{
writer
<<
"uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
\n
"
;
...
...
@@ -469,12 +468,12 @@ void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer
size_t
i
=
0
;
for
(;
i
<
rank
-
1
;
i
++
)
{
writer
<<
"output_idx += (input_idx / input_strides
["
<<
i
<<
"]) * trans_strides[
"
<<
i
<<
"
]
;
\n
"
;
writer
<<
"input_idx %= input_strides
["
<<
i
<<
"]
;
\n
"
;
writer
<<
"output_idx += (input_idx / input_strides
"
<<
i
<<
") * trans_strides
"
<<
i
<<
";
\n
"
;
writer
<<
"input_idx %= input_strides
"
<<
i
<<
"
;
\n
"
;
}
writer
<<
"output_idx += (input_idx / input_strides
["
<<
i
<<
"]) * trans_strides[
"
<<
i
<<
"
]
;
\n
"
;
writer
<<
"output_idx += (input_idx / input_strides
"
<<
i
<<
") * trans_strides
"
<<
i
<<
";
\n
"
;
writer
<<
"out[output_idx] = in[tid];
\n
"
;
}
writer
.
block_end
();
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp
View file @
68eb2e7d
...
...
@@ -56,6 +56,7 @@ namespace ngraph
static
void
get_reshape_op
(
codegen
::
CodeWriter
&
writer
,
const
std
::
string
&
name
,
runtime
::
gpu
::
GPUKernelArgs
&
args
,
const
std
::
array
<
std
::
string
,
2
>&
data_types
,
size_t
rank
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment