Commit 601afeed authored by Vitaliy Lyudvichenko

Add OCL implementations and public interfaces for Convolution and LRN

parent 50c9e1c9
@@ -205,6 +205,46 @@ namespace dnn
     void forward(std::vector<Blob*> &input, std::vector<Blob> &output);
 };
+    class CV_EXPORTS_W BaseConvolutionLayer : public Layer
+    {
+    public:
+        Size kernel, pad, stride;
+    };
+
+    class CV_EXPORTS_W ConvolutionLayer : public BaseConvolutionLayer
+    {
+    public:
+        static Ptr<BaseConvolutionLayer> create(Size kernel = Size(3, 3), Size pad = Size(0, 0), Size stride = Size(1, 1));
+    };
+
+    class CV_EXPORTS_W DeconvolutionLayer : public BaseConvolutionLayer
+    {
+    public:
+        static Ptr<BaseConvolutionLayer> create(Size kernel = Size(3, 3), Size pad = Size(0, 0), Size stride = Size(1, 1));
+    };
+
+    class CV_EXPORTS_W LRNLayer : public Layer
+    {
+    public:
+        enum
+        {
+            CHANNEL_NRM,
+            SPATIAL_NRM
+        };
+        int type;
+
+        int size;
+        double alpha, beta;
+    };
 //! @}
 //! @}
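A minimal usage sketch of the new public interface (illustrative only, not part of the commit; it assumes the factories are implemented as declared above):

    #include <opencv2/dnn/all_layers.hpp>

    void createConvExample()
    {
        using namespace cv;
        using namespace cv::dnn;

        //3x3 convolution with 1-pixel padding and unit stride via the new factory
        Ptr<BaseConvolutionLayer> conv = ConvolutionLayer::create(Size(3, 3), Size(1, 1), Size(1, 1));
        CV_Assert(conv->kernel == Size(3, 3) && conv->pad == Size(1, 1) && conv->stride == Size(1, 1));
    }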
@@ -121,6 +121,7 @@ namespace dnn //! This namespace is used for dnn module functionality.
     Layer();
     explicit Layer(const LayerParams &params); //!< Initializes only #name, #type and #blobs fields.
+    void setParamsFrom(const LayerParams &params); //!< Initializes only #name, #type and #blobs fields.

     virtual ~Layer();
 };
@@ -48,7 +48,10 @@
 namespace cv {
 namespace dnn {

-std::ostream &operator<< (std::ostream &s, cv::Range &r)
+//Useful shortcut
+typedef BlobShape Shape;
+
+inline std::ostream &operator<< (std::ostream &s, cv::Range &r)
 {
     return s << "[" << r.start << ", " << r.end << ")";
 }
@@ -96,8 +99,6 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1)
         ranges[i] = Range::all();
     ranges[0] = r0;
     ranges[1] = r1;
-    // for (int i = 0; i < m.dims; i++)
-    //     std::cout << ranges[i] << "\n";
     return m(&ranges[0]);
 }
@@ -128,8 +129,32 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2, const _Range &r3)
     return m(&ranges[0]);
 }

+//Traits for switching in polymorphic implementations
+template<typename XMat>
+struct MatTraits
+{
+};
+
+template<>
+struct MatTraits<cv::Mat>
+{
+    enum
+    {
+        IS_MAT = 1,
+        IS_UMAT = 0,
+    };
+};
+
+template<>
+struct MatTraits<cv::UMat>
+{
+    enum
+    {
+        IS_MAT = 0,
+        IS_UMAT = 1,
+    };
+};
+
 }
 }
 #endif
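The MatTraits specializations above let code templated on the matrix type branch at compile time, which is how the Mat/UMat forward paths below stay in one function body. A small sketch of the idiom (illustrative, not part of the commit):

    template<typename XMat>
    const char *backendName()
    {
        //resolved per instantiation: no runtime type queries needed
        return MatTraits<XMat>::IS_UMAT ? "OpenCL (UMat)" : "CPU (Mat)";
    }
    //backendName<cv::Mat>() yields "CPU (Mat)"; backendName<cv::UMat>() yields "OpenCL (UMat)".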
@@ -543,6 +543,13 @@ Layer::Layer(const LayerParams &params)
 }

+void Layer::setParamsFrom(const LayerParams &params)
+{
+    blobs = params.blobs;
+    name = params.name;
+    type = params.type;
+}
+
 int Layer::inputNameToIndex(String)
 {
     return -1;
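With setParamsFrom() available, a factory function can copy the common fields without routing through the Layer(LayerParams&) constructor. A hypothetical sketch (MyLayerImpl and createMyLayer are placeholders, not from the commit):

    Ptr<Layer> createMyLayer(LayerParams &params)
    {
        Ptr<Layer> l(new MyLayerImpl()); //hypothetical implementation class
        l->setParamsFrom(params);        //copies only #name, #type and #blobs
        //...parse layer-specific parameters here...
        return l;
    }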
@@ -81,9 +81,9 @@ void initModule()
     REG_RUNTIME_LAYER_CLASS(Split, SplitLayer)
     REG_RUNTIME_LAYER_CLASS(Reshape, ReshapeLayer)
     REG_STATIC_LAYER_FUNC(Flatten, createFlattenLayer)
-    REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayer)
+    REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayerImpl)
     REG_RUNTIME_LAYER_CLASS(MVN, MVNLayer)
-    REG_RUNTIME_LAYER_CLASS(LRN, LRNLayer)
+    REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(InnerProduct, FullyConnectedLayer)
     REG_RUNTIME_LAYER_CLASS(ReLU, ElementWiseLayer<ReLUFunctor>)

@@ -94,8 +94,8 @@ void initModule()
     REG_RUNTIME_LAYER_CLASS(Sigmoid, ElementWiseLayer<SigmoidFunctor>)
     REG_RUNTIME_LAYER_CLASS(Dropout, BlankLayer)
-    REG_RUNTIME_LAYER_CLASS(Convolution, ConvolutionLayer)
-    REG_RUNTIME_LAYER_CLASS(Deconvolution, DeConvolutionLayer)
+    REG_RUNTIME_LAYER_FUNC(Convolution, createConvolutionLayerFromCaffe)
+    REG_RUNTIME_LAYER_FUNC(Deconvolution, createDeconvolutionLayerFromCaffe)
     REG_RUNTIME_LAYER_CLASS(Concat, ConcatLayer)

     init.status = true;
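The REG_RUNTIME_LAYER_CLASS to REG_RUNTIME_LAYER_FUNC switches above move Caffe parameter parsing out of layer constructors: instead of constructing the class directly, the importer now calls a registered factory for each layer of that type. Roughly (a sketch of the mechanism only; registerLayer and constructLayer are hypothetical names for what the macros expand to):

    //REG_RUNTIME_LAYER_FUNC(LRN, createLRNLayerFromCaffe) behaves like:
    registerLayer("LRN", createLRNLayerFromCaffe);          //factory function
    //while REG_RUNTIME_LAYER_CLASS(Concat, ConcatLayer) behaves like:
    registerLayer("Concat", constructLayer<ConcatLayer>);   //constructor-based factory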
@@ -42,61 +42,65 @@
 #ifndef __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
 #define __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
 #include "../precomp.hpp"
+#include <opencv2/dnn/all_layers.hpp>

 namespace cv
 {
 namespace dnn
 {
     //TODO: simultaneously convolution and bias addition for cache optimization
-    class ConvolutionLayer : public Layer
+    class ConvolutionLayerImpl : public ConvolutionLayer
     {
+    public:
+        ConvolutionLayerImpl();
+        virtual void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
+        virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
+        virtual void init();
+
     protected:
-        bool bias;
         int numOutput, group;
-        int padH, padW;
-        int kerH, kerW;
-        int strideH, strideW;
         int inpH, inpW, inpCn;
         int outH, outW, outCn;
         int topH, topW, topCn; //switched between inp/out on deconv/conv
         int inpGroupCn, outGroupCn;
         int ksize;
+        bool bias;
         bool tryUseOpenCL, useOpenCL;

         Blob colBlob, biasOnesBlob;

-        inline bool is1x1() const;
+        bool is1x1() const;
         virtual void computeInpOutShape(const Blob &inpBlob);
+        template<typename XMat>
+        void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void im2col(const Mat &srcImg, Mat &dstCol);
         void im2col(const UMat &srcImg, UMat &dstCol);
-
-    public:
-        ConvolutionLayer() {}
-        ConvolutionLayer(LayerParams &params);
-        void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
-        void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
-        template<typename XMat>
-        void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };

-    class DeConvolutionLayer : public ConvolutionLayer
+    class DeConvolutionLayerImpl : public ConvolutionLayerImpl
     {
+    public:
+        DeConvolutionLayerImpl();
+        virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
+
     protected:
-        void computeInpOutShape(const Blob &inpBlob);
+        virtual void computeInpOutShape(const Blob &inpBlob);
+        template<typename XMat>
+        void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void col2im(const Mat &colMat, Mat &dstImg);
         void col2im(const UMat &colMat, UMat &dstImg);
-
-    public:
-        DeConvolutionLayer(LayerParams &params);
-        void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
-        template<typename XMat>
-        void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };
+
+    //Importers
+    Ptr<Layer> createConvolutionLayerFromCaffe(LayerParams &params);
+    Ptr<Layer> createDeconvolutionLayerFromCaffe(LayerParams &params);
 }
 }
 #endif
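ConvolutionLayerImpl keeps the im2col/col2im strategy implied by colBlob and the im2col/col2im declarations: each kernel-sized patch is unrolled into a column so the convolution becomes one matrix multiplication against the filter matrix. For reference, a naive single-channel im2col without padding might look like this (illustrative only, not the commit's optimized implementation):

    //col has (kh*kw) rows and (outH*outW) columns, row-major
    static void im2colNaive(const float *img, int h, int w, int kh, int kw, int stride, float *col)
    {
        int outH = (h - kh) / stride + 1, outW = (w - kw) / stride + 1;
        for (int y = 0; y < outH; ++y)
            for (int x = 0; x < outW; ++x)
                for (int i = 0; i < kh; ++i)
                    for (int j = 0; j < kw; ++j)
                        //row index = position inside the kernel, column index = output position
                        col[((i * kw + j) * outH + y) * outW + x] =
                            img[(y * stride + i) * w + (x * stride + j)];
    }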
@@ -46,43 +46,44 @@ namespace cv
 namespace dnn
 {

-void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW)
+void getCaffeConvParams(LayerParams &params, Size &kernel, Size &pad, Size &stride)
 {
     if (params.has("kernel_h") && params.has("kernel_w"))
     {
-        kernelH = params.get<int>("kernel_h");
-        kernelW = params.get<int>("kernel_w");
+        kernel.height = params.get<int>("kernel_h");
+        kernel.width = params.get<int>("kernel_w");
     }
     else if (params.has("kernel_size"))
     {
-        kernelH = kernelW = params.get<int>("kernel_size");
+        kernel.height = kernel.width = params.get<int>("kernel_size");
     }
     else
     {
-        CV_Error(cv::Error::StsBadArg, "kernel_size (or kernel_h and kernel_w) not specified");
+        CV_Error(Error::StsBadArg, "kernel_size (or kernel_h and kernel_w) not specified");
     }
+    CV_Assert(kernel.height > 0 && kernel.width > 0);

     if (params.has("pad_h") && params.has("pad_w"))
     {
-        padH = params.get<int>("pad_h");
-        padW = params.get<int>("pad_w");
+        pad.height = params.get<int>("pad_h");
+        pad.width = params.get<int>("pad_w");
     }
     else
     {
-        padH = padW = params.get<int>("pad", 0);
+        pad.height = pad.width = params.get<int>("pad", 0);
     }
+    CV_Assert(pad.height >= 0 && pad.width >= 0);

     if (params.has("stride_h") && params.has("stride_w"))
     {
-        strideH = params.get<int>("stride_h");
-        strideW = params.get<int>("stride_w");
+        stride.height = params.get<int>("stride_h");
+        stride.width = params.get<int>("stride_w");
     }
     else
     {
-        strideH = strideW = params.get<int>("stride", 1);
+        stride.height = stride.width = params.get<int>("stride", 1);
     }
-
-    CV_Assert(kernelH > 0 && kernelW > 0 && padH >= 0 && padW >= 0 && strideH > 0 && strideW > 0);
+    CV_Assert(stride.height > 0 && stride.width > 0);
 }

 }
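An illustrative call (not in the commit) showing the defaulting behavior of the rewritten helper:

    LayerParams params;
    params.set("kernel_size", 3);   //as parsed from a Caffe prototxt
    Size kernel, pad, stride;
    getCaffeConvParams(params, kernel, pad, stride);
    //kernel == Size(3, 3); pad == Size(0, 0) and stride == Size(1, 1) by default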
@@ -48,7 +48,7 @@ namespace cv
 namespace dnn
 {

-void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW);
+void getCaffeConvParams(LayerParams &params, Size &kernel, Size &pad, Size &stride);

 }
 }
@@ -42,45 +42,41 @@
 #include "../precomp.hpp"
 #include "layers_common.hpp"
 #include "lrn_layer.hpp"
+#include "opencl_kernels_dnn.hpp"
 #include <opencv2/imgproc.hpp>
+#include <opencv2/core/ocl.hpp>
+#include <opencv2/dnn/shape_utils.hpp>
 #include <algorithm>

 namespace cv
 {
 namespace dnn
 {
-    LRNLayer::LRNLayer(LayerParams &params) : Layer(params)
+    LRNLayerImpl::LRNLayerImpl()
     {
-        String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
-        if (nrmType == "ACROSS_CHANNELS")
-            type = CHANNEL_NRM;
-        else if (nrmType == "WITHIN_CHANNEL")
-            type = SPATIAL_NRM;
-        else
-            CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
-
-        size = params.get<int>("local_size", 5);
-        if (size % 2 != 1 || size <= 0)
-            CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
-
-        alpha = params.get<double>("alpha", 1);
-        beta = params.get<double>("beta", 0.75);
+        size = 5;
+        alpha = 1;
+        beta = 0.75;
+        type = CHANNEL_NRM;
     }

-    void LRNLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
+    void LRNLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
     {
-        CV_Assert(inputs.size() == 1);
-        outputs.resize(1);
+        CV_Assert(inputs.size() == 1 && inputs[0]->dims() == 4);
+        useOpenCL = cv::ocl::useOpenCL();

-        Vec4i shape = inputs[0]->shape4();
-        outputs[0].create(shape);
+        if (type == SPATIAL_NRM && !useOpenCL)
+            buf.create(inputs[0]->shape().slice(2), inputs[0]->type(), Blob::ALLOC_MAT);
+        if (type == CHANNEL_NRM && useOpenCL)
+            buf.create(inputs[0]->shape().slice(2), inputs[0]->type(), Blob::ALLOC_UMAT);

-        shape[0] = 1; //maybe make shape[0] = 1 too
-        bufBlob.create(shape);
+        outputs.resize(1);
+        outputs[0].create(inputs[0]->shape(), inputs[0]->type());
     }

-    void LRNLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
+    void LRNLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
     {
         Blob &src = *inputs[0];
         Blob &dst = outputs[0];

@@ -93,72 +89,171 @@ namespace dnn
             spatialNormalization(src, dst);
             break;
         default:
-            CV_Error(cv::Error::StsNotImplemented, "Unimplemented mode of LRN layer");
+            CV_Error(Error::StsNotImplemented, "Unimplemented mode of LRN layer");
            break;
        }
    }
-    void LRNLayer::channelNoramlization(Blob &srcBlob, Blob &dstBlob)
+    template<typename XMat>
+    static XMat getPlane(XMat &m, int n, int cn)
+    {
+        return reshaped(slice(m, n, cn), BlobShape::like(m).slice(2));
+    }
+
+    void LRNLayerImpl::channelNoramlization(Blob &src, Blob &dst)
+    {
+        if (!useOpenCL)
+            channelNoramlization_<Mat>(src, dst);
+        else
+        {
+            //channelNoramlization_ocl(src.getRefConst<UMat>(), dst.getRef<UMat>()); //consumes a lot of memory
+            channelNoramlization_<UMat>(src, dst);
+        }
+    }
+
+    template<typename XMat>
+    void LRNLayerImpl::channelNoramlization_(Blob &srcBlob, Blob &dstBlob)
     {
-        CV_DbgAssert(srcBlob.ptr() != dstBlob.ptr());
-
         int num = srcBlob.num();
         int channels = srcBlob.channels();
         int ksize = (size - 1) / 2;

+        XMat srcMat = srcBlob.getRefConst<XMat>();
+        XMat dstMat = dstBlob.getRef<XMat>();
+
         for (int n = 0; n < num; n++)
         {
-            Mat accum = dstBlob.getPlane(n, channels-1); //trick for memory saving
+            XMat accum = getPlane(dstMat, n, channels-1); //trick for memory saving
             accum.setTo(0);

             for (int cn = 0; cn < std::min(ksize, channels); cn++)
-                cv::accumulateSquare(srcBlob.getPlane(n, cn), accum);
+                cv::accumulateSquare(getPlane(srcMat, n, cn), accum);

             for (int cn = 0; cn < channels; cn++)
             {
                 if (cn + ksize < channels)
                 {
-                    cv::accumulateSquare(srcBlob.getPlane(n, cn + ksize), accum);
+                    cv::accumulateSquare(getPlane(srcMat, n, cn + ksize), accum);
                 }

                 if (cn - ksize - 1 >= 0)
                 {
-                    Mat left = srcBlob.getPlane(n, cn - ksize - 1);
-                    cv::subtract(accum, left.mul(left), accum); //subtractSquare
+                    //subtractSquare
+                    XMat left = getPlane(srcMat, n, cn - ksize - 1);
+                    cv::pow(left, 2, left);
+                    cv::subtract(accum, left, accum);
                 }

-                Mat dst = dstBlob.getPlane(n, cn);
+                XMat dst = getPlane(dstMat, n, cn);
                 accum.convertTo(dst, dst.type(), alpha/size, 1);
                 cv::pow(dst, beta, dst);
-                cv::divide(srcBlob.getPlane(n, cn), dst, dst);
+                cv::divide(getPlane(srcMat, n, cn), dst, dst);
             }
         }
     }
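For reference, the across-channels path above evaluates, per spatial position and with k = (size - 1) / 2 (the channel window clipped to valid indices):

    dst[n,c,h,w] = src[n,c,h,w] / (1 + (alpha/size) * sum_{c'=c-k..c+k} src[n,c',h,w]^2)^beta

accum carries the sum of squares as a sliding window over channels, which is why exactly one plane enters (accumulateSquare) and one plane leaves (the subtractSquare step) per iteration.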
-    void LRNLayer::spatialNormalization(Blob &srcBlob, Blob &dstBlob)
+    bool LRNLayerImpl::channelNoramlization_ocl(const UMat &src, UMat &dst)
     {
+        if (src.offset != 0 || dst.offset != 0) //TODO: add offset
+            return false;
+
+        String buildOpts = String("-DT=") + ocl::typeToStr(src.type());
+
+        ocl::Kernel kerScale("LRNFillScale", ocl::dnn::lrn_oclsrc, buildOpts);
+        if (kerScale.empty())
+            return false;
+
+        ocl::Kernel kerOutput("LRNComputeOutput", ocl::dnn::lrn_oclsrc, buildOpts);
+        if (kerOutput.empty())
+            return false;
+
+        Shape shape = Shape::like(src);
+        int ksize = (size - 1) / 2;
+        size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
+        UMat &scaleBuf = buf.umatRef();
+
+        size_t nthreads = (size_t)(shape.total() / shape[1]);
+        kerScale.args((int)nthreads,
+                      ocl::KernelArg::PtrReadOnly(src), shape[0], shape[1], shape[2], shape[3],
+                      size, (float)(alpha/size), (float)ksize, ocl::KernelArg::PtrWriteOnly(scaleBuf));
+        if (!kerScale.run(1, &nthreads, &wgSize, true))
+            return false;
+
+        nthreads = (size_t)shape.total();
+        kerOutput.args((int)nthreads,
+                       ocl::KernelArg::PtrReadOnly(src), ocl::KernelArg::PtrReadOnly(scaleBuf),
+                       -beta, ocl::KernelArg::PtrWriteOnly(dst));
+        if (!kerOutput.run(1, &nthreads, &wgSize, true))
+            return false;
+
+        return true;
+    }
+
+    void LRNLayerImpl::spatialNormalization(Blob &src, Blob &dst)
+    {
+        if (!useOpenCL)
+            spatialNormalization_<Mat>(src, dst);
+        else
+            spatialNormalization_<UMat>(src, dst);
+    }
+
+    template<typename XMat>
+    void LRNLayerImpl::spatialNormalization_(Blob &srcBlob, Blob &dstBlob)
+    {
         int num = srcBlob.num();
         int channels = srcBlob.channels();

+        XMat srcMat = srcBlob.getRefConst<XMat>();
+        XMat dstMat = dstBlob.getRef<XMat>();
+
         for (int n = 0; n < num; n++)
         {
             for (int cn = 0; cn < channels; cn++)
             {
-                Mat src = srcBlob.getPlane(n, cn);
-                Mat dst = dstBlob.getPlane(n, cn);
+                XMat src = getPlane(srcMat, n, cn);
+                XMat dst = getPlane(dstMat, n, cn);
+                uchar *dataDst0 = dst.data;
+
+                if (MatTraits<XMat>::IS_UMAT)
+                {
+                    cv::sqrBoxFilter(src, dst, dst.depth(), Size(size, size), Point(-1, -1), false, BORDER_CONSTANT | BORDER_ISOLATED);
+                }
+                else
+                {
+                    //TODO: fix cv::boxFilter with BORDER_ISOLATED flag in CPU mode
+                    Mat bufMat = buf.getRef<Mat>();
+                    src.copyTo(bufMat);
+                    cv::sqrBoxFilter(bufMat, dst, dst.depth(), Size(size, size), Point(-1, -1), false, BORDER_CONSTANT);
+                }

-                cv::pow(srcBlob.getPlane(n, cn), 2, dst);
-                //TODO: check border type
-                cv::boxFilter(dst, dst, dst.depth(), cv::Size(size, size), cv::Point(-1, -1), false, cv::BORDER_CONSTANT);
                 dst.convertTo(dst, dst.type(), alpha/(size*size), 1);
                 cv::pow(dst, beta, dst);
                 cv::divide(src, dst, dst);
+
+                CV_Assert(dataDst0 == dst.data); //debug
             }
         }
     }
+    Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params)
+    {
+        LRNLayerImpl *l = new LRNLayerImpl();
+
+        String nrmType = params.get<String>("norm_region", "ACROSS_CHANNELS");
+        if (nrmType == "ACROSS_CHANNELS")
+            l->type = LRNLayer::CHANNEL_NRM;
+        else if (nrmType == "WITHIN_CHANNEL")
+            l->type = LRNLayer::SPATIAL_NRM;
+        else
+            CV_Error(Error::StsBadArg, "Unknown region type \"" + nrmType + "\"");
+
+        int size = params.get<int>("local_size", 5);
+        if (size % 2 != 1 || size <= 0)
+            CV_Error(Error::StsBadArg, "LRN layer supports only positive odd values for local_size");
+        l->size = size;
+
+        l->alpha = params.get<double>("alpha", 1);
+        l->beta = params.get<double>("beta", 0.75);
+
+        return Ptr<Layer>(l);
+    }
 }
 }
@@ -42,34 +42,36 @@
 #ifndef __OPENCV_DNN_LAYERS_LRN_LAYER_HPP__
 #define __OPENCV_DNN_LAYERS_LRN_LAYER_HPP__
 #include "../precomp.hpp"
+#include <opencv2/dnn/all_layers.hpp>

 namespace cv
 {
 namespace dnn
 {
-    class LRNLayer : public Layer
+    class LRNLayerImpl : public LRNLayer
     {
-        enum
-        {
-            CHANNEL_NRM,
-            SPATIAL_NRM,
-            SPATIAL_CONTRAST_NRM //cuda-convnet feature
-        } type;
-
-        int size;
-        double alpha, beta;
-
-        Blob bufBlob;
+        bool useOpenCL;
+        Blob buf;

         void channelNoramlization(Blob &src, Blob &dst);
+        template<typename XMat>
+        void channelNoramlization_(Blob &src, Blob &dst);
+        bool channelNoramlization_ocl(const UMat &src, UMat &dst);
+
         void spatialNormalization(Blob &src, Blob &dst);
+        template<typename XMat>
+        void spatialNormalization_(Blob &src, Blob &dst);

     public:
-        LRNLayer(LayerParams &params);
+        LRNLayerImpl();

         void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };

+    Ptr<Layer> createLRNLayerFromCaffe(LayerParams &params);
 }
 }
 #endif
@@ -72,22 +72,21 @@ namespace dnn
         type = MAX;
     }

-    getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
+    getCaffeConvParams(params, kernel, pad, stride);
 }

 void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
 {
     CV_Assert(inputs.size() > 0);

-    inpW = inputs[0]->cols();
-    inpH = inputs[0]->rows();
-    computeOutputShape(inpH, inpW);
+    inp = inputs[0]->size2();
+    computeOutputShape(inp);

     outputs.resize(inputs.size());
     for (size_t i = 0; i < inputs.size(); i++)
     {
-        CV_Assert(inputs[i]->rows() == inpH && inputs[i]->cols() == inpW);
-        outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), outH, outW));
+        CV_Assert(inputs[i]->rows() == inp.height && inputs[i]->cols() == inp.width);
+        outputs[i].create(BlobShape(inputs[i]->num(), inputs[i]->channels(), out.height, out.width));
     }
 }
@@ -104,7 +103,7 @@ namespace dnn
         avePooling(*inputs[ii], outputs[ii]);
         break;
     default:
-        CV_Error(cv::Error::StsNotImplemented, "Not implemented");
+        CV_Error(Error::StsNotImplemented, "Not implemented");
         break;
     }
 }
@@ -112,7 +111,7 @@ namespace dnn
 void PoolingLayer::maxPooling(Blob &input, Blob &output)
 {
-    CV_DbgAssert(output.rows() == outH && output.cols() == outW);
+    CV_DbgAssert(output.rows() == out.height && output.cols() == out.width);

     for (int n = 0; n < input.num(); ++n)
     {
@@ -121,23 +120,23 @@ namespace dnn
             float *srcData = input.ptrf(n, c);
             float *dstData = output.ptrf(n, c);

-            for (int ph = 0; ph < outH; ++ph)
+            for (int ph = 0; ph < out.height; ++ph)
             {
-                for (int pw = 0; pw < outW; ++pw)
+                for (int pw = 0; pw < out.width; ++pw)
                 {
-                    int hstart = ph * strideH - padH;
-                    int wstart = pw * strideW - padW;
-                    int hend = min(hstart + kernelH, inpH);
-                    int wend = min(wstart + kernelW, inpW);
+                    int hstart = ph * stride.height - pad.height;
+                    int wstart = pw * stride.width - pad.width;
+                    int hend = min(hstart + kernel.height, inp.height);
+                    int wend = min(wstart + kernel.width, inp.width);
                     hstart = max(hstart, 0);
                     wstart = max(wstart, 0);
-                    const int poolIndex = ph * outW + pw;
+                    const int poolIndex = ph * out.width + pw;
                     float max_val = -FLT_MAX;

                     for (int h = hstart; h < hend; ++h)
                         for (int w = wstart; w < wend; ++w)
                         {
-                            const int index = h * inpW + w;
+                            const int index = h * inp.width + w;
                             if (srcData[index] > max_val)
                                 max_val = srcData[index];
                         }
@@ -158,49 +157,49 @@ namespace dnn
             float *srcData = input.ptrf(n, c);
             float *dstData = output.ptrf(n, c);

-            for (int ph = 0; ph < outH; ++ph)
+            for (int ph = 0; ph < out.height; ++ph)
             {
-                for (int pw = 0; pw < outW; ++pw)
+                for (int pw = 0; pw < out.width; ++pw)
                 {
-                    int hstart = ph * strideH - padH;
-                    int wstart = pw * strideW - padW;
-                    int hend = min(hstart + kernelH, inpH + padH);
-                    int wend = min(wstart + kernelW, inpW + padW);
+                    int hstart = ph * stride.height - pad.height;
+                    int wstart = pw * stride.width - pad.width;
+                    int hend = min(hstart + kernel.height, inp.height + pad.height);
+                    int wend = min(wstart + kernel.width, inp.width + pad.width);
                     int poolSize = (hend - hstart) * (wend - wstart);
                     hstart = max(hstart, 0);
                     wstart = max(wstart, 0);
-                    hend = min(hend, inpH);
-                    wend = min(wend, inpW);
+                    hend = min(hend, inp.height);
+                    wend = min(wend, inp.width);

-                    dstData[ph * outW + pw] = 0.f;
+                    dstData[ph * out.width + pw] = 0.f;
                     for (int h = hstart; h < hend; ++h)
                         for (int w = wstart; w < wend; ++w)
-                            dstData[ph * outW + pw] += srcData[h * inpW + w];
+                            dstData[ph * out.width + pw] += srcData[h * inp.width + w];

-                    dstData[ph * outW + pw] /= poolSize;
+                    dstData[ph * out.width + pw] /= poolSize;
                 }
             }
         }
     }
 }
-void PoolingLayer::computeOutputShape(int inH, int inW)
+void PoolingLayer::computeOutputShape(Size inpSz)
 {
     //Yeah, something strange Caffe scheme-)
-    outH = static_cast<int>(ceil(static_cast<float>(inH + 2 * padH - kernelH) / strideH)) + 1;
-    outW = static_cast<int>(ceil(static_cast<float>(inW + 2 * padW - kernelW) / strideW)) + 1;
+    out.height = static_cast<int>(ceil(static_cast<float>(inpSz.height + 2 * pad.height - kernel.height) / stride.height)) + 1;
+    out.width = static_cast<int>(ceil(static_cast<float>(inpSz.width + 2 * pad.width - kernel.width) / stride.width)) + 1;

-    if (padH || padW)
+    if (pad.height || pad.width)
     {
         // If we have padding, ensure that the last pooling starts strictly
         // inside the image (instead of at the padding); otherwise clip the last.
-        if ((outH - 1) * strideH >= inH + padH)
-            --outH;
-        if ((outW - 1) * strideW >= inW + padW)
-            --outW;
-        CV_Assert((outH - 1) * strideH < inH + padH);
-        CV_Assert((outW - 1) * strideW < inW + padW);
+        if ((out.height - 1) * stride.height >= inpSz.height + pad.height)
+            --out.height;
+        if ((out.width - 1) * stride.width >= inpSz.width + pad.width)
+            --out.width;
+        CV_Assert((out.height - 1) * stride.height < inpSz.height + pad.height);
+        CV_Assert((out.width - 1) * stride.width < inpSz.width + pad.width);
     }
 }
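A worked example of the "strange Caffe scheme" above (illustrative numbers only): for inp = 5x5, kernel = 2x2, stride = 2, pad = 1,

    out = ceil((5 + 2*1 - 2) / 2.0) + 1 = ceil(2.5) + 1 = 4

but (4 - 1) * 2 = 6 >= 5 + 1, i.e. the last pooling window would start inside the padding, so out is clipped back to 3, which satisfies the assert (3 - 1) * 2 = 4 < 6.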
@@ -57,14 +57,10 @@ namespace dnn
     };

     int type;
-    int padH, padW;
-    int strideH, strideW;
-    int kernelH, kernelW;
+    Size kernel, pad, stride;
+    Size inp, out;

-    int inpH, inpW;
-    int outH, outW;
-
-    void computeOutputShape(int inpH, int inpW);
+    void computeOutputShape(Size inpSz);
     void maxPooling(Blob &input, Blob &output);
     void avePooling(Blob &input, Blob &output);
/*************************************************************************************
* Copyright (c) 2015, Advanced Micro Devices, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation and/or
* other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
* OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
**************************************************************************************/
__kernel void LRNComputeOutput(const int nthreads, __global T* in, __global T* scale,
                               const T negative_beta, __global T* out) {
  int index = get_global_id(0);
  int tmp = get_global_size(0);
  for (; index < nthreads; index += tmp)
    out[index] = in[index] * pow(scale[index], negative_beta);
}
__kernel void LRNFillScale(const int nthreads, __global T* in, const int num, const int channels,
                           const int height, const int width, const int size,
                           const T alpha_over_size, const T k, __global T* scale) {
  int index = get_global_id(0);
  int tmp = get_global_size(0);
  for (; index < nthreads; index += tmp) {
    // find out the local offset
    const int w = index % width;
    const int h = (index / width) % height;
    const int n = index / width / height;
    const int offset = (n * channels * height + h) * width + w;
    const int step = height * width;
    in = in + offset;
    scale = scale + offset;
    int head = 0;
    const int pre_pad = (size - 1) / 2;
    const int post_pad = size - pre_pad - 1;
    T accum_scale = 0;
    // fill the scale at [n, :, h, w]
    // accumulate values
    while (head < post_pad && head < channels) {
      accum_scale += in[head * step] * in[head * step];
      ++head;
    }
    // both add and subtract
    while (head < channels) {
      accum_scale += in[head * step] * in[head * step];
      if (head - size >= 0) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
      ++head;
    }
    // subtract only
    while (head < channels + post_pad) {
      if (head - size >= 0) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
      ++head;
    }
  }
}
\ No newline at end of file
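For clarity, here is a condensed C++ reference sketch (not part of the commit) of what one LRNFillScale work-item does for its fixed (n, h, w) column: a single sliding-window pass over the channels, where `in` and `scale` point at that column and `step` = height * width.

    static void lrnFillScaleColumn(const float *in, float *scale, int channels, int step,
                                   int size, float alpha_over_size, float k)
    {
        const int post_pad = size - (size - 1) / 2 - 1;
        float accum = 0.f;
        for (int head = 0; head < channels + post_pad; ++head)
        {
            if (head < channels)
                accum += in[head * step] * in[head * step];                    //channel enters the window
            if (head - size >= 0)
                accum -= in[(head - size) * step] * in[(head - size) * step]; //channel leaves the window
            if (head >= post_pad)
                scale[(head - post_pad) * step] = k + accum * alpha_over_size;
        }
    }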
@@ -97,12 +97,22 @@ OCL_TEST(Layer_Test_Softmax, Accuracy)
 TEST(Layer_Test_LRN_spatial, Accuracy)
 {
-    testLayerUsingCaffeModels("layer_lrn_spatial");
+    OCL_OFF(testLayerUsingCaffeModels("layer_lrn_spatial"));
+}
+OCL_TEST(Layer_Test_LRN_spatial, Accuracy)
+{
+    OCL_ON(testLayerUsingCaffeModels("layer_lrn_spatial"));
+    OCL_OFF();
 }

 TEST(Layer_Test_LRN_channels, Accuracy)
 {
-    testLayerUsingCaffeModels("layer_lrn_channels");
+    OCL_OFF(testLayerUsingCaffeModels("layer_lrn_channels"));
+}
+OCL_TEST(Layer_Test_LRN_channels, Accuracy)
+{
+    OCL_ON(testLayerUsingCaffeModels("layer_lrn_channels"));
+    OCL_OFF();
 }

 TEST(Layer_Test_Convolution, Accuracy)