Commit 31fd7e1f authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #998 from alalek:dnn_fix_ocl_pooling

parents 5d9808b0 e6550fca
...@@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask) ...@@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask)
bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst, Blob &mask) bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst, Blob &mask)
{ {
return pooling_ocl("MaxPoolForward", src, dst); return pooling_ocl("MaxPoolForward", src, dst, &mask);
} }
void PoolingLayerImpl::avePooling(Blob &src, Blob &dst) void PoolingLayerImpl::avePooling(Blob &src, Blob &dst)
...@@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst ...@@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst
{ {
const UMat &srcMat = src.umatRefConst(); const UMat &srcMat = src.umatRefConst();
UMat &dstMat = dst.umatRef(); UMat &dstMat = dst.umatRef();
UMat* indexesMat = mask == NULL ? NULL : &dst.umatRef(); UMat *maskUMat = mask == NULL ? NULL : &mask->umatRef();
CV_Assert(maskUMat == NULL || maskUMat->type() == CV_32FC1); // FIXIT CV_32SC1
CV_Assert(maskUMat == NULL || maskUMat->offset == 0);
CV_Assert(srcMat.offset == 0 && dstMat.offset == 0); CV_Assert(srcMat.offset == 0 && dstMat.offset == 0);
ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type())); ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc,
cv::format("-DT=%s%s", ocl::typeToStr(src.type()), maskUMat ? " -DMASK=1" : ""));
if (ker.empty()) if (ker.empty())
return false; return false;
BlobShape s = src.shape(); BlobShape s = src.shape();
size_t nthreads = dst.total(); size_t nthreads = dst.total();
if (maskUMat)
{
ker.args((int)nthreads, ker.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3], ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
out.height, out.width, kernel.height, kernel.width, out.height, out.width, kernel.height, kernel.width,
stride.height, stride.width, pad.height, pad.width, stride.height, stride.width, pad.height, pad.width,
ocl::KernelArg::PtrWriteOnly(dstMat), ocl::KernelArg::PtrWriteOnly(dstMat),
ocl::KernelArg(ocl::KernelArg::PTR_ONLY + ocl::KernelArg::WRITE_ONLY, indexesMat)); ocl::KernelArg::PtrWriteOnly(*maskUMat));
}
else
{
ker.args((int)nthreads,
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
out.height, out.width, kernel.height, kernel.width,
stride.height, stride.width, pad.height, pad.width,
ocl::KernelArg::PtrWriteOnly(dstMat));
}
size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize(); size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
if (!ker.run(1, &nthreads, &wgSize, true)) if (!ker.run(1, &nthreads, &wgSize, true))
......
...@@ -24,8 +24,16 @@ ...@@ -24,8 +24,16 @@
* POSSIBILITY OF SUCH DAMAGE. * POSSIBILITY OF SUCH DAMAGE.
**************************************************************************************/ **************************************************************************************/
__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data, __global int* mask __kernel void MaxPoolForward(const int nthreads,
) { __global T* bottom_data, const int num, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
__global T* top_data
#ifdef MASK
, __global float* mask
#endif
)
{
int index = get_global_id(0); int index = get_global_id(0);
int tmp = get_global_size(0); int tmp = get_global_size(0);
for(index; index < nthreads; index += tmp) { for(index; index < nthreads; index += tmp) {
...@@ -51,15 +59,25 @@ __kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const ...@@ -51,15 +59,25 @@ __kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const
} }
} }
} }
top_data[index] = maxval; top_data[index] = maxval;
if (mask) { #ifdef MASK
mask[index] = maxidx; mask[index] = maxidx;
} #endif
} }
} }
__kernel void AvePoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w,__global T* top_data) { __kernel void AvePoolForward(const int nthreads,
__global T* bottom_data, const int num, const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
__global T* top_data
#ifdef MASK
, __global float* mask // NOT USED
#endif
)
{
int index = get_global_id(0); int index = get_global_id(0);
int tmp = get_global_size(0); int tmp = get_global_size(0);
for(index; index < nthreads; index+=tmp) { for(index; index < nthreads; index+=tmp) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment