ngraph / Commits / db553dc5

Commit db553dc5 (unverified)
Merge branch 'master' into jmenon/dex2
Authored Jun 17, 2018 by Jayaram Bobba; committed by GitHub on Jun 17, 2018
Parents: 97a83346, bdfcf5b4

Showing 12 changed files with 506 additions and 5 deletions (+506 -5)
src/ngraph/pass/core_fusion.cpp                     +215   -0
src/ngraph/pass/core_fusion.hpp                       +2   -0
src/ngraph/pass/cse.cpp                              +13   -0
src/ngraph/runtime/gpu/cuda_emitter.cpp              +90   -0
src/ngraph/runtime/gpu/cuda_emitter.hpp               +8   -0
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp   +47   -0
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp    +7   -0
src/ngraph/runtime/gpu/gpu_emitter.cpp               +33   -0
src/ngraph/runtime/gpu/gpu_external_function.cpp      +3   -0
src/ngraph/runtime/gpu/unit_test.manifest             +0   -5
test/core_fusion.cpp                                 +56   -0
test/cse.cpp                                         +32   -0
src/ngraph/pass/core_fusion.cpp
@@ -27,6 +27,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
@@ -162,3 +163,217 @@ void pass::CoreFusion::construct_folded_batch_norm()
    auto m = std::make_shared<ngraph::pattern::Matcher>(bn, callback);
    this->add_matcher(m);
}

static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
                                   bool skip_pad_checks = false)
{
    Strides stride_1{1, 1};
    CoordinateDiff pad_0{0, 0};

    return conv->get_window_dilation_strides() == stride_1 &&
           conv->get_data_dilation_strides() == stride_1 &&
           (skip_pad_checks ||
            (conv->get_padding_above() == pad_0 && conv->get_padding_below() == pad_0));
}

static bool are_img_dims_equal(Shape conv_shape, Shape image_shape)
{
    if (conv_shape.size() != 4)
    {
        return false;
    }

    return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1];
}

static size_t shape_to_index(Shape shape)
{
    if (shape.size() != 4)
    {
        return 0;
    }
    const size_t HEIGHT_DIM = 2;
    const size_t WIDTH_DIM = 3;

    if (shape.at(HEIGHT_DIM) != shape.at(WIDTH_DIM))
    {
        return 0;
    }

    switch (shape.at(HEIGHT_DIM))
    {
    case 28: return 1;
    case 14: return 2;
    case 7: return 3;
    default: return 0;
    }
}

//   conv(56w3s1)                      conv(28w3s2)
//        |                                 |
//   conv(56w1s1)           ==>        conv(28w1s1)
//        |                                 |
//  elt------------56              elt------------pool(28s2)
//   |            |                 |             |
//  conv(28w1s2) conv(28w1s2)      conv(28w1s1)  conv(28w1s1)
void pass::CoreFusion::construct_optimized_strided_conv()
{
    Shape win_size_1{1, 1, 1, 1};
    auto is_bc = ngraph::pattern::has_class<op::Broadcast>();
    auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
    auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);

    auto conv_stride3 = std::make_shared<op::Convolution>(data_stride3, weights_stride3);

    auto conv_stride3_label =
        std::make_shared<pattern::op::Label>(conv_stride3, nullptr, NodeVector{conv_stride3});

    auto broadcast_w3_label = std::make_shared<pattern::op::Label>(conv_stride3_label, is_bc);
    auto add_w3 = std::make_shared<op::Add>(conv_stride3_label, broadcast_w3_label);
    auto relu_w3 = std::make_shared<op::Relu>(add_w3);

    auto weights_stride1 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
    auto conv_stride1 = std::make_shared<op::Convolution>(relu_w3, weights_stride1);
    auto conv_stride1_label =
        std::make_shared<pattern::op::Label>(conv_stride1, nullptr, NodeVector{conv_stride1});
    auto broadcast_w1_label = std::make_shared<pattern::op::Label>(conv_stride1_label, is_bc);
    auto add_w1 = std::make_shared<op::Add>(conv_stride1_label, broadcast_w1_label);

    auto eltwise_arg_label =
        std::make_shared<pattern::op::Label>(element::f32, conv_stride1->get_shape());
    auto add_two_convs = std::make_shared<op::Add>(add_w1, eltwise_arg_label);

    auto relu_two_convs = std::make_shared<op::Relu>(add_two_convs);

    auto eltwise_label =
        std::make_shared<pattern::op::Label>(relu_two_convs, nullptr, NodeVector{relu_two_convs});

    auto weights_eltwise = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
    auto eltwise_conv = std::make_shared<op::Convolution>(eltwise_label, weights_eltwise);

    pattern::graph_rewrite_callback callback = [win_size_1,
                                                eltwise_label,
                                                conv_stride1_label,
                                                conv_stride3_label,
                                                eltwise_arg_label,
                                                broadcast_w3_label,
                                                broadcast_w1_label](pattern::Matcher& m) {

        NGRAPH_DEBUG << "In a callback for construct_conv_skip against "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();
        auto m_eltwise = pattern_map[eltwise_label];

        auto strided_convs = m_eltwise->get_users();
        if (strided_convs.size() != 2)
        {
            NGRAPH_DEBUG << "Number of users of element wise operation isn't equal to 2";
            return false;
        }

        Shape supported_shapes[] = {Shape{56, 56}, Shape{28, 28}, Shape{14, 14}, Shape{7, 7}};
        Shape shape_1{1, 1};
        Shape shape_3{3, 3};
        Strides stride_2{2, 2};
        Strides stride_1{1, 1};
        CoordinateDiff pad_0{0, 0};
        CoordinateDiff pad_1{1, 1};
        Shape win_size_3{1, 1, 3, 3};

        size_t sparse_shape_index = 0;
        NodeVector sconvs;
        for (auto sc : strided_convs)
        {
            if (sc->get_argument(0) != m_eltwise)
            {
                NGRAPH_DEBUG << "element-wise isn't data";
                return false;
            }
            auto sconv = std::dynamic_pointer_cast<op::Convolution>(sc);
            sparse_shape_index = shape_to_index(sconv->get_shape());
            if (sparse_shape_index == 0)
            {
                NGRAPH_DEBUG << "Unsupported shape of " << sconv->get_name();
                return false;
            }
            if (!are_img_dims_equal(sconv->get_shape(), supported_shapes[sparse_shape_index]) ||
                !are_img_dims_equal(sconv->get_argument(1)->get_shape(), shape_1) ||
                sconv->get_window_movement_strides() != stride_2 || !is_trivial_convolution(sconv))
            {
                NGRAPH_DEBUG << sconv->get_name()
                             << " and its weights are of the wrong shape (not "
                             << vector_to_string(supported_shapes[sparse_shape_index])
                             << " and 1x1) and strides (2x2)";
                return false;
            }
            sconvs.push_back(sconv);
        }

        const size_t full_shape_index = sparse_shape_index - 1;

        auto m_conv_stride1 =
            std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]);

        if (!are_img_dims_equal(m_conv_stride1->get_shape(), supported_shapes[full_shape_index]) ||
            !are_img_dims_equal(m_conv_stride1->get_argument(1)->get_shape(), win_size_1) ||
            m_conv_stride1->get_window_movement_strides() != stride_1 ||
            !is_trivial_convolution(m_conv_stride1))
        {
            NGRAPH_DEBUG << m_conv_stride1->get_name()
                         << " and its weights are of the wrong shape (not "
                         << vector_to_string(supported_shapes[full_shape_index])
                         << " and 1x1) and strides (1x1)";
            return false;
        }

        auto m_conv_stride3 =
            std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]);

        if (!are_img_dims_equal(m_conv_stride3->get_shape(), supported_shapes[full_shape_index]) ||
            !are_img_dims_equal(m_conv_stride3->get_argument(1)->get_shape(), shape_3) ||
            m_conv_stride3->get_window_movement_strides() != stride_1 ||
            !is_trivial_convolution(m_conv_stride3, true))
        {
            NGRAPH_DEBUG << m_conv_stride3->get_name()
                         << " and its weights are of the wrong shape (not "
                         << vector_to_string(supported_shapes[full_shape_index])
                         << " and 3x3) and strides (1x1)";
            return false;
        }

        auto conv_28w3s2 = std::make_shared<op::Convolution>(m_conv_stride3->get_argument(0),
                                                             m_conv_stride3->get_argument(1),
                                                             stride_2,
                                                             stride_1,
                                                             pad_1,
                                                             pad_1);

        auto maxpool_w3 =
            std::make_shared<op::MaxPool>(pattern_map[broadcast_w3_label], Shape{1, 1}, stride_2);
        auto new_add_conv_28w3s2 = std::make_shared<op::Add>(conv_28w3s2, maxpool_w3);
        auto new_relu_28w3s2 = std::make_shared<op::Relu>(new_add_conv_28w3s2);

        auto conv_28w1s1 = std::make_shared<op::Convolution>(
            new_relu_28w3s2, m_conv_stride1->get_argument(1), stride_1, stride_1);

        auto maxpool_w1 =
            std::make_shared<op::MaxPool>(pattern_map[broadcast_w1_label], Shape{1, 1}, stride_2);
        auto new_add_conv28s1 = std::make_shared<op::Add>(conv_28w1s1, maxpool_w1);

        auto maxpool =
            std::make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2);
        auto new_add_two_convs = std::make_shared<op::Add>(new_add_conv28s1, maxpool);
        auto new_relu_two_convs = std::make_shared<op::Relu>(new_add_two_convs);

        for (auto sconv : sconvs)
        {
            auto sconv_28w1s1 = std::make_shared<op::Convolution>(
                new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1);
            NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with "
                         << sconv_28w1s1->get_name();
            ngraph::replace_node(sconv, sconv_28w1s1);
        }
        return true;
    };

    auto m = make_shared<pattern::Matcher>(eltwise_conv, callback);
    this->add_matcher(m);
}
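The rewrite is legal because an op::MaxPool with a 1x1 window and stride {2, 2} is simply a strided subsampling of its input, and subsampling commutes with the element-wise Add and Relu nodes, so the values the original stride-2 1x1 convolutions would have read are preserved while the expensive part of the block now runs at the smaller spatial size. A minimal sketch of driving the pass, mirroring the unit test added later in this commit (run_core_fusion is a hypothetical helper, not part of the commit):

// Sketch: registering and running CoreFusion on an ngraph Function.
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/manager.hpp"

void run_core_fusion(std::shared_ptr<ngraph::Function> func)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::CoreFusion>();
    pass_manager.run_passes(func); // rewrites matching stride-2 residual blocks in place
}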
src/ngraph/pass/core_fusion.hpp
@@ -34,7 +34,9 @@ public:
    {
        construct_relu();
        construct_folded_batch_norm();
        construct_optimized_strided_conv();
    }
    void construct_relu();
    void construct_folded_batch_norm();
    void construct_optimized_strided_conv();
};
src/ngraph/pass/cse.cpp
@@ -77,6 +77,17 @@ static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
           (a->get_argument(1) == b->get_argument(0) &&
            a->get_argument(0) == b->get_argument(1));
}

static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();

    auto ar_a = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(a);
    auto ar_b = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(b);

    return ar_a->get_argument(0) == ar_b->get_argument(0) &&
           ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}

static std::unordered_map<std::type_index,
                          std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
    initialize_ops_to_cse_handlers()
@@ -110,6 +121,8 @@ static std::unordered_map<std::type_index,
        {TI(op::Power), cse_binarywise},
        //{TI(op::Remainder), cse_binarywise},
        {TI(op::Subtract), cse_binarywise},
        {TI(op::Sum), cse_reduction},
        {TI(op::Product), cse_reduction},
    });
}
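With op::Sum and op::Product now mapped to cse_reduction, two reductions over the same argument and the same reduction axes are recognized as common subexpressions, while reductions over different axes are kept separate. A minimal sketch of the effect, assuming the same namespaces as the tests in this commit (the real coverage is in test/cse.cpp below):

// Sketch only: two identical Sum reductions collapse to one after CSE.
auto A = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
auto sum1 = std::make_shared<op::Sum>(A, AxisSet{0, 1});
auto sum2 = std::make_shared<op::Sum>(A, AxisSet{0, 1});
auto diff = sum1 - sum2;
auto f = std::make_shared<Function>(NodeVector{diff}, op::ParameterVector{A});

pass::Manager pm;
pm.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pm.run_passes(f);
// After the pass, diff->get_argument(0) == diff->get_argument(1):
// a single Sum node feeds both inputs of the subtraction.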
src/ngraph/runtime/gpu/cuda_emitter.cpp
@@ -331,6 +331,87 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const runtime::gpu::GPURunti
    return primitive_index;
}

size_t runtime::gpu::CUDAEmitter::build_reverse_sequence(const runtime::gpu::GPURuntimeContext* ctx,
                                                         const std::array<std::string, 3>& dtypes,
                                                         GPUShape input_shape0,
                                                         GPUShape input_shape1,
                                                         GPUShape output_shape,
                                                         size_t batch_axis,
                                                         size_t sequence_axis)
{
    std::stringstream kernel_name;
    kernel_name << "reverse_sequence_" << join(dtypes, "_") << "_bi_" << batch_axis << "_si_"
                << sequence_axis << "_r_" << output_shape.size();

    std::string hash = kernel_name.str() + "_i" + join(input_shape0, "_") + "_i" +
                       join(input_shape1, "_") + "_o" + join(output_shape);

    // For backwards compatability we currently use two unordered maps
    // 1. one looks up the compiled cuda kernel (CudaFunctionPool)
    // 2. the other looks to see if this kernel is already in the primitive list

    // check if the requested kernel is already an inserted primitive
    size_t primitive_index = m_primitive_emitter->lookup(hash);
    if (primitive_index != std::numeric_limits<size_t>::max())
    {
        return primitive_index;
    }

    // check if the kernel has already been compiled. if so, create
    // a launch primitive for it based on the input tensor shape
    // but do not recompile the kernel. otherwise, do it all:
    // recompile the kernel and then create the primitive
    auto compiled_kernel = ctx->compiled_kernel_pool->get(kernel_name.str());
    if (compiled_kernel == nullptr)
    {
        codegen::CodeWriter writer;
        CudaKernelBuilder::add_pod_typedefs(writer);
        CudaKernelBuilder::get_reverse_sequence_op(
            writer, kernel_name.str(), dtypes, batch_axis, sequence_axis, output_shape.size());
        compiled_kernel = ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
    }

    uint32_t nthreads = static_cast<uint32_t>(shape_size(output_shape));
    GPUShape output_strides = row_major_strides(output_shape);

    // get an allocator for transient per kernel gpu memory
    GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
    size_t idx_output_shape =
        allocator.reserve_argspace(output_shape.data(), output_shape.size() * sizeof(uint32_t));
    size_t idx_output_strides =
        allocator.reserve_argspace(output_strides.data(), output_strides.size() * sizeof(uint32_t));

    // create the launch primitive
    std::unique_ptr<gpu::primitive> reserve_sequence(
        new gpu::primitive{[=](void** inputs, void** outputs) mutable {
            void* param_output_shape = runtime::gpu::invoke_memory_primitive(ctx, idx_output_shape);
            void* param_output_strides =
                runtime::gpu::invoke_memory_primitive(ctx, idx_output_strides);
            std::vector<void*> args_list{&inputs[0],
                                         &inputs[1],
                                         &outputs[0],
                                         &param_output_shape,
                                         &param_output_strides,
                                         &nthreads};

            CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
                                          static_cast<uint32_t>(nthreads),
                                          1,
                                          1, // grid dim
                                          1,
                                          1,
                                          1, // block dim
                                          0,
                                          NULL, // shared mem and stream
                                          args_list.data(),
                                          0)); // arguments
            CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
        }});

    primitive_index = this->m_primitive_emitter->insert(std::move(reserve_sequence));
    m_primitive_emitter->cache(hash, primitive_index);
    return primitive_index;
}
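The comments above describe a two-level cache: the hash (kernel name plus input and output shapes) guards the primitive list, while the kernel name alone guards the compiled-kernel pool, so a shape change can still reuse an already-compiled kernel. A sketch of the resulting behaviour, with emitter, ctx and the shape variables standing in for a CUDAEmitter, its GPURuntimeContext and concrete GPUShape values (hypothetical calls, not code from this commit):

// First call: compiles the kernel and inserts a new launch primitive.
size_t idx1 = emitter->build_reverse_sequence(ctx, {{"float", "int32_t", "float"}},
                                              in_shape0, in_shape1, out_shape,
                                              /*batch_axis*/ 0, /*sequence_axis*/ 1);
// Identical shapes: the hash hits, the same primitive index comes back, nothing is recompiled.
size_t idx2 = emitter->build_reverse_sequence(ctx, {{"float", "int32_t", "float"}},
                                              in_shape0, in_shape1, out_shape, 0, 1);
// idx1 == idx2
// Different shapes with the same rank and axes miss the primitive cache but still
// reuse the compiled kernel from compiled_kernel_pool, so only a new launch
// primitive is created.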
size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx,
                                                    const std::array<std::string, 2>& dtypes,
                                                    GPUShape input_shape,
@@ -1377,6 +1458,15 @@ __device__ __forceinline__ float load(const float* __restrict__ in, int i=0, b
}
return v;
}
__device__ __forceinline__ int32_t load(const int32_t* __restrict__ in, int i=0, bool b=true)
{
int32_t v = 0;
if (b)
{
v = __ldg(in + i);
}
return v;
}
__device__ __forceinline__ int64_t load(const int64_t* __restrict__ in, int i=0, bool b=true)
{
int64_t v = 0;
src/ngraph/runtime/gpu/cuda_emitter.hpp
@@ -78,6 +78,14 @@ namespace ngraph
                                         GPUShape reduce_window_shape,
                                         GPUShape reduce_window_strides);

                size_t build_reverse_sequence(const runtime::gpu::GPURuntimeContext* ctx,
                                              const std::array<std::string, 3>& dtypes,
                                              GPUShape input_shape0,
                                              GPUShape input_shape1,
                                              GPUShape output_shape,
                                              size_t batch_axis,
                                              size_t sequence_axis);

                template <typename T>
                size_t build_elementwise(const GPURuntimeContext* ctx,
                                         const std::vector<std::string>& dtypes,
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.cpp
@@ -313,6 +313,53 @@ void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
    writer.block_end();
}

void runtime::gpu::CudaKernelBuilder::get_reverse_sequence_op(
    codegen::CodeWriter& writer,
    const std::string& name,
    const std::array<std::string, 3>& data_types,
    const size_t batch_axis,
    const size_t sequence_axis,
    const size_t rank)
{
    writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
           << data_types[1] << "* sequence, " << data_types[2] << "* out, "
           << "uint32_t* output_shape, uint32_t* output_strides, uint32_t n)\n";
    writer.block_begin();
    {
        writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
        writer << "if (tid < n)\n";
        writer.block_begin();
        {
            writer << "uint32_t input_idx = tid;\n";
            for (size_t i = 0; i < rank - 1; i++)
            {
                writer << "uint32_t output_idx_" << i << " = input_idx / output_strides[" << i
                       << "];\n";
                writer << "input_idx %= output_strides[" << i << "];\n";
            }
            writer << "uint32_t output_idx_" << rank - 1 << " = input_idx / output_strides["
                   << rank - 1 << "];\n";
            writer << "uint32_t sequence_length = sequence[output_idx_" << batch_axis << "];\n";
            writer << "assert(sequence_length <= output_shape[" << sequence_axis << "]);\n";
            writer << "bool need_reverse = (output_idx_" << sequence_axis
                   << " < sequence_length) && (sequence_length > 1);\n";
            writer << "output_idx_" << sequence_axis
                   << " = need_reverse ? sequence_length - output_idx_" << sequence_axis
                   << " - 1 : output_idx_" << sequence_axis << ";\n";
            writer << "uint32_t output_idx = need_reverse ? ";
            writer << "output_idx_" << 0 << " * output_strides[" << 0 << "]";
            for (size_t i = 1; i < rank; i++)
            {
                writer << " + output_idx_" << i << " * output_strides[" << i << "]";
            }
            writer << " : tid;\n";
            writer << "out[output_idx] = in[tid];\n";
        }
        writer.block_end();
    }
    writer.block_end();
}
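To make the generated code concrete, here is what the builder above would emit for an assumed rank-3 tensor with batch_axis = 0, sequence_axis = 1 and data types {float, int32_t, float}. This is a hand-expanded sketch of the writer output (with added comments), not a kernel that appears verbatim in the commit:

extern "C" __global__ void cuda_reverse_sequence_float_int32_t_float_bi_0_si_1_r_3(
    float* in, int32_t* sequence, float* out,
    uint32_t* output_shape, uint32_t* output_strides, uint32_t n)
{
    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < n)
    {
        // Decompose the flat thread id into per-dimension output coordinates.
        uint32_t input_idx = tid;
        uint32_t output_idx_0 = input_idx / output_strides[0];
        input_idx %= output_strides[0];
        uint32_t output_idx_1 = input_idx / output_strides[1];
        input_idx %= output_strides[1];
        uint32_t output_idx_2 = input_idx / output_strides[2];
        // Sequence length for this batch element (batch_axis = 0).
        uint32_t sequence_length = sequence[output_idx_0];
        assert(sequence_length <= output_shape[1]);
        // Only positions inside the sequence are mirrored (sequence_axis = 1).
        bool need_reverse = (output_idx_1 < sequence_length) && (sequence_length > 1);
        output_idx_1 = need_reverse ? sequence_length - output_idx_1 - 1 : output_idx_1;
        uint32_t output_idx = need_reverse ? output_idx_0 * output_strides[0]
                                                 + output_idx_1 * output_strides[1]
                                                 + output_idx_2 * output_strides[2]
                                           : tid;
        out[output_idx] = in[tid];
    }
}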
void runtime::gpu::CudaKernelBuilder::get_slice_op(codegen::CodeWriter& writer,
                                                   const std::string& name,
                                                   const std::array<std::string, 2>& data_types)
src/ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp
@@ -76,6 +76,13 @@ namespace ngraph
                                          const std::vector<std::string>& data_types,
                                          const size_t rank);

            static void get_reverse_sequence_op(codegen::CodeWriter& writer,
                                                const std::string& name,
                                                const std::array<std::string, 3>& data_types,
                                                const size_t batch_axis,
                                                const size_t sequence_axis,
                                                const size_t rank);

            static void get_device_helper(codegen::CodeWriter& writer,
                                          const std::string& name,
                                          const std::string& math_kernel,
src/ngraph/runtime/gpu/gpu_emitter.cpp
@@ -80,6 +80,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sign.hpp"
@@ -1151,6 +1152,38 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
                writer.block_end();
            }

            template <>
            void GPU_Emitter::EMITTER_DECL(ngraph::op::ReverseSequence)
            {
                if (out[0].get_size() == 0)
                {
                    return;
                }

                auto rs = static_cast<const ngraph::op::ReverseSequence*>(node);

                size_t bi = rs->get_batch_axis();
                size_t si = rs->get_sequence_axis();
                auto arg_shape0 = args[0].get_shape();
                auto arg_shape1 = args[1].get_shape();
                auto out_shape = out[0].get_shape();

                auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();

                auto rs_index = cuda_emitter->build_reverse_sequence(
                    external_function->ctx().get(),
                    {{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
                    arg_shape0,
                    arg_shape1,
                    out_shape,
                    bi,
                    si);

                writer << "gpu::invoke_primitive(ctx, " << rs_index << ", ";
                writer << "std::vector<void*>{" << args[0].get_name() << ", " << args[1].get_name()
                       << "}.data(), ";
                writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
                writer << ");\n";
            }
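The emitter does its work at compile time: build_reverse_sequence registers (or reuses) the launch primitive, and the only thing written into the generated source is a single invocation. With assumed tensor names arg0, arg1 and out0 and an assumed primitive index of 17, the emitted line would look roughly like:

gpu::invoke_primitive(ctx, 17,
                      std::vector<void*>{arg0, arg1}.data(),
                      std::vector<void*>{out0}.data());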
            template <>
            void GPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
            {
src/ngraph/runtime/gpu/gpu_external_function.cpp
@@ -91,6 +91,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sign.hpp"
@@ -223,6 +224,8 @@ static const runtime::gpu::OpMap dispatcher{
    {TI(ngraph::op::Not), &runtime::gpu::GPU_Emitter::emit_elementwise<ngraph::op::Not>},
    {TI(ngraph::op::MaxPool), &runtime::gpu::GPU_Emitter::emit<ngraph::op::MaxPool>},
    {TI(ngraph::op::Reverse), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reverse>},
    {TI(ngraph::op::ReverseSequence),
     &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReverseSequence>},
    {TI(ngraph::op::Result), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Result>},
    {TI(ngraph::op::ReduceWindow), &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReduceWindow>},
    {TI(ngraph::op::SelectAndScatter),
src/ngraph/runtime/gpu/unit_test.manifest
abc_int64
backwards_reverse_sequence_n4d2c3h2w2
backwards_reverse_sequence_n3_c2_h3
backwards_slice
batch_norm_one_output
batch_norm_three_outputs
@@ -31,9 +29,6 @@ one_hot_vector_1_barely_oob
one_hot_vector_1_far_oob
one_hot_vector_1_fp_nonint
parameter_as_output
reverse_sequence_n4d2c3h2w2
reverse_sequence_n4c3h2w2
reverse_sequence_n2c3h4w2
scalar_constant_float32
scalar_constant_int64
select_and_scatter_3d_without_overlap
test/core_fusion.cpp
@@ -55,3 +55,59 @@ TEST(core_fusion, core_fusion_pass_basic)
    pass_manager.run_passes(func);
    ASSERT_NE(std::dynamic_pointer_cast<op::Relu>(graph->get_argument(0)), nullptr);
}

TEST(core_fusion, sparsity_opt_56x56)
{
    Shape win_size_3{1, 1, 3, 3};
    Shape win_size_1{1, 1, 1, 1};
    Strides stride_2{2, 2};
    Strides stride_1{1, 1};
    CoordinateDiff pad_0{0, 0};
    CoordinateDiff pad_1{1, 1};
    auto data_stride3 = std::make_shared<op::Parameter>(element::f32, Shape{1, 64, 56, 56});
    auto weights_stride3 = std::make_shared<op::Parameter>(element::f32, Shape{64, 64, 3, 3});
    auto conv_stride3 = std::make_shared<op::Convolution>(
        data_stride3, weights_stride3, stride_1, stride_1, pad_1, pad_1);
    auto param_broadcast_w3 = std::make_shared<op::Parameter>(element::f32, Shape{64});
    auto broadcast_w3 = std::make_shared<op::Broadcast>(
        param_broadcast_w3, Shape{1, 64, 56, 56}, AxisSet{0, 2, 3});
    auto add_w3 = std::make_shared<op::Add>(conv_stride3, broadcast_w3);
    auto relu_w3 = std::make_shared<op::Relu>(add_w3);
    ///
    auto weights_stride1 = std::make_shared<op::Parameter>(element::f32, Shape{256, 64, 1, 1});
    auto conv_stride1 = std::make_shared<op::Convolution>(relu_w3, weights_stride1);
    auto param_broadcast_w1 = std::make_shared<op::Parameter>(element::f32, Shape{256});
    auto broadcast_w1 = std::make_shared<op::Broadcast>(
        param_broadcast_w1, Shape{1, 256, 56, 56}, AxisSet{0, 2, 3});
    auto add_w1 = std::make_shared<op::Add>(conv_stride1, broadcast_w1);
    ////
    auto other_arg = std::make_shared<op::Parameter>(element::f32, Shape{1, 256, 56, 56});
    auto add_two_convs = std::make_shared<op::Add>(add_w1, other_arg);
    auto relu_two_convs = std::make_shared<op::Relu>(add_two_convs);
    ///
    auto weights_conv_s2 = std::make_shared<op::Parameter>(element::f32, Shape{512, 256, 1, 1});
    auto conv_s2_1 = std::make_shared<op::Convolution>(relu_two_convs, weights_conv_s2, stride_2);
    auto conv_s2_2 = std::make_shared<op::Convolution>(relu_two_convs, weights_conv_s2, stride_2);

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>();
    auto params = op::ParameterVector{data_stride3,
                                      weights_stride3,
                                      param_broadcast_w3,
                                      weights_stride1,
                                      param_broadcast_w1,
                                      other_arg,
                                      weights_conv_s2};
    auto func = make_shared<Function>(NodeVector{conv_s2_1, conv_s2_2}, params);
    pass_manager.run_passes(func);
    auto results = func->get_results();
    auto t_eltwise_conv1 =
        std::dynamic_pointer_cast<op::Convolution>(results.at(0)->get_argument(0));
    auto t_eltwise_conv2 =
        std::dynamic_pointer_cast<op::Convolution>(results.at(1)->get_argument(0));
    ASSERT_TRUE(t_eltwise_conv1);
    ASSERT_TRUE(t_eltwise_conv2);
    ASSERT_EQ(t_eltwise_conv1->get_window_movement_strides(), stride_1);
    ASSERT_EQ(t_eltwise_conv2->get_window_movement_strides(), stride_1);
}
test/cse.cpp
@@ -188,3 +188,35 @@ TEST(CSE, abs_add_abs_add_negative)
    ASSERT_EQ(oadd4->get_argument(1), D);
    ASSERT_EQ(oadd3->get_argument(0), oadd4->get_argument(0));
}

template <typename T>
static void execute_cse_reduction_test()
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});

    auto a_reduction_op = std::make_shared<T>(A, AxisSet{0, 1});
    auto a_reduction_op2 = std::make_shared<T>(A, AxisSet{0, 1});
    auto a_reduction_op3 = std::make_shared<T>(A, AxisSet{0});
    auto sub_aa = a_reduction_op - a_reduction_op2;

    auto B = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
    auto b_reduction_op = std::make_shared<T>(B, AxisSet{0, 1});
    auto sub_ab = a_reduction_op - b_reduction_op;

    auto f = std::make_shared<Function>(NodeVector{sub_aa, sub_ab, a_reduction_op3},
                                        op::ParameterVector{A, B});
    pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);

    ASSERT_EQ(sub_aa->get_argument(0), sub_aa->get_argument(1));
    ASSERT_NE(sub_ab->get_argument(0), sub_ab->get_argument(1));
    ASSERT_NE(f->get_results().at(2)->get_argument(0), sub_aa->get_argument(0));
}

TEST(CSE, reduction_ops)
{
    execute_cse_reduction_test<op::Sum>();
    execute_cse_reduction_test<op::Product>();
}