Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
f614e354
Commit
f614e354
authored
May 06, 2013
by
Vladislav Vinogradov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split hough sources
parent
1d79e131
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1240 additions
and
0 deletions
+1240
-0
build_point_list.cu
modules/gpuimgproc/src/cuda/build_point_list.cu
+138
-0
generalized_hough.cu
modules/gpuimgproc/src/cuda/generalized_hough.cu
+0
-0
hough_circles.cu
modules/gpuimgproc/src/cuda/hough_circles.cu
+255
-0
hough_lines.cu
modules/gpuimgproc/src/cuda/hough_lines.cu
+212
-0
hough_segments.cu
modules/gpuimgproc/src/cuda/hough_segments.cu
+249
-0
generalized_hough.cpp
modules/gpuimgproc/src/generalized_hough.cpp
+0
-0
hough_circles.cpp
modules/gpuimgproc/src/hough_circles.cpp
+0
-0
hough_lines.cpp
modules/gpuimgproc/src/hough_lines.cpp
+202
-0
hough_segments.cpp
modules/gpuimgproc/src/hough_segments.cpp
+183
-0
precomp.hpp
modules/gpuimgproc/src/precomp.hpp
+1
-0
No files found.
modules/gpuimgproc/src/cuda/build_point_list.cu
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#if !defined CUDA_DISABLER
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/emulation.hpp"
namespace cv { namespace gpu { namespace cudev
{
namespace hough
{
__device__ int g_counter;
template <int PIXELS_PER_THREAD>
__global__ void buildPointList(const PtrStepSzb src, unsigned int* list)
{
__shared__ unsigned int s_queues[4][32 * PIXELS_PER_THREAD];
__shared__ int s_qsize[4];
__shared__ int s_globStart[4];
const int x = blockIdx.x * blockDim.x * PIXELS_PER_THREAD + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
if (threadIdx.x == 0)
s_qsize[threadIdx.y] = 0;
__syncthreads();
if (y < src.rows)
{
// fill the queue
const uchar* srcRow = src.ptr(y);
for (int i = 0, xx = x; i < PIXELS_PER_THREAD && xx < src.cols; ++i, xx += blockDim.x)
{
if (srcRow[xx])
{
const unsigned int val = (y << 16) | xx;
const int qidx = Emulation::smem::atomicAdd(&s_qsize[threadIdx.y], 1);
s_queues[threadIdx.y][qidx] = val;
}
}
}
__syncthreads();
// let one thread reserve the space required in the global list
if (threadIdx.x == 0 && threadIdx.y == 0)
{
// find how many items are stored in each list
int totalSize = 0;
for (int i = 0; i < blockDim.y; ++i)
{
s_globStart[i] = totalSize;
totalSize += s_qsize[i];
}
// calculate the offset in the global list
const int globalOffset = atomicAdd(&g_counter, totalSize);
for (int i = 0; i < blockDim.y; ++i)
s_globStart[i] += globalOffset;
}
__syncthreads();
// copy local queues to global queue
const int qsize = s_qsize[threadIdx.y];
int gidx = s_globStart[threadIdx.y] + threadIdx.x;
for(int i = threadIdx.x; i < qsize; i += blockDim.x, gidx += blockDim.x)
list[gidx] = s_queues[threadIdx.y][i];
}
int buildPointList_gpu(PtrStepSzb src, unsigned int* list)
{
const int PIXELS_PER_THREAD = 16;
void* counterPtr;
cudaSafeCall( cudaGetSymbolAddress(&counterPtr, g_counter) );
cudaSafeCall( cudaMemset(counterPtr, 0, sizeof(int)) );
const dim3 block(32, 4);
const dim3 grid(divUp(src.cols, block.x * PIXELS_PER_THREAD), divUp(src.rows, block.y));
cudaSafeCall( cudaFuncSetCacheConfig(buildPointList<PIXELS_PER_THREAD>, cudaFuncCachePreferShared) );
buildPointList<PIXELS_PER_THREAD><<<grid, block>>>(src, list);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
int totalCount;
cudaSafeCall( cudaMemcpy(&totalCount, counterPtr, sizeof(int), cudaMemcpyDeviceToHost) );
return totalCount;
}
}
}}}
#endif /* CUDA_DISABLER */
modules/gpuimgproc/src/cuda/hough.cu
→
modules/gpuimgproc/src/cuda/
generalized_
hough.cu
View file @
f614e354
This diff is collapsed.
Click to expand it.
modules/gpuimgproc/src/cuda/hough_circles.cu
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#if !defined CUDA_DISABLER
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/emulation.hpp"
#include "opencv2/core/cuda/dynamic_smem.hpp"
namespace cv { namespace gpu { namespace cudev
{
namespace hough_circles
{
__device__ int g_counter;
////////////////////////////////////////////////////////////////////////
// circlesAccumCenters
__global__ void circlesAccumCenters(const unsigned int* list, const int count, const PtrStepi dx, const PtrStepi dy,
PtrStepi accum, const int width, const int height, const int minRadius, const int maxRadius, const float idp)
{
const int SHIFT = 10;
const int ONE = 1 << SHIFT;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= count)
return;
const unsigned int val = list[tid];
const int x = (val & 0xFFFF);
const int y = (val >> 16) & 0xFFFF;
const int vx = dx(y, x);
const int vy = dy(y, x);
if (vx == 0 && vy == 0)
return;
const float mag = ::sqrtf(vx * vx + vy * vy);
const int x0 = __float2int_rn((x * idp) * ONE);
const int y0 = __float2int_rn((y * idp) * ONE);
int sx = __float2int_rn((vx * idp) * ONE / mag);
int sy = __float2int_rn((vy * idp) * ONE / mag);
// Step from minRadius to maxRadius in both directions of the gradient
for (int k1 = 0; k1 < 2; ++k1)
{
int x1 = x0 + minRadius * sx;
int y1 = y0 + minRadius * sy;
for (int r = minRadius; r <= maxRadius; x1 += sx, y1 += sy, ++r)
{
const int x2 = x1 >> SHIFT;
const int y2 = y1 >> SHIFT;
if (x2 < 0 || x2 >= width || y2 < 0 || y2 >= height)
break;
::atomicAdd(accum.ptr(y2 + 1) + x2 + 1, 1);
}
sx = -sx;
sy = -sy;
}
}
void circlesAccumCenters_gpu(const unsigned int* list, int count, PtrStepi dx, PtrStepi dy, PtrStepSzi accum, int minRadius, int maxRadius, float idp)
{
const dim3 block(256);
const dim3 grid(divUp(count, block.x));
cudaSafeCall( cudaFuncSetCacheConfig(circlesAccumCenters, cudaFuncCachePreferL1) );
circlesAccumCenters<<<grid, block>>>(list, count, dx, dy, accum, accum.cols - 2, accum.rows - 2, minRadius, maxRadius, idp);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
}
////////////////////////////////////////////////////////////////////////
// buildCentersList
__global__ void buildCentersList(const PtrStepSzi accum, unsigned int* centers, const int threshold)
{
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
if (x < accum.cols - 2 && y < accum.rows - 2)
{
const int top = accum(y, x + 1);
const int left = accum(y + 1, x);
const int cur = accum(y + 1, x + 1);
const int right = accum(y + 1, x + 2);
const int bottom = accum(y + 2, x + 1);
if (cur > threshold && cur > top && cur >= bottom && cur > left && cur >= right)
{
const unsigned int val = (y << 16) | x;
const int idx = ::atomicAdd(&g_counter, 1);
centers[idx] = val;
}
}
}
int buildCentersList_gpu(PtrStepSzi accum, unsigned int* centers, int threshold)
{
void* counterPtr;
cudaSafeCall( cudaGetSymbolAddress(&counterPtr, g_counter) );
cudaSafeCall( cudaMemset(counterPtr, 0, sizeof(int)) );
const dim3 block(32, 8);
const dim3 grid(divUp(accum.cols - 2, block.x), divUp(accum.rows - 2, block.y));
cudaSafeCall( cudaFuncSetCacheConfig(buildCentersList, cudaFuncCachePreferL1) );
buildCentersList<<<grid, block>>>(accum, centers, threshold);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
int totalCount;
cudaSafeCall( cudaMemcpy(&totalCount, counterPtr, sizeof(int), cudaMemcpyDeviceToHost) );
return totalCount;
}
////////////////////////////////////////////////////////////////////////
// circlesAccumRadius
__global__ void circlesAccumRadius(const unsigned int* centers, const unsigned int* list, const int count,
float3* circles, const int maxCircles, const float dp,
const int minRadius, const int maxRadius, const int histSize, const int threshold)
{
int* smem = DynamicSharedMem<int>();
for (int i = threadIdx.x; i < histSize + 2; i += blockDim.x)
smem[i] = 0;
__syncthreads();
unsigned int val = centers[blockIdx.x];
float cx = (val & 0xFFFF);
float cy = (val >> 16) & 0xFFFF;
cx = (cx + 0.5f) * dp;
cy = (cy + 0.5f) * dp;
for (int i = threadIdx.x; i < count; i += blockDim.x)
{
val = list[i];
const int x = (val & 0xFFFF);
const int y = (val >> 16) & 0xFFFF;
const float rad = ::sqrtf((cx - x) * (cx - x) + (cy - y) * (cy - y));
if (rad >= minRadius && rad <= maxRadius)
{
const int r = __float2int_rn(rad - minRadius);
Emulation::smem::atomicAdd(&smem[r + 1], 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < histSize; i += blockDim.x)
{
const int curVotes = smem[i + 1];
if (curVotes >= threshold && curVotes > smem[i] && curVotes >= smem[i + 2])
{
const int ind = ::atomicAdd(&g_counter, 1);
if (ind < maxCircles)
circles[ind] = make_float3(cx, cy, i + minRadius);
}
}
}
int circlesAccumRadius_gpu(const unsigned int* centers, int centersCount, const unsigned int* list, int count,
float3* circles, int maxCircles, float dp, int minRadius, int maxRadius, int threshold, bool has20)
{
void* counterPtr;
cudaSafeCall( cudaGetSymbolAddress(&counterPtr, g_counter) );
cudaSafeCall( cudaMemset(counterPtr, 0, sizeof(int)) );
const dim3 block(has20 ? 1024 : 512);
const dim3 grid(centersCount);
const int histSize = maxRadius - minRadius + 1;
size_t smemSize = (histSize + 2) * sizeof(int);
circlesAccumRadius<<<grid, block, smemSize>>>(centers, list, count, circles, maxCircles, dp, minRadius, maxRadius, histSize, threshold);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
int totalCount;
cudaSafeCall( cudaMemcpy(&totalCount, counterPtr, sizeof(int), cudaMemcpyDeviceToHost) );
totalCount = ::min(totalCount, maxCircles);
return totalCount;
}
}
}}}
#endif /* CUDA_DISABLER */
modules/gpuimgproc/src/cuda/hough_lines.cu
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#if !defined CUDA_DISABLER
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/emulation.hpp"
#include "opencv2/core/cuda/dynamic_smem.hpp"
namespace cv { namespace gpu { namespace cudev
{
namespace hough_lines
{
__device__ int g_counter;
////////////////////////////////////////////////////////////////////////
// linesAccum
__global__ void linesAccumGlobal(const unsigned int* list, const int count, PtrStepi accum, const float irho, const float theta, const int numrho)
{
const int n = blockIdx.x;
const float ang = n * theta;
float sinVal;
float cosVal;
sincosf(ang, &sinVal, &cosVal);
sinVal *= irho;
cosVal *= irho;
const int shift = (numrho - 1) / 2;
int* accumRow = accum.ptr(n + 1);
for (int i = threadIdx.x; i < count; i += blockDim.x)
{
const unsigned int val = list[i];
const int x = (val & 0xFFFF);
const int y = (val >> 16) & 0xFFFF;
int r = __float2int_rn(x * cosVal + y * sinVal);
r += shift;
::atomicAdd(accumRow + r + 1, 1);
}
}
__global__ void linesAccumShared(const unsigned int* list, const int count, PtrStepi accum, const float irho, const float theta, const int numrho)
{
int* smem = DynamicSharedMem<int>();
for (int i = threadIdx.x; i < numrho + 1; i += blockDim.x)
smem[i] = 0;
__syncthreads();
const int n = blockIdx.x;
const float ang = n * theta;
float sinVal;
float cosVal;
sincosf(ang, &sinVal, &cosVal);
sinVal *= irho;
cosVal *= irho;
const int shift = (numrho - 1) / 2;
for (int i = threadIdx.x; i < count; i += blockDim.x)
{
const unsigned int val = list[i];
const int x = (val & 0xFFFF);
const int y = (val >> 16) & 0xFFFF;
int r = __float2int_rn(x * cosVal + y * sinVal);
r += shift;
Emulation::smem::atomicAdd(&smem[r + 1], 1);
}
__syncthreads();
int* accumRow = accum.ptr(n + 1);
for (int i = threadIdx.x; i < numrho + 1; i += blockDim.x)
accumRow[i] = smem[i];
}
void linesAccum_gpu(const unsigned int* list, int count, PtrStepSzi accum, float rho, float theta, size_t sharedMemPerBlock, bool has20)
{
const dim3 block(has20 ? 1024 : 512);
const dim3 grid(accum.rows - 2);
size_t smemSize = (accum.cols - 1) * sizeof(int);
if (smemSize < sharedMemPerBlock - 1000)
linesAccumShared<<<grid, block, smemSize>>>(list, count, accum, 1.0f / rho, theta, accum.cols - 2);
else
linesAccumGlobal<<<grid, block>>>(list, count, accum, 1.0f / rho, theta, accum.cols - 2);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
}
////////////////////////////////////////////////////////////////////////
// linesGetResult
__global__ void linesGetResult(const PtrStepSzi accum, float2* out, int* votes, const int maxSize, const float rho, const float theta, const int threshold, const int numrho)
{
const int r = blockIdx.x * blockDim.x + threadIdx.x;
const int n = blockIdx.y * blockDim.y + threadIdx.y;
if (r >= accum.cols - 2 || n >= accum.rows - 2)
return;
const int curVotes = accum(n + 1, r + 1);
if (curVotes > threshold &&
curVotes > accum(n + 1, r) &&
curVotes >= accum(n + 1, r + 2) &&
curVotes > accum(n, r + 1) &&
curVotes >= accum(n + 2, r + 1))
{
const float radius = (r - (numrho - 1) * 0.5f) * rho;
const float angle = n * theta;
const int ind = ::atomicAdd(&g_counter, 1);
if (ind < maxSize)
{
out[ind] = make_float2(radius, angle);
votes[ind] = curVotes;
}
}
}
int linesGetResult_gpu(PtrStepSzi accum, float2* out, int* votes, int maxSize, float rho, float theta, int threshold, bool doSort)
{
void* counterPtr;
cudaSafeCall( cudaGetSymbolAddress(&counterPtr, g_counter) );
cudaSafeCall( cudaMemset(counterPtr, 0, sizeof(int)) );
const dim3 block(32, 8);
const dim3 grid(divUp(accum.cols - 2, block.x), divUp(accum.rows - 2, block.y));
cudaSafeCall( cudaFuncSetCacheConfig(linesGetResult, cudaFuncCachePreferL1) );
linesGetResult<<<grid, block>>>(accum, out, votes, maxSize, rho, theta, threshold, accum.cols - 2);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
int totalCount;
cudaSafeCall( cudaMemcpy(&totalCount, counterPtr, sizeof(int), cudaMemcpyDeviceToHost) );
totalCount = ::min(totalCount, maxSize);
if (doSort && totalCount > 0)
{
thrust::device_ptr<float2> outPtr(out);
thrust::device_ptr<int> votesPtr(votes);
thrust::sort_by_key(votesPtr, votesPtr + totalCount, outPtr, thrust::greater<int>());
}
return totalCount;
}
}
}}}
#endif /* CUDA_DISABLER */
modules/gpuimgproc/src/cuda/hough_segments.cu
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#if !defined CUDA_DISABLER
#include "opencv2/core/cuda/common.hpp"
#include "opencv2/core/cuda/vec_math.hpp"
namespace cv { namespace gpu { namespace cudev
{
namespace hough_segments
{
__device__ int g_counter;
texture<uchar, cudaTextureType2D, cudaReadModeElementType> tex_mask(false, cudaFilterModePoint, cudaAddressModeClamp);
__global__ void houghLinesProbabilistic(const PtrStepSzi accum,
int4* out, const int maxSize,
const float rho, const float theta,
const int lineGap, const int lineLength,
const int rows, const int cols)
{
const int r = blockIdx.x * blockDim.x + threadIdx.x;
const int n = blockIdx.y * blockDim.y + threadIdx.y;
if (r >= accum.cols - 2 || n >= accum.rows - 2)
return;
const int curVotes = accum(n + 1, r + 1);
if (curVotes >= lineLength &&
curVotes > accum(n, r) &&
curVotes > accum(n, r + 1) &&
curVotes > accum(n, r + 2) &&
curVotes > accum(n + 1, r) &&
curVotes > accum(n + 1, r + 2) &&
curVotes > accum(n + 2, r) &&
curVotes > accum(n + 2, r + 1) &&
curVotes > accum(n + 2, r + 2))
{
const float radius = (r - (accum.cols - 2 - 1) * 0.5f) * rho;
const float angle = n * theta;
float cosa;
float sina;
sincosf(angle, &sina, &cosa);
float2 p0 = make_float2(cosa * radius, sina * radius);
float2 dir = make_float2(-sina, cosa);
float2 pb[4] = {make_float2(-1, -1), make_float2(-1, -1), make_float2(-1, -1), make_float2(-1, -1)};
float a;
if (dir.x != 0)
{
a = -p0.x / dir.x;
pb[0].x = 0;
pb[0].y = p0.y + a * dir.y;
a = (cols - 1 - p0.x) / dir.x;
pb[1].x = cols - 1;
pb[1].y = p0.y + a * dir.y;
}
if (dir.y != 0)
{
a = -p0.y / dir.y;
pb[2].x = p0.x + a * dir.x;
pb[2].y = 0;
a = (rows - 1 - p0.y) / dir.y;
pb[3].x = p0.x + a * dir.x;
pb[3].y = rows - 1;
}
if (pb[0].x == 0 && (pb[0].y >= 0 && pb[0].y < rows))
{
p0 = pb[0];
if (dir.x < 0)
dir = -dir;
}
else if (pb[1].x == cols - 1 && (pb[0].y >= 0 && pb[0].y < rows))
{
p0 = pb[1];
if (dir.x > 0)
dir = -dir;
}
else if (pb[2].y == 0 && (pb[2].x >= 0 && pb[2].x < cols))
{
p0 = pb[2];
if (dir.y < 0)
dir = -dir;
}
else if (pb[3].y == rows - 1 && (pb[3].x >= 0 && pb[3].x < cols))
{
p0 = pb[3];
if (dir.y > 0)
dir = -dir;
}
float2 d;
if (::fabsf(dir.x) > ::fabsf(dir.y))
{
d.x = dir.x > 0 ? 1 : -1;
d.y = dir.y / ::fabsf(dir.x);
}
else
{
d.x = dir.x / ::fabsf(dir.y);
d.y = dir.y > 0 ? 1 : -1;
}
float2 line_end[2];
int gap;
bool inLine = false;
float2 p1 = p0;
if (p1.x < 0 || p1.x >= cols || p1.y < 0 || p1.y >= rows)
return;
for (;;)
{
if (tex2D(tex_mask, p1.x, p1.y))
{
gap = 0;
if (!inLine)
{
line_end[0] = p1;
line_end[1] = p1;
inLine = true;
}
else
{
line_end[1] = p1;
}
}
else if (inLine)
{
if (++gap > lineGap)
{
bool good_line = ::abs(line_end[1].x - line_end[0].x) >= lineLength ||
::abs(line_end[1].y - line_end[0].y) >= lineLength;
if (good_line)
{
const int ind = ::atomicAdd(&g_counter, 1);
if (ind < maxSize)
out[ind] = make_int4(line_end[0].x, line_end[0].y, line_end[1].x, line_end[1].y);
}
gap = 0;
inLine = false;
}
}
p1 = p1 + d;
if (p1.x < 0 || p1.x >= cols || p1.y < 0 || p1.y >= rows)
{
if (inLine)
{
bool good_line = ::abs(line_end[1].x - line_end[0].x) >= lineLength ||
::abs(line_end[1].y - line_end[0].y) >= lineLength;
if (good_line)
{
const int ind = ::atomicAdd(&g_counter, 1);
if (ind < maxSize)
out[ind] = make_int4(line_end[0].x, line_end[0].y, line_end[1].x, line_end[1].y);
}
}
break;
}
}
}
}
int houghLinesProbabilistic_gpu(PtrStepSzb mask, PtrStepSzi accum, int4* out, int maxSize, float rho, float theta, int lineGap, int lineLength)
{
void* counterPtr;
cudaSafeCall( cudaGetSymbolAddress(&counterPtr, g_counter) );
cudaSafeCall( cudaMemset(counterPtr, 0, sizeof(int)) );
const dim3 block(32, 8);
const dim3 grid(divUp(accum.cols - 2, block.x), divUp(accum.rows - 2, block.y));
bindTexture(&tex_mask, mask);
houghLinesProbabilistic<<<grid, block>>>(accum,
out, maxSize,
rho, theta,
lineGap, lineLength,
mask.rows, mask.cols);
cudaSafeCall( cudaGetLastError() );
cudaSafeCall( cudaDeviceSynchronize() );
int totalCount;
cudaSafeCall( cudaMemcpy(&totalCount, counterPtr, sizeof(int), cudaMemcpyDeviceToHost) );
totalCount = ::min(totalCount, maxSize);
return totalCount;
}
}
}}}
#endif /* CUDA_DISABLER */
modules/gpuimgproc/src/hough.cpp
→
modules/gpuimgproc/src/
generalized_
hough.cpp
View file @
f614e354
This diff is collapsed.
Click to expand it.
modules/gpuimgproc/src/hough_circles.cpp
0 → 100644
View file @
f614e354
This diff is collapsed.
Click to expand it.
modules/gpuimgproc/src/hough_lines.cpp
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#include "precomp.hpp"
using
namespace
cv
;
using
namespace
cv
::
gpu
;
#if !defined (HAVE_CUDA) || defined (CUDA_DISABLER)
Ptr
<
gpu
::
HoughLinesDetector
>
cv
::
gpu
::
createHoughLinesDetector
(
float
,
float
,
int
,
bool
,
int
)
{
throw_no_cuda
();
return
Ptr
<
HoughLinesDetector
>
();
}
#else
/* !defined (HAVE_CUDA) */
namespace
cv
{
namespace
gpu
{
namespace
cudev
{
namespace
hough
{
int
buildPointList_gpu
(
PtrStepSzb
src
,
unsigned
int
*
list
);
}
namespace
hough_lines
{
void
linesAccum_gpu
(
const
unsigned
int
*
list
,
int
count
,
PtrStepSzi
accum
,
float
rho
,
float
theta
,
size_t
sharedMemPerBlock
,
bool
has20
);
int
linesGetResult_gpu
(
PtrStepSzi
accum
,
float2
*
out
,
int
*
votes
,
int
maxSize
,
float
rho
,
float
theta
,
int
threshold
,
bool
doSort
);
}
}}}
namespace
{
class
HoughLinesDetectorImpl
:
public
HoughLinesDetector
{
public
:
HoughLinesDetectorImpl
(
float
rho
,
float
theta
,
int
threshold
,
bool
doSort
,
int
maxLines
)
:
rho_
(
rho
),
theta_
(
theta
),
threshold_
(
threshold
),
doSort_
(
doSort
),
maxLines_
(
maxLines
)
{
}
void
detect
(
InputArray
src
,
OutputArray
lines
);
void
downloadResults
(
InputArray
d_lines
,
OutputArray
h_lines
,
OutputArray
h_votes
=
noArray
());
void
setRho
(
float
rho
)
{
rho_
=
rho
;
}
float
getRho
()
const
{
return
rho_
;
}
void
setTheta
(
float
theta
)
{
theta_
=
theta
;
}
float
getTheta
()
const
{
return
theta_
;
}
void
setThreshold
(
int
threshold
)
{
threshold_
=
threshold
;
}
int
getThreshold
()
const
{
return
threshold_
;
}
void
setDoSort
(
bool
doSort
)
{
doSort_
=
doSort
;
}
bool
getDoSort
()
const
{
return
doSort_
;
}
void
setMaxLines
(
int
maxLines
)
{
maxLines_
=
maxLines
;
}
int
getMaxLines
()
const
{
return
maxLines_
;
}
void
write
(
FileStorage
&
fs
)
const
{
fs
<<
"name"
<<
"HoughLinesDetector_GPU"
<<
"rho"
<<
rho_
<<
"theta"
<<
theta_
<<
"threshold"
<<
threshold_
<<
"doSort"
<<
doSort_
<<
"maxLines"
<<
maxLines_
;
}
void
read
(
const
FileNode
&
fn
)
{
CV_Assert
(
String
(
fn
[
"name"
])
==
"HoughLinesDetector_GPU"
);
rho_
=
(
float
)
fn
[
"rho"
];
theta_
=
(
float
)
fn
[
"theta"
];
threshold_
=
(
int
)
fn
[
"threshold"
];
doSort_
=
(
int
)
fn
[
"doSort"
]
!=
0
;
maxLines_
=
(
int
)
fn
[
"maxLines"
];
}
private
:
float
rho_
;
float
theta_
;
int
threshold_
;
bool
doSort_
;
int
maxLines_
;
GpuMat
accum_
;
GpuMat
list_
;
GpuMat
result_
;
};
void
HoughLinesDetectorImpl
::
detect
(
InputArray
_src
,
OutputArray
lines
)
{
using
namespace
cv
::
gpu
::
cudev
::
hough
;
using
namespace
cv
::
gpu
::
cudev
::
hough_lines
;
GpuMat
src
=
_src
.
getGpuMat
();
CV_Assert
(
src
.
type
()
==
CV_8UC1
);
CV_Assert
(
src
.
cols
<
std
::
numeric_limits
<
unsigned
short
>::
max
()
);
CV_Assert
(
src
.
rows
<
std
::
numeric_limits
<
unsigned
short
>::
max
()
);
ensureSizeIsEnough
(
1
,
src
.
size
().
area
(),
CV_32SC1
,
list_
);
unsigned
int
*
srcPoints
=
list_
.
ptr
<
unsigned
int
>
();
const
int
pointsCount
=
buildPointList_gpu
(
src
,
srcPoints
);
if
(
pointsCount
==
0
)
{
lines
.
release
();
return
;
}
const
int
numangle
=
cvRound
(
CV_PI
/
theta_
);
const
int
numrho
=
cvRound
(((
src
.
cols
+
src
.
rows
)
*
2
+
1
)
/
rho_
);
CV_Assert
(
numangle
>
0
&&
numrho
>
0
);
ensureSizeIsEnough
(
numangle
+
2
,
numrho
+
2
,
CV_32SC1
,
accum_
);
accum_
.
setTo
(
Scalar
::
all
(
0
));
DeviceInfo
devInfo
;
linesAccum_gpu
(
srcPoints
,
pointsCount
,
accum_
,
rho_
,
theta_
,
devInfo
.
sharedMemPerBlock
(),
devInfo
.
supports
(
FEATURE_SET_COMPUTE_20
));
ensureSizeIsEnough
(
2
,
maxLines_
,
CV_32FC2
,
result_
);
int
linesCount
=
linesGetResult_gpu
(
accum_
,
result_
.
ptr
<
float2
>
(
0
),
result_
.
ptr
<
int
>
(
1
),
maxLines_
,
rho_
,
theta_
,
threshold_
,
doSort_
);
if
(
linesCount
==
0
)
{
lines
.
release
();
return
;
}
result_
.
cols
=
linesCount
;
result_
.
copyTo
(
lines
);
}
void
HoughLinesDetectorImpl
::
downloadResults
(
InputArray
_d_lines
,
OutputArray
h_lines
,
OutputArray
h_votes
)
{
GpuMat
d_lines
=
_d_lines
.
getGpuMat
();
if
(
d_lines
.
empty
())
{
h_lines
.
release
();
if
(
h_votes
.
needed
())
h_votes
.
release
();
return
;
}
CV_Assert
(
d_lines
.
rows
==
2
&&
d_lines
.
type
()
==
CV_32FC2
);
d_lines
.
row
(
0
).
download
(
h_lines
);
if
(
h_votes
.
needed
())
{
GpuMat
d_votes
(
1
,
d_lines
.
cols
,
CV_32SC1
,
d_lines
.
ptr
<
int
>
(
1
));
d_votes
.
download
(
h_votes
);
}
}
}
Ptr
<
HoughLinesDetector
>
cv
::
gpu
::
createHoughLinesDetector
(
float
rho
,
float
theta
,
int
threshold
,
bool
doSort
,
int
maxLines
)
{
return
new
HoughLinesDetectorImpl
(
rho
,
theta
,
threshold
,
doSort
,
maxLines
);
}
#endif
/* !defined (HAVE_CUDA) */
modules/gpuimgproc/src/hough_segments.cpp
0 → 100644
View file @
f614e354
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#include "precomp.hpp"
using
namespace
cv
;
using
namespace
cv
::
gpu
;
#if !defined (HAVE_CUDA) || defined (CUDA_DISABLER)
Ptr
<
gpu
::
HoughSegmentDetector
>
cv
::
gpu
::
createHoughSegmentDetector
(
float
,
float
,
int
,
int
,
int
)
{
throw_no_cuda
();
return
Ptr
<
HoughSegmentDetector
>
();
}
#else
/* !defined (HAVE_CUDA) */
namespace
cv
{
namespace
gpu
{
namespace
cudev
{
namespace
hough
{
int
buildPointList_gpu
(
PtrStepSzb
src
,
unsigned
int
*
list
);
}
namespace
hough_lines
{
void
linesAccum_gpu
(
const
unsigned
int
*
list
,
int
count
,
PtrStepSzi
accum
,
float
rho
,
float
theta
,
size_t
sharedMemPerBlock
,
bool
has20
);
}
namespace
hough_segments
{
int
houghLinesProbabilistic_gpu
(
PtrStepSzb
mask
,
PtrStepSzi
accum
,
int4
*
out
,
int
maxSize
,
float
rho
,
float
theta
,
int
lineGap
,
int
lineLength
);
}
}}}
namespace
{
class
HoughSegmentDetectorImpl
:
public
HoughSegmentDetector
{
public
:
HoughSegmentDetectorImpl
(
float
rho
,
float
theta
,
int
minLineLength
,
int
maxLineGap
,
int
maxLines
)
:
rho_
(
rho
),
theta_
(
theta
),
minLineLength_
(
minLineLength
),
maxLineGap_
(
maxLineGap
),
maxLines_
(
maxLines
)
{
}
void
detect
(
InputArray
src
,
OutputArray
lines
);
void
setRho
(
float
rho
)
{
rho_
=
rho
;
}
float
getRho
()
const
{
return
rho_
;
}
void
setTheta
(
float
theta
)
{
theta_
=
theta
;
}
float
getTheta
()
const
{
return
theta_
;
}
void
setMinLineLength
(
int
minLineLength
)
{
minLineLength_
=
minLineLength
;
}
int
getMinLineLength
()
const
{
return
minLineLength_
;
}
void
setMaxLineGap
(
int
maxLineGap
)
{
maxLineGap_
=
maxLineGap
;
}
int
getMaxLineGap
()
const
{
return
maxLineGap_
;
}
void
setMaxLines
(
int
maxLines
)
{
maxLines_
=
maxLines
;
}
int
getMaxLines
()
const
{
return
maxLines_
;
}
void
write
(
FileStorage
&
fs
)
const
{
fs
<<
"name"
<<
"PHoughLinesDetector_GPU"
<<
"rho"
<<
rho_
<<
"theta"
<<
theta_
<<
"minLineLength"
<<
minLineLength_
<<
"maxLineGap"
<<
maxLineGap_
<<
"maxLines"
<<
maxLines_
;
}
void
read
(
const
FileNode
&
fn
)
{
CV_Assert
(
String
(
fn
[
"name"
])
==
"PHoughLinesDetector_GPU"
);
rho_
=
(
float
)
fn
[
"rho"
];
theta_
=
(
float
)
fn
[
"theta"
];
minLineLength_
=
(
int
)
fn
[
"minLineLength"
];
maxLineGap_
=
(
int
)
fn
[
"maxLineGap"
];
maxLines_
=
(
int
)
fn
[
"maxLines"
];
}
private
:
float
rho_
;
float
theta_
;
int
minLineLength_
;
int
maxLineGap_
;
int
maxLines_
;
GpuMat
accum_
;
GpuMat
list_
;
GpuMat
result_
;
};
void
HoughSegmentDetectorImpl
::
detect
(
InputArray
_src
,
OutputArray
lines
)
{
using
namespace
cv
::
gpu
::
cudev
::
hough
;
using
namespace
cv
::
gpu
::
cudev
::
hough_lines
;
using
namespace
cv
::
gpu
::
cudev
::
hough_segments
;
GpuMat
src
=
_src
.
getGpuMat
();
CV_Assert
(
src
.
type
()
==
CV_8UC1
);
CV_Assert
(
src
.
cols
<
std
::
numeric_limits
<
unsigned
short
>::
max
()
);
CV_Assert
(
src
.
rows
<
std
::
numeric_limits
<
unsigned
short
>::
max
()
);
ensureSizeIsEnough
(
1
,
src
.
size
().
area
(),
CV_32SC1
,
list_
);
unsigned
int
*
srcPoints
=
list_
.
ptr
<
unsigned
int
>
();
const
int
pointsCount
=
buildPointList_gpu
(
src
,
srcPoints
);
if
(
pointsCount
==
0
)
{
lines
.
release
();
return
;
}
const
int
numangle
=
cvRound
(
CV_PI
/
theta_
);
const
int
numrho
=
cvRound
(((
src
.
cols
+
src
.
rows
)
*
2
+
1
)
/
rho_
);
CV_Assert
(
numangle
>
0
&&
numrho
>
0
);
ensureSizeIsEnough
(
numangle
+
2
,
numrho
+
2
,
CV_32SC1
,
accum_
);
accum_
.
setTo
(
Scalar
::
all
(
0
));
DeviceInfo
devInfo
;
linesAccum_gpu
(
srcPoints
,
pointsCount
,
accum_
,
rho_
,
theta_
,
devInfo
.
sharedMemPerBlock
(),
devInfo
.
supports
(
FEATURE_SET_COMPUTE_20
));
ensureSizeIsEnough
(
1
,
maxLines_
,
CV_32SC4
,
result_
);
int
linesCount
=
houghLinesProbabilistic_gpu
(
src
,
accum_
,
result_
.
ptr
<
int4
>
(),
maxLines_
,
rho_
,
theta_
,
maxLineGap_
,
minLineLength_
);
if
(
linesCount
==
0
)
{
lines
.
release
();
return
;
}
result_
.
cols
=
linesCount
;
result_
.
copyTo
(
lines
);
}
}
Ptr
<
HoughSegmentDetector
>
cv
::
gpu
::
createHoughSegmentDetector
(
float
rho
,
float
theta
,
int
minLineLength
,
int
maxLineGap
,
int
maxLines
)
{
return
new
HoughSegmentDetectorImpl
(
rho
,
theta
,
minLineLength
,
maxLineGap
,
maxLines
);
}
#endif
/* !defined (HAVE_CUDA) */
modules/gpuimgproc/src/precomp.hpp
View file @
f614e354
...
...
@@ -46,6 +46,7 @@
#include "opencv2/gpuimgproc.hpp"
#include "opencv2/gpufilters.hpp"
#include "opencv2/core/utility.hpp"
#include "opencv2/core/private.hpp"
#include "opencv2/core/private.gpu.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment