    Merge pull request #14827 from YashasSamaga:cuda4dnn-csl-low · 613c12e5
    Yashas Samaga B L authored
    CUDA backend for the DNN module
    
    * stub cuda4dnn design
    
    * minor fixes for tests and doxygen
    
    * add csl public api directory to module headers
    
    * add low-level CSL components
    
    * add high-level CSL components
    
    * integrate csl::Tensor into backbone code
    
    * switch to CPU iff unsupported; otherwise, fail on error
    
    * add fully connected layer
    
    * add softmax layer
    
    * add activation layers
    
    * support arbitrary rank TensorDescriptor
    
    * pass input wrappers to `initCUDA()`
    
    * add 1d/2d/3d-convolution
    
    * add pooling layer
    
    * reorganize and refactor code
    
    * fixes for gcc, clang and doxygen; remove cxx14/17 code
    
    * add blank_layer
    
    * add LRN layer
    
    * add rounding modes for pooling layer
    
    * split tensor.hpp into tensor.hpp and tensor_ops.hpp
    
    * add concat layer
    
    * add scale layer
    
    * add batch normalization layer
    
    * split math.cu into activations.cu and math.hpp
    
    * add eltwise layer
    
    * add flatten layer
    
    * add tensor transform api
    
    * add asymmetric padding support for convolution layer
    
    * add reshape layer
    
    * fix rebase issues
    
    * add permute layer
    
    * add padding support for concat layer
    
    * refactor and reorganize code
    
    * add normalize layer
    
    * optimize bias addition in scale layer
    
    * add prior box layer
    
    * fix and optimize normalize layer
    
    * add asymmetric padding support for pooling layer
    
    * add event API
    
    * improve pooling performance for some padding scenarios
    
    * avoid over-allocation of compute resources to kernels
    
    * improve prior box performance
    
    * enable layer fusion
    
    * add const layer
    
    * add resize layer
    
    * add slice layer
    
    * add padding layer
    
    * add deconvolution layer
    
    * fix channelwise ReLU initialization
    
    * add vector traits
    
    * add vectorized versions of relu, clipped_relu, power
    
    * add vectorized concat kernels
    
    * improve concat_with_offsets performance
    
    * vectorize scale and bias kernels
    
    * add support for multi-billion element tensors
    
    * vectorize prior box kernels
    
    * fix address alignment check
    
    * improve bias addition performance of conv/deconv/fc layers
    
    * restructure code for supporting multiple targets
    
    * add DNN_TARGET_CUDA_FP64
    
    * add DNN_TARGET_FP16
    
    * improve vectorization
    
    * add region layer
    
    * improve tensor API, add dynamic ranks
    
    1. use ManagedPtr instead of a Tensor in backend wrapper
    2. add new methods to tensor classes
      - size_range: computes the combined size for a given axis range (see the sketch below)
      - tensor span/view can be constructed from a raw pointer and shape
    3. the tensor classes can change their rank at runtime (previously rank was fixed at compile-time)
    4. remove device code from tensor classes (as it is unused)
    5. enforce strict conditions on tensor class APIs to improve debugging ability
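    
    A sketch of what items 2 and 3 enable (a hedged illustration based on the list above;
    the exact constructor and method signatures in the PR may differ):
    
        // build a span over existing device memory from a raw pointer and a shape
        csl::TensorSpan<float> span(device_ptr, std::vector<std::size_t>{1, 3, 224, 224});
        // combined size of axes [1, 4): 3 * 224 * 224 elements
        auto inner = span.size_range(1, 4);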
    
    * fix parametric relu activation
    
    * add squeeze/unsqueeze tensor API
    
    * add reorg layer
    
    * optimize permute and enable 2d permute
    
    * enable 1d and 2d slice
    
    * add split layer
    
    * add shuffle channel layer
    
    * allow tensors of different ranks in reshape primitive
    
    * patch SliceOp to allow Crop Layer
    
    * allow extra shape inputs in reshape layer
    
    * use `std::move_backward` instead of `std::move` for insert in resizable_static_array
    
    * improve workspace management
    
    * add spatial LRN
    
    * add nms (cpu) to region layer
    
    * add max pooling with argmax (and a fix to limits.hpp)
    
    * add max unpooling layer
    
    * rename DNN_TARGET_CUDA_FP32 to DNN_TARGET_CUDA
    
    * update supportBackend to be more rigorous
    
    * remove stray include that broke the non-CUDA build
    
    * include op_cuda.hpp outside the #if condition
    
    * refactoring, fixes and many optimizations
    
    * drop DNN_TARGET_CUDA_FP64
    
    * fix gcc errors
    
    * increase max. tensor rank limit to six
    
    * add Interp layer
    
    * drop custom layers; use BackendNode
    
    * vectorize activation kernels
    
    * fixes for gcc
    
    * remove wrong assertion
    
    * fix broken assertion in unpooling primitive
    
    * fix build errors in non-CUDA build
    
    * completely remove workspace from public API
    
    * fix permute layer
    
    * enable accuracy and perf. tests for DNN_TARGET_CUDA
    
    * add asynchronous forward
    
    * vectorize eltwise ops
    
    * vectorize fill kernel
    
    * fixes for gcc
    
    * remove CSL headers from public API
    
    * remove csl header source group from cmake
    
    * update min. cudnn version in cmake
    
    * add numerically stable FP32 log1pexp
    
    * refactor code
    
    * add FP16 specialization to cudnn based tensor addition
    
    * vectorize scale1 and bias1 + minor refactoring
    
    * fix doxygen build
    
    * fix invalid alignment assertion
    
    * clear backend wrappers before allocateLayers
    
    * ignore memory lock failures
    
    * do not allocate internal blobs
    
    * integrate NVTX
    
    * add numerically stable half precision log1pexp
    
    * fix indentation, follow coding style, improve docs
    
    * remove accidental modification of IE code
    
    * Revert "add asynchronous forward"
    
    This reverts commit 1154b9da9da07e9b52f8a81bdcea48cf31c56f70.
    
    * [cmake] throw error for unsupported CC versions
    
    * fix rebase issues
    
    * add more docs, refactor code, fix bugs
    
    * minor refactoring and fixes
    
    * resolve warnings/errors from clang
    
    * remove haveCUDA() checks from supportBackend()
    
    * remove NVTX integration
    
    * changes based on review comments
    
    * avoid exception when no CUDA device is present
    
    * add color code for CUDA in Net::dump
pointer.hpp
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#ifndef OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP

#include "nvcc_defs.hpp"
#include "error.hpp"
#include "stream.hpp"

#include <opencv2/core.hpp>

#include <cuda_runtime_api.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <ostream>

namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {

    /** @brief provides a type-safe device pointer
     *
     * DevicePtr wraps a raw pointer and mimics its behaviour. It does not implicitly convert
     * to a raw pointer, which prevents accidental mixing of host and device pointers.
     *
     * It is meant to point to locations in device memory. Hence, it provides dereferencing or
     * array subscript capability for device code only.
     *
     * A `const DevicePtr<T>` represents an immutable pointer to mutable memory.
     * A `DevicePtr<const T>` represents a mutable pointer to immutable memory.
     * A `const DevicePtr<const T>` represents an immutable pointer to immutable memory.
     *
     * A `DevicePtr<T>` can implicitly convert to `DevicePtr<const T>`.
     *
     * Specializations:
     * - DevicePtr<void>/DevicePtr<const void> do not support pointer arithmetic (but relational operators are provided)
     * - any device pointer pointing to mutable memory is implicitly convertible to DevicePtr<void>
     * - any device pointer is implicitly convertible to DevicePtr<const void>
     * - DevicePtr<void> can be explicitly converted to any device pointer
     * - DevicePtr<const void> can be explicitly converted to any device pointer pointing to immutable memory
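     *
     * Illustrative usage (a sketch; `raw` is assumed to come from a device allocation
     * such as cudaMalloc):
     * \code
     * float* raw = nullptr;                            // assume allocated on the device
     * DevicePtr<float> ptr{raw};                       // explicit construction from a raw pointer
     * DevicePtr<const float> cptr = ptr;               // implicit conversion to pointer-to-const
     * DevicePtr<void> vptr = ptr;                      // implicit conversion to void device pointer
     * auto back = static_cast<DevicePtr<float>>(vptr); // explicit conversion back
     * \endcode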
     */
    template <class T>
    class DevicePtr {
        static_assert(std::is_standard_layout<T>::value, "T must satisfy StandardLayoutType");

    public:
        using element_type = T;
        using difference_type = std::ptrdiff_t;
        using pointer = typename std::add_pointer<element_type>::type;
        using reference = typename std::add_lvalue_reference<element_type>::type;

        DevicePtr() = default;
        CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { }

        CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; }

        CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }

        CUDA4DNN_DEVICE reference operator[](difference_type idx) const noexcept { return get()[idx]; }
        CUDA4DNN_DEVICE reference operator*() const noexcept { return *get(); }
        CUDA4DNN_DEVICE pointer operator->() const noexcept { return get(); }

        template<class U = T, typename std::enable_if<!std::is_const<U>::value, bool>::type = true>
        CUDA4DNN_HOST_DEVICE operator DevicePtr<typename std::add_const<U>::type>() const noexcept {
            return DevicePtr<typename std::add_const<U>::type>{ptr};
        }

        CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; }

        CUDA4DNN_HOST_DEVICE DevicePtr operator++() noexcept {
            ++ptr;
            return *this;
        }

        CUDA4DNN_HOST_DEVICE DevicePtr operator++(int) noexcept {
            auto tmp = DevicePtr(*this);
            ptr++;
            return tmp;
        }

        CUDA4DNN_HOST_DEVICE DevicePtr operator--() noexcept {
            --ptr;
            return *this;
        }

        CUDA4DNN_HOST_DEVICE DevicePtr operator--(int) noexcept {
            auto tmp = DevicePtr(*this);
            ptr--;
            return tmp;
        }

        CUDA4DNN_HOST_DEVICE DevicePtr operator+=(std::ptrdiff_t offset) noexcept {
            ptr += offset;
            return *this;
        }

        CUDA4DNN_HOST_DEVICE DevicePtr operator-=(std::ptrdiff_t offset) noexcept {
            ptr -= offset;
            return *this;
        }

        CUDA4DNN_HOST_DEVICE friend DevicePtr operator+(DevicePtr lhs, std::ptrdiff_t offset) noexcept {
            return lhs += offset;
        }

        CUDA4DNN_HOST_DEVICE friend DevicePtr operator-(DevicePtr lhs, std::ptrdiff_t offset) noexcept {
            return lhs -= offset;
        }

        CUDA4DNN_HOST_DEVICE friend difference_type operator-(DevicePtr lhs, DevicePtr rhs) noexcept {
            return lhs.ptr - rhs.ptr;
        }

        CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; }
        CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs < rhs); }

        CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; }

        CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept {
            using std::swap;
            swap(lhs.ptr, rhs.ptr);
        }

        template <class U, class V>
        CUDA4DNN_HOST friend std::basic_ostream<U, V>& operator<<(std::basic_ostream<U, V>& os, DevicePtr other) {
            os << other.get() << " (device)";
            return os;
        }

    private:
        pointer ptr;
    };

    template <>
    class DevicePtr<const void> {
    public:
        using element_type = const void;
        using pointer = typename std::add_pointer<element_type>::type;

        DevicePtr() = default;

        /* host const void pointer to const void device pointer */
        CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { }

        /* allow any device pointer to be implicitly converted to void device pointer */
        template <class T>
        CUDA4DNN_HOST_DEVICE DevicePtr(DevicePtr<T> ptr_) noexcept : ptr{ ptr_.get() } { }

        CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; }

        CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }

        CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; }

        CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; }
        CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs < rhs); }

        /* explicit conversion into host void pointer */
        CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; }

        /* const void device pointer can be explicitly cast into any const device pointer type */
        template <class T, typename std::enable_if<std::is_const<T>::value, bool>::type = true>
        CUDA4DNN_HOST_DEVICE explicit operator DevicePtr<T>() const noexcept {
            return DevicePtr<T>(static_cast<T*>(ptr));
        }

        CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept {
            using std::swap;
            swap(lhs.ptr, rhs.ptr);
        }

        template <class U, class V>
        CUDA4DNN_HOST friend std::basic_ostream<U, V>& operator<<(std::basic_ostream<U, V>& os, DevicePtr other) {
            os << other.get() << " (device)";
            return os;
        }

    private:
        pointer ptr;
    };

    template <>
    class DevicePtr<void> {
    public:
        using element_type = void;
        using pointer = typename std::add_pointer<element_type>::type;

        DevicePtr() = default;

        /* host pointer to device pointer */
        CUDA4DNN_HOST_DEVICE explicit DevicePtr(pointer ptr_) noexcept : ptr{ ptr_ } { }

        /* allow any device pointer to mutable memory to be implicitly converted to void device pointer */
        template <class T, typename std::enable_if<!std::is_const<T>::value, bool>::type = false>
        CUDA4DNN_HOST_DEVICE DevicePtr(DevicePtr<T> ptr_) noexcept : ptr { ptr_.get() } { }

        CUDA4DNN_HOST_DEVICE DevicePtr operator=(pointer ptr_) noexcept { ptr = ptr_; return *this; }

        CUDA4DNN_HOST_DEVICE pointer get() const noexcept { return ptr; }

        CUDA4DNN_HOST_DEVICE operator DevicePtr<const void>() const noexcept { return DevicePtr<const void>{ptr}; }

        CUDA4DNN_HOST_DEVICE explicit operator bool() const noexcept { return ptr; }

        CUDA4DNN_HOST_DEVICE friend bool operator==(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr == rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator!=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs == rhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator<(DevicePtr lhs, DevicePtr rhs) noexcept { return lhs.ptr < rhs.ptr; }
        CUDA4DNN_HOST_DEVICE friend bool operator>(DevicePtr lhs, DevicePtr rhs) noexcept { return rhs < lhs; }
        CUDA4DNN_HOST_DEVICE friend bool operator<=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(rhs < lhs); }
        CUDA4DNN_HOST_DEVICE friend bool operator>=(DevicePtr lhs, DevicePtr rhs) noexcept { return !(lhs < rhs); }

        /* explicit conversion into host void pointer */
        CUDA4DNN_HOST_DEVICE explicit operator pointer() const noexcept { return ptr; }

        /* void device pointer can be explicitly cast into any device pointer type */
        template <class T>
        CUDA4DNN_HOST_DEVICE explicit operator DevicePtr<T>() const noexcept {
            return DevicePtr<T>(static_cast<T*>(ptr));
        }

        CUDA4DNN_HOST friend void swap(DevicePtr& lhs, DevicePtr& rhs) noexcept {
            using std::swap;
            swap(lhs.ptr, rhs.ptr);
        }

        template <class U, class V>
        CUDA4DNN_HOST friend std::basic_ostream<U, V>& operator<<(std::basic_ostream<U, V>& os, DevicePtr other) {
            os << other.get() << " (device)";
            return os;
        }

    private:
        pointer ptr;
    };

    template <class T>
    bool is_aligned(DevicePtr<const T> ptr, std::size_t alignment) {
        auto addr = reinterpret_cast<std::intptr_t>(ptr.get());
        return addr % alignment == 0;
    }
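
    /* For example, a caller preparing vectorized loads might verify 16-byte alignment
     * before launching a float4-based kernel (an illustrative sketch; `fptr` is any
     * DevicePtr<const float>):
     *
     *     if (is_aligned(fptr, sizeof(float) * 4)) {
     *         // safe to launch a kernel that reads float4 vectors
     *     }
     */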

    /** copies \p n elements from \p src to \p dest
     *
     * \param[in]   src     device pointer
     * \param[out]  dest    host pointer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
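     *
     * Example (an illustrative sketch; \p output is assumed to be a DevicePtr<const float>
     * referring to \p n device-side results):
     * \code
     * std::vector<float> result(n);
     * memcpy(result.data(), output, n); // blocks until the copy completes
     * \endcode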
     */
    template <class T>
    void memcpy(T *dest, DevicePtr<const T> src, std::size_t n) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest, src.get(), n * sizeof(T), cudaMemcpyDefault));
    }

    /** copies \p n elements from \p src to \p dest
     *
     * \param[in]   src     host pointer
     * \param[out]  dest    device pointer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void memcpy(DevicePtr<T> dest, const T* src, std::size_t n) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest.get(), src, n * sizeof(T), cudaMemcpyDefault));
    }

    /** copies \p n elements from \p src to \p dest
     *
     * \param[in]   src     device pointer
     * \param[out]  dest    device pointer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void memcpy(DevicePtr<T> dest, DevicePtr<const T> src, std::size_t n) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpy(dest.get(), src.get(), n * sizeof(T), cudaMemcpyDefault));
    }

    /** sets \p n elements to \p ch in \p dest
     *
     * \param[out]  dest    device pointer
     * \param[in]   ch      8-bit value to fill the device memory with
     *
     * Pre-conditions:
     * - memory pointed by \p dest must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
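     *
     * Example (a sketch; \p buffer is assumed to be a DevicePtr<float>):
     * \code
     * memset(buffer, 0, n); // zero out a device buffer of n elements
     * \endcode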
     */
    template <class T>
    void memset(DevicePtr<T> dest, std::int8_t ch, std::size_t n) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to set is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemset(dest.get(), ch, n * sizeof(T)));
    }

    /** copies \p n elements from \p src to \p dest asynchronously
     *
     * \param[in]   src     device pointer
     * \param[out]  dest    host pointer
     * \param       stream  CUDA stream that has to be used for the memory transfer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     * - \p dest points to page-locked memory
     *
     * Exception Guarantee: Basic
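     *
     * Example (a sketch; assumes \p result refers to page-locked host memory and that
     * Stream exposes a synchronize() method):
     * \code
     * memcpy(result, output, n, stream); // enqueued on `stream`, returns immediately
     * stream.synchronize();              // block until the copy has finished
     * \endcode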
     */
    template <class T>
    void memcpy(T *dest, DevicePtr<const T> src, std::size_t n, const Stream& stream) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest, src.get(), n * sizeof(T), cudaMemcpyDefault, stream.get()));
    }

    /** copies \p n elements from \p src to \p dest asynchronously
     *
     * \param[in]   src     host pointer
     * \param[out]  dest    device pointer
     * \param       stream  CUDA stream that has to be used for the memory transfer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     * - \p src points to page-locked memory
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void memcpy(DevicePtr<T> dest, const T *src, std::size_t n, const Stream& stream) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest.get(), src, n * sizeof(T), cudaMemcpyDefault, stream.get()));
    }

    /** copies \p n elements from \p src to \p dest asynchronously
     *
     * \param[in]   src     device pointer
     * \param[out]  dest    device pointer
     * \param       stream  CUDA stream that has to be used for the memory transfer
     *
     * Pre-conditions:
     * - memory pointed by \p dest and \p src must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void memcpy(DevicePtr<T> dest, DevicePtr<const T> src, std::size_t n, const Stream& stream) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to copy is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemcpyAsync(dest.get(), src.get(), n * sizeof(T), cudaMemcpyDefault, stream.get()));
    }

    /** sets \p n elements to \p ch in \p dest asynchronously
     *
     * \param[out]  dest    device pointer
     * \param[in]   ch      8-bit value to fill the device memory with
     * \param       stream  CUDA stream that has to be used for the memory operation
     *
     * Pre-conditions:
     * - memory pointed by \p dest must be large enough to hold \p n elements
     *
     * Exception Guarantee: Basic
     */
    template <class T>
    void memset(DevicePtr<T> dest, std::int8_t ch, std::size_t n, const Stream& stream) {
        if (n == 0) {
            CV_Error(Error::StsBadArg, "number of elements to set is zero");
        }

        CUDA4DNN_CHECK_CUDA(cudaMemsetAsync(dest.get(), ch, n * sizeof(T), stream.get()));
    }

}}}} /* namespace cv::dnn::cuda4dnn::csl */

#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_POINTER_HPP */