opencv / Commits / 0ddd16cf

Commit 0ddd16cf, authored Nov 12, 2012 by Vladislav Vinogradov
Parent: 7e57648e

    calcHist & equalizeHist

Showing 4 changed files with 91 additions and 194 deletions (+91 -194)
modules/gpu/include/opencv2/gpu/gpu.hpp   (+0  -2)
modules/gpu/perf/perf_imgproc.cpp         (+2  -3)
modules/gpu/src/cuda/hist.cu              (+80 -148)
modules/gpu/src/imgproc.cpp               (+9  -41)
modules/gpu/include/opencv2/gpu/gpu.hpp

The explicit-buffer overload of calcHist and the hist-only overload of equalizeHist are dropped from the public API:

```diff
@@ -1028,11 +1028,9 @@ CV_EXPORTS void histRange(const GpuMat& src, GpuMat hist[4], const GpuMat levels
 //! Calculates histogram for 8u one channel image
 //! Output hist will have one row, 256 cols and CV32SC1 type.
 CV_EXPORTS void calcHist(const GpuMat& src, GpuMat& hist, Stream& stream = Stream::Null());
-CV_EXPORTS void calcHist(const GpuMat& src, GpuMat& hist, GpuMat& buf, Stream& stream = Stream::Null());
 
 //! normalizes the grayscale image brightness and contrast by normalizing its histogram
 CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, Stream& stream = Stream::Null());
-CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, Stream& stream = Stream::Null());
 CV_EXPORTS void equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, GpuMat& buf, Stream& stream = Stream::Null());
 
 //////////////////////////////// StereoBM_GPU ////////////////////////////////
```
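For orientation, a hedged sketch of what caller code looks like against the slimmed API; the file name and use of imread are illustrative, not part of the commit:

```cpp
#include <opencv2/opencv.hpp>
#include <opencv2/gpu/gpu.hpp>

int main()
{
    // Illustrative input: any 8-bit single-channel image works.
    cv::Mat img = cv::imread("input.png", 0 /* grayscale */);
    if (img.empty()) return 1;

    cv::gpu::GpuMat d_src(img), d_hist, d_dst;

    // After this commit the buffered calcHist overload is gone:
    // the 256-bin histogram is computed in a single pass.
    cv::gpu::calcHist(d_src, d_hist);      // 1 x 256, CV_32SC1

    // equalizeHist still offers a buffered overload for reuse;
    // the simple form allocates internally.
    cv::gpu::equalizeHist(d_src, d_dst);

    cv::Mat hist, dst;
    d_hist.download(hist);
    d_dst.download(dst);
    return 0;
}
```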
modules/gpu/perf/perf_imgproc.cpp

```diff
@@ -581,13 +581,12 @@ PERF_TEST_P(Sz, ImgProc_CalcHist, GPU_TYPICAL_MAT_SIZES)
 {
     cv::gpu::GpuMat d_src(src);
     cv::gpu::GpuMat d_hist;
-    cv::gpu::GpuMat d_buf;
 
-    cv::gpu::calcHist(d_src, d_hist, d_buf);
+    cv::gpu::calcHist(d_src, d_hist);
 
     TEST_CYCLE()
     {
-        cv::gpu::calcHist(d_src, d_hist, d_buf);
+        cv::gpu::calcHist(d_src, d_hist);
     }
 
     GPU_SANITY_CHECK(d_hist);
```
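Outside the perf framework, the same measurement can be reproduced with a bare timing loop. This is a sketch under assumptions, not part of the commit: the size and iteration count are arbitrary, and wall-clock timing is valid here only because calcHist on the null stream synchronizes the device (see the stream == 0 path in hist.cu below):

```cpp
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/gpu/gpu.hpp>

int main()
{
    cv::Mat src(1080, 1920, CV_8UC1);
    cv::randu(src, 0, 256);

    cv::gpu::GpuMat d_src(src), d_hist;
    cv::gpu::calcHist(d_src, d_hist);   // warm-up: context, allocations

    const int iters = 100;
    const int64 t0 = cv::getTickCount();
    for (int i = 0; i < iters; ++i)
        cv::gpu::calcHist(d_src, d_hist);
    const double ms = (cv::getTickCount() - t0) * 1000.0
                      / (cv::getTickFrequency() * iters);

    std::cout << "calcHist: " << ms << " ms/iter" << std::endl;
    return 0;
}
```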
modules/gpu/src/cuda/hist.cu

The hunk (@@ -43,182 +43,115 @@) rewrites the file, so the interleaved side-by-side view is easiest to read as the full before and after. Before: the NVIDIA-SDK-style two-pass histogram (per-warp subhistograms, tag-based software atomics for devices without shared-memory atomics, 240 partial histograms merged by a second kernel) plus a hand-rolled equalizeHist kernel:

```cuda
#if !defined CUDA_DISABLER

#include "internal_shared.hpp"
#include "opencv2/gpu/device/utility.hpp"
#include "opencv2/gpu/device/saturate_cast.hpp"

namespace cv { namespace gpu { namespace device
{
    #define UINT_BITS 32U

    //Warps == subhistograms per threadblock
    #define WARP_COUNT 6

    //Threadblock size
    #define HISTOGRAM256_THREADBLOCK_SIZE (WARP_COUNT * OPENCV_GPU_WARP_SIZE)
    #define HISTOGRAM256_BIN_COUNT 256

    //Shared memory per threadblock
    #define HISTOGRAM256_THREADBLOCK_MEMORY (WARP_COUNT * HISTOGRAM256_BIN_COUNT)

    #define PARTIAL_HISTOGRAM256_COUNT 240

    #define MERGE_THREADBLOCK_SIZE 256

    #define USE_SMEM_ATOMICS (defined (__CUDA_ARCH__) && (__CUDA_ARCH__ >= 120))

    namespace hist
    {
    #if (!USE_SMEM_ATOMICS)

        #define TAG_MASK ( (1U << (UINT_BITS - OPENCV_GPU_LOG_WARP_SIZE)) - 1U )

        __forceinline__ __device__ void addByte(volatile uint* s_WarpHist, uint data, uint threadTag)
        {
            uint count;
            do
            {
                count = s_WarpHist[data] & TAG_MASK;
                count = threadTag | (count + 1);
                s_WarpHist[data] = count;
            } while (s_WarpHist[data] != count);
        }

    #else

        #define TAG_MASK 0xFFFFFFFFU

        __forceinline__ __device__ void addByte(uint* s_WarpHist, uint data, uint threadTag)
        {
            atomicAdd(s_WarpHist + data, 1);
        }

    #endif

        __forceinline__ __device__ void addWord(uint* s_WarpHist, uint data, uint tag, uint pos_x, uint cols)
        {
            uint x = pos_x << 2;

            if (x + 0 < cols) addByte(s_WarpHist, (data >>  0) & 0xFFU, tag);
            if (x + 1 < cols) addByte(s_WarpHist, (data >>  8) & 0xFFU, tag);
            if (x + 2 < cols) addByte(s_WarpHist, (data >> 16) & 0xFFU, tag);
            if (x + 3 < cols) addByte(s_WarpHist, (data >> 24) & 0xFFU, tag);
        }

        __global__ void histogram256(const PtrStep<uint> d_Data, uint* d_PartialHistograms, uint dataCount, uint cols)
        {
            //Per-warp subhistogram storage
            __shared__ uint s_Hist[HISTOGRAM256_THREADBLOCK_MEMORY];
            uint* s_WarpHist = s_Hist + (threadIdx.x >> OPENCV_GPU_LOG_WARP_SIZE) * HISTOGRAM256_BIN_COUNT;

            //Clear shared memory storage for current threadblock before processing
            #pragma unroll
            for (uint i = 0; i < (HISTOGRAM256_THREADBLOCK_MEMORY / HISTOGRAM256_THREADBLOCK_SIZE); i++)
                s_Hist[threadIdx.x + i * HISTOGRAM256_THREADBLOCK_SIZE] = 0;

            //Cycle through the entire data set, update subhistograms for each warp
            const uint tag = threadIdx.x << (UINT_BITS - OPENCV_GPU_LOG_WARP_SIZE);

            __syncthreads();
            const uint colsui = d_Data.step / sizeof(uint);
            for (uint pos = blockIdx.x * blockDim.x + threadIdx.x; pos < dataCount; pos += blockDim.x * gridDim.x)
            {
                uint pos_y = pos / colsui;
                uint pos_x = pos % colsui;
                uint data = d_Data.ptr(pos_y)[pos_x];
                addWord(s_WarpHist, data, tag, pos_x, cols);
            }

            //Merge per-warp histograms into per-block and write to global memory
            __syncthreads();
            for (uint bin = threadIdx.x; bin < HISTOGRAM256_BIN_COUNT; bin += HISTOGRAM256_THREADBLOCK_SIZE)
            {
                uint sum = 0;

                for (uint i = 0; i < WARP_COUNT; i++)
                    sum += s_Hist[bin + i * HISTOGRAM256_BIN_COUNT] & TAG_MASK;

                d_PartialHistograms[blockIdx.x * HISTOGRAM256_BIN_COUNT + bin] = sum;
            }
        }

        ////////////////////////////////////////////////////////////////////////////////
        // Merge histogram256() output
        // Run one threadblock per bin; each threadblock adds up the same bin counter
        // from every partial histogram. Reads are uncoalesced, but mergeHistogram256
        // takes only a fraction of total processing time
        ////////////////////////////////////////////////////////////////////////////////

        __global__ void mergeHistogram256(const uint* d_PartialHistograms, int* d_Histogram)
        {
            uint sum = 0;

            #pragma unroll
            for (uint i = threadIdx.x; i < PARTIAL_HISTOGRAM256_COUNT; i += MERGE_THREADBLOCK_SIZE)
                sum += d_PartialHistograms[blockIdx.x + i * HISTOGRAM256_BIN_COUNT];

            __shared__ uint data[MERGE_THREADBLOCK_SIZE];
            data[threadIdx.x] = sum;

            for (uint stride = MERGE_THREADBLOCK_SIZE / 2; stride > 0; stride >>= 1)
            {
                __syncthreads();
                if (threadIdx.x < stride)
                    data[threadIdx.x] += data[threadIdx.x + stride];
            }

            if (threadIdx.x == 0)
                d_Histogram[blockIdx.x] = saturate_cast<int>(data[0]);
        }

        void histogram256_gpu(PtrStepSzb src, int* hist, uint* buf, cudaStream_t stream)
        {
            histogram256<<<PARTIAL_HISTOGRAM256_COUNT, HISTOGRAM256_THREADBLOCK_SIZE, 0, stream>>>(
                PtrStepSz<uint>(src),
                buf,
                static_cast<uint>(src.rows * src.step / sizeof(uint)),
                src.cols);
            cudaSafeCall( cudaGetLastError() );

            mergeHistogram256<<<HISTOGRAM256_BIN_COUNT, MERGE_THREADBLOCK_SIZE, 0, stream>>>(buf, hist);
            cudaSafeCall( cudaGetLastError() );

            if (stream == 0)
                cudaSafeCall( cudaDeviceSynchronize() );
        }

        __constant__ int c_lut[256];

        __global__ void equalizeHist(const PtrStepSzb src, PtrStepb dst)
        {
            const int x = blockIdx.x * blockDim.x + threadIdx.x;
            const int y = blockIdx.y * blockDim.y + threadIdx.y;

            if (x < src.cols && y < src.rows)
            {
                const uchar val = src.ptr(y)[x];
                const int lut = c_lut[val];
                dst.ptr(y)[x] = __float2int_rn(255.0f / (src.cols * src.rows) * lut);
            }
        }

        void equalizeHist_gpu(PtrStepSzb src, PtrStepSzb dst, const int* lut, cudaStream_t stream)
        {
            dim3 block(16, 16);
            dim3 grid(divUp(src.cols, block.x), divUp(src.rows, block.y));

            cudaSafeCall( cudaMemcpyToSymbol(c_lut, lut, 256 * sizeof(int), 0, cudaMemcpyDeviceToDevice) );

            equalizeHist<<<grid, block, 0, stream>>>(src, dst);
            cudaSafeCall( cudaGetLastError() );

            if (stream == 0)
                cudaSafeCall( cudaDeviceSynchronize() );
        }
    } // namespace hist
}}} // namespace cv { namespace gpu { namespace device

#endif /* CUDA_DISABLER */
```

After: a single-pass kernel in which each 32x8 block histograms a stripe of rows into one shared 256-bin histogram (reading four packed bytes per 32-bit load) and folds it into the global result with one atomicAdd per bin, while equalizeHist becomes a LUT functor on the device transform framework:

```cuda
#if !defined CUDA_DISABLER

#include "opencv2/gpu/device/common.hpp"
#include "opencv2/gpu/device/functional.hpp"
#include "opencv2/gpu/device/emulation.hpp"
#include "opencv2/gpu/device/transform.hpp"

using namespace cv::gpu;
using namespace cv::gpu::device;

namespace
{
    __global__ void histogram256(const uchar* src, int cols, int rows, size_t step, int* hist)
    {
        __shared__ int shist[256];

        const int y = blockIdx.x * blockDim.y + threadIdx.y;
        const int tid = threadIdx.y * blockDim.x + threadIdx.x;

        shist[tid] = 0;
        __syncthreads();

        if (y < rows)
        {
            const unsigned int* rowPtr = (const unsigned int*) (src + y * step);

            const int cols_4 = cols / 4;
            for (int x = threadIdx.x; x < cols_4; x += blockDim.x)
            {
                unsigned int data = rowPtr[x];

                Emulation::smem::atomicAdd(&shist[(data >>  0) & 0xFFU], 1);
                Emulation::smem::atomicAdd(&shist[(data >>  8) & 0xFFU], 1);
                Emulation::smem::atomicAdd(&shist[(data >> 16) & 0xFFU], 1);
                Emulation::smem::atomicAdd(&shist[(data >> 24) & 0xFFU], 1);
            }

            if (cols % 4 != 0 && threadIdx.x == 0)
            {
                for (int x = cols_4 * 4; x < cols; ++x)
                {
                    unsigned int data = ((const uchar*)rowPtr)[x];
                    Emulation::smem::atomicAdd(&shist[data], 1);
                }
            }
        }

        __syncthreads();

        const int histVal = shist[tid];
        if (histVal > 0)
            ::atomicAdd(hist + tid, histVal);
    }
}

namespace hist
{
    void histogram256(PtrStepSzb src, int* hist, cudaStream_t stream)
    {
        const dim3 block(32, 8);
        const dim3 grid(divUp(src.rows, block.y));

        ::histogram256<<<grid, block, 0, stream>>>(src.data, src.cols, src.rows, src.step, hist);
        cudaSafeCall( cudaGetLastError() );

        if (stream == 0)
            cudaSafeCall( cudaDeviceSynchronize() );
    }
}

/////////////////////////////////////////////////////////////////////////

namespace
{
    __constant__ int c_lut[256];

    struct EqualizeHist : unary_function<uchar, uchar>
    {
        float scale;

        __host__ EqualizeHist(float _scale) : scale(_scale) {}

        __device__ __forceinline__ uchar operator ()(uchar val) const
        {
            const int lut = c_lut[val];
            return __float2int_rn(scale * lut);
        }
    };
}

namespace cv { namespace gpu { namespace device
{
    template <> struct TransformFunctorTraits<EqualizeHist> : DefaultTransformFunctorTraits<EqualizeHist>
    {
        enum { smart_shift = 4 };
    };
}}}

namespace hist
{
    void equalizeHist(PtrStepSzb src, PtrStepSzb dst, const int* lut, cudaStream_t stream)
    {
        if (stream == 0)
            cudaSafeCall( cudaMemcpyToSymbol(c_lut, lut, 256 * sizeof(int), 0, cudaMemcpyDeviceToDevice) );
        else
            cudaSafeCall( cudaMemcpyToSymbolAsync(c_lut, lut, 256 * sizeof(int), 0, cudaMemcpyDeviceToDevice, stream) );

        const float scale = 255.0f / (src.cols * src.rows);

        transform(src, dst, EqualizeHist(scale), WithOutMask(), stream);
    }
}

#endif /* CUDA_DISABLER */
```
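The single-pass strategy is easy to see in isolation. Below is a minimal standalone sketch of the same idea, not the committed code: it drops the word-at-a-time loads for clarity and uses native atomicAdd on shared memory, which needs sm_12 or newer; the committed version routes through Emulation::smem::atomicAdd precisely so that older targets without shared-memory atomics still work. The image size and fill value are arbitrary test inputs:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void hist256(const unsigned char* img, int cols, int rows,
                        size_t step, int* hist)
{
    __shared__ int shist[256];
    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    shist[tid] = 0;                          // 32x8 = 256 threads, one bin each
    __syncthreads();

    const int y = blockIdx.x * blockDim.y + threadIdx.y;
    if (y < rows)
        for (int x = threadIdx.x; x < cols; x += blockDim.x)
            atomicAdd(&shist[img[y * step + x]], 1);

    __syncthreads();
    if (shist[tid] > 0)
        atomicAdd(hist + tid, shist[tid]);   // merge block result into global
}

int main()
{
    const int cols = 1024, rows = 768;
    unsigned char* d_img;  int* d_hist;
    cudaMalloc(&d_img, cols * rows);
    cudaMemset(d_img, 7, cols * rows);       // constant image: one hot bin
    cudaMalloc(&d_hist, 256 * sizeof(int));
    cudaMemset(d_hist, 0, 256 * sizeof(int));

    dim3 block(32, 8);
    dim3 grid((rows + block.y - 1) / block.y);
    hist256<<<grid, block>>>(d_img, cols, rows, cols, d_hist);

    int h[256];
    cudaMemcpy(h, d_hist, sizeof(h), cudaMemcpyDeviceToHost);
    std::printf("bin 7 = %d (expect %d)\n", h[7], cols * rows);

    cudaFree(d_img); cudaFree(d_hist);
    return 0;
}
```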
modules/gpu/src/imgproc.cpp

```diff
@@ -71,9 +71,7 @@ void cv::gpu::histRange(const GpuMat&, GpuMat&, const GpuMat&, GpuMat&, Stream&)
 void cv::gpu::histRange(const GpuMat&, GpuMat*, const GpuMat*, Stream&) { throw_nogpu(); }
 void cv::gpu::histRange(const GpuMat&, GpuMat*, const GpuMat*, GpuMat&, Stream&) { throw_nogpu(); }
 void cv::gpu::calcHist(const GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
-void cv::gpu::calcHist(const GpuMat&, GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
 void cv::gpu::equalizeHist(const GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
-void cv::gpu::equalizeHist(const GpuMat&, GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
 void cv::gpu::equalizeHist(const GpuMat&, GpuMat&, GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
 void cv::gpu::cornerHarris(const GpuMat&, GpuMat&, int, int, double, int) { throw_nogpu(); }
 void cv::gpu::cornerHarris(const GpuMat&, GpuMat&, GpuMat&, GpuMat&, int, int, double, int) { throw_nogpu(); }
```
```diff
@@ -989,36 +987,20 @@ void cv::gpu::histRange(const GpuMat& src, GpuMat hist[4], const GpuMat levels[4
     hist_callers[src.depth()](src, hist, levels, buf, StreamAccessor::getStream(stream));
 }
 
-namespace cv { namespace gpu { namespace device
+namespace hist
 {
-    namespace hist
-    {
-        void histogram256_gpu(PtrStepSzb src, int* hist, unsigned int* buf, cudaStream_t stream);
-
-        const int PARTIAL_HISTOGRAM256_COUNT = 240;
-        const int HISTOGRAM256_BIN_COUNT = 256;
-
-        void equalizeHist_gpu(PtrStepSzb src, PtrStepSzb dst, const int* lut, cudaStream_t stream);
-    }
-}}}
-
-void cv::gpu::calcHist(const GpuMat& src, GpuMat& hist, Stream& stream)
-{
-    GpuMat buf;
-    calcHist(src, hist, buf, stream);
+    void histogram256(PtrStepSzb src, int* hist, cudaStream_t stream);
+    void equalizeHist(PtrStepSzb src, PtrStepSzb dst, const int* lut, cudaStream_t stream);
 }
 
-void cv::gpu::calcHist(const GpuMat& src, GpuMat& hist, GpuMat& buf, Stream& stream)
+void cv::gpu::calcHist(const GpuMat& src, GpuMat& hist, Stream& stream)
 {
-    using namespace ::cv::gpu::device::hist;
-
     CV_Assert(src.type() == CV_8UC1);
 
     hist.create(1, 256, CV_32SC1);
+    hist.setTo(Scalar::all(0));
 
-    ensureSizeIsEnough(1, PARTIAL_HISTOGRAM256_COUNT * HISTOGRAM256_BIN_COUNT, CV_32SC1, buf);
-
-    histogram256_gpu(src, hist.ptr<int>(), buf.ptr<unsigned int>(), StreamAccessor::getStream(stream));
+    hist::histogram256(src, hist.ptr<int>(), StreamAccessor::getStream(stream));
 }
 
 void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, Stream& stream)
```

The hist.setTo(Scalar::all(0)) is new: the rewritten kernel accumulates each block's shared histogram into the output with atomicAdd rather than writing a merged result, so the destination must start zeroed.
```diff
@@ -1028,16 +1010,8 @@ void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, Stream& stream)
     equalizeHist(src, dst, hist, buf, stream);
 }
 
-void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, Stream& stream)
-{
-    GpuMat buf;
-    equalizeHist(src, dst, hist, buf, stream);
-}
-
 void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, GpuMat& buf, Stream& s)
 {
-    using namespace ::cv::gpu::device::hist;
-
     CV_Assert(src.type() == CV_8UC1);
 
     dst.create(src.size(), src.type());
```
```diff
@@ -1045,15 +1019,12 @@ void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, GpuMat&
     int intBufSize;
     nppSafeCall( nppsIntegralGetBufferSize_32s(256, &intBufSize) );
 
-    int bufSize = static_cast<int>(std::max(256 * 240 * sizeof(int), intBufSize + 256 * sizeof(int)));
-
-    ensureSizeIsEnough(1, bufSize, CV_8UC1, buf);
+    ensureSizeIsEnough(1, intBufSize + 256 * sizeof(int), CV_8UC1, buf);
 
-    GpuMat histBuf(1, 256 * 240, CV_32SC1, buf.ptr());
     GpuMat intBuf(1, intBufSize, CV_8UC1, buf.ptr());
     GpuMat lut(1, 256, CV_32S, buf.ptr() + intBufSize);
 
-    calcHist(src, hist, histBuf, s);
+    calcHist(src, hist, s);
 
     cudaStream_t stream = StreamAccessor::getStream(s);
```
```diff
@@ -1061,10 +1032,7 @@ void cv::gpu::equalizeHist(const GpuMat& src, GpuMat& dst, GpuMat& hist, GpuMat&
     nppSafeCall( nppsIntegral_32s(hist.ptr<Npp32s>(), lut.ptr<Npp32s>(), 256, intBuf.ptr<Npp8u>()) );
 
-    if (stream == 0)
-        cudaSafeCall( cudaDeviceSynchronize() );
-
-    equalizeHist_gpu(src, dst, lut.ptr<int>(), stream);
+    hist::equalizeHist(src, dst, lut.ptr<int>(), stream);
 }
 
 ////////////////////////////////////////////////////////////////////////
```
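For reference, a host-side sketch of the arithmetic this path performs: nppsIntegral_32s produces the running sum (the CDF) of the 256-bin histogram into lut, and the EqualizeHist functor then maps each pixel through round(scale * lut[v]) with scale = 255 / (rows * cols). The toy image and the inclusive prefix-sum convention below are assumptions for illustration, not the committed code:

```cpp
#include <cstdio>

int main()
{
    // Toy 16-pixel image with values 0..7, each appearing twice.
    unsigned char img[16] = { 0, 0, 1, 1, 2, 2, 3, 3,
                              4, 4, 5, 5, 6, 6, 7, 7 };
    const int total = 16;

    int hist[256] = {0};
    for (int i = 0; i < total; ++i)
        ++hist[img[i]];

    // Prefix sum of the histogram = CDF; this is the role nppsIntegral_32s
    // plays on the GPU path (exact inclusive/exclusive convention aside).
    int lut[256];
    int sum = 0;
    for (int v = 0; v < 256; ++v)
    {
        sum += hist[v];
        lut[v] = sum;
    }

    const float scale = 255.0f / total;
    for (int i = 0; i < total; ++i)
    {
        // Round-to-nearest, mirroring __float2int_rn in the functor.
        int out = (int)(scale * lut[img[i]] + 0.5f);
        std::printf("%3d -> %3d\n", img[i], out);
    }
    return 0;
}
```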