Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
5df0e17e
Commit
5df0e17e
authored
Jun 15, 2018
by
Jaikrishnan Menon
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into dex2
parents
c829a9c7
f75b8006
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
92 additions
and
29 deletions
+92
-29
code_writer.hpp
src/ngraph/codegen/code_writer.hpp
+4
-4
compiler.cpp
src/ngraph/codegen/compiler.cpp
+0
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+3
-2
cpu_rnn_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp
+0
-0
cpu_rnn_fusion.hpp
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp
+14
-0
cuda_emitter.cpp
src/ngraph/runtime/gpu/cuda_emitter.cpp
+9
-9
gpu_cuda_kernel_builder.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
+7
-7
gpu_cuda_kernel_emitters.cpp
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
+4
-4
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+0
-0
gpu_emitter.hpp
src/ngraph/runtime/gpu/gpu_emitter.hpp
+1
-1
gpu_external_function.cpp
src/ngraph/runtime/gpu/gpu_external_function.cpp
+0
-0
gpu_external_function.hpp
src/ngraph/runtime/gpu/gpu_external_function.hpp
+7
-0
unit_test.manifest
src/ngraph/runtime/gpu/unit_test.manifest
+0
-1
cpu_fusion.cpp
test/cpu_fusion.cpp
+43
-0
2rnn_layer_1timestep.json
test/models/mxnet/2rnn_layer_1timestep.json
+0
-0
No files found.
src/ngraph/codegen/code_writer.hpp
View file @
5df0e17e
...
...
@@ -68,16 +68,16 @@ public:
std
::
string
generate_temporary_name
(
std
::
string
prefix
=
"tempvar"
);
void
block_begin
(
std
::
string
block_prefix
=
""
)
void
block_begin
()
{
*
this
<<
"{
"
<<
block_prefix
<<
"
\n
"
;
*
this
<<
"{
\n
"
;
indent
++
;
}
void
block_end
(
std
::
string
block_suffix
=
""
)
void
block_end
()
{
indent
--
;
*
this
<<
"}
"
<<
block_suffix
<<
"
\n
"
;
*
this
<<
"}
\n
"
;
}
private
:
...
...
src/ngraph/codegen/compiler.cpp
View file @
5df0e17e
...
...
@@ -265,7 +265,6 @@ void codegen::StaticCompiler::add_header_search_path(const string& p)
vector
<
string
>
paths
=
split
(
p
,
';'
);
for
(
const
string
&
path
:
paths
)
{
NGRAPH_INFO
<<
path
;
if
(
!
contains
(
m_extra_search_path_list
,
path
))
{
m_extra_search_path_list
.
push_back
(
path
);
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
5df0e17e
...
...
@@ -344,9 +344,10 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager
.
register_pass
<
ngraph
::
pass
::
NopElimination
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
LSTMFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
RNNFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUBatchFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
ConcatInputs
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
AlgebraicSimplification
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
MultiLayerRNNFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
ConcatInputs
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUBatchFusion
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
CommonSubexpressionElimination
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
CoreFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUFusion
>
();
...
...
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.cpp
View file @
5df0e17e
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp
View file @
5df0e17e
...
...
@@ -29,6 +29,7 @@ namespace ngraph
{
class
LSTMFusion
;
class
RNNFusion
;
class
MultiLayerRNNFusion
;
}
}
}
...
...
@@ -61,3 +62,16 @@ public:
private
:
void
construct_rnn_lstm_fprop
();
};
class
ngraph
::
runtime
::
cpu
::
pass
::
MultiLayerRNNFusion
:
public
ngraph
::
pass
::
RecurrentGraphRewrite
{
public
:
MultiLayerRNNFusion
()
:
RecurrentGraphRewrite
()
{
construct_multi_layer_rnn_fusion_fprop
();
}
private
:
void
construct_multi_layer_rnn_fusion_fprop
();
};
src/ngraph/runtime/gpu/cuda_emitter.cpp
View file @
5df0e17e
...
...
@@ -268,8 +268,8 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
compiled_kernel
=
ctx
->
compiled_kernel_pool
->
set
(
kernel_name
.
str
(),
writer
.
get_code
());
}
u
nsigned
int
rank
=
static_cast
<
unsigned
in
t
>
(
input_shape
.
size
());
u
nsigned
int
nthreads
=
static_cast
<
unsigned
in
t
>
(
shape_size
(
input_shape
));
u
int32_t
rank
=
static_cast
<
uint32_
t
>
(
input_shape
.
size
());
u
int32_t
nthreads
=
static_cast
<
uint32_
t
>
(
shape_size
(
input_shape
));
GPUShape
pad_below
(
input_shape
.
size
(),
0
);
GPUShape
pad_interior
(
input_shape
.
size
(),
1
);
...
...
@@ -286,14 +286,14 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
// get an allocator for transient per kernel gpu memory
GPUAllocator
allocator
=
this
->
m_primitive_emitter
->
get_memory_allocator
();
size_t
idx_input_strides
=
allocator
.
reserve_argspace
(
input_strides
.
data
(),
input_strides
.
size
()
*
sizeof
(
unsigned
in
t
));
size_t
idx_output_strides
=
allocator
.
reserve_argspace
(
output_strides
.
data
(),
output_strides
.
size
()
*
sizeof
(
unsigned
in
t
));
size_t
idx_input_strides
=
allocator
.
reserve_argspace
(
input_strides
.
data
(),
input_strides
.
size
()
*
sizeof
(
uint32_
t
));
size_t
idx_output_strides
=
allocator
.
reserve_argspace
(
output_strides
.
data
(),
output_strides
.
size
()
*
sizeof
(
uint32_
t
));
size_t
idx_padding_below
=
allocator
.
reserve_argspace
(
pad_below
.
data
(),
pad_below
.
size
()
*
sizeof
(
u
nsigned
in
t
));
allocator
.
reserve_argspace
(
pad_below
.
data
(),
pad_below
.
size
()
*
sizeof
(
u
int32_
t
));
size_t
idx_padding_interior
=
allocator
.
reserve_argspace
(
pad_interior
.
data
(),
pad_interior
.
size
()
*
sizeof
(
u
nsigned
in
t
));
allocator
.
reserve_argspace
(
pad_interior
.
data
(),
pad_interior
.
size
()
*
sizeof
(
u
int32_
t
));
// create the launch primitive
std
::
unique_ptr
<
gpu
::
primitive
>
pad_dynamic
(
new
gpu
::
primitive
{[
=
](
void
**
inputs
,
...
...
@@ -1015,7 +1015,7 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const GPURuntimeContext* c
args_list
[
6
]
=
&
nthreads
;
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
static_cast
<
u
nsigned
in
t
>
(
nthreads
),
static_cast
<
u
int32_
t
>
(
nthreads
),
1
,
1
,
// grid dim
1
,
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
View file @
5df0e17e
...
...
@@ -285,19 +285,19 @@ void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
const
std
::
array
<
std
::
string
,
2
>&
data_types
)
{
writer
<<
"extern
\"
C
\"
__global__ void cuda_"
<<
name
<<
"("
<<
data_types
[
0
]
<<
"* in, "
<<
data_types
[
1
]
<<
"* out, u
nsigned int* input_strides, unsigned in
t* output_strides, "
"u
nsigned int* padding_below, unsigned in
t* "
"padding_interior, u
nsigned int rank, unsigned in
t n)
\n
"
;
<<
data_types
[
1
]
<<
"* out, u
int32_t* input_strides, uint32_
t* output_strides, "
"u
int32_t* padding_below, uint32_
t* "
"padding_interior, u
int32_t rank, uint32_
t n)
\n
"
;
writer
.
block_begin
();
{
writer
<<
"u
nsigned in
t tid = blockIdx.x * blockDim.x + threadIdx.x;
\n
"
;
writer
<<
"u
int32_
t tid = blockIdx.x * blockDim.x + threadIdx.x;
\n
"
;
writer
<<
"if (tid < n)
\n
"
;
writer
.
block_begin
();
{
writer
<<
"u
nsigned in
t output_idx = 0;
\n
"
;
writer
<<
"u
nsigned in
t input_idx = tid;
\n
"
;
writer
<<
"u
int32_
t output_idx = 0;
\n
"
;
writer
<<
"u
int32_
t input_idx = tid;
\n
"
;
writer
<<
"for(u
nsigned in
t i = 0; i < rank; i++)
\n
"
;
writer
<<
"for(u
int32_
t i = 0; i < rank; i++)
\n
"
;
writer
.
block_begin
();
{
writer
<<
"output_idx += (input_idx / input_strides[i] * padding_interior[i] + "
...
...
src/ngraph/runtime/gpu/gpu_cuda_kernel_emitters.cpp
View file @
5df0e17e
...
...
@@ -47,7 +47,7 @@ void runtime::gpu::emit_onehot(const std::string& name,
void
*
args_list
[]
=
{
&
in
,
&
out
,
&
repeat_size
,
&
repeat_times
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
static_cast
<
u
nsigned
in
t
>
(
count
),
static_cast
<
u
int32_
t
>
(
count
),
1
,
1
,
// grid dim
1
,
...
...
@@ -84,7 +84,7 @@ void runtime::gpu::emit_reshape(const std::string& name,
void
*
args_list
[]
=
{
&
in
,
&
out
,
&
input_strides
,
&
trans_strides
,
&
rank
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
static_cast
<
u
nsigned
in
t
>
(
count
),
static_cast
<
u
int32_
t
>
(
count
),
1
,
1
,
// grid dim
1
,
...
...
@@ -124,7 +124,7 @@ void runtime::gpu::emit_slice(const std::string& name,
void
*
args_list
[]
=
{
&
in
,
&
out
,
&
input_strides
,
&
lower_bounds
,
&
slice_strides
,
&
output_strides
,
&
rank
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
static_cast
<
u
nsigned
in
t
>
(
count
),
static_cast
<
u
int32_
t
>
(
count
),
1
,
1
,
// grid dim
1
,
...
...
@@ -161,7 +161,7 @@ void runtime::gpu::emit_reverse(const std::string& name,
void
*
args_list
[]
=
{
&
in
,
&
out
,
&
input_shapes
,
&
reverse_axes
,
&
rank
,
&
count
};
CUDA_SAFE_CALL
(
cuLaunchKernel
(
*
compiled_kernel
.
get
(),
static_cast
<
u
nsigned
in
t
>
(
count
),
static_cast
<
u
int32_
t
>
(
count
),
1
,
1
,
// grid dim
1
,
...
...
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
5df0e17e
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/gpu/gpu_emitter.hpp
View file @
5df0e17e
...
...
@@ -77,7 +77,7 @@ namespace ngraph
auto
&
cuda_emitter
=
external_function
->
get_primitive_emitter
()
->
get_cuda_emitter
();
writer
.
block_begin
(
" // "
+
node
->
get_name
()
);
writer
.
block_begin
();
{
std
::
vector
<
std
::
string
>
dtypes
;
for
(
auto
&
arg
:
args
)
...
...
src/ngraph/runtime/gpu/gpu_external_function.cpp
View file @
5df0e17e
This diff is collapsed.
Click to expand it.
src/ngraph/runtime/gpu/gpu_external_function.hpp
View file @
5df0e17e
...
...
@@ -83,6 +83,13 @@ namespace ngraph
const
Node
&
,
const
std
::
unordered_map
<
descriptor
::
TensorView
*
,
std
::
vector
<
size_t
>>&
);
void
release_function
()
{
m_function
=
nullptr
;
}
std
::
string
emit_op_as_function
(
const
Node
&
node
,
const
std
::
string
&
function_name
);
std
::
string
strip_comments
(
const
std
::
string
&
s
)
const
;
bool
is_functionally_identical
(
const
Node
&
n1
,
const
Node
&
n2
,
const
std
::
unordered_map
<
const
Node
*
,
std
::
string
>&
node_cache
)
const
;
std
::
unique_ptr
<
codegen
::
Compiler
>
m_compiler
;
std
::
unique_ptr
<
codegen
::
ExecutionEngine
>
m_execution_engine
;
bool
m_emit_timing
;
...
...
src/ngraph/runtime/gpu/unit_test.manifest
View file @
5df0e17e
...
...
@@ -21,7 +21,6 @@ divide_by_zero_float32
divide_by_zero_int32
dot_4d_5d_multi_axis_big_fp64_VERY_SLOW
dot_matrix_vector_int64
function_call
mkldnn_layouts
numeric_double_nan
numeric_float_inf
...
...
test/cpu_fusion.cpp
View file @
5df0e17e
...
...
@@ -35,6 +35,7 @@
#include "ngraph/op/relu.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
...
...
@@ -2197,3 +2198,45 @@ TEST(cpu_fusion, fuse_batch_dot_forward)
EXPECT_TRUE
(
test
::
all_close
(
cpu_results
.
at
(
i
),
int_results
.
at
(
i
),
1.0e-4
f
,
1.0e-4
f
));
}
}
TEST
(
cpu_fusion
,
fuse_rnn_across_layer
)
{
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
LSTMFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
RNNFusion
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
AlgebraicSimplification
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
MultiLayerRNNFusion
>
();
const
string
json_path
=
file_util
::
path_join
(
SERIALIZED_ZOO
,
"mxnet/2rnn_layer_1timestep.json"
);
const
string
json_string
=
file_util
::
read_file_to_string
(
json_path
);
stringstream
ss
(
json_string
);
shared_ptr
<
Function
>
func
=
ngraph
::
deserialize
(
ss
);
pass_manager
.
run_passes
(
func
);
size_t
ref_rnn_count
=
1
;
auto
rnn_count
=
count_ops_of_type
<
op
::
Rnn
>
(
func
);
EXPECT_EQ
(
ref_rnn_count
,
rnn_count
);
}
TEST
(
cpu_fusion
,
fuse_rnn_across_2layer_1timestep
)
{
const
std
::
string
file_name
(
"mxnet/2rnn_layer_1timestep.json"
);
auto
cpu_f
=
make_function
(
file_name
);
auto
int_f
=
make_function
(
file_name
);
test
::
Uniform
<
float
>
rng
(
0.0
f
,
1.0
f
);
vector
<
vector
<
float
>>
args
;
for
(
shared_ptr
<
op
::
Parameter
>
param
:
int_f
->
get_parameters
())
{
vector
<
float
>
tensor_val
(
shape_size
(
param
->
get_shape
()));
rng
.
initialize
(
tensor_val
);
args
.
push_back
(
tensor_val
);
}
auto
int_results
=
execute
(
int_f
,
args
,
"INTERPRETER"
);
auto
cpu_results
=
execute
(
cpu_f
,
args
,
"CPU"
);
EXPECT_EQ
(
1
,
count_ops_of_type
<
op
::
Rnn
>
(
cpu_f
));
for
(
size_t
i
=
0
;
i
<
cpu_results
.
size
();
i
++
)
{
EXPECT_TRUE
(
test
::
all_close
(
cpu_results
.
at
(
1
),
int_results
.
at
(
1
),
1.0e-4
f
,
1.0e-4
f
));
}
}
test/models/mxnet/2rnn_layer_1timestep.json
0 → 100644
View file @
5df0e17e
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment