Commit 191d5c1d authored by Maksim Shabunin's avatar Maksim Shabunin

Merge pull request #747 from sbokov:color_constancy

parents 3468ae57 75dedf9e
......@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld,
Mat dst(size, CV_8UC3);
declare.in(src, WARMUP_RNG).out(dst);
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
TEST_CYCLE() xphoto::autowbGrayworld(src, dst, wb_thresh);
TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK(dst);
}
......
......@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV
RNG rng(1234);
rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val);
resize(src_dscl, src, src.size());
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(range_max_val);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(hist_bin_num);
TEST_CYCLE() xphoto::autowbLearningBased(src, dst, range_max_val, 0.98f, hist_bin_num);
TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK_NOTHING();
}
......@@ -4,14 +4,16 @@
using namespace cv;
using namespace std;
const char *keys = {"{help h usage ? | | print this message}"
const char *keys = { "{help h usage ? | | print this message}"
"{i | | input image name }"
"{o | | output image name }"};
"{o | | output image name }"
"{a |grayworld| color balance algorithm (simple, grayworld or learning_based)}"
"{m | | path to the model for the learning-based algorithm (optional) }" };
int main(int argc, const char **argv)
{
CommandLineParser parser(argc, argv, keys);
parser.about("OpenCV learning-based color balance demonstration sample");
parser.about("OpenCV color balance demonstration sample");
if (parser.has("help") || argc < 2)
{
parser.printMessage();
......@@ -20,6 +22,8 @@ int main(int argc, const char **argv)
string inFilename = parser.get<string>("i");
string outFilename = parser.get<string>("o");
string algorithm = parser.get<string>("a");
string modelFilename = parser.get<string>("m");
if (!parser.check())
{
......@@ -35,7 +39,20 @@ int main(int argc, const char **argv)
}
Mat res;
xphoto::autowbLearningBased(src, res);
Ptr<xphoto::WhiteBalancer> wb;
if (algorithm == "simple")
wb = xphoto::createSimpleWB();
else if (algorithm == "grayworld")
wb = xphoto::createGrayworldWB();
else if (algorithm == "learning_based")
wb = xphoto::createLearningBasedWB(modelFilename);
else
{
printf("Unsupported algorithm: %s\n", algorithm.c_str());
return -1;
}
wb->balanceWhite(src, res);
if (outFilename == "")
{
......
......@@ -5,6 +5,7 @@ import numpy as np
import scipy.io
import cv2
import timeit
from learn_color_balance import load_ground_truth
def load_json(path):
......@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5):
return arr.astype(np.uint8)
def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder, model_folder):
new_im = None
start_time = timeit.default_timer()
if algo=="grayworld":
new_im = cv2.xphoto.autowbGrayworld(im, 0.95)
inst = cv2.xphoto.createGrayworldWB()
inst.setSaturationThreshold(0.95)
new_im = inst.balanceWhite(im)
elif algo=="nothing":
new_im = im
elif algo=="learning_based":
new_im = cv2.xphoto.autowbLearningBased(im, None, range_thresh, 0.98, bin_num)
elif algo.split(":")[0]=="learning_based":
model_path = ""
if len(algo.split(":"))>1:
model_path = os.path.join(model_folder, algo.split(":")[1])
inst = cv2.xphoto.createLearningBasedWB(model_path)
inst.setRangeMaxVal(range_thresh)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(bin_num)
new_im = inst.balanceWhite(im)
elif algo=="GT":
gains = gt_illuminant / min(gt_illuminant)
g1 = float(1.0 / gains[2])
......@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
if len(dst_folder)>0:
if not os.path.exists(dst_folder):
os.makedirs(dst_folder)
im_name = ("%04d_" % i) + algo + ".jpg"
im_name = ("%04d_" % i) + algo.replace(":","_") + ".jpg"
cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im))
#recover the illuminant from the color balancing result, assuming the standard model:
......@@ -140,7 +150,9 @@ if __name__ == '__main__':
metavar="ALGORITHMS",
default="",
help=("Comma-separated list of color balance algorithms to evaluate. "
"Currently available: GT,learning_based,grayworld,nothing."))
"Currently available: GT,learning_based,grayworld,nothing. "
"Use a colon to set a specific model for the learning-based "
"algorithm, e.g. learning_based:model1.yml,learning_based:model2.yml"))
parser.add_argument(
"-i",
"--input_folder",
......@@ -196,6 +208,12 @@ if __name__ == '__main__':
default="0,0",
help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). "
"All available images are used by default."))
parser.add_argument(
"-m",
"--model_folder",
metavar="MODEL_FOLDER",
default="",
help=("Path to the folder containing models for the learning-based color balance algorithm (optional)"))
args, other_args = parser.parse_known_args()
if not os.path.exists(args.input_folder):
......@@ -218,22 +236,8 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder))
gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
(gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
for algorithm in algorithm_list:
i = 0
......@@ -254,7 +258,7 @@ if __name__ == '__main__':
im = stretch_to_8bit(im)
(time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh,
256 if range_thresh > 255 else 64, args.dst_folder)
256 if range_thresh > 255 else 64, args.dst_folder, args.model_folder)
state[algorithm][file] = {"angular_error": angular_err, "time": time}
sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)),
sys.stdout.flush()
......
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
using namespace cv;
using namespace std;
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
int main( int argc, const char** argv )
{
bool printHelp = ( argc == 1 );
printHelp = printHelp || ( argc == 2 && string(argv[1]) == "--help" );
printHelp = printHelp || ( argc == 2 && string(argv[1]) == "-h" );
if ( printHelp )
{
printf("\nThis sample demonstrates the grayworld balance algorithm\n"
"Call:\n"
" simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
return 0;
}
CommandLineParser parser(argc, argv, keys);
if ( !parser.check() )
{
parser.printErrors();
return -1;
}
string inFilename = parser.get<string>("i");
string outFilename = parser.get<string>("o");
Mat src = imread(inFilename, 1);
if ( src.empty() )
{
printf("Cannot read image file: %s\n", inFilename.c_str());
return -1;
}
Mat res(src.size(), src.type());
xphoto::autowbGrayworld(src, res);
if ( outFilename == "" )
{
namedWindow("after white balance", 1);
imshow("after white balance", res);
waitKey(0);
}
else
imwrite(outFilename, res);
return 0;
}
......@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth):
return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals)
def generate_code(model, input_params):
def generate_code(model, input_params, use_YML, out_file):
feature_idx = []
thresh_vals = []
leaf_vals = []
......@@ -95,31 +95,60 @@ def generate_code(model, input_params):
feature_idx += local_feature_idx
thresh_vals += local_thresh_vals
leaf_vals += local_leaf_vals
if use_YML:
fs = cv2.FileStorage(out_file, 1)
fs.write("num_trees", len(model))
fs.write("num_tree_nodes", 2**depth)
fs.write("feature_idx", np.array(feature_idx).astype(np.uint8))
fs.write("thresh_vals", np.array(thresh_vals).astype(np.float32))
fs.write("leaf_vals", np.array(leaf_vals).astype(np.float32))
fs.release()
else:
res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
" * using the following parameters:\n"
for key in input_params:
res += " " + key + " " + input_params[key]
res += "\n */\n"
res += "const int num_trees = " + str(len(model)) + ";\n"
res += "const int num_features = 4;\n"
res += "const int num_tree_nodes = " + str(2**depth) + ";\n"
res += "const int _num_trees = " + str(len(model)) + ";\n"
res += "const int _num_tree_nodes = " + str(2**depth) + ";\n"
res += "unsigned char feature_idx[num_trees*num_features*2*(num_tree_nodes-1)] = {" + str(feature_idx[0])
res += "unsigned char _feature_idx[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + str(feature_idx[0])
for i in range(1,len(feature_idx)):
res += "," + str(feature_idx[i])
res += "};\n"
res += "float thresh_vals[num_trees*num_features*2*(num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
res += "float _thresh_vals[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
for i in range(1,len(thresh_vals)):
res += "," + ("%.3ff" % thresh_vals[i])[1:]
res += "};\n"
res += "float leaf_vals[num_trees*num_features*2*num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
res += "float _leaf_vals[_num_trees*num_features*2*_num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
for i in range(1,len(leaf_vals)):
res += "," + ("%.3ff" % leaf_vals[i])[1:]
res += "};\n"
return res
f = open(out_file,"w")
f.write(res)
f.close()
def load_ground_truth(gt_path):
gt = scipy.io.loadmat(gt_path)
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(base_gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
return (base_gt_illuminants, black_levels)
if __name__ == '__main__':
......@@ -153,8 +182,9 @@ if __name__ == '__main__':
"-o",
"--out",
metavar="OUT",
default="learning_based_color_balance_model.hpp",
help="Path to the output learnt model")
default="color_balance_model.yml",
help="Path to the output learnt model. Either a .yml (for loading during runtime) "
"or .hpp (for compiling with the main code) file ")
parser.add_argument(
"--hist_bin_num",
metavar="HIST_BIN_NUM",
......@@ -196,39 +226,37 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1)
use_YML = None
if args.out.endswith(".yml"):
use_YML = True
elif args.out.endswith(".hpp"):
use_YML = False
else:
print("Error: Only .hpp and .yml are supported as output formats")
sys.exit(1)
hist_bin_num = int(args.hist_bin_num)
num_trees = int(args.num_trees)
max_tree_depth = int(args.max_tree_depth)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder))
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
(base_gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
features = []
gt_illuminants = []
i=0
sz = len(img_files)
random.seed(1234)
inst = cv2.xphoto.createLearningBasedWB()
inst.setRangeMaxVal(255)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(hist_bin_num)
for file in img_files:
if (i>=img_range[0] and i<img_range[1]) or (img_range[0]==img_range[1]==0):
cur_path = os.path.join(args.input_folder,file)
im = cv2.imread(cur_path, -1).astype(np.float32)
im -= black_levels[i]
im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num)
cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist())
gt_illuminants.append(base_gt_illuminants[i].tolist())
......@@ -241,7 +269,7 @@ if __name__ == '__main__':
im_8bit[:,:,1] *= G_coef
im_8bit[:,:,2] *= R_coef
im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num)
cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist())
illum = base_gt_illuminants[i]
illum[0] *= R_coef
......@@ -255,10 +283,8 @@ if __name__ == '__main__':
print("\nLearning the model...")
model = learn_regression_tree_ensemble(features, gt_illuminants, num_trees, max_tree_depth)
print("Generating code...")
str = generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
"--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented})
f = open(args.out,"w")
f.write(str)
f.close()
print("Writing the model...")
generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
"--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented},
use_YML, args.out)
print("Done")
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
#include "opencv2/imgproc/types_c.h"
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
int main( int argc, const char** argv )
{
bool printHelp = ( argc == 1 );
printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "--help" );
printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "-h" );
if ( printHelp )
{
printf("\nThis sample demonstrates simple color balance algorithm\n"
"Call:\n"
" simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
return 0;
}
cv::CommandLineParser parser(argc, argv, keys);
if ( !parser.check() )
{
parser.printErrors();
return -1;
}
std::string inFilename = parser.get<std::string>("i");
std::string outFilename = parser.get<std::string>("o");
cv::Mat src = cv::imread(inFilename, 1);
if ( src.empty() )
{
printf("Cannot read image file: %s\n", inFilename.c_str());
return -1;
}
cv::Mat res(src.size(), src.type());
cv::xphoto::balanceWhite(src, res, cv::xphoto::WHITE_BALANCE_SIMPLE);
if ( outFilename == "" )
{
cv::namedWindow("after white balance", 1);
cv::imshow("after white balance", res);
cv::waitKey(0);
}
else
cv::imwrite(outFilename, res);
return 0;
}
\ No newline at end of file
......@@ -49,6 +49,54 @@ namespace xphoto
void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh);
void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh);
class GrayworldWBImpl : public GrayworldWB
{
private:
float thresh;
public:
GrayworldWBImpl() { thresh = 0.9f; }
float getSaturationThreshold() const { return thresh; }
void setSaturationThreshold(float val) { thresh = val; }
void balanceWhite(InputArray _src, OutputArray _dst)
{
CV_Assert(!_src.empty());
CV_Assert(_src.isContinuous());
CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
Mat src = _src.getMat();
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB),
dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
};
/* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh
* (version for CV_8UC3)
*/
......@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai
}
}
void autowbGrayworld(InputArray _src, OutputArray _dst, float thresh)
{
Mat src = _src.getMat();
CV_Assert(!src.empty());
CV_Assert(src.isContinuous());
CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB), dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
Ptr<GrayworldWB> createGrayworldWB() { return makePtr<GrayworldWBImpl>(); }
}
}
......@@ -2,10 +2,10 @@
* using the following parameters:
--num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0
*/
const int num_trees = 20;
const int num_features = 4;
const int num_tree_nodes = 16;
unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
const int _num_trees = 20;
const int _num_tree_nodes = 16;
unsigned char _feature_idx[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
......@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] =
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,
1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
float _thresh_vals[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
.193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f, .324f,
.794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f, .363f, .047f,
.351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f, .271f, .477f, .636f,
......@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
.550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f, .271f,
.372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f, .283f, .352f,
.516f, .685f, .855f};
float leaf_vals[num_trees * num_features * 2 * num_tree_nodes] = {
float _leaf_vals[_num_trees * num_features * 2 * _num_tree_nodes] = {
.011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f, .069f,
.165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f, .014f, .040f,
.061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f, .064f, .155f, .177f,
......
......@@ -37,50 +37,39 @@
//
//M*/
#include <vector>
#include <algorithm>
#include <iterator>
#include <iostream>
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include <iterator>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/core/core_c.h"
#include "opencv2/core/types.hpp"
#include "opencv2/core/types_c.h"
#include "opencv2/imgproc.hpp"
#include "opencv2/xphoto.hpp"
namespace cv
{
namespace xphoto
{
template <typename T>
void balanceWhite(std::vector < Mat_<T> > &src, Mat &dst,
const float inputMin, const float inputMax,
const float outputMin, const float outputMax, const int algorithmType)
{
switch ( algorithmType )
{
case WHITE_BALANCE_SIMPLE:
{
template <typename T>
void balanceWhiteSimple(std::vector<Mat_<T> > &src, Mat &dst, const float inputMin, const float inputMax,
const float outputMin, const float outputMax, const float p)
{
/********************* Simple white balance *********************/
float s1 = 2.0f; // low quantile
float s2 = 2.0f; // high quantile
const float s1 = p; // low quantile
const float s2 = p; // high quantile
int depth = 2; // depth of histogram tree
if (src[0].depth() != CV_8U)
++depth;
int bins = 16; // number of bins at each histogram level
int nElements = int( pow((float)bins, (float)depth) );
int nElements = int(pow((float)bins, (float)depth));
// number of elements in histogram tree
for (size_t i = 0; i < src.size(); ++i)
{
std::vector <int> hist(nElements, 0);
std::vector<int> hist(nElements, 0);
typename Mat_<T>::iterator beginIt = src[i].begin();
typename Mat_<T>::iterator endIt = src[i].end();
......@@ -97,19 +86,19 @@ namespace xphoto
for (int j = 0; j < depth; ++j)
{
int currentBin = int( (val - minValue + 1e-4f) / interval );
int currentBin = int((val - minValue + 1e-4f) / interval);
++hist[pos + currentBin];
pos = (pos + currentBin)*bins;
pos = (pos + currentBin) * bins;
minValue = minValue + currentBin*interval;
minValue = minValue + currentBin * interval;
maxValue = minValue + interval;
interval /= bins;
}
}
int total = int( src[i].total() );
int total = int(src[i].total());
int p1 = 0, p2 = bins - 1;
int n1 = 0, n2 = total;
......@@ -134,78 +123,90 @@ namespace xphoto
n2 -= hist[p2--];
maxValue -= interval;
}
p2 = p2*bins - 1;
p2 = p2 * bins - 1;
interval /= bins;
}
src[i] = (outputMax - outputMin) * (src[i] - minValue)
/ (maxValue - minValue) + outputMin;
src[i] = (outputMax - outputMin) * (src[i] - minValue) / (maxValue - minValue) + outputMin;
}
/****************************************************************/
break;
}
default:
CV_Error_( CV_StsNotImplemented,
("Unsupported algorithm type (=%d)", algorithmType) );
}
dst.create(/**/ src[0].size(), CV_MAKETYPE( src[0].depth(), int( src.size() ) ) /**/);
dst.create(/**/ src[0].size(), CV_MAKETYPE(src[0].depth(), int(src.size())) /**/);
cv::merge(src, dst);
}
class SimpleWBImpl : public SimpleWB
{
private:
float inputMin, inputMax, outputMin, outputMax, p;
public:
SimpleWBImpl()
{
inputMin = 0.0f;
inputMax = 255.0f;
outputMin = 0.0f;
outputMax = 255.0f;
p = 2.0f;
}
/*!
* Wrappers over different white balance algorithm
*
* \param src : source image (RGB)
* \param dst : destination image
*
* \param inputMin : minimum input value
* \param inputMax : maximum input value
* \param outputMin : minimum output value
* \param outputMax : maximum output value
*
* \param algorithmType : type of the algorithm to use
*/
void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
const float inputMin, const float inputMax,
const float outputMin, const float outputMax)
float getInputMin() const { return inputMin; }
void setInputMin(float val) { inputMin = val; }
float getInputMax() const { return inputMax; }
void setInputMax(float val) { inputMax = val; }
float getOutputMin() const { return outputMin; }
void setOutputMin(float val) { outputMin = val; }
float getOutputMax() const { return outputMax; }
void setOutputMax(float val) { outputMax = val; }
float getP() const { return p; }
void setP(float val) { p = val; }
void balanceWhite(InputArray _src, OutputArray _dst)
{
switch ( src.depth() )
CV_Assert(!_src.empty());
CV_Assert(_src.depth() == CV_8U || _src.depth() == CV_16S || _src.depth() == CV_32S || _src.depth() == CV_32F);
Mat src = _src.getMat();
Mat &dst = _dst.getMatRef();
switch (src.depth())
{
case CV_8U:
{
std::vector < Mat_<uchar> > mv;
std::vector<Mat_<uchar> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_16S:
{
std::vector < Mat_<short> > mv;
std::vector<Mat_<short> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_32S:
{
std::vector < Mat_<int> > mv;
std::vector<Mat_<int> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_32F:
{
std::vector < Mat_<float> > mv;
std::vector<Mat_<float> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
default:
CV_Error_( CV_StsNotImplemented,
("Unsupported source image format (=%d)", src.type()) );
break;
}
}
};
Ptr<SimpleWB> createSimpleWB() { return makePtr<SimpleWBImpl>(); }
}
}
......@@ -7,6 +7,7 @@ namespace cvtest
cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/";
int nTests = 12;
float threshold = 0.005f;
cv::Ptr<cv::xphoto::WhiteBalancer> wb = cv::xphoto::createSimpleWB();
for (int i = 0; i < nTests; ++i)
{
......@@ -18,7 +19,7 @@ namespace cvtest
cv::Mat previousResult = cv::imread( previousResultName, 1 );
cv::Mat currentResult;
cv::xphoto::balanceWhite(src, currentResult, cv::xphoto::WHITE_BALANCE_SIMPLE);
wb->balanceWhite(src, currentResult);
cv::Mat sqrError = ( currentResult - previousResult )
.mul( currentResult - previousResult );
......
......@@ -69,6 +69,8 @@ namespace cvtest {
const int nTests = 14;
const float wb_thresh = 0.5f;
const float acc_thresh = 2.f;
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
for ( int i = 0; i < nTests; ++i )
{
......@@ -80,13 +82,13 @@ namespace cvtest {
ref_autowbGrayworld(src, referenceResult, wb_thresh);
Mat currentResult;
xphoto::autowbGrayworld(src, currentResult, wb_thresh);
wb->balanceWhite(src, currentResult);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
// test the 16-bit depth:
Mat currentResult_16U, src_16U;
src.convertTo(src_16U, CV_16UC3, 256.0);
xphoto::autowbGrayworld(src_16U, currentResult_16U, wb_thresh);
wb->balanceWhite(src_16U, currentResult_16U);
currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
}
......
......@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression)
Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200));
vector<Vec2f> dst_features;
xphoto::extractSimpleFeatures(test_im, dst_features, 255, 0.98f, 64);
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(255);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(64);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
......@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression)
// check 16 bit depth:
test_im.convertTo(test_im, CV_16U, 256.0);
xphoto::extractSimpleFeatures(test_im, dst_features, 65535, 0.98f, 64);
wb->setRangeMaxVal(65535);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(128);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
......
Training the learning-based white balance algorithm {#tutorial_xphoto_training_white_balance}
===================================================
Introduction
------------
Many traditional white balance algorithms are statistics-based, i.e. they rely on the fact that certain assumptions should hold in properly white-balanced images
like the well-known grey-world assumption. However, better results can often be achieved by leveraging large datasets of images with ground-truth
illuminants in a learning-based framework. This tutorial demonstrates how to train a learning-based white balance algorithm and evaluate the quality of the results.
How to train a model
--------------------
-# Download a dataset for training. In this tutorial we will use the [Gehler-Shi dataset ](http://www.cs.sfu.ca/~colour/data/shi_gehler/). Extract all 568 training images
in one folder. A file containing ground-truth illuminant values (real_illum_568..mat) is downloaded separately.
-# We will be using a [Python script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/learn_color_balance.py) for training.
Call it with the following parameters:
@code
python learn_color_balance.py -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 0,378 --num_trees 30 --max_tree_depth 6 --num_augmented 0
@endcode
This should start training a model on the first 378 images (2/3 of the whole dataset). We set the size of the model to be 30 regression tree pairs per feature and limit
the tree depth to be no more then 6. By default the resulting model will be saved to color_balance_model.yml
-# Use the trained model by passing its path when constructing an instance of LearningBasedWB:
@code{.cpp}
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB(modelFilename);
@endcode
How to evaluate a model
----------------------
-# We will use a [benchmarking script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/color_balance_benchmark.py) to compare
the model that we've trained with the classic grey-world algorithm on the remaining 1/3 of the dataset. Call the script with the following parameters:
@code
python color_balance_benchmark.py -a grayworld,learning_based:color_balance_model.yml -m <full path to folder containing the model> -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 379,567 -d "img"
@endcode
-# The objective evaluation results are stored in white_balance_eval_result.html and the resulting white-balanced images are stored in the img folder for a qualitative
comparison of algorithms. Different algorithms are compared in terms of angular error between the estimated and ground-truth illuminants.
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment