Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
9b8ddba4
Commit
9b8ddba4
authored
Dec 06, 2019
by
YashasSamaga
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add ROIPoolingOp
parent
4b0132ed
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
208 additions
and
1 deletion
+208
-1
math.hpp
modules/dnn/src/cuda/math.hpp
+12
-0
roi_pooling.cu
modules/dnn/src/cuda/roi_pooling.cu
+121
-0
roi_pooling.hpp
modules/dnn/src/cuda4dnn/kernels/roi_pooling.hpp
+19
-0
roi_pooling.hpp
modules/dnn/src/cuda4dnn/primitives/roi_pooling.hpp
+52
-0
pooling_layer.cpp
modules/dnn/src/layers/pooling_layer.cpp
+4
-1
No files found.
modules/dnn/src/cuda/math.hpp
View file @
9b8ddba4
...
...
@@ -120,6 +120,18 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace de
/** Limits `value` using `lower` and `upper`: the larger of `value` and `lower` is
 *  taken first and the smaller of that result and `upper` is returned.
 */
template <class T>
__device__ T clamp(T value, T lower, T upper)
{
    const T raised = max(value, lower);
    return min(raised, upper);
}
/**
 * Rounds `value` to the nearest integral value; halfway cases are rounded away from zero
 * for every supported type.
 *
 * Only the explicit specializations below are defined (double, float, __half, __half2).
 *
 * The half precision variants deliberately go through `roundf` instead of `hrint`/`h2rint`:
 * the latter resolve halfway cases to the nearest even integer, which would make the result
 * of this generic function depend on the data type (e.g. 2.5 -> 3 for float but 2 for __half).
 * Widening a half to float is exact and every integral result of `roundf` on such a value is
 * representable in half precision, so the round trip introduces no additional error.
 */
template <class T>
__device__ T round(T value);

/** double: forwards to the global `::round`. */
template <>
inline __device__ double round(double value)
{
    return ::round(value);
}

/** float: forwards to `roundf`. */
template <>
inline __device__ float round(float value)
{
    return roundf(value);
}

/** __half: rounded in single precision to obtain the same tie-breaking as the float variant. */
template <>
inline __device__ __half round(__half value)
{
    return __float2half(roundf(__half2float(value)));
}

/** __half2: both lanes are rounded independently, each like the scalar __half variant. */
template <>
inline __device__ __half2 round(__half2 value)
{
    return __floats2half2_rn(roundf(__low2float(value)), roundf(__high2float(value)));
}
/**
 * Type-generic ceiling for device code.
 *
 * Only the explicit specializations below are defined (double, float, __half, __half2);
 * each one forwards to the matching math function for its type.
 */
template <class T>
__device__ T ceil(T value);

/** double: forwards to the global `::ceil`. */
template <>
inline __device__ double ceil(double value)
{
    const double result = ::ceil(value);
    return result;
}

/** float: forwards to `ceilf`. */
template <>
inline __device__ float ceil(float value)
{
    const float result = ceilf(value);
    return result;
}

/** __half: forwards to `hceil`. */
template <>
inline __device__ __half ceil(__half value)
{
    const __half result = hceil(value);
    return result;
}

/** __half2: forwards to `h2ceil`. */
template <>
inline __device__ __half2 ceil(__half2 value)
{
    const __half2 result = h2ceil(value);
    return result;
}
}}}}}
/* namespace cv::dnn::cuda4dnn::csl::device */
#endif
/* OPENCV_DNN_SRC_CUDA_MATH_HPP */
modules/dnn/src/cuda/roi_pooling.cu
0 → 100644
View file @
9b8ddba4
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "math.hpp"
#include "limits.hpp"
#include "types.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
/**
 * ROI max pooling kernel: every output element is the maximum of the input values that fall
 * inside one pooling window (bin) of a region of interest.
 *
 * Layouts, as indexed by the code below:
 *   input  : [1, num_channels, in_height, in_width] (the plane is selected using the ROI's batch id)
 *   rois   : [num_rois, 5], a row is read as (batch_id, x_start, y_start, x_end, y_end)
 *   output : [num_rois, num_channels, pooled_height, pooled_width]
 *
 * ROI coordinates are multiplied by `spatial_scale` and rounded to obtain input coordinates.
 *
 * All coordinate arithmetic is carried out in `float` regardless of `T`. With `T = __half`
 * the previous same-type arithmetic could not represent every integer above 2048 and lost
 * precision in the bin ratios, which shifts window boundaries; for `T = float` the results
 * are unchanged.
 */
template <class T>
__global__ void roi_pooling(
    Span<T> output, size_type pooled_height, size_type pooled_width,
    View<T> input, size_type in_height, size_type in_width,
    View<T> rois, size_type num_channels, T spatial_scale)
{
    using device::round;
    using device::ceil;
    using device::min;
    using device::max;

    const auto out_spatial_size = pooled_height * pooled_width;
    const auto out_roi_size = num_channels * out_spatial_size;

    /* loop invariant: the scale factor widened once for the float coordinate arithmetic */
    const float scale = static_cast<float>(spatial_scale);

    /* every element in the output is mapped to a window in the input and each thread processes several windows */
    for (auto idx : grid_stride_range(output.size()))
    {
        /* decompose the flat output index into (roi, channel, y, x) */
        const auto n = idx / out_roi_size;
        const auto c = (idx % out_roi_size) / out_spatial_size;
        const auto y = (idx % out_spatial_size) / pooled_width;
        const auto x = idx % pooled_width;

        const index_type roi_offset = n * 5;

        const index_type batch_id = rois[roi_offset + 0];
        const index_type x_start_roi = static_cast<index_type>(round(static_cast<float>(rois[roi_offset + 1]) * scale));
        const index_type y_start_roi = static_cast<index_type>(round(static_cast<float>(rois[roi_offset + 2]) * scale));
        const index_type x_end_roi = static_cast<index_type>(round(static_cast<float>(rois[roi_offset + 3]) * scale));
        const index_type y_end_roi = static_cast<index_type>(round(static_cast<float>(rois[roi_offset + 4]) * scale));

        /* the end coordinates are inclusive; malformed ROIs are forced to be at least one pixel wide/high */
        const auto roi_width = max<index_type>(x_end_roi - x_start_roi + 1, 1);
        const auto roi_height = max<index_type>(y_end_roi - y_start_roi + 1, 1);

        /* size of one pooling window in input pixels (fractional) */
        const auto roi_width_ratio = static_cast<float>(roi_width) / static_cast<float>(pooled_width);
        const auto roi_height_ratio = static_cast<float>(roi_height) / static_cast<float>(pooled_height);

        /* window of this output element: start is truncated, end is rounded up */
        auto x_start = x_start_roi + static_cast<index_type>(static_cast<float>(x) * roi_width_ratio);
        auto y_start = y_start_roi + static_cast<index_type>(static_cast<float>(y) * roi_height_ratio);

        auto x_end = x_start_roi + static_cast<index_type>(ceil(static_cast<float>(x + 1) * roi_width_ratio));
        auto y_end = y_start_roi + static_cast<index_type>(ceil(static_cast<float>(y + 1) * roi_height_ratio));

        /* clip the window to the input plane */
        x_start = max<index_type>(x_start, 0);
        y_start = max<index_type>(y_start, 0);

        x_end = min<index_type>(x_end, in_width);
        y_end = min<index_type>(y_end, in_height);

        /* We have to set the output to zero if (x_start >= x_end) or (y_start >= y_end). If either
         * condition is true, the loops below won't execute even a single iteration. Hence, by setting
         * `max_val` to zero in this case, we can combine it with the `else` code.
         */
        T max_val = (x_start >= x_end || y_start >= y_end) ? T(0) : device::numeric_limits<T>::lowest();

        const index_type in_offset = (batch_id * num_channels + c) * in_height * in_width;
        for (auto iy = y_start; iy < y_end; iy++)
        {
            for (auto ix = x_start; ix < x_end; ix++)
            {
                const auto in_idx = in_offset + iy * in_width + ix;
                max_val = max(max_val, input[in_idx]);
            }
        }

        output[idx] = max_val;
    }
}
}
template <class T>
void roi_pooling(const Stream& stream, TensorSpan<T> output, TensorView<T> input, View<T> rois, T spatial_scale)
{
CV_Assert(input.get_axis_size(1) == output.get_axis_size(1));
size_type num_channels = output.get_axis_size(1);
size_type pooled_height = output.get_axis_size(2);
size_type pooled_width = output.get_axis_size(3);
size_type in_height = input.get_axis_size(2);
size_type in_width = input.get_axis_size(3);
auto kernel = raw::roi_pooling<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, pooled_height, pooled_width, input, in_height, in_width, rois, num_channels, spatial_scale);
}
template void roi_pooling(const Stream& stream, TensorSpan<__half> output, TensorView<__half> input, View<__half> rois, __half spatial_scale);
template void roi_pooling(const Stream& stream, TensorSpan<float> output, TensorView<float> input, View<float> rois, float spatial_scale);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
modules/dnn/src/cuda4dnn/kernels/roi_pooling.hpp
0 → 100644
View file @
9b8ddba4
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ROI_POOLING_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ROI_POOLING_HPP
#include "../csl/stream.hpp"
#include "../csl/tensor.hpp"
#include "../csl/span.hpp"
namespace
cv
{
namespace
dnn
{
namespace
cuda4dnn
{
namespace
kernels
{
/**
 * ROI pooling of `input` over the regions listed in `rois`; the result is written to `output`.
 *
 * @param stream         stream used for the operation
 * @param output         destination tensor
 * @param input          source tensor
 * @param rois           view of the ROI descriptors
 * @param spatial_scale  scale factor for the ROI coordinates
 *
 * NOTE(review): only the declaration is visible here; tensor layouts and preconditions are
 * defined by the implementation - confirm there.
 */
template <class T>
void roi_pooling(
    const csl::Stream& stream,
    csl::TensorSpan<T> output,
    csl::TensorView<T> input,
    csl::View<T> rois,
    T spatial_scale);
}}}}
/* namespace cv::dnn::cuda4dnn::kernels */
#endif
/* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ROI_POOLING_HPP */
modules/dnn/src/cuda4dnn/primitives/roi_pooling.hpp
0 → 100644
View file @
9b8ddba4
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ROI_POOLING_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ROI_POOLING_HPP
#include "../../op_cuda.hpp"
#include "../csl/stream.hpp"
#include "../kernels/roi_pooling.hpp"
#include <utility>
namespace
cv
{
namespace
dnn
{
namespace
cuda4dnn
{
/**
 * CUDA backend node that performs ROI pooling by delegating to kernels::roi_pooling<T>.
 *
 * forward() expects exactly two inputs (the tensor to pool from, followed by the ROIs)
 * and exactly one output.
 */
template <class T>
class ROIPoolingOp final : public CUDABackendNode {
public:
    using wrapper_type = GetCUDABackendWrapperType<T>;

    /**
     * @param stream_        stream handed to the pooling kernel on every forward pass
     * @param spatial_scale  scale factor passed to kernels::roi_pooling<T>
     */
    ROIPoolingOp(csl::Stream stream_, float spatial_scale)
        : stream(std::move(stream_)), spatial_scale{spatial_scale}
    {
    }

    /**
     * Runs ROI pooling of inputs[0] over the regions in inputs[1] and stores the result
     * in outputs[0]. `workspace` is not used.
     */
    void forward(
        const std::vector<cv::Ptr<BackendWrapper>>& inputs,
        const std::vector<cv::Ptr<BackendWrapper>>& outputs,
        csl::Workspace& workspace) override
    {
        CV_Assert(inputs.size() == 2 && outputs.size() == 1);

        /* tensor to pool from */
        auto features_wrapper = inputs[0].dynamicCast<wrapper_type>();
        auto features = features_wrapper->getView();

        /* regions of interest */
        auto regions_wrapper = inputs[1].dynamicCast<wrapper_type>();
        auto regions = regions_wrapper->getView();

        /* destination of the pooled values */
        auto pooled_wrapper = outputs[0].dynamicCast<wrapper_type>();
        auto pooled = pooled_wrapper->getSpan();

        kernels::roi_pooling<T>(stream, pooled, features, regions, spatial_scale);
    }

private:
    csl::Stream stream;
    float spatial_scale;
};
}}}
/* namespace cv::dnn::cuda4dnn */
#endif
/* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ROI_POOLING_HPP */
modules/dnn/src/layers/pooling_layer.cpp
View file @
9b8ddba4
...
...
@@ -68,6 +68,7 @@ using namespace cv::dnn::ocl4dnn;
#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/pooling.hpp"
#include "../cuda4dnn/primitives/roi_pooling.hpp"
#include "../cuda4dnn/primitives/max_unpooling.hpp"
using
namespace
cv
::
dnn
::
cuda4dnn
;
#endif
...
...
@@ -178,7 +179,7 @@ public:
{
if
(
backendId
==
DNN_BACKEND_CUDA
)
{
return
type
==
MAX
||
type
==
AVE
;
return
type
==
MAX
||
type
==
AVE
||
type
==
ROI
;
}
else
if
(
backendId
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
)
{
...
...
@@ -313,6 +314,8 @@ public:
)
override
{
auto
context
=
reinterpret_cast
<
csl
::
CSLContext
*>
(
context_
);
if
(
type
==
ROI
)
return
make_cuda_node
<
cuda4dnn
::
ROIPoolingOp
>
(
preferableTarget
,
std
::
move
(
context
->
stream
),
spatialScale
);
auto
input_wrapper
=
inputs
[
0
].
dynamicCast
<
CUDABackendWrapper
>
();
auto
input_shape
=
input_wrapper
->
getShape
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment