Commit d85e67d3 authored by Yashas Samaga B L, committed by Alexander Alekhin

Merge pull request #16063 from YashasSamaga:cuda4dnn-shortcut-unequal

support eltwise sum with different number of input channels in CUDA backend

* add shortcut primitive

* add offsets in shortcut kernel

* skip tests involving more than two inputs

* remove redundant modulus operation

* support multiple inputs

* remove whole file indentation

* skip acc in0 trunc test if weighted

* use shortcut iff channels are unequal
parent c30af724
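
For orientation before the new files and diffs below: the operation being added is an element-wise sum in which the output takes the shape of the first input, and any further input contributes only to the channels it shares with the first. A minimal CPU reference of that semantic, with illustrative names and layout assumptions (contiguous NCHW), not part of the patch:

#include <algorithm>
#include <cstddef>
#include <vector>

/* hypothetical reference helper: `output` and `input` have `c_output` channels,
 * `from` has `c_from` channels; only the overlapping channels are summed */
void shortcut_reference(std::vector<float>& output, std::size_t c_output,
                        const std::vector<float>& input,
                        const std::vector<float>& from, std::size_t c_from,
                        std::size_t batch, std::size_t channel_stride /* H * W */)
{
    const std::size_t c_common = std::min(c_output, c_from);
    for (std::size_t b = 0; b < batch; b++)
        for (std::size_t c = 0; c < c_output; c++)
            for (std::size_t p = 0; p < channel_stride; p++)
            {
                const std::size_t idx = (b * c_output + c) * channel_stride + p;
                float value = input[idx];
                if (c < c_common) /* channel exists in `from` as well */
                    value += from[(b * c_from + c) * channel_stride + p];
                output[idx] = value;
            }
}
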
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "grid_stride_range.hpp"
#include "execution.hpp"
#include "vector_traits.hpp"
#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"
#include "../cuda4dnn/csl/tensor.hpp"
#include <opencv2/core.hpp>
using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

    namespace raw {
        template <class T, std::size_t N>
        __global__ void input_shortcut_vec(
            Span<T> output,
            View<T> input, index_type c_input, /* `c_input` = number of channels in `input` */
            View<T> from, index_type c_from, /* `c_from` = number of channels in `from` */
            size_type channel_stride /* common for both `input` and `from` */)
        {
            using vector_type = get_vector_type_t<T, N>;

            auto output_vPtr = vector_type::get_pointer(output.data());
            auto input_vPtr = vector_type::get_pointer(input.data());
            auto from_vPtr = vector_type::get_pointer(from.data());

            auto batch_stride_input = c_input * channel_stride;
            auto batch_stride_from = c_from * channel_stride;

            for (auto i : grid_stride_range(output.size() / vector_type::size())) {
                const auto actual_idx = i * vector_type::size();
                const auto b = actual_idx / batch_stride_input; /* `input` and `output` have the same shape */
                const auto c = (actual_idx % batch_stride_input) / channel_stride;
                const auto c_offset = actual_idx % channel_stride;

                vector_type vec_input;
                v_load(vec_input, input_vPtr[i]);

                /* We can break down the shortcut operation into two steps:
                 * - copy `input` to `output`
                 * - add `from` to corresponding channels in `output`
                 *
                 * In this scheme, only some channels in the `output` differ from `input`. They differ in the channels
                 * which have a corresponding channel in `from`.
                 */
                if (c < c_from) {
                    const auto from_actual_idx = b * batch_stride_from + c * channel_stride + c_offset;
                    const auto from_vec_idx = from_actual_idx / vector_type::size();

                    vector_type vec_from;
                    v_load(vec_from, from_vPtr[from_vec_idx]);
                    for (int j = 0; j < vector_type::size(); j++)
                        vec_input.data[j] += vec_from.data[j];
                }

                v_store(output_vPtr[i], vec_input);
            }
        }
    }

    template <class T, std::size_t N>
    void launch_vectorized_input_shortcut(const Stream& stream, Span<T> output, View<T> input, std::size_t c_input, View<T> from, std::size_t c_from, std::size_t channel_stride) {
        CV_Assert(is_fully_aligned<T>(output, N));
        CV_Assert(is_fully_aligned<T>(input, N));
        CV_Assert(is_fully_aligned<T>(from, N));
        CV_Assert(channel_stride % N == 0);

        auto kernel = raw::input_shortcut_vec<T, N>;
        auto policy = make_policy(kernel, output.size() / N, 0, stream);
        launch_kernel(kernel, policy, output, input, c_input, from, c_from, channel_stride);
    }

    template <class T>
    void input_shortcut(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T> from) {
        CV_Assert(is_shape_same(output, input));
        CV_Assert(output.rank() == from.rank());
        for (int i = 0; i < output.rank(); i++) {
            if (i != 1) {
                CV_Assert(from.get_axis_size(i) == output.get_axis_size(i));
            }
        }

        auto channel_stride = output.size_range(2, output.rank()); /* same for `output`, `input` and `from` */
        auto c_input = input.get_axis_size(1);
        auto c_from = from.get_axis_size(1);

        if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4) && is_fully_aligned<T>(from, 4) && channel_stride % 4 == 0) {
            launch_vectorized_input_shortcut<T, 4>(stream, output, input, c_input, from, c_from, channel_stride);
        } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2) && is_fully_aligned<T>(from, 2) && channel_stride % 2 == 0) {
            launch_vectorized_input_shortcut<T, 2>(stream, output, input, c_input, from, c_from, channel_stride);
        } else {
            launch_vectorized_input_shortcut<T, 1>(stream, output, input, c_input, from, c_from, channel_stride);
        }
    }

    template void input_shortcut(const Stream&, TensorSpan<__half>, TensorView<__half>, TensorView<__half>);
    template void input_shortcut(const Stream&, TensorSpan<float>, TensorView<float>, TensorView<float>);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */
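
The kernel above maps each flat element index to a (batch, channel, intra-channel offset) triple and adds the matching element of `from` only when that channel also exists in `from`. A scalar sketch of that index arithmetic (standalone; names are illustrative, not part of the patch):

#include <cstddef>

struct Coords { std::size_t b, c, offset; };

/* decompose a flat NCHW index using the strides the kernel derives from
 * the channel count and the per-channel element count (channel_stride = H * W) */
Coords decompose(std::size_t flat_idx, std::size_t c_input, std::size_t channel_stride)
{
    const std::size_t batch_stride = c_input * channel_stride; /* elements per sample */
    Coords r;
    r.b = flat_idx / batch_stride;
    r.c = (flat_idx % batch_stride) / channel_stride;
    r.offset = flat_idx % channel_stride;
    return r;
}

/* the matching element of `from` lives at (r.b * c_from + r.c) * channel_stride + r.offset
 * and is added only when r.c < c_from */
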
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP
#include "../csl/stream.hpp"
#include "../csl/tensor.hpp"
namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
template <class T>
void input_shortcut(const csl::Stream& stream, csl::TensorSpan<T> output, csl::TensorView<T> input, csl::TensorView<T> from);
}}}} /* namespace cv::dnn::cuda4dnn::kernels */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_SHORTCUT_HPP */
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP
#include "../../op_cuda.hpp"
#include "../csl/stream.hpp"
#include "../csl/tensor.hpp"
#include "../csl/tensor_ops.hpp"
#include "../kernels/shortcut.hpp"
#include <opencv2/core.hpp>
#include <utility>
namespace cv { namespace dnn { namespace cuda4dnn {

    template <class T>
    class ShortcutOp final : public CUDABackendNode {
    public:
        using wrapper_type = GetCUDABackendWrapperType<T>;

        ShortcutOp(csl::Stream stream_) : stream(std::move(stream_)) { }

        void forward(
            const std::vector<cv::Ptr<BackendWrapper>>& inputs,
            const std::vector<cv::Ptr<BackendWrapper>>& outputs,
            csl::Workspace& workspace) override
        {
            CV_Assert(outputs.size() == 1);

            auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
            auto output = output_wrapper->getSpan();

            auto input_wrapper = inputs[0].dynamicCast<wrapper_type>();
            auto input = input_wrapper->getView();

            /* output shape is determined by the shape of the first input */
            CV_Assert(is_shape_same(output, input));

            for (int i = 1; i < inputs.size(); i++)
            {
                auto from_wrapper = inputs[i].dynamicCast<wrapper_type>();
                auto from = from_wrapper->getView();

                /* every input must match the output in all axes except the channel axis */
                CV_Assert(output.rank() == from.rank());
                for (int j = 0; j < output.rank(); j++) {
                    if (j != 1) {
                        CV_Assert(from.get_axis_size(j) == output.get_axis_size(j));
                    }
                }

                if (i == 1)
                {
                    /* optimized path for the first two inputs: fuse the copy and the add */
                    kernels::input_shortcut<T>(stream, output, input, from);
                }
                else
                {
                    /* accumulate the remaining inputs in place; the kernel reads and writes each
                     * element exactly once, so passing `output` as the input is safe
                     */
                    kernels::input_shortcut<T>(stream, output, output, from);
                }
            }
        }

    private:
        csl::Stream stream;
    };

}}} /* namespace cv::dnn::cuda4dnn */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_SHORTCUT_HPP */
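
The forward pass above handles an arbitrary number of inputs by fusing the first two and then accumulating the rest in place. A plain CPU stand-in for that reduction (illustrative only; the real work is done by kernels::input_shortcut, and the channel handling here is simplified to a common prefix):

#include <algorithm>
#include <cstddef>
#include <vector>

/* stand-in for a two-operand shortcut: reads both operands of an element, then writes
 * it once, so calling it with `out` aliased as the first operand is safe */
static void shortcut_pair(std::vector<float>& out, const std::vector<float>& a,
                          const std::vector<float>& b)
{
    const std::size_t common = std::min(out.size(), b.size());
    for (std::size_t i = 0; i < out.size(); i++)
        out[i] = a[i] + (i < common ? b[i] : 0.f);
}

void shortcut_accumulate(std::vector<float>& out, const std::vector<std::vector<float>>& inputs)
{
    shortcut_pair(out, inputs[0], inputs[1]); /* fused copy + add for the first pair */
    for (std::size_t i = 2; i < inputs.size(); i++)
        shortcut_pair(out, out, inputs[i]);   /* in-place accumulation, mirroring the
                                                 `output, output, from` call above */
}
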
@@ -53,6 +53,7 @@
#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/eltwise.hpp"
#include "../cuda4dnn/primitives/shortcut.hpp"
using namespace cv::dnn::cuda4dnn;
#endif
@@ -155,8 +156,14 @@ public:
    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        if (backendId == DNN_BACKEND_CUDA)
        {
            if (channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 || channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE)
                return op == SUM && coeffs.empty();
            return channelsModeInput == ELTWISE_CHANNNELS_SAME;
        }
        return backendId == DNN_BACKEND_OPENCV ||
               backendId == DNN_BACKEND_CUDA ||
               (backendId == DNN_BACKEND_HALIDE && op != DIV) ||  // TODO: not implemented, see PR #15811
               ((((backendId == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && (preferableTarget != DNN_TARGET_OPENCL || coeffs.empty()))
                  || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) && channelsMode == ELTWISE_CHANNNELS_SAME));
@@ -623,6 +630,25 @@ public:
    {
        auto context = reinterpret_cast<csl::CSLContext*>(context_);

        CV_Assert(channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 ||
                  channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE ||
                  channelsModeInput == ELTWISE_CHANNNELS_SAME);

        if (channelsModeInput == ELTWISE_CHANNNELS_INPUT_0 || channelsModeInput == ELTWISE_CHANNNELS_INPUT_0_TRUNCATE)
        {
            auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
            for (int i = 1; i < inputs.size(); i++)
            {
                auto from_wrapper = inputs[i].dynamicCast<CUDABackendWrapper>();
                if (input_wrapper->getShape()[1] != from_wrapper->getShape()[1])
                {
                    CV_Assert(op == SUM);
                    CV_Assert(coeffs.empty());
                    return make_cuda_node<cuda4dnn::ShortcutOp>(preferableTarget, std::move(context->stream));
                }
            }
        }

        auto op_ = [this] {
            switch (op) {
            case MAX: return cuda4dnn::EltwiseOpType::MAX;
......
@@ -528,8 +528,6 @@ INSTANTIATE_TEST_CASE_P(/**/, Test_Darknet_nets, dnnBackendsAndTargets());
TEST_P(Test_Darknet_layers, shortcut)
{
    if (backend == DNN_BACKEND_CUDA)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
    testDarknetLayer("shortcut");
    testDarknetLayer("shortcut_leaky");
    testDarknetLayer("shortcut_unequal");
......
@@ -1624,7 +1624,7 @@ TEST_P(Layer_Test_Eltwise_unequal, accuracy_input_0_truncate)
    int backendId = get<0>(get<1>(GetParam()));
    int targetId = get<1>(get<1>(GetParam()));
    if (backendId == DNN_BACKEND_CUDA)
    if (backendId == DNN_BACKEND_CUDA && weighted)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
    Net net;
@@ -1690,15 +1690,15 @@ TEST_P(Layer_Test_Eltwise_unequal, accuracy_input_0)
    int backendId = get<0>(get<1>(GetParam()));
    int targetId = get<1>(get<1>(GetParam()));
    if (backendId == DNN_BACKEND_CUDA)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
    Net net;
    LayerParams lp;
    lp.type = "Eltwise";
    lp.name = "testLayer";
    lp.set<std::string>("output_channels_mode", "input_0");
    if (backendId == DNN_BACKEND_CUDA && weighted)
        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
    const int inpShapes[][4] = {{1, 4, 2, 2}, {1, 2, 2, 2}, {1, 3, 2, 2}};
    const int out_channels = inpShapes[0][1];
    std::vector<String> inpNames(3);
......
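
As a concrete reading of the `input_0` test above: with input shapes {1, 4, 2, 2}, {1, 2, 2, 2} and {1, 3, 2, 2}, the output keeps the first input's 4 channels; channels 0 and 1 receive contributions from all three inputs, channel 2 from the first and third, and channel 3 from the first input only. A toy check of that channel coverage (illustrative, not part of the test suite):

#include <cstdio>

int main()
{
    const int out_channels = 4;
    const int in_channels[3] = {4, 2, 3}; /* channel counts from the test's inpShapes */

    for (int c = 0; c < out_channels; c++)
    {
        int contributors = 0;
        for (int k = 0; k < 3; k++)
            if (c < in_channels[k]) /* input k covers this output channel */
                contributors++;
        std::printf("output channel %d sums %d input(s)\n", c, contributors);
    }
    return 0;
}
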