Commit 191d5c1d authored by Maksim Shabunin's avatar Maksim Shabunin

Merge pull request #747 from sbokov:color_constancy

parents 3468ae57 75dedf9e
...@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld, ...@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld,
Mat dst(size, CV_8UC3); Mat dst(size, CV_8UC3);
declare.in(src, WARMUP_RNG).out(dst); declare.in(src, WARMUP_RNG).out(dst);
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
TEST_CYCLE() xphoto::autowbGrayworld(src, dst, wb_thresh); TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK(dst); SANITY_CHECK(dst);
} }
......
...@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV ...@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV
RNG rng(1234); RNG rng(1234);
rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val); rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val);
resize(src_dscl, src, src.size()); resize(src_dscl, src, src.size());
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(range_max_val);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(hist_bin_num);
TEST_CYCLE() xphoto::autowbLearningBased(src, dst, range_max_val, 0.98f, hist_bin_num); TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK_NOTHING(); SANITY_CHECK_NOTHING();
} }
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
using namespace cv; using namespace cv;
using namespace std; using namespace std;
const char *keys = {"{help h usage ? | | print this message}" const char *keys = { "{help h usage ? | | print this message}"
"{i | | input image name }" "{i | | input image name }"
"{o | | output image name }"}; "{o | | output image name }"
"{a |grayworld| color balance algorithm (simple, grayworld or learning_based)}"
"{m | | path to the model for the learning-based algorithm (optional) }" };
int main(int argc, const char **argv) int main(int argc, const char **argv)
{ {
CommandLineParser parser(argc, argv, keys); CommandLineParser parser(argc, argv, keys);
parser.about("OpenCV learning-based color balance demonstration sample"); parser.about("OpenCV color balance demonstration sample");
if (parser.has("help") || argc < 2) if (parser.has("help") || argc < 2)
{ {
parser.printMessage(); parser.printMessage();
...@@ -20,6 +22,8 @@ int main(int argc, const char **argv) ...@@ -20,6 +22,8 @@ int main(int argc, const char **argv)
string inFilename = parser.get<string>("i"); string inFilename = parser.get<string>("i");
string outFilename = parser.get<string>("o"); string outFilename = parser.get<string>("o");
string algorithm = parser.get<string>("a");
string modelFilename = parser.get<string>("m");
if (!parser.check()) if (!parser.check())
{ {
...@@ -35,7 +39,20 @@ int main(int argc, const char **argv) ...@@ -35,7 +39,20 @@ int main(int argc, const char **argv)
} }
Mat res; Mat res;
xphoto::autowbLearningBased(src, res); Ptr<xphoto::WhiteBalancer> wb;
if (algorithm == "simple")
wb = xphoto::createSimpleWB();
else if (algorithm == "grayworld")
wb = xphoto::createGrayworldWB();
else if (algorithm == "learning_based")
wb = xphoto::createLearningBasedWB(modelFilename);
else
{
printf("Unsupported algorithm: %s\n", algorithm.c_str());
return -1;
}
wb->balanceWhite(src, res);
if (outFilename == "") if (outFilename == "")
{ {
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import scipy.io import scipy.io
import cv2 import cv2
import timeit import timeit
from learn_color_balance import load_ground_truth
def load_json(path): def load_json(path):
...@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5): ...@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5):
return arr.astype(np.uint8) return arr.astype(np.uint8)
def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder): def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder, model_folder):
new_im = None new_im = None
start_time = timeit.default_timer() start_time = timeit.default_timer()
if algo=="grayworld": if algo=="grayworld":
new_im = cv2.xphoto.autowbGrayworld(im, 0.95) inst = cv2.xphoto.createGrayworldWB()
inst.setSaturationThreshold(0.95)
new_im = inst.balanceWhite(im)
elif algo=="nothing": elif algo=="nothing":
new_im = im new_im = im
elif algo=="learning_based": elif algo.split(":")[0]=="learning_based":
new_im = cv2.xphoto.autowbLearningBased(im, None, range_thresh, 0.98, bin_num) model_path = ""
if len(algo.split(":"))>1:
model_path = os.path.join(model_folder, algo.split(":")[1])
inst = cv2.xphoto.createLearningBasedWB(model_path)
inst.setRangeMaxVal(range_thresh)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(bin_num)
new_im = inst.balanceWhite(im)
elif algo=="GT": elif algo=="GT":
gains = gt_illuminant / min(gt_illuminant) gains = gt_illuminant / min(gt_illuminant)
g1 = float(1.0 / gains[2]) g1 = float(1.0 / gains[2])
...@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder): ...@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
if len(dst_folder)>0: if len(dst_folder)>0:
if not os.path.exists(dst_folder): if not os.path.exists(dst_folder):
os.makedirs(dst_folder) os.makedirs(dst_folder)
im_name = ("%04d_" % i) + algo + ".jpg" im_name = ("%04d_" % i) + algo.replace(":","_") + ".jpg"
cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im)) cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im))
#recover the illuminant from the color balancing result, assuming the standard model: #recover the illuminant from the color balancing result, assuming the standard model:
...@@ -140,7 +150,9 @@ if __name__ == '__main__': ...@@ -140,7 +150,9 @@ if __name__ == '__main__':
metavar="ALGORITHMS", metavar="ALGORITHMS",
default="", default="",
help=("Comma-separated list of color balance algorithms to evaluate. " help=("Comma-separated list of color balance algorithms to evaluate. "
"Currently available: GT,learning_based,grayworld,nothing.")) "Currently available: GT,learning_based,grayworld,nothing. "
"Use a colon to set a specific model for the learning-based "
"algorithm, e.g. learning_based:model1.yml,learning_based:model2.yml"))
parser.add_argument( parser.add_argument(
"-i", "-i",
"--input_folder", "--input_folder",
...@@ -196,6 +208,12 @@ if __name__ == '__main__': ...@@ -196,6 +208,12 @@ if __name__ == '__main__':
default="0,0", default="0,0",
help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). " help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). "
"All available images are used by default.")) "All available images are used by default."))
parser.add_argument(
"-m",
"--model_folder",
metavar="MODEL_FOLDER",
default="",
help=("Path to the folder containing models for the learning-based color balance algorithm (optional)"))
args, other_args = parser.parse_known_args() args, other_args = parser.parse_known_args()
if not os.path.exists(args.input_folder): if not os.path.exists(args.input_folder):
...@@ -218,22 +236,8 @@ if __name__ == '__main__': ...@@ -218,22 +236,8 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>") print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1) sys.exit(1)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder)) img_files = sorted(os.listdir(args.input_folder))
(gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
for algorithm in algorithm_list: for algorithm in algorithm_list:
i = 0 i = 0
...@@ -254,7 +258,7 @@ if __name__ == '__main__': ...@@ -254,7 +258,7 @@ if __name__ == '__main__':
im = stretch_to_8bit(im) im = stretch_to_8bit(im)
(time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh, (time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh,
256 if range_thresh > 255 else 64, args.dst_folder) 256 if range_thresh > 255 else 64, args.dst_folder, args.model_folder)
state[algorithm][file] = {"angular_error": angular_err, "time": time} state[algorithm][file] = {"angular_error": angular_err, "time": time}
sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)), sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)),
sys.stdout.flush() sys.stdout.flush()
......
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
using namespace cv;
using namespace std;
// Option descriptions handed to CommandLineParser in main():
// "i" is the input image name and "o" is the output image name.
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
/* Entry point of the grayworld white balance sample.
 *
 * Reads the image named by -i, runs xphoto::autowbGrayworld on it and then either
 * writes the result to the file named by -o or, when -o is empty, shows it in a
 * window until a key is pressed.
 *
 * Returns 0 on success (and after printing the usage text), -1 when the
 * command line could not be parsed or the input image could not be read.
 */
int main( int argc, const char** argv )
{
    // Show the usage text when called without arguments or with a lone --help / -h.
    bool printHelp = ( argc == 1 );
    printHelp = printHelp || ( argc == 2 && string(argv[1]) == "--help" );
    printHelp = printHelp || ( argc == 2 && string(argv[1]) == "-h" );

    if ( printHelp )
    {
        // Print the real executable name: the previously hard-coded one was
        // misspelled and belonged to a different sample.
        printf("\nThis sample demonstrates the grayworld balance algorithm\n"
               "Call:\n"
               " %s -i=in_image_name [-o=out_image_name]\n\n", argv[0]);
        return 0;
    }

    CommandLineParser parser(argc, argv, keys);

    string inFilename = parser.get<string>("i");
    string outFilename = parser.get<string>("o");

    // check() reports errors that occurred while accessing the parameters,
    // so it has to run after the get() calls above to be able to catch them.
    if ( !parser.check() )
    {
        parser.printErrors();
        return -1;
    }

    Mat src = imread(inFilename, 1);
    if ( src.empty() )
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    Mat res(src.size(), src.type());
    xphoto::autowbGrayworld(src, res);

    if ( outFilename == "" )
    {
        // No output file requested: display the result interactively.
        namedWindow("after white balance", 1);
        imshow("after white balance", res);
        waitKey(0);
    }
    else
        imwrite(outFilename, res);

    return 0;
}
...@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth): ...@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth):
return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals) return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals)
def generate_code(model, input_params): def generate_code(model, input_params, use_YML, out_file):
feature_idx = [] feature_idx = []
thresh_vals = [] thresh_vals = []
leaf_vals = [] leaf_vals = []
...@@ -95,31 +95,60 @@ def generate_code(model, input_params): ...@@ -95,31 +95,60 @@ def generate_code(model, input_params):
feature_idx += local_feature_idx feature_idx += local_feature_idx
thresh_vals += local_thresh_vals thresh_vals += local_thresh_vals
leaf_vals += local_leaf_vals leaf_vals += local_leaf_vals
if use_YML:
fs = cv2.FileStorage(out_file, 1)
fs.write("num_trees", len(model))
fs.write("num_tree_nodes", 2**depth)
fs.write("feature_idx", np.array(feature_idx).astype(np.uint8))
fs.write("thresh_vals", np.array(thresh_vals).astype(np.float32))
fs.write("leaf_vals", np.array(leaf_vals).astype(np.float32))
fs.release()
else:
res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
" * using the following parameters:\n"
for key in input_params:
res += " " + key + " " + input_params[key]
res += "\n */\n"
res += "const int num_features = 4;\n"
res += "const int _num_trees = " + str(len(model)) + ";\n"
res += "const int _num_tree_nodes = " + str(2**depth) + ";\n"
res += "unsigned char _feature_idx[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + str(feature_idx[0])
for i in range(1,len(feature_idx)):
res += "," + str(feature_idx[i])
res += "};\n"
res += "float _thresh_vals[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
for i in range(1,len(thresh_vals)):
res += "," + ("%.3ff" % thresh_vals[i])[1:]
res += "};\n"
res += "float _leaf_vals[_num_trees*num_features*2*_num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
for i in range(1,len(leaf_vals)):
res += "," + ("%.3ff" % leaf_vals[i])[1:]
res += "};\n"
f = open(out_file,"w")
f.write(res)
f.close()
def load_ground_truth(gt_path):
gt = scipy.io.loadmat(gt_path)
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(base_gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
res = "/* This file was automatically generated by learn_color_balance.py script\n" +\ return (base_gt_illuminants, black_levels)
" * using the following parameters:\n"
for key in input_params:
res += " " + key + " " + input_params[key]
res += "\n */\n"
res += "const int num_trees = " + str(len(model)) + ";\n"
res += "const int num_features = 4;\n"
res += "const int num_tree_nodes = " + str(2**depth) + ";\n"
res += "unsigned char feature_idx[num_trees*num_features*2*(num_tree_nodes-1)] = {" + str(feature_idx[0])
for i in range(1,len(feature_idx)):
res += "," + str(feature_idx[i])
res += "};\n"
res += "float thresh_vals[num_trees*num_features*2*(num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
for i in range(1,len(thresh_vals)):
res += "," + ("%.3ff" % thresh_vals[i])[1:]
res += "};\n"
res += "float leaf_vals[num_trees*num_features*2*num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
for i in range(1,len(leaf_vals)):
res += "," + ("%.3ff" % leaf_vals[i])[1:]
res += "};\n"
return res
if __name__ == '__main__': if __name__ == '__main__':
...@@ -153,8 +182,9 @@ if __name__ == '__main__': ...@@ -153,8 +182,9 @@ if __name__ == '__main__':
"-o", "-o",
"--out", "--out",
metavar="OUT", metavar="OUT",
default="learning_based_color_balance_model.hpp", default="color_balance_model.yml",
help="Path to the output learnt model") help="Path to the output learnt model. Either a .yml (for loading during runtime) "
"or .hpp (for compiling with the main code) file ")
parser.add_argument( parser.add_argument(
"--hist_bin_num", "--hist_bin_num",
metavar="HIST_BIN_NUM", metavar="HIST_BIN_NUM",
...@@ -196,39 +226,37 @@ if __name__ == '__main__': ...@@ -196,39 +226,37 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>") print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1) sys.exit(1)
use_YML = None
if args.out.endswith(".yml"):
use_YML = True
elif args.out.endswith(".hpp"):
use_YML = False
else:
print("Error: Only .hpp and .yml are supported as output formats")
sys.exit(1)
hist_bin_num = int(args.hist_bin_num) hist_bin_num = int(args.hist_bin_num)
num_trees = int(args.num_trees) num_trees = int(args.num_trees)
max_tree_depth = int(args.max_tree_depth) max_tree_depth = int(args.max_tree_depth)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder)) img_files = sorted(os.listdir(args.input_folder))
(base_gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
features = [] features = []
gt_illuminants = [] gt_illuminants = []
i=0 i=0
sz = len(img_files) sz = len(img_files)
random.seed(1234) random.seed(1234)
inst = cv2.xphoto.createLearningBasedWB()
inst.setRangeMaxVal(255)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(hist_bin_num)
for file in img_files: for file in img_files:
if (i>=img_range[0] and i<img_range[1]) or (img_range[0]==img_range[1]==0): if (i>=img_range[0] and i<img_range[1]) or (img_range[0]==img_range[1]==0):
cur_path = os.path.join(args.input_folder,file) cur_path = os.path.join(args.input_folder,file)
im = cv2.imread(cur_path, -1).astype(np.float32) im = cv2.imread(cur_path, -1).astype(np.float32)
im -= black_levels[i] im -= black_levels[i]
im_8bit = convert_to_8bit(im) im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num) cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist()) features.append(cur_img_features.tolist())
gt_illuminants.append(base_gt_illuminants[i].tolist()) gt_illuminants.append(base_gt_illuminants[i].tolist())
...@@ -241,7 +269,7 @@ if __name__ == '__main__': ...@@ -241,7 +269,7 @@ if __name__ == '__main__':
im_8bit[:,:,1] *= G_coef im_8bit[:,:,1] *= G_coef
im_8bit[:,:,2] *= R_coef im_8bit[:,:,2] *= R_coef
im_8bit = convert_to_8bit(im) im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num) cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist()) features.append(cur_img_features.tolist())
illum = base_gt_illuminants[i] illum = base_gt_illuminants[i]
illum[0] *= R_coef illum[0] *= R_coef
...@@ -255,10 +283,8 @@ if __name__ == '__main__': ...@@ -255,10 +283,8 @@ if __name__ == '__main__':
print("\nLearning the model...") print("\nLearning the model...")
model = learn_regression_tree_ensemble(features, gt_illuminants, num_trees, max_tree_depth) model = learn_regression_tree_ensemble(features, gt_illuminants, num_trees, max_tree_depth)
print("Generating code...") print("Writing the model...")
str = generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees, generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
"--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented}) "--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented},
f = open(args.out,"w") use_YML, args.out)
f.write(str)
f.close()
print("Done") print("Done")
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
#include "opencv2/imgproc/types_c.h"
// Option descriptions handed to cv::CommandLineParser in main():
// "i" is the input image name and "o" is the output image name.
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
/* Entry point of the simple color balance sample.
 *
 * Reads the image named by -i, runs cv::xphoto::balanceWhite with
 * WHITE_BALANCE_SIMPLE on it and then either writes the result to the file
 * named by -o or, when -o is empty, shows it in a window until a key is pressed.
 *
 * Returns 0 on success (and after printing the usage text), -1 when the
 * command line could not be parsed or the input image could not be read.
 */
int main( int argc, const char** argv )
{
    // Show the usage text when called without arguments or with a lone --help / -h.
    bool printHelp = ( argc == 1 );
    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "--help" );
    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "-h" );

    if ( printHelp )
    {
        // Print the real executable name instead of the previously
        // hard-coded, misspelled one ("simple_color_blance").
        printf("\nThis sample demonstrates simple color balance algorithm\n"
               "Call:\n"
               " %s -i=in_image_name [-o=out_image_name]\n\n", argv[0]);
        return 0;
    }

    cv::CommandLineParser parser(argc, argv, keys);

    std::string inFilename = parser.get<std::string>("i");
    std::string outFilename = parser.get<std::string>("o");

    // check() reports errors that occurred while accessing the parameters,
    // so it has to run after the get() calls above to be able to catch them.
    if ( !parser.check() )
    {
        parser.printErrors();
        return -1;
    }

    cv::Mat src = cv::imread(inFilename, 1);
    if ( src.empty() )
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    cv::Mat res(src.size(), src.type());
    cv::xphoto::balanceWhite(src, res, cv::xphoto::WHITE_BALANCE_SIMPLE);

    if ( outFilename == "" )
    {
        // No output file requested: display the result interactively.
        cv::namedWindow("after white balance", 1);
        cv::imshow("after white balance", res);
        cv::waitKey(0);
    }
    else
        cv::imwrite(outFilename, res);

    return 0;
}
\ No newline at end of file
...@@ -49,6 +49,54 @@ namespace xphoto ...@@ -49,6 +49,54 @@ namespace xphoto
void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh); void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh);
void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh); void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh);
class GrayworldWBImpl : public GrayworldWB
{
private:
float thresh;
public:
GrayworldWBImpl() { thresh = 0.9f; }
float getSaturationThreshold() const { return thresh; }
void setSaturationThreshold(float val) { thresh = val; }
void balanceWhite(InputArray _src, OutputArray _dst)
{
CV_Assert(!_src.empty());
CV_Assert(_src.isContinuous());
CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
Mat src = _src.getMat();
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB),
dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
};
/* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh /* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh
* (version for CV_8UC3) * (version for CV_8UC3)
*/ */
...@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai ...@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai
} }
} }
void autowbGrayworld(InputArray _src, OutputArray _dst, float thresh) Ptr<GrayworldWB> createGrayworldWB() { return makePtr<GrayworldWBImpl>(); }
{
Mat src = _src.getMat();
CV_Assert(!src.empty());
CV_Assert(src.isContinuous());
CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB), dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
} }
} }
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
* using the following parameters: * using the following parameters:
--num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0 --num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0
*/ */
const int num_trees = 20;
const int num_features = 4; const int num_features = 4;
const int num_tree_nodes = 16; const int _num_trees = 20;
unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] = { const int _num_tree_nodes = 16;
unsigned char _feature_idx[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
...@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] = ...@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] =
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,
1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = { float _thresh_vals[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
.193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f, .324f, .193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f, .324f,
.794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f, .363f, .047f, .794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f, .363f, .047f,
.351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f, .271f, .477f, .636f, .351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f, .271f, .477f, .636f,
...@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = { ...@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
.550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f, .271f, .550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f, .271f,
.372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f, .283f, .352f, .372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f, .283f, .352f,
.516f, .685f, .855f}; .516f, .685f, .855f};
float leaf_vals[num_trees * num_features * 2 * num_tree_nodes] = { float _leaf_vals[_num_trees * num_features * 2 * _num_tree_nodes] = {
.011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f, .069f, .011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f, .069f,
.165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f, .014f, .040f, .165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f, .014f, .040f,
.061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f, .064f, .155f, .177f, .061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f, .064f, .155f, .177f,
......
This diff is collapsed.
...@@ -7,6 +7,7 @@ namespace cvtest ...@@ -7,6 +7,7 @@ namespace cvtest
cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/"; cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/";
int nTests = 12; int nTests = 12;
float threshold = 0.005f; float threshold = 0.005f;
cv::Ptr<cv::xphoto::WhiteBalancer> wb = cv::xphoto::createSimpleWB();
for (int i = 0; i < nTests; ++i) for (int i = 0; i < nTests; ++i)
{ {
...@@ -18,7 +19,7 @@ namespace cvtest ...@@ -18,7 +19,7 @@ namespace cvtest
cv::Mat previousResult = cv::imread( previousResultName, 1 ); cv::Mat previousResult = cv::imread( previousResultName, 1 );
cv::Mat currentResult; cv::Mat currentResult;
cv::xphoto::balanceWhite(src, currentResult, cv::xphoto::WHITE_BALANCE_SIMPLE); wb->balanceWhite(src, currentResult);
cv::Mat sqrError = ( currentResult - previousResult ) cv::Mat sqrError = ( currentResult - previousResult )
.mul( currentResult - previousResult ); .mul( currentResult - previousResult );
......
...@@ -69,6 +69,8 @@ namespace cvtest { ...@@ -69,6 +69,8 @@ namespace cvtest {
const int nTests = 14; const int nTests = 14;
const float wb_thresh = 0.5f; const float wb_thresh = 0.5f;
const float acc_thresh = 2.f; const float acc_thresh = 2.f;
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
for ( int i = 0; i < nTests; ++i ) for ( int i = 0; i < nTests; ++i )
{ {
...@@ -80,13 +82,13 @@ namespace cvtest { ...@@ -80,13 +82,13 @@ namespace cvtest {
ref_autowbGrayworld(src, referenceResult, wb_thresh); ref_autowbGrayworld(src, referenceResult, wb_thresh);
Mat currentResult; Mat currentResult;
xphoto::autowbGrayworld(src, currentResult, wb_thresh); wb->balanceWhite(src, currentResult);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
// test the 16-bit depth: // test the 16-bit depth:
Mat currentResult_16U, src_16U; Mat currentResult_16U, src_16U;
src.convertTo(src_16U, CV_16UC3, 256.0); src.convertTo(src_16U, CV_16UC3, 256.0);
xphoto::autowbGrayworld(src_16U, currentResult_16U, wb_thresh); wb->balanceWhite(src_16U, currentResult_16U);
currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0); currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
} }
......
...@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression) ...@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression)
Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200)); Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200));
vector<Vec2f> dst_features; vector<Vec2f> dst_features;
xphoto::extractSimpleFeatures(test_im, dst_features, 255, 0.98f, 64); Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(255);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(64);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
...@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression) ...@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression)
// check 16 bit depth: // check 16 bit depth:
test_im.convertTo(test_im, CV_16U, 256.0); test_im.convertTo(test_im, CV_16U, 256.0);
xphoto::extractSimpleFeatures(test_im, dst_features, 65535, 0.98f, 64); wb->setRangeMaxVal(65535);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(128);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh); ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
......
Training the learning-based white balance algorithm {#tutorial_xphoto_training_white_balance}
===================================================
Introduction
------------
Many traditional white balance algorithms are statistics-based, i.e. they rely on certain assumptions, such as the well-known grey-world assumption,
holding in properly white-balanced images. However, better results can often be achieved by leveraging large datasets of images with ground-truth
illuminants in a learning-based framework. This tutorial demonstrates how to train a learning-based white balance algorithm and evaluate the quality of the results.
How to train a model
--------------------
-# Download a dataset for training. In this tutorial we will use the [Gehler-Shi dataset ](http://www.cs.sfu.ca/~colour/data/shi_gehler/). Extract all 568 training images
in one folder. A file containing ground-truth illuminant values (real_illum_568..mat) is downloaded separately.
-# We will be using a [Python script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/learn_color_balance.py) for training.
Call it with the following parameters:
@code
python learn_color_balance.py -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 0,378 --num_trees 30 --max_tree_depth 6 --num_augmented 0
@endcode
This should start training a model on the first 378 images (2/3 of the whole dataset). We set the size of the model to be 30 regression tree pairs per feature and limit
the tree depth to be no more than 6. By default the resulting model will be saved to color_balance_model.yml.
-# Use the trained model by passing its path when constructing an instance of LearningBasedWB:
@code{.cpp}
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB(modelFilename);
@endcode
How to evaluate a model
----------------------
-# We will use a [benchmarking script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/color_balance_benchmark.py) to compare
the model that we've trained with the classic grey-world algorithm on the remaining 1/3 of the dataset. Call the script with the following parameters:
@code
python color_balance_benchmark.py -a grayworld,learning_based:color_balance_model.yml -m <full path to folder containing the model> -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 379,567 -d "img"
@endcode
-# The objective evaluation results are stored in white_balance_eval_result.html and the resulting white-balanced images are stored in the img folder for a qualitative
comparison of algorithms. Different algorithms are compared in terms of angular error between the estimated and ground-truth illuminants.
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment