Commit dea8aae2 authored by Seon-Wook Park's avatar Seon-Wook Park

Add saturation based thresholding to grayworld WB

parent aaf0ffe9
...@@ -61,7 +61,8 @@ namespace cv { namespace xphoto { ...@@ -61,7 +61,8 @@ namespace cv { namespace xphoto {
@param dst @param dst
@sa balanceWhite @sa balanceWhite
*/ */
CV_EXPORTS_W void autowbGrayworld(InputArray src, OutputArray dst); CV_EXPORTS_W void autowbGrayworld(InputArray src, OutputArray dst,
const float thresh = 0.5f);
//! @} //! @}
......
...@@ -46,7 +46,7 @@ namespace cv { namespace xphoto { ...@@ -46,7 +46,7 @@ namespace cv { namespace xphoto {
/*! /*!
*/ */
void autowbGrayworld(InputArray _src, OutputArray _dst) void autowbGrayworld(InputArray _src, OutputArray _dst, const float thresh)
{ {
Mat src = _src.getMat(); Mat src = _src.getMat();
...@@ -70,48 +70,81 @@ namespace cv { namespace xphoto { ...@@ -70,48 +70,81 @@ namespace cv { namespace xphoto {
ulong sum1 = 0, sum2 = 0, sum3 = 0; ulong sum1 = 0, sum2 = 0, sum3 = 0;
int i = 0; int i = 0;
#if CV_SIMD128 #if CV_SIMD128
v_uint8x16 v_in; v_uint8x16 v_inB, v_inG, v_inR;
v_uint16x8 v_s1, v_s2; v_uint16x8 v_s1, v_s2;
v_uint32x4 v_i1, v_i2, v_i3, v_i4, v_uint32x4 v_iB1, v_iB2, v_iB3, v_iB4,
v_S1 = v_setzero_u32(), v_iG1, v_iG2, v_iG3, v_iG4,
v_S2 = v_setzero_u32(), v_iR1, v_iR2, v_iR3, v_iR4,
v_S3 = v_setzero_u32(), v_SB = v_setzero_u32(),
v_S4 = v_setzero_u32(); v_SG = v_setzero_u32(),
v_SR = v_setzero_u32(),
for (; i < N3 - 14; i += 15) v_m1, v_m2, v_m3, v_m4;
v_float32x4 v_thresh = v_setall_f32(thresh),
v_min1, v_min2, v_min3, v_min4,
v_max1, v_max2, v_max3, v_max4,
v_sat1, v_sat2, v_sat3, v_sat4;
for ( ; i < N3 - 47; i += 48 )
{ {
// Load 16 x 8bit uchars // NOTE: This block assumes BGR channels in naming variables
v_in = v_load(&src_data[i]);
// Load 3x uint8x16 and deinterleave into vectors of each channel
// Split into two vectors of 8 ushorts v_load_deinterleave(&src_data[i], v_inB, v_inG, v_inR);
v_expand(v_in, v_s1, v_s2);
// Split into four int vectors per channel
// Split into four vectors of 4 uints v_expand(v_inB, v_s1, v_s2);
v_expand(v_s1, v_i1, v_i2); v_expand(v_s1, v_iB1, v_iB2);
v_expand(v_s2, v_i3, v_i4); v_expand(v_s2, v_iB3, v_iB4);
// Add to accumulators v_expand(v_inG, v_s1, v_s2);
v_S1 += v_i1; v_expand(v_s1, v_iG1, v_iG2);
v_S2 += v_i2; v_expand(v_s2, v_iG3, v_iG4);
v_S3 += v_i3;
v_S4 += v_i4; v_expand(v_inR, v_s1, v_s2);
v_expand(v_s1, v_iR1, v_iR2);
v_expand(v_s2, v_iR3, v_iR4);
// Get saturation
v_min1 = v_cvt_f32(v_reinterpret_as_s32(v_min(v_iB1, v_min(v_iG1, v_iR1))));
v_min2 = v_cvt_f32(v_reinterpret_as_s32(v_min(v_iB2, v_min(v_iG2, v_iR2))));
v_min3 = v_cvt_f32(v_reinterpret_as_s32(v_min(v_iB3, v_min(v_iG3, v_iR3))));
v_min4 = v_cvt_f32(v_reinterpret_as_s32(v_min(v_iB4, v_min(v_iG4, v_iR4))));
v_max1 = v_cvt_f32(v_reinterpret_as_s32(v_max(v_iB1, v_max(v_iG1, v_iR1))));
v_max2 = v_cvt_f32(v_reinterpret_as_s32(v_max(v_iB2, v_max(v_iG2, v_iR2))));
v_max3 = v_cvt_f32(v_reinterpret_as_s32(v_max(v_iB3, v_max(v_iG3, v_iR3))));
v_max4 = v_cvt_f32(v_reinterpret_as_s32(v_max(v_iB4, v_max(v_iG4, v_iR4))));
v_sat1 = (v_max1 - v_min1) / v_max1;
v_sat2 = (v_max2 - v_min2) / v_max2;
v_sat3 = (v_max3 - v_min3) / v_max3;
v_sat4 = (v_max4 - v_min4) / v_max4;
// Calculate masks
v_m1 = v_reinterpret_as_u32(v_sat1 <= v_thresh);
v_m2 = v_reinterpret_as_u32(v_sat2 <= v_thresh);
v_m3 = v_reinterpret_as_u32(v_sat3 <= v_thresh);
v_m4 = v_reinterpret_as_u32(v_sat4 <= v_thresh);
// Apply mask
v_SB += (v_iB1 & v_m1) + (v_iB2 & v_m2) + (v_iB3 & v_m3) + (v_iB4 & v_m4);
v_SG += (v_iG1 & v_m1) + (v_iG2 & v_m2) + (v_iG3 & v_m3) + (v_iG4 & v_m4);
v_SR += (v_iR1 & v_m1) + (v_iR2 & v_m2) + (v_iR3 & v_m3) + (v_iR4 & v_m4);
} }
// Store accumulated values into memory
uint sums[16];
v_store(&sums[0], v_S1);
v_store(&sums[4], v_S2);
v_store(&sums[8], v_S3);
v_store(&sums[12], v_S4);
// Perform final reduction // Perform final reduction
sum1 = sums[0] + sums[3] + sums[6] + sums[9] + sums[12], sum1 = v_reduce_sum(v_SB);
sum2 = sums[1] + sums[4] + sums[7] + sums[10] + sums[13], sum2 = v_reduce_sum(v_SG);
sum3 = sums[2] + sums[5] + sums[8] + sums[11] + sums[14]; sum3 = v_reduce_sum(v_SR);
#endif #endif
for (; i < N3; i += 3) double minRGB, maxRGB, satur;
for ( ; i < N3; i += 3 )
{ {
sum1 += src_data[i + 0]; minRGB = min(src_data[i], min(src_data[i + 1], src_data[i + 2]));
maxRGB = max(src_data[i], max(src_data[i + 1], src_data[i + 2]));
satur = (maxRGB - minRGB) / maxRGB;
if ( satur > thresh ) continue;
sum1 += src_data[i];
sum2 += src_data[i + 1]; sum2 += src_data[i + 1];
sum3 += src_data[i + 2]; sum3 += src_data[i + 2];
} }
...@@ -130,7 +163,7 @@ namespace cv { namespace xphoto { ...@@ -130,7 +163,7 @@ namespace cv { namespace xphoto {
inv3 = (float) dinv3; inv3 = (float) dinv3;
// Scale by maximum // Scale by maximum
if (inv_max > 0) if ( inv_max > 0 )
{ {
inv1 /= inv_max; inv1 /= inv_max;
inv2 /= inv_max; inv2 /= inv_max;
...@@ -141,14 +174,15 @@ namespace cv { namespace xphoto { ...@@ -141,14 +174,15 @@ namespace cv { namespace xphoto {
uchar* dst_data = dst.ptr<uchar>(0); uchar* dst_data = dst.ptr<uchar>(0);
i = 0; i = 0;
#if CV_SIMD128 #if CV_SIMD128
v_uint8x16 v_out; v_uint8x16 v_in, v_out;
v_uint32x4 v_i1, v_i2, v_i3, v_i4;
v_float32x4 v_f1, v_f2, v_f3, v_f4, v_float32x4 v_f1, v_f2, v_f3, v_f4,
scal1(inv1, inv2, inv3, inv1), scal1(inv1, inv2, inv3, inv1),
scal2(inv2, inv3, inv1, inv2), scal2(inv2, inv3, inv1, inv2),
scal3(inv3, inv1, inv2, inv3), scal3(inv3, inv1, inv2, inv3),
scal4(inv1, inv2, inv3, 0.f); scal4(inv1, inv2, inv3, 0.f);
for (; i < N3 - 14; i += 15) for ( ; i < N3 - 14; i += 15 )
{ {
// Load 16 x 8bit uchars // Load 16 x 8bit uchars
v_in = v_load(&src_data[i]); v_in = v_load(&src_data[i]);
...@@ -189,7 +223,7 @@ namespace cv { namespace xphoto { ...@@ -189,7 +223,7 @@ namespace cv { namespace xphoto {
v_store(&dst_data[i], v_out); v_store(&dst_data[i], v_out);
} }
#endif #endif
for (; i < N3; i += 3) for ( ; i < N3; i += 3 )
{ {
dst_data[i + 0] = src_data[i + 0] * inv1; dst_data[i + 0] = src_data[i + 0] * inv1;
dst_data[i + 1] = src_data[i + 1] * inv2; dst_data[i + 1] = src_data[i + 1] * inv2;
......
...@@ -4,8 +4,7 @@ namespace cvtest { ...@@ -4,8 +4,7 @@ namespace cvtest {
using namespace cv; using namespace cv;
// TODO: Remove debug print statements void ref_autowbGrayworld(InputArray _src, OutputArray _dst, const float thresh)
void ref_autowbGrayworld(InputArray _src, OutputArray _dst)
{ {
Mat src = _src.getMat(); Mat src = _src.getMat();
...@@ -21,70 +20,66 @@ namespace cvtest { ...@@ -21,70 +20,66 @@ namespace cvtest {
const uchar* src_data = src.ptr<uchar>(0); const uchar* src_data = src.ptr<uchar>(0);
unsigned long sum1 = 0, sum2 = 0, sum3 = 0; unsigned long sum1 = 0, sum2 = 0, sum3 = 0;
int i = 0; int i = 0;
for (; i < N3; i += 3) double minRGB, maxRGB, satur;
for ( ; i < N3; i += 3 )
{ {
sum1 += src_data[i + 0]; minRGB = std::min(src_data[i], std::min(src_data[i + 1], src_data[i + 2]));
maxRGB = std::max(src_data[i], std::max(src_data[i + 1], src_data[i + 2]));
satur = (maxRGB - minRGB) / maxRGB;
if ( satur > thresh ) continue;
sum1 += src_data[i];
sum2 += src_data[i + 1]; sum2 += src_data[i + 1];
sum3 += src_data[i + 2]; sum3 += src_data[i + 2];
} }
//printf("sums:\t\t\t%lu, %lu, %lu\n", sum1, sum2, sum3);
// Find inverse of averages // Find inverse of averages
double inv1 = sum1 == 0 ? 0.f : (double)N / (double)sum1, double inv1 = sum1 == 0 ? 0.f : (double)N / (double)sum1,
inv2 = sum2 == 0 ? 0.f : (double)N / (double)sum2, inv2 = sum2 == 0 ? 0.f : (double)N / (double)sum2,
inv3 = sum3 == 0 ? 0.f : (double)N / (double)sum3; inv3 = sum3 == 0 ? 0.f : (double)N / (double)sum3;
//printf("inverse avgs:\t\t%f, %f, %f\n", inv1, inv2, inv3);
// Find maximum // Find maximum
double inv_max = std::max(std::max(inv1, inv2), inv3); double inv_max = std::max(std::max(inv1, inv2), inv3);
// Scale by maximum // Scale by maximum
if (inv_max > 0) if ( inv_max > 0 )
{ {
inv1 = (double) inv1 / inv_max; inv1 = (double) inv1 / inv_max;
inv2 = (double) inv2 / inv_max; inv2 = (double) inv2 / inv_max;
inv3 = (double) inv3 / inv_max; inv3 = (double) inv3 / inv_max;
} }
//printf("scaling factors:\t%f, %f, %f\n", inv1, inv2, inv3);
//printf("scaling factors applied:\t%f, %f, %f\n",
// (double) sum1 * inv1,
// (double) sum2 * inv2,
// (double) sum3 * inv3);
// Scale input pixel values // Scale input pixel values
uchar* dst_data = dst.ptr<uchar>(0); uchar* dst_data = dst.ptr<uchar>(0);
i = 0; i = 0;
for (; i < N3; i += 3) for ( ; i < N3; i += 3 )
{ {
dst_data[i + 0] = src_data[i + 0] * inv1; dst_data[i] = src_data[i] * inv1;
dst_data[i + 1] = src_data[i + 1] * inv2; dst_data[i + 1] = src_data[i + 1] * inv2;
dst_data[i + 2] = src_data[i + 2] * inv3; dst_data[i + 2] = src_data[i + 2] * inv3;
} }
//imshow("original", src);
//imshow("grayworld", dst);
//waitKey();
} }
TEST(xphoto_grayworld_white_balance, regression) TEST(xphoto_grayworld_white_balance, regression)
{ {
String subfolder = "/xphoto/"; String subfolder = "/xphoto/";
String dir = cvtest::TS::ptr()->get_data_path() + subfolder + "simple_white_balance/"; String dir = cvtest::TS::ptr()->get_data_path() + subfolder + "simple_white_balance/";
int nTests = 14; const int nTests = 14;
float threshold = 2.f; const float wb_thresh = 0.5f;
const float acc_thresh = 2.f;
for (int i = 0; i < nTests; ++i) for ( int i = 0; i < nTests; ++i )
{ {
String srcName = dir + format("sources/%02d.png", i + 1); String srcName = dir + format("sources/%02d.png", i + 1);
Mat src = imread(srcName, IMREAD_COLOR); Mat src = imread(srcName, IMREAD_COLOR);
ASSERT_TRUE(!src.empty()); ASSERT_TRUE(!src.empty());
Mat referenceResult; Mat referenceResult;
ref_autowbGrayworld(src, referenceResult); ref_autowbGrayworld(src, referenceResult, wb_thresh);
Mat currentResult; Mat currentResult;
xphoto::autowbGrayworld(src, currentResult); xphoto::autowbGrayworld(src, currentResult, wb_thresh);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), threshold); ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment