Commit 613c12e5 authored by Yashas Samaga B L, committed by Alexander Alekhin

Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low

CUDA backend for the DNN module

* stub cuda4dnn design

* minor fixes for tests and doxygen

* add csl public api directory to module headers

* add low-level CSL components

* add high-level CSL components

* integrate csl::Tensor into backbone code

* switch to CPU iff unsupported; otherwise, fail on error

* add fully connected layer

* add softmax layer

* add activation layers

* support arbitrary rank TensorDescriptor

* pass input wrappers to `initCUDA()`

* add 1d/2d/3d-convolution

* add pooling layer

* reorganize and refactor code

* fixes for gcc, clang and doxygen; remove cxx14/17 code

* add blank_layer

* add LRN layer

* add rounding modes for pooling layer

* split tensor.hpp into tensor.hpp and tensor_ops.hpp

* add concat layer

* add scale layer

* add batch normalization layer

* split math.cu into activations.cu and math.hpp

* add eltwise layer

* add flatten layer

* add tensor transform api

* add asymmetric padding support for convolution layer

* add reshape layer

* fix rebase issues

* add permute layer

* add padding support for concat layer

* refactor and reorganize code

* add normalize layer

* optimize bias addition in scale layer

* add prior box layer

* fix and optimize normalize layer

* add asymmetric padding support for pooling layer

* add event API

* improve pooling performance for some padding scenarios

* avoid over-allocation of compute resources to kernels

* improve prior box performance

* enable layer fusion

* add const layer

* add resize layer

* add slice layer

* add padding layer

* add deconvolution layer

* fix channelwise ReLU initialization

* add vector traits

* add vectorized versions of relu, clipped_relu, power

* add vectorized concat kernels

* improve concat_with_offsets performance

* vectorize scale and bias kernels

* add support for multi-billion element tensors

* vectorize prior box kernels

* fix address alignment check

* improve bias addition performance of conv/deconv/fc layers

* restructure code for supporting multiple targets

* add DNN_TARGET_CUDA_FP64

* add DNN_TARGET_FP16

* improve vectorization

* add region layer

* improve tensor API, add dynamic ranks

1. use ManagedPtr instead of a Tensor in backend wrapper
2. add new methods to tensor classes
  - size_range: computes the combined size for a given axis range (see the sketch after this list)
  - tensor span/view can be constructed from a raw pointer and shape
3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
4. remove device code from tensor classes (as it is unused)
5. enforce strict conditions on tensor class APIs to improve debugging ability
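For illustration only, a host-side sketch of what size_range computes; this is not the actual csl::Tensor code, and the (start, end) signature returning the product of the axis sizes over the half-open range [start, end) is an assumption based on the description above:
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>
// hypothetical stand-in for Tensor::size_range(start, end):
// for a shape {N, C, H, W}, size_range(1, 4) would yield C * H * W
std::size_t size_range(const std::vector<std::size_t>& shape, std::size_t start, std::size_t end)
{
    return std::accumulate(shape.begin() + start, shape.begin() + end,
                           std::size_t(1), std::multiplies<std::size_t>());
}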

* fix parametric relu activation

* add squeeze/unsqueeze tensor API

* add reorg layer

* optimize permute and enable 2d permute

* enable 1d and 2d slice

* add split layer

* add shuffle channel layer

* allow tensors of different ranks in reshape primitive

* patch SliceOp to allow Crop Layer

* allow extra shape inputs in reshape layer

* use `std::move_backward` instead of `std::move` for insert in resizable_static_array

* improve workspace management

* add spatial LRN

* add nms (cpu) to region layer

* add max pooling with argmax ( and a fix to limits.hpp)

* add max unpooling layer

* rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA

* update supportBackend to be more rigorous

* remove stray include from preventing non-cuda build

* include op_cuda.hpp outside condition #if

* refactoring, fixes and many optimizations

* drop DNN_TARGET_CUDA_FP64

* fix gcc errors

* increase max. tensor rank limit to six

* add Interp layer

* drop custom layers; use BackendNode

* vectorize activation kernels

* fixes for gcc

* remove wrong assertion

* fix broken assertion in unpooling primitive

* fix build errors in non-CUDA build

* completely remove workspace from public API

* fix permute layer

* enable accuracy and perf. tests for DNN_TARGET_CUDA

* add asynchronous forward

* vectorize eltwise ops

* vectorize fill kernel

* fixes for gcc

* remove CSL headers from public API

* remove csl header source group from cmake

* update min. cudnn version in cmake

* add numerically stable FP32 log1pexp

* refactor code

* add FP16 specialization to cudnn based tensor addition

* vectorize scale1 and bias1 + minor refactoring

* fix doxygen build

* fix invalid alignment assertion

* clear backend wrappers before allocateLayers

* ignore memory lock failures

* do not allocate internal blobs

* integrate NVTX

* add numerically stable half precision log1pexp

* fix indentation, following coding style, improve docs

* remove accidental modification of IE code

* Revert "add asynchronous forward"

This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.

* [cmake] throw error for unsupported CC versions

* fix rebase issues

* add more docs, refactor code, fix bugs

* minor refactoring and fixes

* resolve warnings/errors from clang

* remove haveCUDA() checks from supportBackend()

* remove NVTX integration

* changes based on review comments

* avoid exception when no CUDA device is present

* add color code for CUDA in Net::dump
parent 8ec65446
@@ -2,7 +2,7 @@ if(NOT DEFINED MIN_VER_CMAKE)
set(MIN_VER_CMAKE 3.5.1)
endif()
set(MIN_VER_CUDA 6.5)
set(MIN_VER_CUDNN 6)
set(MIN_VER_CUDNN 7.5)
set(MIN_VER_PYTHON2 2.7)
set(MIN_VER_PYTHON3 3.2)
set(MIN_VER_ZLIB 1.2.3)
@@ -90,6 +90,14 @@ endif()
if(OPENCV_DNN_CUDA AND HAVE_CUDA AND HAVE_CUBLAS AND HAVE_CUDNN)
list(APPEND include_dirs ${CUDA_TOOLKIT_INCLUDE} ${CUDNN_INCLUDE_DIRS})
set(CC_LIST ${CUDA_ARCH_BIN})
separate_arguments(CC_LIST)
foreach(cc ${CC_LIST})
if(cc VERSION_LESS 5.3)
message(FATAL_ERROR "CUDA backend for DNN module requires CC 5.3 or higher. Please remove unsupported architectures from CUDA_ARCH_BIN option.")
endif()
endforeach()
unset(CC_LIST)
else()
set(sources_options ${sources_options} EXCLUDE_CUDA)
endif()
@@ -71,7 +71,8 @@ CV__DNN_INLINE_NS_BEGIN
DNN_BACKEND_HALIDE,
DNN_BACKEND_INFERENCE_ENGINE, //!< Intel's Inference Engine computational backend.
DNN_BACKEND_OPENCV,
DNN_BACKEND_VKCOM
DNN_BACKEND_VKCOM,
DNN_BACKEND_CUDA
};
/**
@@ -85,7 +86,9 @@ CV__DNN_INLINE_NS_BEGIN
DNN_TARGET_OPENCL_FP16,
DNN_TARGET_MYRIAD,
DNN_TARGET_VULKAN,
DNN_TARGET_FPGA //!< FPGA device with CPU fallbacks using Inference Engine's Heterogeneous plugin.
DNN_TARGET_FPGA, //!< FPGA device with CPU fallbacks using Inference Engine's Heterogeneous plugin.
DNN_TARGET_CUDA,
DNN_TARGET_CUDA_FP16
};
CV_EXPORTS std::vector< std::pair<Backend, Target> > getAvailableBackends();
@@ -274,6 +277,20 @@ CV__DNN_INLINE_NS_BEGIN
virtual Ptr<BackendNode> initInfEngine(const std::vector<Ptr<BackendWrapper> > &inputs);
virtual Ptr<BackendNode> initVkCom(const std::vector<Ptr<BackendWrapper> > &inputs);
/**
* @brief Returns a CUDA backend node
*
* @param context void pointer to CSLContext object
* @param inputs layer inputs
* @param outputs layer outputs
*/
virtual Ptr<BackendNode> initCUDA(
void *context,
const std::vector<Ptr<BackendWrapper>>& inputs,
const std::vector<Ptr<BackendWrapper>>& outputs
);
/**
* @brief Automatic Halide scheduling based on layer hyper-parameters.
* @param[in] node Backend node with Halide functions.
@@ -515,13 +532,15 @@ CV__DNN_INLINE_NS_BEGIN
* @see Target
*
* List of supported combinations backend / target:
* | | DNN_BACKEND_OPENCV | DNN_BACKEND_INFERENCE_ENGINE | DNN_BACKEND_HALIDE |
* |------------------------|--------------------|------------------------------|--------------------|
* | DNN_TARGET_CPU | + | + | + |
* | DNN_TARGET_OPENCL | + | + | + |
* | DNN_TARGET_OPENCL_FP16 | + | + | |
* | DNN_TARGET_MYRIAD | | + | |
* | DNN_TARGET_FPGA | | + | |
* | | DNN_BACKEND_OPENCV | DNN_BACKEND_INFERENCE_ENGINE | DNN_BACKEND_HALIDE | DNN_BACKEND_CUDA |
* |------------------------|--------------------|------------------------------|--------------------|-------------------|
* | DNN_TARGET_CPU | + | + | + | |
* | DNN_TARGET_OPENCL | + | + | + | |
* | DNN_TARGET_OPENCL_FP16 | + | + | | |
* | DNN_TARGET_MYRIAD | | + | | |
* | DNN_TARGET_FPGA | | + | | |
* | DNN_TARGET_CUDA | | | | + |
* | DNN_TARGET_CUDA_FP16 | | | | + |
*/
CV_WRAP void setPreferableTarget(int targetId);
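As a usage sketch of the new combinations (the model file name here is hypothetical; any network that cv::dnn::readNet can load works the same way):
#include <opencv2/dnn.hpp>
void run_on_cuda()
{
    cv::dnn::Net net = cv::dnn::readNet("model.onnx");          // hypothetical model file
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);          // FP32 inference
    // net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16);  // FP16 on GPUs with half support
}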
@@ -111,8 +111,8 @@ PERF_TEST_P_(Conv3D, conv3d)
Backend backendId = get<0>(get<1>(GetParam()));
Target targetId = get<1>(get<1>(GetParam()));
if (targetId != DNN_TARGET_CPU)
throw SkipTestException("Only CPU is supported");
if (targetId != DNN_TARGET_CPU && backendId != DNN_BACKEND_CUDA)
throw SkipTestException("Only CPU and CUDA is supported");
int inChannels = inputShape[1];
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_ARRAY_HPP
#define OPENCV_DNN_SRC_CUDA_ARRAY_HPP
#include <cuda_runtime.h>
#include "types.hpp"
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <algorithm>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
template <class T, std::size_t N>
struct array {
using value_type = T;
using size_type = device::size_type;
using difference_type = std::ptrdiff_t;
using reference = typename std::add_lvalue_reference<value_type>::type;
using const_reference = typename std::add_lvalue_reference<typename std::add_const<value_type>::type>::type;
using pointer = typename std::add_pointer<value_type>::type;
using const_pointer = typename std::add_pointer<typename std::add_const<value_type>::type>::type;
using iterator = pointer;
using const_iterator = const_pointer;
using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
__host__ __device__ bool empty() const noexcept { return N == 0; }
__host__ __device__ size_type size() const noexcept { return N; }
__host__ __device__ iterator begin() noexcept { return ptr; }
__host__ __device__ iterator end() noexcept { return ptr + N; }
__host__ __device__ const_iterator begin() const noexcept { return ptr; }
__host__ __device__ const_iterator end() const noexcept { return ptr + N; }
__host__ __device__ const_iterator cbegin() const noexcept { return ptr; }
__host__ __device__ const_iterator cend() const noexcept { return ptr + N; }
__host__ __device__ reverse_iterator rbegin() noexcept { return ptr + N; }
__host__ __device__ reverse_iterator rend() noexcept { return ptr; }
__host__ __device__ const_reverse_iterator rbegin() const noexcept { return ptr + N; }
__host__ __device__ const_reverse_iterator rend() const noexcept { return ptr; }
__host__ __device__ const_reverse_iterator crbegin() const noexcept { return ptr + N; }
__host__ __device__ const_reverse_iterator crend() const noexcept { return ptr; }
template <class InputItr>
__host__ void assign(InputItr first, InputItr last) {
std::copy(first, last, std::begin(ptr));
}
__host__ __device__ reference operator[](int idx) { return ptr[idx]; }
__host__ __device__ const_reference operator[](int idx) const { return ptr[idx]; }
__host__ __device__ reference front() { return ptr[0]; }
__host__ __device__ const_reference front() const { return ptr[0]; }
__host__ __device__ reference back() { return ptr[N - 1]; }
__host__ __device__ const_reference back() const { return ptr[N - 1]; }
__host__ __device__ pointer data() noexcept { return ptr; }
__host__ __device__ const_pointer data() const noexcept { return ptr; }
T ptr[N];
};
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_ARRAY_HPP */
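A minimal usage sketch (the kernel and shapes are hypothetical): the host fills the array with assign() and passes it to the kernel by value, which is how the permute and padding kernels further below receive their stride arrays.
#include <cstdio>
#include <vector>
using namespace cv::dnn::cuda4dnn::csl::device;
__global__ void print_extents(array<size_type, 4> dims) {
    if (blockIdx.x == 0 && threadIdx.x == 0)
        printf("%d %d %d %d\n", (int)dims[0], (int)dims[1], (int)dims[2], (int)dims[3]);
}
void launch_example(const std::vector<std::size_t>& shape /* expects shape.size() == 4 */) {
    array<size_type, 4> dims;
    dims.assign(std::begin(shape), std::end(shape));
    print_extents<<<1, 1>>>(dims); /* trivially copyable, so pass-by-value is fine */
}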
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_ATOMICS_HPP
#define OPENCV_DNN_SRC_CUDA_ATOMICS_HPP
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
#else
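/* CUDA provides a native atomicAdd(__half*, __half) for compute capability 7.0 and above;
 * for older devices, emulate it with a 32-bit atomicCAS on the aligned word that contains the half
 */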
inline __device__ void atomicAdd(__half* address, __half val) {
unsigned int* address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
__half tmpres = hsum + val;
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}
#endif
#endif /* OPENCV_DNN_SRC_CUDA_ATOMICS_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "math.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t N>
__global__ void eltwise_max_2_vec(Span<T> output, View<T> x, View<T> y) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++) {
using device::max;
vec_x.data[j] = max(vec_x.data[j], vec_y.data[j]);
}
v_store(output_vPtr[i], vec_x);
}
}
template <class T, std::size_t N>
__global__ void eltwise_sum_2_vec(Span<T> output, View<T> x, View<T> y) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++)
vec_x.data[j] = vec_x.data[j] + vec_y.data[j];
v_store(output_vPtr[i], vec_x);
}
}
template <class T, std::size_t N>
__global__ void eltwise_sum_coeff_2_vec(Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++)
vec_x.data[j] = coeff_x * vec_x.data[j] + coeff_y * vec_y.data[j];
v_store(output_vPtr[i], vec_x);
}
}
template <class T, std::size_t N>
__global__ void eltwise_prod_2_vec(Span<T> output, View<T> x, View<T> y) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
auto x_vPtr = vector_type::get_pointer(x.data());
auto y_vPtr = vector_type::get_pointer(y.data());
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec_x, vec_y;
v_load(vec_x, x_vPtr[i]);
v_load(vec_y, y_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++)
vec_x.data[j] = vec_x.data[j] * vec_y.data[j];
v_store(output_vPtr[i], vec_x);
}
}
}
template <class T, std::size_t N>
void launch_vectorized_eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
auto kernel = raw::eltwise_max_2_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, x, y);
}
template <class T>
void eltwise_max_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
launch_vectorized_eltwise_max_2<T, 4>(stream, output, x, y);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
launch_vectorized_eltwise_max_2<T, 2>(stream, output, x, y);
} else {
launch_vectorized_eltwise_max_2<T, 1>(stream, output, x, y);
}
}
template void eltwise_max_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
template void eltwise_max_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
template <class T, std::size_t N>
void launch_vectorized_eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
auto kernel = raw::eltwise_sum_2_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, x, y);
}
template <class T>
void eltwise_sum_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
launch_vectorized_eltwise_sum_2<T, 4>(stream, output, x, y);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
launch_vectorized_eltwise_sum_2<T, 2>(stream, output, x, y);
} else {
launch_vectorized_eltwise_sum_2<T, 1>(stream, output, x, y);
}
}
template void eltwise_sum_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
template void eltwise_sum_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
template <class T, std::size_t N>
void launch_vectorized_eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
auto kernel = raw::eltwise_sum_coeff_2_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, coeff_x, x, coeff_y, y);
}
template <class T>
void eltwise_sum_coeff_2(const Stream& stream, Span<T> output, T coeff_x, View<T> x, T coeff_y, View<T> y) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
if (static_cast<float>(coeff_x) == 1.0f && static_cast<float>(coeff_y) == 1.0f) {
eltwise_sum_2(stream, output, x, y);
return;
}
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
launch_vectorized_eltwise_sum_coeff_2<T, 4>(stream, output, coeff_x, x, coeff_y, y);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
launch_vectorized_eltwise_sum_coeff_2<T, 2>(stream, output, coeff_x, x, coeff_y, y);
} else {
launch_vectorized_eltwise_sum_coeff_2<T, 1>(stream, output, coeff_x, x, coeff_y, y);
}
}
template void eltwise_sum_coeff_2(const Stream&, Span<__half>, __half, View<__half>, __half, View<__half>);
template void eltwise_sum_coeff_2(const Stream&, Span<float>, float, View<float>, float, View<float>);
template <class T, std::size_t N>
void launch_vectorized_eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(is_fully_aligned<T>(output, N));
CV_Assert(is_fully_aligned<T>(x, N));
CV_Assert(is_fully_aligned<T>(y, N));
auto kernel = raw::eltwise_prod_2_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, x, y);
}
template <class T>
void eltwise_prod_2(const Stream& stream, Span<T> output, View<T> x, View<T> y) {
CV_Assert(x.size() == y.size());
CV_Assert(x.size() == output.size());
if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(x, 4) && is_fully_aligned<T>(y, 4)) {
launch_vectorized_eltwise_prod_2<T, 4>(stream, output, x, y);
} else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(x, 2) && is_fully_aligned<T>(y, 2)) {
launch_vectorized_eltwise_prod_2<T, 2>(stream, output, x, y);
} else {
launch_vectorized_eltwise_prod_2<T, 1>(stream, output, x, y);
}
}
template void eltwise_prod_2(const Stream& stream, Span<__half> output, View<__half> x, View<__half> y);
template void eltwise_prod_2(const Stream& stream, Span<float> output, View<float> x, View<float> y);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
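The eltwise_*_2 entry points above all follow the same dispatch pattern: pick the widest vectorized instantiation (4, then 2, then 1 elements per load) whose alignment requirements the operands satisfy. The actual check lives in vector_traits.hpp, which is not shown here; a rough host-side sketch of the property it has to guarantee (an assumption, not the real helper):
#include <cstddef>
#include <cstdint>
// assumed meaning of is_fully_aligned<T>(span, N): the data pointer sits on an
// N * sizeof(T) boundary and the element count is a multiple of N, so the span
// can be reinterpreted as a sequence of N-element vectors without a tail
template <class T>
bool is_fully_aligned_sketch(const T* ptr, std::size_t size, std::size_t N)
{
    return reinterpret_cast<std::uintptr_t>(ptr) % (N * sizeof(T)) == 0 && size % N == 0;
}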
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_EXECUTION_HPP
#define OPENCV_DNN_SRC_CUDA_EXECUTION_HPP
#include "../cuda4dnn/csl/error.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include <opencv2/core.hpp>
#include <cuda_runtime_api.h>
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
struct execution_policy {
execution_policy(dim3 grid_size, dim3 block_size)
: grid{ grid_size }, block{ block_size }, sharedMem{ 0 }, stream{ 0 } { }
execution_policy(dim3 grid_size, dim3 block_size, std::size_t shared_mem)
: grid{ grid_size }, block{ block_size }, sharedMem{ shared_mem }, stream{ nullptr } { }
execution_policy(dim3 grid_size, dim3 block_size, const Stream& strm)
: grid{ grid_size }, block{ block_size }, sharedMem{ 0 }, stream{ strm.get() } { }
execution_policy(dim3 grid_size, dim3 block_size, std::size_t shared_mem, const Stream& strm)
: grid{ grid_size }, block{ block_size }, sharedMem{ shared_mem }, stream{ strm.get() } { }
dim3 grid;
dim3 block;
std::size_t sharedMem;
cudaStream_t stream;
};
/* this overload shouldn't be necessary; we should always provide a bound on the number of threads */
/*
template <class Kernel> inline
execution_policy make_policy(Kernel kernel, std::size_t sharedMem = 0, const Stream& stream = 0) {
int grid_size, block_size;
CUDA4DNN_CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&grid_size, &block_size, kernel, sharedMem));
return execution_policy(grid_size, block_size, sharedMem, stream);
}*/
template <class Kernel> inline
execution_policy make_policy(Kernel kernel, std::size_t max_threads, std::size_t sharedMem = 0, const Stream& stream = 0) {
CV_Assert(max_threads > 0);
int grid_size = 0, block_size = 0;
CUDA4DNN_CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&grid_size, &block_size, kernel, sharedMem));
if (grid_size * block_size > max_threads) {
grid_size = (max_threads + block_size - 1) / block_size;
if (block_size > max_threads)
block_size = max_threads;
}
CV_Assert(grid_size >= 1 && block_size >= 1);
return execution_policy(grid_size, block_size, sharedMem, stream);
}
template <class Kernel, typename ...Args> inline
void launch_kernel(Kernel kernel, Args ...args) {
auto policy = make_policy(kernel);
kernel <<<policy.grid, policy.block>>> (std::forward<Args>(args)...);
}
template <class Kernel, typename ...Args> inline
void launch_kernel(Kernel kernel, dim3 grid, dim3 block, Args ...args) {
kernel <<<grid, block>>> (std::forward<Args>(args)...);
}
template <class Kernel, typename ...Args> inline
void launch_kernel(Kernel kernel, execution_policy policy, Args ...args) {
kernel <<<policy.grid, policy.block, policy.sharedMem, policy.stream>>> (std::forward<Args>(args)...);
}
}}}} /* namespace cv::dnn::cuda4dnn::csl */
#endif /* OPENCV_DNN_SRC_CUDA_EXECUTION_HPP */
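A minimal usage sketch of the helpers above (the kernel is hypothetical; the real kernels in this backend take Span/View arguments instead of raw pointers): make_policy bounds the launch to roughly one thread per element and launch_kernel forwards the arguments on the policy's stream.
using namespace cv::dnn::cuda4dnn::csl;
// hypothetical kernel: grid-stride loop writing `value` into a raw device buffer
__global__ void fill_raw(float* data, std::size_t n, float value) {
    for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
        data[i] = value;
}
void fill_example(const Stream& stream, float* d_data, std::size_t n) {
    auto kernel = fill_raw;
    auto policy = make_policy(kernel, /* max_threads */ n, /* sharedMem */ 0, stream);
    launch_kernel(kernel, policy, d_data, n, 1.0f);
}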
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t N>
__global__ void fill_vec(Span<T> output, T value) {
using vector_type = get_vector_type_t<T, N>;
auto output_vPtr = vector_type::get_pointer(output.data());
for (auto i : grid_stride_range(output.size() / vector_type::size())) {
vector_type vec;
for (int j = 0; j < vector_type::size(); j++)
vec.data[j] = value;
v_store(output_vPtr[i], vec);
}
}
}
template <class T, std::size_t N>
void launch_vectorized_fill(const Stream& stream, Span<T> output, T value) {
CV_Assert(is_fully_aligned<T>(output, N));
auto kernel = raw::fill_vec<T, N>;
auto policy = make_policy(kernel, output.size() / N, 0, stream);
launch_kernel(kernel, policy, output, value);
}
template <class T>
void fill(const Stream& stream, Span<T> output, T value) {
if (is_fully_aligned<T>(output, 4)) {
launch_vectorized_fill<T, 4>(stream, output, value);
} else if (is_fully_aligned<T>(output, 2)) {
launch_vectorized_fill<T, 2>(stream, output, value);
} else {
launch_vectorized_fill<T, 1>(stream, output, value);
}
}
template void fill(const Stream&, Span<__half>, __half);
template void fill(const Stream&, Span<float>, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP
#define OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP
#include "types.hpp"
#include <cuda_runtime.h>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
namespace detail {
template <int> __device__ auto getGridDim()->decltype(dim3::x);
template <> inline __device__ auto getGridDim<0>()->decltype(dim3::x) { return gridDim.x; }
template <> inline __device__ auto getGridDim<1>()->decltype(dim3::x) { return gridDim.y; }
template <> inline __device__ auto getGridDim<2>()->decltype(dim3::x) { return gridDim.z; }
template <int> __device__ auto getBlockDim()->decltype(dim3::x);
template <> inline __device__ auto getBlockDim<0>()->decltype(dim3::x) { return blockDim.x; }
template <> inline __device__ auto getBlockDim<1>()->decltype(dim3::x) { return blockDim.y; }
template <> inline __device__ auto getBlockDim<2>()->decltype(dim3::x) { return blockDim.z; }
template <int> __device__ auto getBlockIdx()->decltype(uint3::x);
template <> inline __device__ auto getBlockIdx<0>()->decltype(uint3::x) { return blockIdx.x; }
template <> inline __device__ auto getBlockIdx<1>()->decltype(uint3::x) { return blockIdx.y; }
template <> inline __device__ auto getBlockIdx<2>()->decltype(uint3::x) { return blockIdx.z; }
template <int> __device__ auto getThreadIdx()->decltype(uint3::x);
template <> inline __device__ auto getThreadIdx<0>()->decltype(uint3::x) { return threadIdx.x; }
template <> inline __device__ auto getThreadIdx<1>()->decltype(uint3::x) { return threadIdx.y; }
template <> inline __device__ auto getThreadIdx<2>()->decltype(uint3::x) { return threadIdx.z; }
}
template <int dim, class index_type = device::index_type, class size_type = device::size_type>
class grid_stride_range_generic {
public:
__device__ grid_stride_range_generic(index_type to_) : from(0), to(to_) { }
__device__ grid_stride_range_generic(index_type from_, index_type to_) : from(from_), to(to_) { }
class iterator
{
public:
__device__ iterator(index_type pos_) : pos(pos_) {}
/* these iterators return the index when dereferenced; this allows us to loop
* through the indices using a range based for loop
*/
__device__ index_type operator*() const { return pos; }
__device__ iterator& operator++() {
pos += detail::getGridDim<dim>() * static_cast<index_type>(detail::getBlockDim<dim>());
return *this;
}
__device__ bool operator!=(const iterator& other) const {
/* NOTE HACK
** 'pos' can move in large steps (see operator++)
** expansion of the range-based for loop uses != as the loop condition
** => operator!= must return false if 'pos' crosses the end
*/
return pos < other.pos;
}
private:
index_type pos;
};
__device__ iterator begin() const {
using detail::getBlockDim;
using detail::getBlockIdx;
using detail::getThreadIdx;
return iterator(from + getBlockDim<dim>() * getBlockIdx<dim>() + getThreadIdx<dim>());
}
__device__ iterator end() const {
return iterator(to);
}
private:
index_type from, to;
};
using grid_stride_range_x = grid_stride_range_generic<0>;
using grid_stride_range_y = grid_stride_range_generic<1>;
using grid_stride_range_z = grid_stride_range_generic<2>;
using grid_stride_range = grid_stride_range_x;
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_GRID_STRIDE_RANGE_HPP */
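In kernels, the range above replaces the hand-written grid-stride idiom; the two hypothetical kernels below traverse exactly the same indices (begin() starts at blockDim * blockIdx + threadIdx and operator++ advances by gridDim * blockDim):
using namespace cv::dnn::cuda4dnn::csl::device;
__global__ void iota_classic(float* out, size_type n) {
    for (size_type i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
        out[i] = i;
}
__global__ void iota_range(float* out, size_type n) {
    for (auto i : grid_stride_range(n))
        out[i] = i;
}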
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP
#define OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP
#include <cstddef>
#include <type_traits>
/* The performance of many kernels is highly dependent on the tensor rank. Instead of having
* one kernel which can work with the maximally ranked tensors, we make one kernel for each supported
* tensor rank. This is to ensure that the requirements of the maximally ranked tensors do not take a
* toll on the performance of the operation for low ranked tensors. Hence, many kernels take the tensor
* rank as a template parameter.
*
* The kernel is a template and we have different instantiations for each rank. This causes the following pattern
* to arise frequently:
*
* if(rank == 3)
* kernel<T, 3>();
* else if(rank == 2)
* kernel<T, 2>();
* else
* kernel<T, 1>();
*
* The rank is a runtime variable. To facilitate creation of such structures, we use GENERATE_KERNEL_DISPATCHER.
* This macro creates a function which selects the correct kernel instantiation at runtime.
*
* Example:
*
* // function which sets up the kernel and launches it
* template <class T, std::size_t Rank>
* void launch_some_kernel(...);
*
* // creates the dispatcher named "some_dispatcher" which invokes the correct instantiation of "launch_some_kernel"
* GENERATE_KERNEL_DISPATCHER(some_dispatcher, launch_some_kernel);
*
* // internal API function
* template <class T>
* void some(...) {
* // ...
* auto rank = input.rank();
* some_dispatcher<T, MIN_RANK, MAX_RANK>(rank, ...);
* }
*/
/*
* name name of the dispatcher function that is generated
* func template function that requires runtime selection
*
* T first template parameter to `func`
* start starting rank
* end ending rank (inclusive)
*
* Executes func<T, selector> based on runtime `selector` argument given `selector` lies
* within the range [start, end]. If outside the range, no instantiation of `func` is executed.
*/
#define GENERATE_KERNEL_DISPATCHER(name,func) \
template <class T, std::size_t start, std::size_t end, class... Args> static \
typename std::enable_if<start == end, void> \
::type name(int selector, Args&& ...args) { \
if(selector == start) \
func<T, start>(std::forward<Args>(args)...); \
} \
\
template <class T, std::size_t start, std::size_t end, class... Args> static \
typename std::enable_if<start != end, void> \
::type name(int selector, Args&& ...args) { \
if(selector == start) \
func<T, start>(std::forward<Args>(args)...); \
else \
name<T, start + 1, end, Args...>(selector, std::forward<Args>(args)...); \
}
#endif /* OPENCV_DNN_SRC_CUDA_KERNEL_DISPATCHER_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_LIMITS_HPP
#define OPENCV_DNN_SRC_CUDA_LIMITS_HPP
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cfloat>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
template <class T>
struct numeric_limits;
template <>
struct numeric_limits<__half> {
__device__ static __half min() { return 0.0000610; }
__device__ static __half max() { return 65504.0; }
__device__ static __half lowest() { return -65504.0; }
};
template <>
struct numeric_limits<float> {
__device__ static float min() { return FLT_MIN; }
__device__ static float max() { return FLT_MAX; }
__device__ static float lowest() { return -FLT_MAX; }
};
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_LIMITS_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_MATH_HPP
#define OPENCV_DNN_SRC_CUDA_MATH_HPP
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
template <class T> __device__ T abs(T val) { return (val < T(0) ? -val : val); }
template <> inline __device__ __half2 abs(__half2 val) {
val.x = abs(val.x);
val.y = abs(val.y);
return val;
}
template <> inline __device__ float abs(float val) { return fabsf(val); }
template <> inline __device__ double abs(double val) { return fabs(val); }
template <class T> __device__ T exp(T val);
template <> inline __device__ __half exp(__half val) { return hexp(val); }
template <> inline __device__ __half2 exp(__half2 val) { return h2exp(val); }
template <> inline __device__ float exp(float val) { return expf(val); }
template <> inline __device__ double exp(double val) { return ::exp(val); }
template <class T> __device__ T expm1(T val);
template <> inline __device__ __half expm1(__half val) { return hexp(val) - __half(1); }
template <> inline __device__ __half2 expm1(__half2 val) { return h2exp(val) - __half2(1, 1); }
template <> inline __device__ float expm1(float val) { return expm1f(val); }
template <> inline __device__ double expm1(double val) { return ::expm1(val); }
template <class T> __device__ T max(T x, T y) { return (x > y ? x : y); }
template <> inline __device__ __half2 max(__half2 a, __half2 b) {
a.x = max(a.x, b.x);
a.y = max(a.y, b.y);
return a;
}
template <> inline __device__ float max(float x, float y) { return fmaxf(x, y); }
template <> inline __device__ double max(double x, double y) { return fmax(x, y); }
template <class T> __device__ T min(T x, T y) { return (x > y ? y : x); }
template <> inline __device__ __half2 min(__half2 a, __half2 b) {
a.x = min(a.x, b.x);
a.y = min(a.y, b.y);
return a;
}
template <> inline __device__ float min(float x, float y) { return fminf(x, y); }
template <> inline __device__ double min(double x, double y) { return fmin(x, y); }
template <class T> __device__ T log1p(T val);
template <> inline __device__ __half log1p(__half val) { return hlog(val + __half(1)); }
template <> inline __device__ __half2 log1p(__half2 val) { return h2log(val + __half2(1, 1)); }
template <> inline __device__ float log1p(float val) { return log1pf(val); }
template <> inline __device__ double log1p(double val) { return ::log1p(val); }
template <class T> __device__ T log1pexp(T val);
template <> inline __device__ __half log1pexp(__half val) {
if (val <= __half(-4.0))
return exp(val);
else if (val <= __half(8.0))
return log1p(exp(val));
else if (val <= __half(8.7))
return val + exp(-val);
else
return val;
}
template <> inline __device__ __half2 log1pexp(__half2 val) {
val.x = log1pexp(val.x);
val.y = log1pexp(val.y);
return val;
}
template <> inline __device__ float log1pexp(float val) {
if (val <= -20)
return expf(val);
else if (val <= 9.0)
return log1pf(expf(val));
else if (val <= 14.6)
return val + exp(-val);
else
return val;
}
template <> inline __device__ double log1pexp(double val) {
if (val <= -37)
return exp(val);
else if (val <= 18)
return log1p(exp(val));
else if (val <= 33.3)
return val + exp(-val);
else
return val;
}
template <class T> __device__ T tanh(T val);
template <> inline __device__ __half tanh(__half val) { return tanhf(val); }
template <> inline __device__ __half2 tanh(__half2 val) { return __half2(tanh(val.x), tanh(val.y)); }
template <> inline __device__ float tanh(float val) { return tanhf(val); }
template <> inline __device__ double tanh(double val) { return ::tanh(val); }
template <class T> __device__ T pow(T val, T exp);
template <> inline __device__ __half pow(__half val, __half exp) { return powf(val, exp); }
template <> inline __device__ __half2 pow(__half2 val, __half2 exp) { return __half2(pow(val.x, exp.x), pow(val.y, exp.y)); }
template <> inline __device__ float pow(float val, float exp) { return powf(val, exp); }
template <> inline __device__ double pow(double val, double exp) { return ::pow(val, exp); }
template <class T> __device__ T sqrt(T val);
template <> inline __device__ __half sqrt(__half val) { return hsqrt(val); }
template <> inline __device__ __half2 sqrt(__half2 val) { return h2sqrt(val); }
template <> inline __device__ float sqrt(float val) { return sqrtf(val); }
template <> inline __device__ double sqrt(double val) { return ::sqrt(val); }
template <class T> __device__ T rsqrt(T val);
template <> inline __device__ __half rsqrt(__half val) { return hrsqrt(val); }
template <> inline __device__ __half2 rsqrt(__half2 val) { return h2rsqrt(val); }
template <> inline __device__ float rsqrt(float val) { return rsqrtf(val); }
template <> inline __device__ double rsqrt(double val) { return ::rsqrt(val); }
template <class T> __device__ T sigmoid(T val) { return T(1) / (T(1) + exp(-val)); }
template <> inline __device__ __half2 sigmoid(__half2 val) { return __half2(1, 1) / (__half2(1, 1) + exp(__hneg2(val))); }
template <class T> __device__ T clamp(T value, T lower, T upper) { return min(max(value, lower), upper); }
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_MATH_HPP */
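The piecewise branches above keep log1pexp(x) = log(1 + exp(x)) stable: for very negative x it degenerates to exp(x), for moderate x the direct form is safe, for large x it is x + exp(-x), and past the last threshold the exp(-x) correction drops below the type's precision. A small host-only illustration in plain C++ of why the naive form fails (not part of the backend):
#include <cmath>
#include <cstdio>
int main()
{
    float x = 100.0f;
    float naive  = log1pf(expf(x)); // expf(100) overflows float to +inf, so the result is +inf
    float stable = x + expf(-x);    // ~100.0; the exp(-x) term is negligible at this magnitude
    std::printf("naive = %f, stable = %f\n", naive, stable);
    return 0;
}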
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "math.hpp"
#include "types.hpp"
#include "atomics.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include "../cuda4dnn/kernels/fill.hpp"
#include "../cuda4dnn/kernels/scale_shift.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T>
__global__ void reduce_sum_abs(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride) {
for (auto idx : grid_stride_range(input.size())) {
const index_type outer_idx = idx / outer_stride;
const index_type inner_idx = idx % mid_stride;
const index_type sum_idx = outer_idx * mid_stride + inner_idx;
atomicAdd(&output[sum_idx], device::abs(input[idx]));
}
}
template <class T>
__global__ void reciprocal(Span<T> output, T epsilon) {
for (auto idx : grid_stride_range(output.size()))
output[idx] = T(1) / (output[idx] + epsilon);
}
template <class T>
__global__ void reduce_sum_squared(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride) {
for (auto idx : grid_stride_range(input.size())) {
const index_type outer_idx = idx / outer_stride;
const index_type inner_idx = idx % mid_stride;
const index_type sum_idx = outer_idx * mid_stride + inner_idx;
atomicAdd(&output[sum_idx], input[idx] * input[idx]);
}
}
template <class T>
__global__ void rsqrt(Span<T> output, T epsilon) {
for (auto idx : grid_stride_range(output.size())) {
using device::sqrt;
output[idx] = T(1) / sqrt(output[idx] + epsilon);
}
}
template <class T>
__global__ void apply_norm(Span<T> output, View<T> input, size_type outer_stride, size_type mid_stride, View<T> sums) {
for (auto idx : grid_stride_range(output.size())) {
const index_type outer_idx = idx / outer_stride;
const index_type inner_idx = idx % mid_stride;
const index_type sum_idx = outer_idx * mid_stride + inner_idx;
output[idx] = input[idx] * sums[sum_idx];
}
}
}
template <class T>
void normalize(
const Stream& stream,
Span<T> output,
View<T> input, std::size_t outer_size, std::size_t mid_size, std::size_t inner_size, std::size_t norm, T epsilon,
Span<T> workspace)
{
CV_Assert(output.size() == input.size());
CV_Assert(output.size() == outer_size * mid_size * inner_size);
CV_Assert(norm == 1 || norm == 2);
CV_Assert(workspace.size() >= outer_size * inner_size);
auto sums = Span<T>(workspace.data(), outer_size * inner_size);
fill<T>(stream, sums, 0.0);
if (norm == 1) {
auto reduce_kernel = raw::reduce_sum_abs<T>;
auto policy = make_policy(reduce_kernel, input.size(), 0, stream);
launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size);
auto reciprocal_kernel = raw::reciprocal<T>;
policy = make_policy(reciprocal_kernel, sums.size(), 0, stream);
launch_kernel(reciprocal_kernel, policy, sums, epsilon);
} else {
auto reduce_kernel = raw::reduce_sum_squared<T>;
auto policy = make_policy(reduce_kernel, input.size(), 0, stream);
launch_kernel(reduce_kernel, policy, sums, input, mid_size * inner_size, inner_size);
auto rsqrt_kernel = raw::rsqrt<T>;
policy = make_policy(rsqrt_kernel, sums.size(), 0, stream);
launch_kernel(rsqrt_kernel, policy, sums, epsilon);
}
auto scale_kernel = raw::apply_norm<T>;
auto policy = make_policy(scale_kernel, output.size(), 0, stream);
launch_kernel(scale_kernel, policy, output, input, mid_size * inner_size, inner_size, sums);
}
template void normalize(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t, std::size_t, __half, Span<__half>);
template void normalize(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t, std::size_t, float, Span<float>);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
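For reference, the kernels above implement an Lp normalization (p = 1 or 2) across the middle axis of an outer x mid x inner view of the tensor; a plain C++ host sketch of the same computation (illustrative only, not part of the backend):
#include <cmath>
#include <cstddef>
#include <vector>
// input and output hold outer_size x mid_size x inner_size values, stored contiguously
void normalize_reference(std::vector<float>& output, const std::vector<float>& input,
                         std::size_t outer_size, std::size_t mid_size, std::size_t inner_size,
                         std::size_t norm /* 1 or 2 */, float epsilon)
{
    for (std::size_t o = 0; o < outer_size; o++)
        for (std::size_t i = 0; i < inner_size; i++) {
            float sum = 0.0f; // reduction over the mid axis (reduce_sum_abs / reduce_sum_squared)
            for (std::size_t m = 0; m < mid_size; m++) {
                float v = input[(o * mid_size + m) * inner_size + i];
                sum += (norm == 1) ? std::abs(v) : v * v;
            }
            // reciprocal / rsqrt step
            float scale = (norm == 1) ? 1.0f / (sum + epsilon) : 1.0f / std::sqrt(sum + epsilon);
            for (std::size_t m = 0; m < mid_size; m++) {
                std::size_t idx = (o * mid_size + m) * inner_size + i;
                output[idx] = input[idx] * scale; // apply_norm step
            }
        }
}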
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "math.hpp"
#include "types.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "kernel_dispatcher.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
#include <numeric>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t Rank>
__global__ void copy_with_reflection101(
Span<T> output, array<size_type, Rank> out_strides, array<index_type, Rank> start, array<index_type, Rank> end,
View<T> input, array<size_type, Rank> in_strides)
{
for (auto i : grid_stride_range(output.size())) {
/* compute output axis indices corresponding to element 'i' */
array<index_type, Rank> out_index;
out_index[0] = i / out_strides[0];
for (int j = 1; j < Rank; j++)
out_index[j] = (i % out_strides[j - 1]) / out_strides[j];
/* compute input axis indices corresponding to output axis indices */
array<index_type, Rank> in_index;
for (int j = 0; j < Rank; j++) {
/* if out_index < start, the point is in the left reflection region
* the reflected value's index is the absolute value of the difference
*
* otherwise, if the value is in the copy region, out_index - start gives the input index
*/
using device::abs;
in_index[j] = abs(out_index[j] - start[j]);
/* if out_index >= end, it's in the right reflection region */
if (out_index[j] >= end[j])
in_index[j] = (end[j] - start[j]) - (out_index[j] - end[j]) - 2;
}
/* compute input element number from input axis indices */
index_type iidx = 0;
for (int j = 0; j < Rank; j++)
iidx += in_index[j] * in_strides[j];
output[i] = input[iidx];
}
}
}
template <class T, std::size_t Rank> static
void launch_copy_with_reflection101(
const Stream& stream,
Span<T> output, const std::vector<std::size_t>& outStride,
View<T> input, const std::vector<std::size_t>& inStride,
const std::vector<std::pair<std::size_t, std::size_t>>& ranges)
{
CV_Assert(outStride.size() == Rank);
CV_Assert(inStride.size() == Rank);
CV_Assert(ranges.size() == Rank);
array<size_type, Rank> outStride_k, inStride_k;
outStride_k.assign(std::begin(outStride), std::end(outStride));
inStride_k.assign(std::begin(inStride), std::end(inStride));
array<index_type, Rank> start_k, end_k;
for (int i = 0; i < Rank; i++) {
start_k[i] = ranges[i].first;
end_k[i] = ranges[i].second;
}
auto kernel = raw::copy_with_reflection101<T, Rank>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, outStride_k, start_k, end_k, input, inStride_k);
}
GENERATE_KERNEL_DISPATCHER(copy_with_reflection101_dispatcher, launch_copy_with_reflection101);
template <class T>
void copy_with_reflection101(
const Stream& stream,
TensorSpan<T> output, TensorView<T> input,
std::vector<std::pair<std::size_t, std::size_t>> ranges)
{
CV_Assert(output.rank() == input.rank());
CV_Assert(output.rank() == ranges.size());
/* squeezable axes at the beginning of both tensors can be eliminated
*
* Reasoning:
* ----------
* Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the
* output tensor will be [i1 + off1, i2 + off2, ...]. The rest of the elements in the output are padding.
* The padding operation essentially copies items from the input tensor to new locations in the output tensor
* and pads the remaining.
*
* If the size of the first axis of the input and output tensor is unity, the input and output indices
* for all the elements will be of the form [0, i2, ...] and [0, i2 + off2, ...] respectively. Note that
* there cannot be extra padding since the axes have unit size. The first index does not contribute to the
* element's address calculation and hence does nothing apart from eating up a few cycles.
*/
while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
CV_Assert(ranges[0].first == 0 && ranges[0].second == 1);
input.squeeze(0);
output.squeeze(0);
ranges.erase(std::begin(ranges));
CV_Assert(output.rank() == input.rank());
CV_Assert(output.rank() == ranges.size());
}
auto inShape = input.shape_as_vector();
auto outShape = output.shape_as_vector();
/* contiguous axes which do not have any padding can be combined into one axis
*
* Reasoning:
* ----------
* Suppose an item's indices in the input tensor is [i1, i2, i3, ...]. Let the first two axes not have any
* padding. The indices in the output tensor will be [i1, i2, i3 + off3, ...].
*
* Each axis in the contiguous unpadded axes sequence will add an offset of iN * strideN. In the above example,
* the two axes add a total offset of `i1 * stride1 + i2 * stride2`. We can merge the two axes into one axis with
* a size of `size1 * size2`. The new offset added will be `i12 * stride2` as the kernel iterates through `i12`.
* Note that `i12` is actually `(i1 * size2 + i2)` in the original tensor.
*/
for (int i = 0; i < inShape.size(); i++) {
/* check if axis `i` requires any padding */
if (ranges[i].first == 0 && ranges[i].second == inShape[i]) {
/* loop invariant: `i` is the first axis in the contiguous unpadded axis sequence */
CV_Assert(inShape[i] == outShape[i]);
/* we now iterate through the axes which follow and try to merge */
int j = i + 1; /* `j` is the axis which we will attempt to merge */
while (j < inShape.size() && ranges[j].first == 0 && ranges[j].second == inShape[j]) {
CV_Assert(inShape[j] == outShape[j]);
/* `j` is also unpadded; merge `i` and `j` */
auto new_size = inShape[i] * inShape[j];
inShape[i] = new_size;
outShape[i] = new_size;
ranges[i].second = new_size;
/* delete axis `j` */
inShape.erase(std::begin(inShape) + j);
outShape.erase(std::begin(outShape) + j);
ranges.erase(std::begin(ranges) + j);
/* optimizations should not break the invariants */
CV_Assert(inShape.size() == outShape.size());
CV_Assert(inShape.size() == ranges.size());
CV_Assert(inShape[i] == outShape[i]);
CV_Assert(ranges[i].first == 0 && ranges[i].second == inShape[i]);
}
}
}
auto rank = inShape.size();
std::vector<std::size_t> inStride(rank), outStride(rank);
inStride.back() = 1;
outStride.back() = 1;
/* garbage, ..., garbage, 1 */
std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
/* dim[0], dim[1], ..., dim[-1], 1 */
std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<std::size_t>());
std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
/* stride[0], stride[1], ..., stride[-2], 1 */
CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
copy_with_reflection101_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, input, inStride, ranges);
}
template void copy_with_reflection101(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::pair<std::size_t, std::size_t>> ranges);
template void copy_with_reflection101(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::pair<std::size_t, std::size_t>> ranges);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
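A worked example of the two shape optimizations in copy_with_reflection101 above, with made-up shapes:
/* input [1, 3, 32, 32] padded to output [1, 3, 36, 36] with ranges {{0,1}, {0,3}, {2,34}, {2,34}}:
 *   - the leading unit axis is squeezed away:      [3, 32, 32] -> [3, 36, 36]
 *   - axis 0 is unpadded but axis 1 is padded, so no merge happens; the rank-3 kernel runs
 *
 * if only the last axis were padded, i.e. [3, 32, 32] -> [3, 32, 36]:
 *   - axes 0 and 1 are contiguous and unpadded, so they merge into one axis of size 96,
 *     and the rank-2 kernel runs on [96, 32] -> [96, 36]
 */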
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "types.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "kernel_dispatcher.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
#include <vector>
#include <algorithm>
#include <functional>
#include <numeric>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t Rank>
__global__ void permute(
array<index_type, Rank> axis_order,
Span<T> output, array<size_type, Rank> outStrides,
View<T> input, array<size_type, Rank> inStrides)
{
for (auto i : grid_stride_range(input.size())) {
index_type oldPosition = 0;
index_type newPosition = i;
for (int j = 0; j < Rank; j++)
{
auto order = axis_order[j];
oldPosition += (newPosition / outStrides[j]) * inStrides[order];
newPosition %= outStrides[j];
}
output[i] = input[oldPosition];
}
}
}
template <class T, std::size_t Rank> static
void launch_permute_kernel(
const Stream& stream,
const std::vector<std::size_t>& order,
Span<T> output, const std::vector<std::size_t>& outStride,
View<T> input, const std::vector<std::size_t>& inStride)
{
CV_Assert(order.size() == Rank);
CV_Assert(outStride.size() == Rank);
CV_Assert(inStride.size() == Rank);
array<index_type, Rank> order_k;
order_k.assign(std::begin(order), std::end(order));
array<size_type, Rank> outStride_k, inStride_k;
outStride_k.assign(std::begin(outStride), std::end(outStride));
inStride_k.assign(std::begin(inStride), std::end(inStride));
auto kernel = raw::permute<T, Rank>;
auto policy = make_policy(kernel, input.size(), 0, stream);
launch_kernel(kernel, policy, order_k, output, outStride_k, input, inStride_k);
}
GENERATE_KERNEL_DISPATCHER(permute_dispatcher, launch_permute_kernel);
template <class T>
void permute(
const Stream& stream,
TensorSpan<T> output, TensorView<T> input,
std::vector<std::size_t> order)
{
CV_Assert(output.rank() == input.rank());
CV_Assert(input.rank() == order.size());
CV_Assert(input.size() == output.size());
/* squeezable axes at the beginning of both tensors which aren't permuted can be eliminated
*
* Reasoning:
* ----------
* Suppose an item's indices in the input tensor is [i1, i2, ...]. The indices in the
* output tensor will be some permutation of the input tensor indices. Let the output
* tensor indices be [o1, o2, ...]. The permutation operation essentially copies items
* from the input tensor to new locations in the output tensor as dictated by the indices.
*
* If the size of the first axis of the input and output tensor is one and these axes are
* not involved in any permutation, i.e. order[0] = 0, the input and output indices for
* all the elements will be of the form [0, i2, ...] and [0, o2, ...] respectively.
* The first index does not contribute to the element's address calculation and hence does
* nothing apart from eating up a few cycles.
*/
while (order[0] == 0 && input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
/* remove the axes */
input.squeeze(0);
output.squeeze(0);
/* when we remove axis zero, the axis index will be one less than the previous index
* for the remaining axes
*/
order.erase(order.begin());
for (auto& axis : order)
axis--;
/* optimizations should not break the invariants */
CV_Assert(output.rank() == input.rank());
CV_Assert(input.rank() == order.size());
CV_Assert(input.size() == output.size());
}
auto rank = output.rank();
auto inShape = input.shape_as_vector();
auto outShape = output.shape_as_vector();
std::vector<std::size_t> inStride(rank), outStride(rank);
inStride.back() = 1;
outStride.back() = 1;
/* garbage, ..., garbage, 1 */
std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
/* dim[0], dim[1], ..., dim[-1], 1 */
std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<std::size_t>());
std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
/* stride[0], stride[1], ..., stride[-2], 1 */
CV_Assert(2 <= rank && rank <= CSL_MAX_TENSOR_RANK);
permute_dispatcher<T, 2, CSL_MAX_TENSOR_RANK>(rank, stream, order, output, outStride, input, inStride);
}
template void permute(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
template void permute(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
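The stride computation used here (and in the padding kernel above) copies the shape shifted left by one, sets the last element to 1, and runs partial_sum from the right; a worked example with a made-up shape:
/* for a contiguous row-major shape {2, 3, 4, 5}:
 *   after the shifted copy:          stride = {3, 4, 5, 1}
 *   after the reverse partial_sum:   stride = {60, 20, 5, 1}
 * i.e. stride[i] is the product of all axis sizes to the right of axis i
 */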
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "math.hpp"
#include "types.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <cstddef>
#include <vector>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, bool Normalize>
__global__ void prior_box(
Span<T> output,
View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
size_type layerWidth, size_type layerHeight,
size_type imageWidth, size_type imageHeight)
{
/* each box consists of two pairs of coordinates and hence 4 values in total */
/* since the entire output (the first channel at least) consists of these boxes,
* we are guaranteed that the output is aligned to a boundary of 4 values
*/
using vector_type = get_vector_type_t<T, 4>;
auto output_vPtr = vector_type::get_pointer(output.data());
/* num_points contains the number of points in the feature map of interest
* each iteration of the stride loop selects a point and generates prior boxes for it
*/
size_type num_points = layerWidth * layerHeight;
for (auto idx : grid_stride_range(num_points)) {
const index_type x = idx % layerWidth,
y = idx / layerWidth;
index_type output_offset_v4 = idx * offsetX.size() * boxWidth.size();
for (int i = 0; i < boxWidth.size(); i++) {
for (int j = 0; j < offsetX.size(); j++) {
float center_x = (x + offsetX[j]) * stepX;
float center_y = (y + offsetY[j]) * stepY;
vector_type vec;
if(Normalize) {
vec.data[0] = (center_x - boxWidth[i] * 0.5f) / imageWidth;
vec.data[1] = (center_y - boxHeight[i] * 0.5f) / imageHeight;
vec.data[2] = (center_x + boxWidth[i] * 0.5f) / imageWidth;
vec.data[3] = (center_y + boxHeight[i] * 0.5f) / imageHeight;
} else {
vec.data[0] = center_x - boxWidth[i] * 0.5f;
vec.data[1] = center_y - boxHeight[i] * 0.5f;
vec.data[2] = center_x + boxWidth[i] * 0.5f - 1.0f;
vec.data[3] = center_y + boxHeight[i] * 0.5f - 1.0f;
}
v_store(output_vPtr[output_offset_v4], vec);
output_offset_v4++;
}
}
}
}
template <class T>
__global__ void prior_box_clip(Span<T> output) {
for (auto i : grid_stride_range(output.size())) {
using device::clamp;
output[i] = clamp<T>(output[i], 0.0, 1.0);
}
}
template <class T>
__global__ void prior_box_set_variance1(Span<T> output, float variance) {
using vector_type = get_vector_type_t<T, 4>;
auto output_vPtr = vector_type::get_pointer(output.data());
for (auto i : grid_stride_range(output.size() / 4)) {
vector_type vec;
for (int j = 0; j < 4; j++)
vec.data[j] = variance;
v_store(output_vPtr[i], vec);
}
}
template <class T>
__global__ void prior_box_set_variance4(Span<T> output, array<float, 4> variance) {
using vector_type = get_vector_type_t<T, 4>;
auto output_vPtr = vector_type::get_pointer(output.data());
for (auto i : grid_stride_range(output.size() / 4)) {
vector_type vec;
for(int j = 0; j < 4; j++)
vec.data[j] = variance[j];
v_store(output_vPtr[i], vec);
}
}
}
template <class T, bool Normalize> static
void launch_prior_box_kernel(
const Stream& stream,
Span<T> output, View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
std::size_t layerWidth, std::size_t layerHeight, std::size_t imageWidth, std::size_t imageHeight)
{
auto num_points = layerWidth * layerHeight;
auto kernel = raw::prior_box<T, Normalize>;
auto policy = make_policy(kernel, num_points, 0, stream);
launch_kernel(kernel, policy,
output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
layerWidth, layerHeight, imageWidth, imageHeight);
}
template <class T>
void generate_prior_boxes(
const Stream& stream,
Span<T> output,
View<float> boxWidth, View<float> boxHeight, View<float> offsetX, View<float> offsetY, float stepX, float stepY,
std::vector<float> variance,
std::size_t numPriors,
std::size_t layerWidth, std::size_t layerHeight,
std::size_t imageWidth, std::size_t imageHeight,
bool normalize, bool clip)
{
if (normalize) {
launch_prior_box_kernel<T, true>(
stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
layerWidth, layerHeight, imageWidth, imageHeight
);
} else {
launch_prior_box_kernel<T, false>(
stream, output, boxWidth, boxHeight, offsetX, offsetY, stepX, stepY,
layerWidth, layerHeight, imageWidth, imageHeight
);
}
std::size_t channel_size = layerHeight * layerWidth * numPriors * 4;
CV_Assert(channel_size * 2 == output.size());
if (clip) {
auto output_span_c1 = Span<T>(output.data(), channel_size);
auto kernel = raw::prior_box_clip<T>;
auto policy = make_policy(kernel, output_span_c1.size(), 0, stream);
launch_kernel(kernel, policy, output_span_c1);
}
auto output_span_c2 = Span<T>(output.data() + channel_size, channel_size);
if (variance.size() == 1) {
auto kernel = raw::prior_box_set_variance1<T>;
auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
launch_kernel(kernel, policy, output_span_c2, variance[0]);
} else {
array<float, 4> variance_k;
variance_k.assign(std::begin(variance), std::end(variance));
auto kernel = raw::prior_box_set_variance4<T>;
auto policy = make_policy(kernel, output_span_c2.size() / 4, 0, stream);
launch_kernel(kernel, policy, output_span_c2, variance_k);
}
}
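/* Illustrative sketch (hypothetical configuration): the memory layout produced by
 * generate_prior_boxes. With layerWidth = layerHeight = 19 and numPriors = 3:
 *
 *   channel_size  = 19 * 19 * 3 * 4 = 4332      // four coordinates per box
 *   output.size() = 2 * channel_size = 8664
 *
 *   output[0 .. 4331]    -> box coordinates (xmin, ymin, xmax, ymax) for every prior
 *   output[4332 .. 8663] -> the variance values replicated over the same layout
 *
 * The guaranteed 4-value granularity is what allows the kernels above to use 4-wide
 * vectorized stores.
 */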
template void generate_prior_boxes(const Stream&, Span<__half>, View<float>, View<float>, View<float>, View<float>, float, float,
std::vector<float>, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, bool, bool);
template void generate_prior_boxes(const Stream&, Span<float>, View<float>, View<float>, View<float>, View<float>, float, float,
std::vector<float>, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, bool, bool);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "math.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "limits.hpp"
#include "vector_traits.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T>
__global__ void sigmoid_strided(Span<T> output, View<T> input, size_type n, size_type stride, size_type offset) {
/* - the input is divided into equal blocks strided by `stride`
* - we must apply sigmoid to a continuous range of `n` values starting from `offset` in every block
*/
for (auto i : grid_stride_range(n * output.size() / stride)) {
auto block_idx = i / n;
auto index = block_idx * stride + offset + (i % n);
using device::sigmoid;
output[index] = sigmoid(input[index]);
}
}
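/* Illustrative sketch (hypothetical values): the index mapping used by sigmoid_strided.
 * With output.size() = 16, stride = 8, offset = 2 and n = 3, the grid-stride loop visits
 * n * output.size() / stride = 6 work items:
 *
 *   i = 0, 1, 2  ->  block 0  ->  indices 2, 3, 4
 *   i = 3, 4, 5  ->  block 1  ->  indices 10, 11, 12
 *
 * i.e. a contiguous run of n values starting at `offset` inside every block of `stride`
 * elements has the sigmoid applied; the remaining elements are not written by this kernel.
 */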
template <class T>
__global__ void softmax_strided(Span<T> output, View<T> input, size_type n, size_type stride, size_type offset_) {
for (auto idx : grid_stride_range(output.size() / stride)) {
index_type offset = idx * stride + offset_;
auto largest = numeric_limits<T>::lowest();
for (int i = 0; i < n; i++) {
using device::max;
largest = max(largest, output[offset + i]);
}
auto sum = T(0);
for (int i = 0; i < n; i++) {
using device::exp;
auto temp = exp(output[offset + i] - largest);
sum += temp;
output[offset + i] = temp;
}
for (int i = 0; i < n; i++) {
output[offset + i] /= sum;
}
}
}
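/* The kernel above is the usual numerically stable softmax applied to each strided group
 * of n values:
 *
 *   m         = max_j x_j
 *   softmax_i = exp(x_i - m) / sum_j exp(x_j - m)
 *
 * Subtracting the maximum before exponentiating keeps the intermediates in a representable
 * range, which matters especially for the __half instantiation.
 */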
template <class T>
__global__ void region_finalize(Span<T> output, View<T> input, View<T> bias,
T object_prob_cutoff, T class_prob_cutoff,
size_type height_norm, size_type width_norm,
size_type rows, size_type cols,
size_type boxes_per_cell,
size_type box_size,
size_type classes)
{
for (auto box_index : grid_stride_range(output.size() / box_size)) {
auto box_of_the_cell = box_index % boxes_per_cell; /* box number within a cell */
auto box_offset = box_index * box_size;
auto batch_inner_size = rows * cols * boxes_per_cell;
auto row_inner_size = cols * boxes_per_cell;
auto col_inner_size = boxes_per_cell;
auto y = (box_index % batch_inner_size) / row_inner_size;
auto x = (box_index % row_inner_size) / col_inner_size;
using device::sigmoid;
using device::exp;
output[box_offset + 0] = (T(x) + sigmoid(input[box_offset + 0])) / T(cols);
output[box_offset + 1] = (T(y) + sigmoid(input[box_offset + 1])) / T(rows);
output[box_offset + 2] = exp(input[box_offset + 2]) * bias[2 * box_of_the_cell + 0] / T(width_norm);
output[box_offset + 3] = exp(input[box_offset + 3]) * bias[2 * box_of_the_cell + 1] / T(height_norm);
/* squash objectness score into a probability */
using device::sigmoid;
T objectness_prob = sigmoid(output[box_offset + 4]);
output[box_offset + 4] = objectness_prob;
/* ignore prediction if the objectness probability is less than the cutoff */
if (objectness_prob < object_prob_cutoff)
objectness_prob = 0;
/* the class probabilities we have currently are conditional class probabilities
* given the object
*
* to obtain the actual class probability, we multiply the conditional probability
* with the object probability
*/
const index_type class_begin = box_offset + 5; /* 4 box coordinates, 1 obj prob, class probs... */
const index_type class_end = class_begin + classes;
index_type offset = class_begin;
using vector_type = get_vector_type_t<T, 4>;
/* process each class independently until the offset is aligned to an n-element boundary */
while (offset % vector_type::size() != 0 && offset < class_end) {
T actual_class_prob = objectness_prob * output[offset];
if (actual_class_prob <= class_prob_cutoff)
actual_class_prob = T(0);
output[offset] = actual_class_prob;
offset++;
}
auto output_vPtr = vector_type::get_pointer(output.data() + offset);
auto input_vPtr = vector_type::get_pointer(input.data() + offset);
for (int i = 0; (offset + vector_type::size()) < class_end; i++) {
vector_type vec;
v_load(vec, output_vPtr[i]);
for (int j = 0; j < vector_type::size(); j++) {
T actual_class_prob = objectness_prob * vec.data[j];
if (actual_class_prob <= class_prob_cutoff)
actual_class_prob = T(0);
vec.data[j] = actual_class_prob;
}
v_store(output_vPtr[i], vec);
offset += vector_type::size();
}
/* process the remaining classes */
while (offset < class_end) {
T actual_class_prob = objectness_prob * output[offset];
if (actual_class_prob <= class_prob_cutoff)
actual_class_prob = T(0);
output[offset] = actual_class_prob;
offset++;
}
}
}
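/* Summary of the decoding performed above, with hypothetical numbers for illustration:
 *
 *   x_center = (cell_x + sigmoid(t_x)) / cols    e.g. (4 + 0.5) / 13 ~ 0.346
 *   y_center = (cell_y + sigmoid(t_y)) / rows
 *   width    = exp(t_w) * anchor_w / width_norm
 *   height   = exp(t_h) * anchor_h / height_norm
 *   object   = sigmoid(t_o)
 *   class_i  = object * p_i   (p_i being the conditional class probability already in `output`)
 *
 * If `object` is below object_prob_cutoff, the multiplier is treated as zero and every class
 * score of the box vanishes; individual class scores at or below class_prob_cutoff are also
 * zeroed. The class scores are processed in three phases (scalar head, vectorized middle,
 * scalar tail) so that the middle phase can use aligned 4-wide loads and stores.
 */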
}
template <class T>
void sigmoid_strided(const Stream& stream, Span<T> output, View<T> input, std::size_t n, std::size_t stride, std::size_t offset) {
CV_Assert(output.size() % stride == 0);
auto kernel = raw::sigmoid_strided<T>;
auto policy = make_policy(kernel, n * output.size() / stride, 0, stream);
launch_kernel(kernel, policy, output, input, n, stride, offset);
}
template void sigmoid_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t);
template void sigmoid_strided(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t);
template <class T>
void softmax_strided(const Stream& stream, Span<T> output, View<T> input, std::size_t n, std::size_t stride, std::size_t offset) {
CV_Assert(output.size() % stride == 0);
auto kernel = raw::softmax_strided<T>;
auto policy = make_policy(kernel, output.size() / stride, 0, stream);
launch_kernel(kernel, policy, output, input, n, stride, offset);
}
template void softmax_strided(const Stream&, Span<__half>, View<__half>, std::size_t, std::size_t, std::size_t);
template void softmax_strided(const Stream&, Span<float>, View<float>, std::size_t, std::size_t, std::size_t);
template <class T>
void region_finalize(const Stream& stream, Span<T> output, View<T> input, View<T> bias,
T object_prob_cutoff, T class_prob_cutoff,
std::size_t height_norm, std::size_t width_norm,
std::size_t rows, std::size_t cols,
std::size_t boxes_per_cell,
std::size_t box_size,
std::size_t classes)
{
CV_Assert(output.size() % box_size == 0);
auto kernel = raw::region_finalize<T>;
auto policy = make_policy(kernel, output.size() / box_size, 0, stream);
launch_kernel(kernel, policy, output, input, bias,
object_prob_cutoff, class_prob_cutoff,
height_norm, width_norm,
rows, cols, boxes_per_cell, box_size, classes);
}
template void region_finalize(const Stream&, Span<__half>, View<__half>, View<__half>,
__half, __half, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t);
template void region_finalize(const Stream&, Span<float>, View<float>, View<float>,
float, float, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t, std::size_t);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "math.hpp"
#include "types.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <cuda_runtime.h>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T>
__global__ void resize_nn(
Span<T> output, size_type out_height, size_type out_width,
View<T> input, size_type in_height, size_type in_width)
{
auto in_image_size = in_height * in_width;
auto out_image_size = out_height * out_width;
/* o2i = output to input */
auto o2i_fx = static_cast<float>(in_width) / out_width;
auto o2i_fy = static_cast<float>(in_height) / out_height;
/* think of the output and input as a collection of 2d images with the last axis
* representing the width and the last but one axis representing the height
*
* the remaining axes together form a collection of these images
*/
for (auto idx : grid_stride_range(output.size())) {
const index_type n = idx / out_image_size;
const index_type x = (idx % out_image_size) % out_width;
const index_type y = (idx % out_image_size) / out_width;
auto in_x = static_cast<index_type>(x * o2i_fx);
auto in_y = static_cast<index_type>(y * o2i_fy);
index_type in_idx = n * in_image_size + in_y * in_width + in_x;
output[idx] = input[in_idx];
}
}
template <class T>
__global__ void resize_bilinear(
Span<T> output, size_type out_height, size_type out_width,
View<T> input, size_type in_height, size_type in_width,
float o2i_fy, float o2i_fx)
{
auto in_image_size = in_height * in_width;
auto out_image_size = out_height * out_width;
/* think of the output and input as a collection of 2d images with the last axis
* representing the width and the last but one axis representing the height
*
* the remaining axes together form a collection of these images
*/
for (auto idx : grid_stride_range(output.size())) {
const index_type n = idx / out_image_size;
const index_type x = (idx % out_image_size) % out_width;
const index_type y = (idx % out_image_size) / out_width;
auto in_x = x * o2i_fx;
auto in_y = y * o2i_fy;
auto in_x0 = static_cast<index_type>(in_x);
auto in_y0 = static_cast<index_type>(in_y);
using device::min;
auto in_x1 = min<index_type>(in_x0 + 1, in_width - 1);
auto in_y1 = min<index_type>(in_y0 + 1, in_height - 1);
const index_type in_offset_r0 = n * in_image_size + in_y0 * in_width;
const index_type in_offset_r1 = n * in_image_size + in_y1 * in_width;
auto v_00 = input[in_offset_r0 + in_x0],
v_01 = input[in_offset_r0 + in_x1],
v_10 = input[in_offset_r1 + in_x0],
v_11 = input[in_offset_r1 + in_x1];
output[idx] =
v_00 +
T(in_y - in_y0) * T(v_10 - v_00) +
T(in_x - in_x0) * T(v_01 - v_00) +
T(in_y - in_y0) * T(in_x - in_x0) * T(v_11 - v_01 - v_10 + v_00);
}
}
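/* The expression above is the standard bilinear interpolation written in incremental form.
 * With dx = in_x - in_x0 and dy = in_y - in_y0 it is algebraically identical to
 *
 *   (1 - dy) * (1 - dx) * v_00 + (1 - dy) * dx * v_01 +
 *        dy  * (1 - dx) * v_10 +      dy  * dx * v_11
 *
 * which makes the four corner weights explicit.
 */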
}
template <class T>
void resize_nn(const Stream& stream, TensorSpan<T> output, TensorView<T> input) {
auto in_height = input.get_axis_size(-2);
auto in_width = input.get_axis_size(-1);
auto out_height = output.get_axis_size(-2);
auto out_width = output.get_axis_size(-1);
auto kernel = raw::resize_nn<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width);
}
template void resize_nn<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>);
template void resize_nn<float>(const Stream&, TensorSpan<float>, TensorView<float>);
template <class T>
void resize_bilinear(const Stream& stream, TensorSpan<T> output, TensorView<T> input, float scale_y, float scale_x) {
auto in_height = input.get_axis_size(-2);
auto in_width = input.get_axis_size(-1);
auto out_height = output.get_axis_size(-2);
auto out_width = output.get_axis_size(-1);
auto kernel = raw::resize_bilinear<T>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, out_height, out_width, input, in_height, in_width, scale_y, scale_x);
}
template void resize_bilinear<__half>(const Stream&, TensorSpan<__half>, TensorView<__half>, float, float);
template void resize_bilinear<float>(const Stream&, TensorSpan<float>, TensorView<float>, float, float);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "array.hpp"
#include "types.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "kernel_dispatcher.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include <opencv2/core.hpp>
#include <cstddef>
#include <vector>
#include <iostream>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
namespace raw {
template <class T, std::size_t Rank>
__global__ void slice(
Span<T> output, array<size_type, Rank> out_strides,
View<T> input, array<size_type, Rank> in_strides, array<index_type, Rank> in_offset)
{
for (auto i : grid_stride_range(output.size())) {
index_type out_index = i / out_strides[0];
index_type in_index = in_offset[0] + out_index;
index_type iidx = in_index * in_strides[0];
for (int j = 1; j < Rank; j++) {
out_index = (i % out_strides[j - 1]) / out_strides[j];
in_index = in_offset[j] + out_index;
iidx += in_index * in_strides[j];
}
output[i] = input[iidx];
}
}
}
template <class T, std::size_t Rank> static
void launch_slice(
const Stream& stream,
Span<T> output, const std::vector<std::size_t>& outStride,
View<T> input, const std::vector<std::size_t>& inStride, const std::vector<std::size_t>& inOffset)
{
CV_Assert(outStride.size() == Rank);
CV_Assert(inStride.size() == Rank);
CV_Assert(inOffset.size() == Rank);
array<size_type, Rank> outStride_k, inStride_k;
outStride_k.assign(std::begin(outStride), std::end(outStride));
inStride_k.assign(std::begin(inStride), std::end(inStride));
array<index_type, Rank> inOffset_k;
inOffset_k.assign(std::begin(inOffset), std::end(inOffset));
auto kernel = raw::slice<T, Rank>;
auto policy = make_policy(kernel, output.size(), 0, stream);
launch_kernel(kernel, policy, output, outStride_k, input, inStride_k, inOffset_k);
}
GENERATE_KERNEL_DISPATCHER(slice_dispatcher, launch_slice);
template <class T>
void slice(const Stream& stream,
TensorSpan<T> output, TensorView<T> input,
std::vector<std::size_t> offsets)
{
CV_Assert(output.rank() == input.rank());
CV_Assert(output.rank() == offsets.size());
/* squeezable axes at the beginning of both tensors can be eliminated
*
* Reasoning:
* ----------
* Suppose an item's indices in the output tensor are [o1, o2, ...]. The indices in the input
* tensor will be [o1 + off1, o2 + off2, ...]. The rest of the elements in the input are ignored.
*
* If the size of the first axis of the input and output tensor is unity, the input and output indices
* for all the elements will be of the form [0, o2 + off2, ...] and [0, o2, ...] respectively. Note that
* there cannot be any ignored items since the axes have unit size. The first index does not contribute to the
* element's address calculation and hence does nothing apart from eating up a few cycles.
*/
while (input.get_axis_size(0) == 1 && output.get_axis_size(0) == 1) {
CV_Assert(offsets[0] == 0);
input.squeeze(0);
output.squeeze(0);
offsets.erase(std::begin(offsets));
CV_Assert(output.rank() == input.rank());
CV_Assert(output.rank() == offsets.size());
}
auto inShape = input.shape_as_vector();
auto outShape = output.shape_as_vector();
/* contiguous axes which do not undergo slicing can be combined into one axis
*
* Reasoning:
* ----------
* Suppose an item's indices in the output tensor are [o1, o2, o3, ...]. Let the first two axes not undergo any
* slicing. The indices in the input tensor will be [o1, o2, o3 + off3, ...].
*
* Each axis in the contiguous unsliced axes sequence will add an offset of iN * strideN. In the above example,
* the two axes add a total offset of `o1 * stride1 + o2 * stride2`. We can merge the two axes into one axis with
* a size of `size1 * size2`. The new offset added will be `o12 * stride2` as the kernel iterates through `o12`.
* Note that `o12` is actually `(o1 * size2 + o2)` in the original tensor.
*/
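/* Worked example with hypothetical shapes:
 *
 *   input shape [1, 4, 6, 8], output shape [1, 4, 6, 4], offsets [0, 0, 0, 2]
 *
 *   squeeze the leading unit axes    ->  [4, 6, 8] / [4, 6, 4], offsets [0, 0, 2]
 *   merge the unsliced axes 0 and 1  ->  [24, 8]   / [24, 4],   offsets [0, 2]
 *
 * so the dispatcher below is eventually invoked with rank two instead of rank four.
 */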
for (int i = 0; i < inShape.size(); i++) {
/* check if axis `i` requires any slicing */
if (offsets[i] == 0 && inShape[i] == outShape[i]) {
/* loop invariant: `i` is the first axis in the contiguous unsliced axis sequence */
int j = i + 1; /* `j` is the axis which we will attempt to merge */
while (j < inShape.size() && offsets[j] == 0 && inShape[j] == outShape[j]) {
/* `j` axis is also unsliced; merge `i` and `j` */
auto new_size = inShape[i] * inShape[j];
inShape[i] = new_size;
outShape[i] = new_size;
offsets[i] = 0; /* redundant */
/* delete axis `j` */
inShape.erase(std::begin(inShape) + j);
outShape.erase(std::begin(outShape) + j);
offsets.erase(std::begin(offsets) + j);
/* optimizations should not break the invariants */
CV_Assert(inShape.size() == outShape.size());
CV_Assert(inShape.size() == offsets.size());
CV_Assert(inShape[i] == outShape[i]);
CV_Assert(offsets[i] == 0);
}
}
}
auto rank = inShape.size();
std::vector<std::size_t> inStride(rank), outStride(rank);
inStride.back() = 1;
outStride.back() = 1;
/* garbage, ..., garbage, 1 */
std::copy(std::begin(inShape) + 1, std::end(inShape), std::begin(inStride));
std::copy(std::begin(outShape) + 1, std::end(outShape), std::begin(outStride));
/* dim[1], dim[2], ..., dim[-1], 1 */
std::partial_sum(inStride.rbegin(), inStride.rend(), inStride.rbegin(), std::multiplies<std::size_t>());
std::partial_sum(outStride.rbegin(), outStride.rend(), outStride.rbegin(), std::multiplies<std::size_t>());
/* stride[0], stride[1], ..., stride[-2], 1 */
CV_Assert(1 <= rank && rank <= CSL_MAX_TENSOR_RANK);
slice_dispatcher<T, 1, CSL_MAX_TENSOR_RANK>(rank, stream, output, outStride, input, inStride, offsets);
}
template void slice(const Stream&, TensorSpan<__half>, TensorView<__half>, std::vector<std::size_t>);
template void slice(const Stream&, TensorSpan<float>, TensorView<float>, std::vector<std::size_t>);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_TYPES_HPP
#define OPENCV_DNN_SRC_CUDA_TYPES_HPP
#include <cstdint>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
/* For indices, we can use 32-bit or 64-bit variables. The GPU registers are 32 bits in size.
* Hence, a 64-bit variable requires two registers and is significantly slower than the 32-bit version.
*
* If we do not need to handle huge tensors, we can use 32-bit indices and get better performance.
*/
#ifdef __CUDACC__
using size_type = int;
using index_type = int;
#else
using size_type = std::int32_t;
using index_type = std::int32_t;
#endif
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_TYPES_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
#define OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP
#include <cuda_runtime.h>
#include "types.hpp"
#include "../cuda4dnn/csl/pointer.hpp"
#include <type_traits>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {
/** \file vector_traits.hpp
* \brief utility classes and functions for vectorized memory loads/stores
*
* Example:
* using vector_type = get_vector_type_t<float, 4>;
*
* auto input_vPtr = vector_type::get_pointer(iptr); // iptr is of type DevicePtr<const float>
* auto output_vPtr = vector_type::get_pointer(optr); // optr is of type DevicePtr<float>
*
* vector_type vec;
* v_load(vec, input_vPtr);
*
* for(int i = 0; i < vector_type::size(); i++)
* vec[i] = do_something(vec[i]);
*
* v_store(output_vPtr, vec);
*/
namespace detail {
template <size_type N> struct raw_type_ { };
template <> struct raw_type_<256> { typedef ulonglong4 type; };
template <> struct raw_type_<128> { typedef uint4 type; };
template <> struct raw_type_<64> { typedef uint2 type; };
template <> struct raw_type_<32> { typedef uint1 type; };
template <> struct raw_type_<16> { typedef uchar2 type; };
template <> struct raw_type_<8> { typedef uchar1 type; };
template <size_type N> struct raw_type {
using type = typename raw_type_<N>::type;
static_assert(sizeof(type) * 8 == N, "");
};
}
/* \tparam T type of element in the vector
* \tparam N "number of elements" of type T in the vector
*/
template <class T, size_type N>
union vector_type {
using value_type = T;
using raw_type = typename detail::raw_type<N * sizeof(T) * 8>::type;
__device__ vector_type() { }
__device__ static constexpr size_type size() { return N; }
raw_type raw;
T data[N];
template <class U> static __device__
typename std::enable_if<std::is_const<U>::value, const vector_type*>
::type get_pointer(csl::DevicePtr<U> ptr) {
return reinterpret_cast<const vector_type*>(ptr.get());
}
template <class U> static __device__
typename std::enable_if<!std::is_const<U>::value, vector_type*>
::type get_pointer(csl::DevicePtr<U> ptr) {
return reinterpret_cast<vector_type*>(ptr.get());
}
};
template <class V>
__device__ void v_load(V& dest, const V& src) {
dest.raw = src.raw;
}
template <class V>
__device__ void v_load(V& dest, const V* src) {
dest.raw = src->raw;
}
template <class V>
__device__ void v_store(V* dest, const V& src) {
dest->raw = src.raw;
}
template <class V>
__device__ void v_store(V& dest, const V& src) {
dest.raw = src.raw;
}
template <class T, size_type N>
struct get_vector_type {
typedef vector_type<T, N> type;
};
template <class T, size_type N>
using get_vector_type_t = typename get_vector_type<T, N>::type;
}}}}} /* namespace cv::dnn::cuda4dnn::csl::device */
#endif /* OPENCV_DNN_SRC_CUDA_VECTOR_TRAITS_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP
#include "error.hpp"
#include "stream.hpp"
#include "pointer.hpp"
#include "fp16.hpp"
#include <opencv2/core.hpp>
#include <cublas_v2.h>
#include <cstddef>
#include <memory>
#include <utility>
#define CUDA4DNN_CHECK_CUBLAS(call) \
::cv::dnn::cuda4dnn::csl::cublas::detail::check((call), CV_Func, __FILE__, __LINE__)
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cublas {
/** @brief exception class for errors thrown by the cuBLAS API */
class cuBLASException : public CUDAException {
public:
using CUDAException::CUDAException;
};
namespace detail {
static void check(cublasStatus_t status, const char* func, const char* file, int line) {
auto cublasGetErrorString = [](cublasStatus_t err) {
switch (err) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "UNKNOWN_CUBLAS_ERROR";
};
if (status != CUBLAS_STATUS_SUCCESS)
throw cuBLASException(Error::GpuApiCallError, cublasGetErrorString(status), func, file, line);
}
}
/** noncopyable cuBLAS smart handle
*
* UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle
* is destroyed after use. The handle can be associated with a CUDA stream by specifying the
* stream during construction. By default, the handle is associated with the default stream.
*/
class UniqueHandle {
public:
UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); }
UniqueHandle(UniqueHandle&) = delete;
UniqueHandle(UniqueHandle&& other) noexcept
: stream(std::move(other.stream)), handle{ other.handle } {
other.handle = nullptr;
}
UniqueHandle(Stream strm) : stream(std::move(strm)) {
CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle));
try {
CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get()));
} catch (...) {
/* cublasDestroy won't throw if a valid handle is passed */
CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
throw;
}
}
~UniqueHandle() noexcept {
if (handle != nullptr) {
/* cublasDestroy won't throw if a valid handle is passed */
CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
}
}
UniqueHandle& operator=(const UniqueHandle&) = delete;
UniqueHandle& operator=(UniqueHandle&& other) noexcept {
stream = std::move(other.stream);
handle = other.handle;
other.handle = nullptr;
return *this;
}
/** @brief returns the raw cuBLAS handle */
cublasHandle_t get() const noexcept { return handle; }
private:
Stream stream;
cublasHandle_t handle;
};
/** @brief sharable cuBLAS smart handle
*
* Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle
* is destroyed after all references to the handle are destroyed. The handle can be
* associated with a CUDA stream by specifying the stream during construction. By default,
* the handle is associated with the default stream.
*
* @note Moving a Handle object to another invalidates the former
*/
class Handle {
public:
Handle() : handle(std::make_shared<UniqueHandle>()) { }
Handle(const Handle&) = default;
Handle(Handle&&) = default;
Handle(Stream strm) : handle(std::make_shared<UniqueHandle>(std::move(strm))) { }
Handle& operator=(const Handle&) = default;
Handle& operator=(Handle&&) = default;
/** returns true if the handle is valid */
explicit operator bool() const noexcept { return static_cast<bool>(handle); }
cublasHandle_t get() const noexcept {
CV_Assert(handle);
return handle->get();
}
private:
std::shared_ptr<UniqueHandle> handle;
};
/** @brief GEMM for column-major matrices
*
* \f$ C = \alpha AB + \beta C \f$
*
* @tparam T matrix element type (must be `half` or `float`)
*
* @param handle valid cuBLAS Handle
* @param transa use transposed matrix of A for computation
* @param transb use transposed matrix of B for computation
* @param rows_c number of rows in C
* @param cols_c number of columns in C
* @param common_dim common dimension of A (or trans A) and B (or trans B)
* @param alpha scale factor for AB
* @param[in] A pointer to column-major matrix A in device memory
* @param lda leading dimension of matrix A
* @param[in] B pointer to column-major matrix B in device memory
* @param ldb leading dimension of matrix B
* @param beta scale factor for C
* @param[in,out] C pointer to column-major matrix C in device memory
* @param ldc leading dimension of matrix C
*
* Exception Guarantee: Basic
*/
template <class T>
void gemm(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
T alpha, const DevicePtr<const T> A, std::size_t lda,
const DevicePtr<const T> B, std::size_t ldb,
T beta, const DevicePtr<T> C, std::size_t ldc);
template <> inline
void gemm<half>(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
half alpha, const DevicePtr<const half> A, std::size_t lda,
const DevicePtr<const half> B, std::size_t ldb,
half beta, const DevicePtr<half> C, std::size_t ldc)
{
CV_Assert(handle);
auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
int irows_c = static_cast<int>(rows_c),
icols_c = static_cast<int>(cols_c),
icommon_dim = static_cast<int>(common_dim),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
CUDA4DNN_CHECK_CUBLAS(
cublasHgemm(
handle.get(),
opa, opb,
irows_c, icols_c, icommon_dim,
&alpha, A.get(), ilda,
B.get(), ildb,
&beta, C.get(), ildc
)
);
}
template <> inline
void gemm<float>(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
float alpha, const DevicePtr<const float> A, std::size_t lda,
const DevicePtr<const float> B, std::size_t ldb,
float beta, const DevicePtr<float> C, std::size_t ldc)
{
CV_Assert(handle);
auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
int irows_c = static_cast<int>(rows_c),
icols_c = static_cast<int>(cols_c),
icommon_dim = static_cast<int>(common_dim),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
CUDA4DNN_CHECK_CUBLAS(
cublasSgemm(
handle.get(),
opa, opb,
irows_c, icols_c, icommon_dim,
&alpha, A.get(), ilda,
B.get(), ildb,
&beta, C.get(), ildc
)
);
}
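/* Illustrative usage sketch (hypothetical sizes and device pointers): computing C = A * B for
 * column-major matrices, where A is M x K, B is K x N and C is M x N.
 *
 *   std::size_t M = 128, N = 64, K = 256;
 *   // A, B are DevicePtr<const float>, C is DevicePtr<float>, all column-major
 *   gemm<float>(handle,
 *               false, false,   // neither operand transposed
 *               M, N, K,        // rows_c, cols_c, common_dim
 *               1.0f, A, M,     // lda = number of rows of A
 *                     B, K,     // ldb = number of rows of B
 *               0.0f, C, M);    // ldc = number of rows of C
 */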
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */
@@ -2,17 +2,9 @@
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// this file is a stub and will be removed once actual code is added
#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP
#include "../precomp.hpp"
#include "cudnn/cudnn.hpp"
#ifndef HAVE_CUDA
# error "CUDA4DNN should be enabled iff CUDA and cuDNN were found"
#endif
#include <cudnn.h>
void cuda4dnn_build_test_func() {
auto ver = cudnnGetVersion();
CV_UNUSED(ver);
}
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUDNN_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP
#include "cudnn.hpp"
#include "../pointer.hpp"
#include "../workspace.hpp"
#include <opencv2/core.hpp>
#include <cudnn.h>
#include <cstddef>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
class LRNDescriptor {
public:
enum class LRNType {
ACROSS_CHANNELS,
WITHIN_CHANNEL
};
LRNDescriptor() noexcept : descriptor{ nullptr } { }
LRNDescriptor(const LRNDescriptor&) = delete;
LRNDescriptor(LRNDescriptor&& other) noexcept
: descriptor{ other.descriptor }, type{ other.type } {
other.descriptor = nullptr;
}
/** sets up a LRN descriptor
*
* @param local_size size of the normalization window
* @param alpha variance scaling parameter
* @param beta power parameter
* @param k bias parameter
*
* @note \p alpha is divided by the window width in across channels mode
* @note \p alpha is divided by the (window width)^spatialDimensions in within channel mode
*
* @note \p alpha, \p beta and \p k will be cast to the tensor data type during the operation
*
* Exception Guarantee: Basic
*/
LRNDescriptor(std::size_t local_size, double alpha, double beta, double k, LRNType type_) {
constructor(local_size, alpha, beta, k, type_);
}
~LRNDescriptor() noexcept {
if (descriptor != nullptr) {
/* cudnnDestroyLRNDescriptor will not fail for a valid descriptor */
CUDA4DNN_CHECK_CUDNN(cudnnDestroyLRNDescriptor(descriptor));
}
}
LRNDescriptor& operator=(const LRNDescriptor&) = delete;
LRNDescriptor& operator=(LRNDescriptor&& other) noexcept {
descriptor = other.descriptor;
type = other.type;
other.descriptor = nullptr;
return *this;
};
cudnnLRNDescriptor_t get() const noexcept { return descriptor; }
LRNType getType() const noexcept { return type; }
private:
void constructor(std::size_t local_size, double alpha, double beta, double k, LRNType type_) {
CV_Assert(CUDNN_LRN_MIN_N <= local_size && local_size <= CUDNN_LRN_MAX_N);
type = type_;
CUDA4DNN_CHECK_CUDNN(cudnnCreateLRNDescriptor(&descriptor));
try {
CUDA4DNN_CHECK_CUDNN(
cudnnSetLRNDescriptor(
descriptor,
local_size,
alpha,
beta,
k
)
);
} catch (...) {
/* cudnnDestroyLRNDescriptor will not fail for a valid descriptor */
CUDA4DNN_CHECK_CUDNN(cudnnDestroyLRNDescriptor(descriptor));
throw;
}
}
cudnnLRNDescriptor_t descriptor;
LRNType type;
};
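/* Illustrative usage sketch (hypothetical parameters): a cross-channel LRN with a window of 5,
 * alpha = 1e-4, beta = 0.75 and k = 2.0.
 *
 *   LRNDescriptor lrnDesc(5, 0.0001, 0.75, 2.0, LRNDescriptor::LRNType::ACROSS_CHANNELS);
 *
 * The constructor asserts that the window size lies within [CUDNN_LRN_MIN_N, CUDNN_LRN_MAX_N].
 */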
/** @brief performs local response normalization
*
* dstValue = alpha * result + beta * priorDstValue
*
* @tparam T element type (must be `half` or `float`)
*
* @param handle valid cuDNN Handle
* @param lrnDesc LRN description
* @param inputDesc tensor descriptor describing the input
* @param[in] inputPtr pointer to input tensor in device memory
* @param alpha result scale factor
* @param beta previous value scale factor
* @param outputDesc tensor descriptor describing the output
* @param[out] outputPtr pointer to output tensor in device memory
* @param workspace workspace memory required by the within-channel normalization mode
*
* Exception Guarantee: Basic
*/
template <class T>
void LRNForward(
const Handle& handle,
const LRNDescriptor& lrnDesc,
const TensorDescriptor<T>& inputDesc,
DevicePtr<const T> inputPtr,
T alpha, T beta,
const TensorDescriptor<T>& outputDesc,
DevicePtr<T> outputPtr,
WorkspaceInstance workspace)
{
CV_Assert(handle);
if (lrnDesc.getType() == LRNDescriptor::LRNType::ACROSS_CHANNELS) {
CUDA4DNN_CHECK_CUDNN(
cudnnLRNCrossChannelForward(
handle.get(),
lrnDesc.get(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
&alpha, inputDesc.get(), inputPtr.get(),
&beta, outputDesc.get(), outputPtr.get()
)
);
} else if (lrnDesc.getType() == LRNDescriptor::LRNType::WITHIN_CHANNEL) {
std::size_t size;
CUDA4DNN_CHECK_CUDNN(cudnnGetTensorSizeInBytes(inputDesc.get(), &size));
DevicePtr<void> temp1 = workspace.get_span<half>(size).data();
DevicePtr<void> temp2 = workspace.get_span<half>(size).data();
CUDA4DNN_CHECK_CUDNN(
cudnnDivisiveNormalizationForward(
handle.get(),
lrnDesc.get(), CUDNN_DIVNORM_PRECOMPUTED_MEANS,
&alpha, inputDesc.get(), inputPtr.get(),
NULL,
static_cast<void*>(temp1), static_cast<void*>(temp2),
&beta, outputDesc.get(), outputPtr.get()
)
);
}
}
template <> inline
void LRNForward(
const Handle& handle,
const LRNDescriptor& lrnDesc,
const TensorDescriptor<half>& inputDesc,
DevicePtr<const half> inputPtr,
half alpha, half beta,
const TensorDescriptor<half>& outputDesc,
DevicePtr<half> outputPtr,
WorkspaceInstance workspace)
{
CV_Assert(handle);
/* we specialize for fp16 as the scaling factors must be provided as `float` */
float alpha_ = alpha, beta_ = beta;
if (lrnDesc.getType() == LRNDescriptor::LRNType::ACROSS_CHANNELS) {
CUDA4DNN_CHECK_CUDNN(
cudnnLRNCrossChannelForward(
handle.get(),
lrnDesc.get(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
&alpha_, inputDesc.get(), inputPtr.get(),
&beta_, outputDesc.get(), outputPtr.get()
)
);
} else if (lrnDesc.getType() == LRNDescriptor::LRNType::WITHIN_CHANNEL) {
std::size_t size;
CUDA4DNN_CHECK_CUDNN(cudnnGetTensorSizeInBytes(inputDesc.get(), &size));
DevicePtr<void> temp1 = workspace.get_span<half>(size).data();
DevicePtr<void> temp2 = workspace.get_span<half>(size).data();
CUDA4DNN_CHECK_CUDNN(
cudnnDivisiveNormalizationForward(
handle.get(),
lrnDesc.get(), CUDNN_DIVNORM_PRECOMPUTED_MEANS,
&alpha_, inputDesc.get(), inputPtr.get(),
NULL,
static_cast<void*>(temp1), static_cast<void*>(temp2),
&beta_, outputDesc.get(), outputPtr.get()
)
);
}
}
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_LRN_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP
#include "cudnn.hpp"
#include "../pointer.hpp"
#include <opencv2/core.hpp>
#include <cudnn.h>
#include <cstddef>
#include <array>
#include <algorithm>
#include <vector>
#include <type_traits>
#include <iterator>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
class PoolingDescriptor {
public:
enum class PoolingType {
MAX,
MAX_DETERMINISTIC,
AVERAGE_EXCLUDE_PADDING,
AVERAGE_INCLUDE_PADDING
};
PoolingDescriptor() noexcept : descriptor{ nullptr } { }
PoolingDescriptor(const PoolingDescriptor&) = delete;
PoolingDescriptor(PoolingDescriptor&& other) noexcept
: descriptor{ other.descriptor } {
other.descriptor = nullptr;
}
/** constructs a pooling descriptor
*
* Pre-conditions:
* - \p window_size, \p padding and \p stride must have the same size
*
* The length of the containers is interpreted as the order of the pooling operation.
*
* Exception Guarantee: Basic
*/
template <class SequenceContainer, typename = decltype(std::begin(std::declval<SequenceContainer>()))>
PoolingDescriptor(
const SequenceContainer& window_size,
const SequenceContainer& padding,
const SequenceContainer& stride,
PoolingType type)
{
constructor(window_size, padding, stride, type);
}
~PoolingDescriptor() noexcept {
if (descriptor != nullptr) {
/* cudnnDestroyPoolingDescriptor will not fail for a valid descriptor */
CUDA4DNN_CHECK_CUDNN(cudnnDestroyPoolingDescriptor(descriptor));
}
}
PoolingDescriptor& operator=(const PoolingDescriptor&) = delete;
PoolingDescriptor& operator=(PoolingDescriptor&& other) noexcept {
descriptor = other.descriptor;
other.descriptor = nullptr;
return *this;
};
cudnnPoolingDescriptor_t get() const noexcept { return descriptor; }
private:
template <class SequenceContainer>
void constructor(
const SequenceContainer& window_size,
const SequenceContainer& padding,
const SequenceContainer& stride,
PoolingType type)
{
CV_Assert(window_size.size() == padding.size());
CV_Assert(window_size.size() == stride.size());
auto get_pooling_type = [] (PoolingType type) {
switch (type) {
case PoolingType::MAX:
return CUDNN_POOLING_MAX;
case PoolingType::MAX_DETERMINISTIC:
return CUDNN_POOLING_MAX_DETERMINISTIC;
case PoolingType::AVERAGE_EXCLUDE_PADDING:
return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
case PoolingType::AVERAGE_INCLUDE_PADDING:
return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
}
CV_Error(Error::StsBadArg, "unknown pooling type");
};
CUDA4DNN_CHECK_CUDNN(cudnnCreatePoolingDescriptor(&descriptor));
try {
const auto rank = window_size.size();
if (rank == 2) {
CUDA4DNN_CHECK_CUDNN(
cudnnSetPooling2dDescriptor(
descriptor,
get_pooling_type(type), CUDNN_PROPAGATE_NAN,
window_size[0], window_size[1],
padding[0], padding[1],
stride[0], stride[1]
)
);
} else {
std::vector<int> iwindow_size(std::begin(window_size), std::end(window_size));
std::vector<int> ipadding(std::begin(padding), std::end(padding));
std::vector<int> istride(std::begin(stride), std::end(stride));
CUDA4DNN_CHECK_CUDNN(
cudnnSetPoolingNdDescriptor(
descriptor,
get_pooling_type(type), CUDNN_PROPAGATE_NAN,
rank, iwindow_size.data(), ipadding.data(), istride.data()
)
);
}
} catch (...) {
/* cudnnDestroyPoolingDescriptor will not fail for a valid descriptor */
CUDA4DNN_CHECK_CUDNN(cudnnDestroyPoolingDescriptor(descriptor));
throw;
}
}
cudnnPoolingDescriptor_t descriptor;
};
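/* Illustrative usage sketch (hypothetical parameters): a 2x2 max pooling with stride 2 and no
 * padding, which halves the spatial dimensions.
 *
 *   std::vector<std::size_t> window { 2, 2 }, padding { 0, 0 }, stride { 2, 2 };
 *   PoolingDescriptor poolingDesc(window, padding, stride, PoolingDescriptor::PoolingType::MAX);
 */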
/** gives the shape of the output tensor after pooling
*
* @note it's not required to enforce this shape in the output tensor; slightly different shapes will work
*
* Exception Guarantee: Basic
*/
template <class T> inline
void getPoolingForwardOutputDim(
const PoolingDescriptor& poolingDesc,
const TensorDescriptor<T>& inputDesc,
std::vector<int>& output_dim)
{
output_dim.clear();
output_dim.resize(CUDNN_DIM_MAX); /* we use `output_dim` to hold temporaries */
std::vector<int> temp(CUDNN_DIM_MAX);
cudnnDataType_t tempDataType;
CUDA4DNN_CHECK_CUDNN(
cudnnGetTensorNdDescriptor(
inputDesc.get(),
CUDNN_DIM_MAX + 1, /* according to docs, this is what we do to get the rank */
&tempDataType,
output_dim.data(),
temp.data(),
temp.data()
)
);
const auto rank = output_dim[0];
output_dim.resize(rank);
CUDA4DNN_CHECK_CUDNN(
cudnnGetPoolingNdForwardOutputDim(poolingDesc.get(), inputDesc.get(), rank, output_dim.data())
);
}
/** @brief performs pooling operation
*
* dstValue = alpha * result + beta * priorDstValue
*
* @tparam T pooling element type (must be `half` or `float`)
*
* @param handle valid cuDNN Handle
* @param poolingDesc pooling description
* @param inputDesc tensor descriptor describing the input
* @param[in] inputPtr pointer to input tensor in device memory
* @param alpha result scale factor
* @param beta previous value scale factor
* @param outputDesc tensor descriptor describing the output
* @param[out] outputPtr pointer to output tensor in device memory
*
* Exception Guarantee: Basic
*/
template <class T>
void pool(
const Handle& handle,
const PoolingDescriptor& poolingDesc,
const TensorDescriptor<T>& inputDesc,
const DevicePtr<const T> inputPtr,
T alpha, T beta,
const TensorDescriptor<T>& outputDesc,
DevicePtr<T> outputPtr)
{
CV_Assert(handle);
CUDA4DNN_CHECK_CUDNN(
cudnnPoolingForward(
handle.get(),
poolingDesc.get(),
&alpha, inputDesc.get(), inputPtr.get(),
&beta, outputDesc.get(), outputPtr.get()
)
);
}
template <> inline
void pool(
const Handle& handle,
const PoolingDescriptor& poolingDesc,
const TensorDescriptor<half>& inputDesc,
const DevicePtr<const half> inputPtr,
half alpha, half beta,
const TensorDescriptor<half>& outputDesc,
DevicePtr<half> outputPtr)
{
CV_Assert(handle);
/* we specialize for fp16 as the scaling factors must be provided as `float` */
float alpha_ = alpha, beta_ = beta;
CUDA4DNN_CHECK_CUDNN(
cudnnPoolingForward(
handle.get(),
poolingDesc.get(),
&alpha_, inputDesc.get(), inputPtr.get(),
&beta_, outputDesc.get(), outputPtr.get()
)
);
}
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_POOLING_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP
#include "cudnn.hpp"
#include "../pointer.hpp"
#include <cudnn.h>
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
/** @brief computes softmax (or log softmax)
*
* @tparam T element type (must be `half` or `float`)
*
* @param handle valid cuDNN handle
* @param outputDesc tensor descriptor describing the output
* @param[out] output pointer to output tensor in device memory
* @param inputDesc tensor descriptor describing the input
* @param[in] input pointer to input tensor in device memory
* @param log apply log on probabilities
*
* Exception Guarantee: Basic
*/
template <class T>
void softmax(const cudnn::Handle& handle,
const TensorDescriptor<T>& outputDesc, DevicePtr<T> output,
const TensorDescriptor<T>& inputDesc, DevicePtr<const T> input,
bool log)
{
T alpha = 1.0, beta = 0.0;
cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
CUDA4DNN_CHECK_CUDNN(
cudnnSoftmaxForward(
handle.get(),
algo, CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha, inputDesc.get(), input.get(),
&beta, outputDesc.get(), output.get()
)
);
}
template <> inline
void softmax(const cudnn::Handle& handle,
const TensorDescriptor<half>& outputDesc, DevicePtr<half> output,
const TensorDescriptor<half>& inputDesc, DevicePtr<const half> input,
bool log)
{
/* we specialize for fp16 as the scaling factors must be provided as `float` */
float alpha = 1.0, beta = 0.0;
cudnnSoftmaxAlgorithm_t algo = log ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
CUDA4DNN_CHECK_CUDNN(
cudnnSoftmaxForward(
handle.get(),
algo, CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha, inputDesc.get(), input.get(),
&beta, outputDesc.get(), output.get()
)
);
}
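/* Illustrative usage sketch (hypothetical descriptors and pointers): channel-wise log-softmax
 * of a tensor described by `desc`, with `inputPtr` and `outputPtr` referring to tensors of the
 * same shape in device memory.
 *
 *   softmax<float>(handle, desc, outputPtr, desc, inputPtr, true);
 */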
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_SOFTMAX_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP
#include <opencv2/core.hpp>
#include <cuda_runtime_api.h>
#define CUDA4DNN_CHECK_CUDA(call) \
::cv::dnn::cuda4dnn::csl::detail::check((call), CV_Func, __FILE__, __LINE__)
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
/** @brief exception class for errors thrown by the CUDA APIs */
class CUDAException : public cv::Exception {
public:
using cv::Exception::Exception;
};
namespace detail {
inline void check(cudaError_t err, const char* func, const char* file, int line) {
if (err != cudaSuccess)
throw CUDAException(Error::GpuApiCallError, cudaGetErrorString(err), func, file, line);
}
}
}}}} /* namespace cv::dnn::cuda4dnn::csl */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_ERROR_HPP */
@@ -2,17 +2,19 @@
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// this file is a stub and will be removed once actual code is added
#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP
#include "../precomp.hpp"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#ifndef HAVE_CUDA
# error "CUDA files should not be compiled if CUDA was not enabled"
#ifdef __CUDACC__
# define CUDA4DNN_HOST __host__
# define CUDA4DNN_DEVICE __device__
# define CUDA4DNN_HOST_DEVICE CUDA4DNN_HOST CUDA4DNN_DEVICE
#else
# define CUDA4DNN_HOST
# define CUDA4DNN_DEVICE
# define CUDA4DNN_HOST_DEVICE
#endif
__global__ void cuda4dnn_build_test_kernel(float* addr) {
int idx = threadIdx.x;
addr[idx] = 0.0;
}
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_NVCC_DEFS_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP
#include "../csl/stream.hpp"
#include "../csl/span.hpp"
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void fill(const csl::Stream& stream, csl::Span<T> output, T value);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_FILL_HPP */