Commit e6550fca authored by Alexander Alekhin's avatar Alexander Alekhin

dnn: fix OpenCL code in pooling_ocl

parent 3cdc0e48
......@@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask)
bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst, Blob &mask)
{
return pooling_ocl("MaxPoolForward", src, dst);
return pooling_ocl("MaxPoolForward", src, dst, &mask);
}
void PoolingLayerImpl::avePooling(Blob &src, Blob &dst)
......@@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst
{
const UMat &srcMat = src.umatRefConst();
UMat &dstMat = dst.umatRef();
UMat* indexesMat = mask == NULL ? NULL : &dst.umatRef();
UMat *maskUMat = mask == NULL ? NULL : &mask->umatRef();
CV_Assert(maskUMat == NULL || maskUMat->type() == CV_32FC1); // FIXIT CV_32SC1
CV_Assert(maskUMat == NULL || maskUMat->offset == 0);
CV_Assert(srcMat.offset == 0 && dstMat.offset == 0);
ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type()));
ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc,
cv::format("-DT=%s%s", ocl::typeToStr(src.type()), maskUMat ? " -DMASK=1" : ""));
if (ker.empty())
return false;
BlobShape s = src.shape();
size_t nthreads = dst.total();
if (maskUMat)
{
ker.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
out.height, out.width, kernel.height, kernel.width,
stride.height, stride.width, pad.height, pad.width,
ocl::KernelArg::PtrWriteOnly(dstMat),
ocl::KernelArg(ocl::KernelArg::PTR_ONLY + ocl::KernelArg::WRITE_ONLY, indexesMat));
ocl::KernelArg::PtrWriteOnly(*maskUMat));
}
else
{
ker.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
out.height, out.width, kernel.height, kernel.width,
stride.height, stride.width, pad.height, pad.width,
ocl::KernelArg::PtrWriteOnly(dstMat));
}
size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
if (!ker.run(1, &nthreads, &wgSize, true))
......
......@@ -24,8 +24,16 @@
* POSSIBILITY OF SUCH DAMAGE.
**************************************************************************************/
__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data, __global int* mask
) {
__kernel void MaxPoolForward(const int nthreads,
__global T* bottom_data, const int num, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
__global T* top_data
#ifdef MASK
, __global float* mask
#endif
)
{
int index = get_global_id(0);
int tmp = get_global_size(0);
for(index; index < nthreads; index += tmp) {
......@@ -51,15 +59,25 @@ __kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const
}
}
}
top_data[index] = maxval;
if (mask) {
#ifdef MASK
mask[index] = maxidx;
}
#endif
}
}
__kernel void AvePoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w,__global T* top_data) {
__kernel void AvePoolForward(const int nthreads,
__global T* bottom_data, const int num, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
__global T* top_data
#ifdef MASK
, __global float* mask // NOT USED
#endif
)
{
int index = get_global_id(0);
int tmp = get_global_size(0);
for(index; index < nthreads; index+=tmp) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment