Commit ed8209eb authored by Anna Petrovicheva

Implemented dilated convolution

parent d83aa810
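
A note on the math: dilation d expands a k-wide kernel to an effective extent of d*(k-1)+1 input pixels, sampling every d-th pixel, so the spatial output shrinks accordingly while the number of weights stays k*k. A minimal sketch of the output-height computation this commit introduces (variable names follow the patch; the worked numbers are illustrative, not from the commit):

    // Effective kernel extent under dilation, and the resulting output height.
    int effKerH = dilationH * (kerH - 1) + 1;               // 3x3 kernel, dilation 2 -> 5
    int outH = (inpH + 2 * padH - effKerH) / strideH + 1;   // 7x7 input, pad 0, stride 1 -> 3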
@@ -53,7 +53,7 @@ namespace dnn
 {
     ConvolutionLayer::ConvolutionLayer(LayerParams &params) : Layer(params)
     {
-        getKernelParams(params, kerH, kerW, padH, padW, strideH, strideW);
+        getKernelParams(params, kerH, kerW, padH, padW, strideH, strideW, dilationH, dilationW);

         numOutput = params.get<int>("num_output");
         bias = params.get<bool>("bias_term", true);
@@ -119,7 +119,7 @@ namespace dnn
     inline bool ConvolutionLayer::is1x1() const
     {
-        return (kerH == 1 && kerW == 1) && (strideW == 1 && strideH == 1); //hotfix with stride
+        return (kerH == 1 && kerW == 1) && (strideW == 1 && strideH == 1) && (dilationW == 1 && dilationH == 1); //hotfix with stride
     }

     void ConvolutionLayer::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
@@ -179,9 +179,9 @@ namespace dnn
 #endif // HAVE_OPENCL

         if (inpBlob.type() == CV_32F)
-            im2col_CpuPBody<float>::run((float*)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, colMat.ptr<float>());
+            im2col_CpuPBody<float>::run((float*)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dilationH, dilationW, colMat.ptr<float>());
         if (inpBlob.type() == CV_64F)
-            im2col_CpuPBody<double>::run((double*)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, colMat.ptr<double>());
+            im2col_CpuPBody<double>::run((double*)srcPtr, inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dilationH, dilationW, colMat.ptr<double>());
     }

     void ConvolutionLayer::computeInpOutShape(const Blob &inpBlob)
@@ -253,9 +253,9 @@ namespace dnn
         if (is1x1()) return;

         if (dstMat.type() == CV_32F)
-            col2im_cpu(colMat.ptr<float>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dstMat.ptr<float>());
+            col2im_cpu(colMat.ptr<float>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dilationH, dilationW, dstMat.ptr<float>());
         if (dstMat.type() == CV_64F)
-            col2im_cpu(colMat.ptr<double>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dstMat.ptr<double>());
+            col2im_cpu(colMat.ptr<double>(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, dilationH, dilationW, dstMat.ptr<double>());
     }
 }
 }
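
Both call sites simply forward the new arguments: the layer still lowers convolution to an im2col unroll followed by a matrix multiply against the weights, so only the unrolling and its inverse, col2im (used on the deconvolution path), need to know about dilation; the GEMM itself is unchanged.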
@@ -56,6 +56,7 @@ namespace dnn
         int padH, padW;
         int kerH, kerW;
         int strideH, strideW;
+        int dilationH, dilationW;
         int inpH, inpW, inpCn;
         int outH, outW, outCn;
@@ -46,7 +46,7 @@ namespace cv
 namespace dnn
 {

-void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW)
+void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW, int &dilationH, int &dilationW)
 {
     if (params.has("kernel_h") && params.has("kernel_w"))
     {
@@ -82,7 +82,17 @@ void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH,
         strideH = strideW = params.get<int>("stride", 1);
     }

-    CV_Assert(kernelH > 0 && kernelW > 0 && padH >= 0 && padW >= 0 && strideH > 0 && strideW > 0);
+    if (params.has("dilation_h") && params.has("dilation_w"))
+    {
+        dilationH = params.get<int>("dilation_h");
+        dilationW = params.get<int>("dilation_w");
+    }
+    else
+    {
+        dilationH = dilationW = params.get<int>("dilation", 1);
+    }
+
+    CV_Assert(kernelH > 0 && kernelW > 0 && padH >= 0 && padW >= 0 && strideH > 0 && strideW > 0 && dilationH > 0 && dilationW > 0);
 }

 }
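The key names mirror Caffe's convolution_param fields, with per-axis keys taking precedence over the shared one. A hypothetical sanity check of the new default, assuming the elided branches read the shared "kernel_size" and "pad" keys the same way the visible "stride" fallback does, and that LayerParams exposes a Dict-style set<>() setter:

    LayerParams params;
    params.set("kernel_size", 3);   // assumed shared-key fallback, as with "stride"
    params.set("pad", 0);
    params.set("stride", 1);        // no "dilation" key: exercise the default

    int kerH, kerW, padH, padW, strideH, strideW, dilH, dilW;
    getKernelParams(params, kerH, kerW, padH, padW, strideH, strideW, dilH, dilW);
    // dilH == dilW == 1 via the else-branch; setting "dilation" (or
    // "dilation_h"/"dilation_w") would request an atrous kernel instead.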
@@ -48,7 +48,7 @@ namespace cv
 namespace dnn
 {

-void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW);
+void getKernelParams(LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW, int &strideH, int &strideW, int &dilationH, int &dilationW);

 }
 }
@@ -57,26 +57,31 @@ class im2col_CpuPBody : public cv::ParallelLoopBody
     int kernel_h, kernel_w;
     int pad_h, pad_w;
     int stride_h, stride_w;
+    int dilation_h, dilation_w;
     Dtype* data_col;
     int height_col, width_col, channels_col;

 public:
     im2col_CpuPBody(const Dtype* data_im_,
-                    int channels_, int height_, int width_,
-                    int kernel_h_, int kernel_w_,
-                    int pad_h_, int pad_w_,
-                    int stride_h_, int stride_w_,
-                    Dtype* data_col_) :
+                    int channels_, int height_, int width_,
+                    int kernel_h_, int kernel_w_,
+                    int pad_h_, int pad_w_,
+                    int stride_h_, int stride_w_,
+                    int dilation_h_, int dilation_w_,
+                    Dtype* data_col_) :
         data_im(data_im_),
         channels(channels_), height(height_), width(width_),
         kernel_h(kernel_h_), kernel_w(kernel_w_),
         pad_h(pad_h_), pad_w(pad_w_),
         stride_h(stride_h_), stride_w(stride_w_),
+        dilation_h(dilation_h_), dilation_w(dilation_w_),
         data_col(data_col_)
     {
-        height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
-        width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+        height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1))
+                     / stride_h + 1;
+        width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1))
+                    / stride_w + 1;
         channels_col = channels * kernel_h * kernel_w;
     }
@@ -85,25 +85,29 @@ public:
                     int kernel_h, int kernel_w,
                     int pad_h, int pad_w,
                     int stride_h, int stride_w,
+                    int dilation_h, int dilation_w,
                     Dtype* data_col)
     {
-        im2col_CpuPBody<Dtype> pb(data_im, channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, data_col);
+        im2col_CpuPBody<Dtype> pb(data_im, channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, data_col);
         cv::parallel_for_(Range(0, pb.channels_col), pb);
     }

     virtual void operator ()(const Range &r) const
     {
-        for (int c = r.start; c < r.end; ++c) {
+        for (int c = r.start; c < r.end; ++c)
+        {
             int w_offset = c % kernel_w;
             int h_offset = (c / kernel_w) % kernel_h;
             int c_im = c / kernel_h / kernel_w;
-            for (int h = 0; h < height_col; ++h) {
-                for (int w = 0; w < width_col; ++w) {
-                    int h_pad = h * stride_h - pad_h + h_offset;
-                    int w_pad = w * stride_w - pad_w + w_offset;
+            for (int h = 0; h < height_col; ++h)
+            {
+                for (int w = 0; w < width_col; ++w)
+                {
+                    int h_pad = h * stride_h - pad_h + h_offset * dilation_h;
+                    int w_pad = w * stride_w - pad_w + w_offset * dilation_w;
                     if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
                         data_col[(c * height_col + h) * width_col + w] =
-                        data_im[(c_im * height + h_pad) * width + w_pad];
+                            data_im[(c_im * height + h_pad) * width + w_pad];
                     else
                         data_col[(c * height_col + h) * width_col + w] = 0;
                 }
@@ -118,12 +127,13 @@ void col2im_cpu(const Dtype* data_col,
                 int patch_h, int patch_w,
                 int pad_h, int pad_w,
                 int stride_h, int stride_w,
+                int dilation_h, int dilation_w,
                 Dtype* data_im)
 {
     memset(data_im, 0, height * width * channels * sizeof(Dtype));

-    int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
-    int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
+    int height_col = (height + 2 * pad_h - (dilation_h * (patch_h - 1) + 1)) / stride_h + 1;
+    int width_col = (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1)) / stride_w + 1;
     int channels_col = channels * patch_h * patch_w;

     for (int c = 0; c < channels_col; ++c)
@@ -136,12 +146,12 @@ void col2im_cpu(const Dtype* data_col,
         {
             for (int w = 0; w < width_col; ++w)
             {
-                int h_pad = h * stride_h - pad_h + h_offset;
-                int w_pad = w * stride_w - pad_w + w_offset;
+                int h_pad = h * stride_h - pad_h + h_offset * dilation_h;
+                int w_pad = w * stride_w - pad_w + w_offset * dilation_w;
                 if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
                     data_im[(c_im * height + h_pad) * width + w_pad] +=
-                    data_col[(c * height_col + h) * width_col + w];
+                        data_col[(c * height_col + h) * width_col + w];
             }
         }
     }
@@ -153,6 +163,7 @@ void im2col_ocl(UMat &img,
                 int kernel_h, int kernel_w,
                 int pad_h, int pad_w,
                 int stride_h, int stride_w,
+                int dilation_h, int dilation_w,
                 UMat &col);

 #endif
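
To make the sampling pattern concrete, here is a hypothetical call (parameter values are illustrative, not from the commit): a 5x5 single-channel image, 3x3 kernel, stride 1, no padding, dilation 2. The effective kernel extent is 5, so height_col = width_col = 1 and channels_col = 9:

    float img[5 * 5];
    for (int i = 0; i < 25; ++i) img[i] = (float)i;

    float col[9];   // channels_col * height_col * width_col
    im2col_CpuPBody<float>::run(img, /*channels*/ 1, /*height*/ 5, /*width*/ 5,
                                /*kernel*/ 3, 3, /*pad*/ 0, 0, /*stride*/ 1, 1,
                                /*dilation*/ 2, 2, col);
    // col == { img[0], img[2], img[4], img[10], img[12], img[14],
    //          img[20], img[22], img[24] }: every other pixel of the patch.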
@@ -72,7 +72,8 @@ namespace dnn
             type = MAX;
         }

-        getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW);
+        int defaultDilation = 1;
+        getKernelParams(params, kernelH, kernelW, padH, padW, strideH, strideW, defaultDilation, defaultDilation);
     }

     void PoolingLayer::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
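
Pooling has no notion of dilation, so defaultDilation serves purely as a writable dummy for the two new out-parameters: getKernelParams still parses any dilation keys into it, and the pooling layer discards the result.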