Commit 191d5c1d authored by Maksim Shabunin's avatar Maksim Shabunin

Merge pull request #747 from sbokov:color_constancy

parents 3468ae57 75dedf9e
......@@ -58,129 +58,172 @@ namespace xphoto
//! @addtogroup xphoto
//! @{
//! various white balance algorithms
enum WhitebalanceTypes
{
/** perform smart histogram adjustments (ignoring 4% pixels with minimal and maximal
values) for each channel */
WHITE_BALANCE_SIMPLE = 0,
WHITE_BALANCE_GRAYWORLD = 1
};
/** @brief The function implements different algorithm of automatic white balance,
i.e. it tries to map image's white color to perceptual white (this can be violated due to
specific illumination or camera settings).
@param src
@param dst
@param algorithmType see xphoto::WhitebalanceTypes
@param inputMin minimum value in the input image
@param inputMax maximum value in the input image
@param outputMin minimum value in the output image
@param outputMax maximum value in the output image
@sa cvtColor, equalizeHist
/** @brief The base class for auto white balance algorithms.
*/
CV_EXPORTS_W void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
const float inputMin = 0.0f, const float inputMax = 255.0f,
const float outputMin = 0.0f, const float outputMax = 255.0f);
class CV_EXPORTS_W WhiteBalancer : public Algorithm
{
public:
/** @brief Applies white balancing to the input image
/** @brief Implements a simple grayworld white balance algorithm.
@param src Input image
@param dst White balancing result
@sa cvtColor, equalizeHist
*/
CV_WRAP virtual void balanceWhite(InputArray src, OutputArray dst) = 0;
};
The function autowbGrayworld scales the values of pixels based on a
gray-world assumption which states that the average of all channels
should result in a gray image.
/** @brief A simple white balance algorithm that works by independently stretching
each of the input image channels to the specified range. For increased robustness
it ignores the top and bottom \f$p\%\f$ of pixel values.
*/
class CV_EXPORTS_W SimpleWB : public WhiteBalancer
{
public:
/** @brief Input image range minimum value
@see setInputMin */
CV_WRAP virtual float getInputMin() const = 0;
/** @copybrief getInputMin @see getInputMin */
CV_WRAP virtual void setInputMin(float val) = 0;
/** @brief Input image range maximum value
@see setInputMax */
CV_WRAP virtual float getInputMax() const = 0;
/** @copybrief getInputMax @see getInputMax */
CV_WRAP virtual void setInputMax(float val) = 0;
/** @brief Output image range minimum value
@see setOutputMin */
CV_WRAP virtual float getOutputMin() const = 0;
/** @copybrief getOutputMin @see getOutputMin */
CV_WRAP virtual void setOutputMin(float val) = 0;
/** @brief Output image range maximum value
@see setOutputMax */
CV_WRAP virtual float getOutputMax() const = 0;
/** @copybrief getOutputMax @see getOutputMax */
CV_WRAP virtual void setOutputMax(float val) = 0;
/** @brief Percent of top/bottom values to ignore
@see setP */
CV_WRAP virtual float getP() const = 0;
/** @copybrief getP @see getP */
CV_WRAP virtual void setP(float val) = 0;
};
/** @brief Creates an instance of SimpleWB
*/
CV_EXPORTS_W Ptr<SimpleWB> createSimpleWB();
This function adds a modification which thresholds pixels based on their
saturation value and only uses pixels below the provided threshold in
finding average pixel values.
/** @brief Gray-world white balance algorithm
Saturation is calculated using the following for a 3-channel RGB image per
pixel I and is in the range [0, 1]:
This algorithm scales the values of pixels based on a
gray-world assumption which states that the average of all channels
should result in a gray image.
\f[ \texttt{Saturation} [I] = \frac{\textrm{max}(R,G,B) - \textrm{min}(R,G,B)
}{\textrm{max}(R,G,B)} \f]
It adds a modification which thresholds pixels based on their
saturation value and only uses pixels below the provided threshold in
finding average pixel values.
A threshold of 1 means that all pixels are used to white-balance, while a
threshold of 0 means no pixels are used. Lower thresholds are useful in
white-balancing saturated images.
Saturation is calculated using the following for a 3-channel RGB image per
pixel I and is in the range [0, 1]:
Currently only works on images of type @ref CV_8UC3 and @ref CV_16UC3.
\f[ \texttt{Saturation} [I] = \frac{\textrm{max}(R,G,B) - \textrm{min}(R,G,B)
}{\textrm{max}(R,G,B)} \f]
@param src Input array.
@param dst Output array of the same size and type as src.
@param thresh Maximum saturation for a pixel to be included in the
gray-world assumption.
A threshold of 1 means that all pixels are used to white-balance, while a
threshold of 0 means no pixels are used. Lower thresholds are useful in
white-balancing saturated images.
@sa balanceWhite
Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
*/
CV_EXPORTS_W void autowbGrayworld(InputArray src, OutputArray dst,
float thresh = 0.5f);
/** @brief Implements a more sophisticated learning-based automatic color balance algorithm.
As autowbGrayworld, this function works by applying different gains to the input
image channels, but their computation is a bit more involved compared to the
simple grayworld assumption. More details about the algorithm can be found in
@cite Cheng2015 .
class CV_EXPORTS_W GrayworldWB : public WhiteBalancer
{
public:
/** @brief Maximum saturation for a pixel to be included in the
gray-world assumption
@see setSaturationThreshold */
CV_WRAP virtual float getSaturationThreshold() const = 0;
/** @copybrief getSaturationThreshold @see getSaturationThreshold */
CV_WRAP virtual void setSaturationThreshold(float val) = 0;
};
/** @brief Creates an instance of GrayworldWB
*/
CV_EXPORTS_W Ptr<GrayworldWB> createGrayworldWB();
To mask out saturated pixels this function uses only pixels that satisfy the
following condition:
/** @brief More sophisticated learning-based automatic white balance algorithm.
\f[ \frac{\textrm{max}(R,G,B)}{\texttt{range_max_val}} < \texttt{saturation_thresh} \f]
As @ref GrayworldWB, this algorithm works by applying different gains to the input
image channels, but their computation is a bit more involved compared to the
simple gray-world assumption. More details about the algorithm can be found in
@cite Cheng2015 .
Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
To mask out saturated pixels this function uses only pixels that satisfy the
following condition:
@param src Input three-channel image in the BGR color space.
@param dst Output image of the same size and type as src.
@param range_max_val Maximum possible value of the input image (e.g. 255 for 8 bit images, 4095 for 12 bit images)
@param saturation_thresh Threshold that is used to determine saturated pixels
@param hist_bin_num Defines the size of one dimension of a three-dimensional RGB histogram that is used internally by
the algorithm. It often makes sense to increase the number of bins for images with higher bit depth (e.g. 256 bins
for a 12 bit image)
\f[ \frac{\textrm{max}(R,G,B)}{\texttt{range_max_val}} < \texttt{saturation_thresh} \f]
@sa autowbGrayworld
Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
*/
CV_EXPORTS_W void autowbLearningBased(InputArray src, OutputArray dst, int range_max_val = 255,
float saturation_thresh = 0.98f, int hist_bin_num = 64);
/** @brief Implements the feature extraction part of the learning-based color balance algorithm.
class CV_EXPORTS_W LearningBasedWB : public WhiteBalancer
{
public:
/** @brief Implements the feature extraction part of the algorithm.
In accordance with @cite Cheng2015 , computes the following features for the input image:
1. Chromaticity of an average (R,G,B) tuple
2. Chromaticity of the brightest (R,G,B) tuple (while ignoring saturated pixels)
3. Chromaticity of the dominant (R,G,B) tuple (the one that has the highest value in the RGB histogram)
4. Mode of the chromaticity pallete, that is constructed by taking 300 most common colors according to
4. Mode of the chromaticity palette, that is constructed by taking 300 most common colors according to
the RGB histogram and projecting them on the chromaticity plane. Mode is the most high-density point
of the pallete, which is computed by a straightforward fixed-bandwidth kernel density estimator with
of the palette, which is computed by a straightforward fixed-bandwidth kernel density estimator with
a Epanechnikov kernel function.
@param src Input three-channel image in the BGR color space.
@param src Input three-channel image (BGR color space is assumed).
@param dst An array of four (r,g) chromaticity tuples corresponding to the features listed above.
@param range_max_val Maximum possible value of the input image (e.g. 255 for 8 bit images, 4095 for 12 bit images)
@param saturation_thresh Threshold that is used to determine saturated pixels
@param hist_bin_num Defines the size of one dimension of a three-dimensional RGB histogram that is used internally by
the algorithm. It often makes sense to increase the number of bins for images with higher bit depth (e.g. 256 bins
for a 12 bit image)
@sa autowbLearningBased
*/
CV_EXPORTS_W void extractSimpleFeatures(InputArray src, OutputArray dst, int range_max_val = 255,
float saturation_thresh = 0.98f, int hist_bin_num = 64);
/** @brief Implements an efficient fixed-point approximation for applying channel gains.
@param src Input three-channel image in the BGR color space (either CV_8UC3 or CV_16UC3)
@param dst Output image of the same size and type as src.
@param gainB gain for the B channel
@param gainG gain for the G channel
@param gainR gain for the R channel
@sa autowbGrayworld, autowbLearningBased
CV_WRAP virtual void extractSimpleFeatures(InputArray src, OutputArray dst) = 0;
/** @brief Maximum possible value of the input image (e.g. 255 for 8 bit images,
4095 for 12 bit images)
@see setRangeMaxVal */
CV_WRAP virtual int getRangeMaxVal() const = 0;
/** @copybrief getRangeMaxVal @see getRangeMaxVal */
CV_WRAP virtual void setRangeMaxVal(int val) = 0;
/** @brief Threshold that is used to determine saturated pixels, i.e. pixels where at least one of the
channels exceeds \f$\texttt{saturation_threshold}\times\texttt{range_max_val}\f$ are ignored.
@see setSaturationThreshold */
CV_WRAP virtual float getSaturationThreshold() const = 0;
/** @copybrief getSaturationThreshold @see getSaturationThreshold */
CV_WRAP virtual void setSaturationThreshold(float val) = 0;
/** @brief Defines the size of one dimension of a three-dimensional RGB histogram that is used internally
by the algorithm. It often makes sense to increase the number of bins for images with higher bit depth
(e.g. 256 bins for a 12 bit image).
@see setHistBinNum */
CV_WRAP virtual int getHistBinNum() const = 0;
/** @copybrief getHistBinNum @see getHistBinNum */
CV_WRAP virtual void setHistBinNum(int val) = 0;
};
/** @brief Creates an instance of LearningBasedWB
@param path_to_model Path to a .yml file with the model. If not specified, the default model is used
*/
CV_EXPORTS_W void applyChannelGains(InputArray src, OutputArray dst, float gainB, float gainG, float gainR);
//! @}
CV_EXPORTS_W Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model = String());
/** @brief Implements an efficient fixed-point approximation for applying channel gains, which is
the last step of multiple white balance algorithms.
@param src Input three-channel image in the BGR color space (either CV_8UC3 or CV_16UC3)
@param dst Output image of the same size and type as src.
@param gainB gain for the B channel
@param gainG gain for the G channel
@param gainR gain for the R channel
*/
CV_EXPORTS_W void applyChannelGains(InputArray src, OutputArray dst, float gainB, float gainG, float gainR);
//! @}
}
}
......
......@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld,
Mat dst(size, CV_8UC3);
declare.in(src, WARMUP_RNG).out(dst);
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
TEST_CYCLE() xphoto::autowbGrayworld(src, dst, wb_thresh);
TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK(dst);
}
......
......@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV
RNG rng(1234);
rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val);
resize(src_dscl, src, src.size());
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(range_max_val);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(hist_bin_num);
TEST_CYCLE() xphoto::autowbLearningBased(src, dst, range_max_val, 0.98f, hist_bin_num);
TEST_CYCLE() wb->balanceWhite(src, dst);
SANITY_CHECK_NOTHING();
}
......@@ -4,14 +4,16 @@
using namespace cv;
using namespace std;
const char *keys = {"{help h usage ? | | print this message}"
const char *keys = { "{help h usage ? | | print this message}"
"{i | | input image name }"
"{o | | output image name }"};
"{o | | output image name }"
"{a |grayworld| color balance algorithm (simple, grayworld or learning_based)}"
"{m | | path to the model for the learning-based algorithm (optional) }" };
int main(int argc, const char **argv)
{
CommandLineParser parser(argc, argv, keys);
parser.about("OpenCV learning-based color balance demonstration sample");
parser.about("OpenCV color balance demonstration sample");
if (parser.has("help") || argc < 2)
{
parser.printMessage();
......@@ -20,6 +22,8 @@ int main(int argc, const char **argv)
string inFilename = parser.get<string>("i");
string outFilename = parser.get<string>("o");
string algorithm = parser.get<string>("a");
string modelFilename = parser.get<string>("m");
if (!parser.check())
{
......@@ -35,7 +39,20 @@ int main(int argc, const char **argv)
}
Mat res;
xphoto::autowbLearningBased(src, res);
Ptr<xphoto::WhiteBalancer> wb;
if (algorithm == "simple")
wb = xphoto::createSimpleWB();
else if (algorithm == "grayworld")
wb = xphoto::createGrayworldWB();
else if (algorithm == "learning_based")
wb = xphoto::createLearningBasedWB(modelFilename);
else
{
printf("Unsupported algorithm: %s\n", algorithm.c_str());
return -1;
}
wb->balanceWhite(src, res);
if (outFilename == "")
{
......
......@@ -5,6 +5,7 @@ import numpy as np
import scipy.io
import cv2
import timeit
from learn_color_balance import load_ground_truth
def load_json(path):
......@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5):
return arr.astype(np.uint8)
def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder, model_folder):
new_im = None
start_time = timeit.default_timer()
if algo=="grayworld":
new_im = cv2.xphoto.autowbGrayworld(im, 0.95)
inst = cv2.xphoto.createGrayworldWB()
inst.setSaturationThreshold(0.95)
new_im = inst.balanceWhite(im)
elif algo=="nothing":
new_im = im
elif algo=="learning_based":
new_im = cv2.xphoto.autowbLearningBased(im, None, range_thresh, 0.98, bin_num)
elif algo.split(":")[0]=="learning_based":
model_path = ""
if len(algo.split(":"))>1:
model_path = os.path.join(model_folder, algo.split(":")[1])
inst = cv2.xphoto.createLearningBasedWB(model_path)
inst.setRangeMaxVal(range_thresh)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(bin_num)
new_im = inst.balanceWhite(im)
elif algo=="GT":
gains = gt_illuminant / min(gt_illuminant)
g1 = float(1.0 / gains[2])
......@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
if len(dst_folder)>0:
if not os.path.exists(dst_folder):
os.makedirs(dst_folder)
im_name = ("%04d_" % i) + algo + ".jpg"
im_name = ("%04d_" % i) + algo.replace(":","_") + ".jpg"
cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im))
#recover the illuminant from the color balancing result, assuming the standard model:
......@@ -140,7 +150,9 @@ if __name__ == '__main__':
metavar="ALGORITHMS",
default="",
help=("Comma-separated list of color balance algorithms to evaluate. "
"Currently available: GT,learning_based,grayworld,nothing."))
"Currently available: GT,learning_based,grayworld,nothing. "
"Use a colon to set a specific model for the learning-based "
"algorithm, e.g. learning_based:model1.yml,learning_based:model2.yml"))
parser.add_argument(
"-i",
"--input_folder",
......@@ -196,6 +208,12 @@ if __name__ == '__main__':
default="0,0",
help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). "
"All available images are used by default."))
parser.add_argument(
"-m",
"--model_folder",
metavar="MODEL_FOLDER",
default="",
help=("Path to the folder containing models for the learning-based color balance algorithm (optional)"))
args, other_args = parser.parse_known_args()
if not os.path.exists(args.input_folder):
......@@ -218,22 +236,8 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder))
gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
(gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
for algorithm in algorithm_list:
i = 0
......@@ -254,7 +258,7 @@ if __name__ == '__main__':
im = stretch_to_8bit(im)
(time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh,
256 if range_thresh > 255 else 64, args.dst_folder)
256 if range_thresh > 255 else 64, args.dst_folder, args.model_folder)
state[algorithm][file] = {"angular_error": angular_err, "time": time}
sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)),
sys.stdout.flush()
......
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
using namespace cv;
using namespace std;
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
int main( int argc, const char** argv )
{
bool printHelp = ( argc == 1 );
printHelp = printHelp || ( argc == 2 && string(argv[1]) == "--help" );
printHelp = printHelp || ( argc == 2 && string(argv[1]) == "-h" );
if ( printHelp )
{
printf("\nThis sample demonstrates the grayworld balance algorithm\n"
"Call:\n"
" simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
return 0;
}
CommandLineParser parser(argc, argv, keys);
if ( !parser.check() )
{
parser.printErrors();
return -1;
}
string inFilename = parser.get<string>("i");
string outFilename = parser.get<string>("o");
Mat src = imread(inFilename, 1);
if ( src.empty() )
{
printf("Cannot read image file: %s\n", inFilename.c_str());
return -1;
}
Mat res(src.size(), src.type());
xphoto::autowbGrayworld(src, res);
if ( outFilename == "" )
{
namedWindow("after white balance", 1);
imshow("after white balance", res);
waitKey(0);
}
else
imwrite(outFilename, res);
return 0;
}
......@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth):
return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals)
def generate_code(model, input_params):
def generate_code(model, input_params, use_YML, out_file):
feature_idx = []
thresh_vals = []
leaf_vals = []
......@@ -95,31 +95,60 @@ def generate_code(model, input_params):
feature_idx += local_feature_idx
thresh_vals += local_thresh_vals
leaf_vals += local_leaf_vals
if use_YML:
fs = cv2.FileStorage(out_file, 1)
fs.write("num_trees", len(model))
fs.write("num_tree_nodes", 2**depth)
fs.write("feature_idx", np.array(feature_idx).astype(np.uint8))
fs.write("thresh_vals", np.array(thresh_vals).astype(np.float32))
fs.write("leaf_vals", np.array(leaf_vals).astype(np.float32))
fs.release()
else:
res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
" * using the following parameters:\n"
for key in input_params:
res += " " + key + " " + input_params[key]
res += "\n */\n"
res += "const int num_trees = " + str(len(model)) + ";\n"
res += "const int num_features = 4;\n"
res += "const int num_tree_nodes = " + str(2**depth) + ";\n"
res += "const int _num_trees = " + str(len(model)) + ";\n"
res += "const int _num_tree_nodes = " + str(2**depth) + ";\n"
res += "unsigned char feature_idx[num_trees*num_features*2*(num_tree_nodes-1)] = {" + str(feature_idx[0])
res += "unsigned char _feature_idx[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + str(feature_idx[0])
for i in range(1,len(feature_idx)):
res += "," + str(feature_idx[i])
res += "};\n"
res += "float thresh_vals[num_trees*num_features*2*(num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
res += "float _thresh_vals[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
for i in range(1,len(thresh_vals)):
res += "," + ("%.3ff" % thresh_vals[i])[1:]
res += "};\n"
res += "float leaf_vals[num_trees*num_features*2*num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
res += "float _leaf_vals[_num_trees*num_features*2*_num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
for i in range(1,len(leaf_vals)):
res += "," + ("%.3ff" % leaf_vals[i])[1:]
res += "};\n"
return res
f = open(out_file,"w")
f.write(res)
f.close()
def load_ground_truth(gt_path):
gt = scipy.io.loadmat(gt_path)
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(base_gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
return (base_gt_illuminants, black_levels)
if __name__ == '__main__':
......@@ -153,8 +182,9 @@ if __name__ == '__main__':
"-o",
"--out",
metavar="OUT",
default="learning_based_color_balance_model.hpp",
help="Path to the output learnt model")
default="color_balance_model.yml",
help="Path to the output learnt model. Either a .yml (for loading during runtime) "
"or .hpp (for compiling with the main code) file ")
parser.add_argument(
"--hist_bin_num",
metavar="HIST_BIN_NUM",
......@@ -196,39 +226,37 @@ if __name__ == '__main__':
print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
sys.exit(1)
use_YML = None
if args.out.endswith(".yml"):
use_YML = True
elif args.out.endswith(".hpp"):
use_YML = False
else:
print("Error: Only .hpp and .yml are supported as output formats")
sys.exit(1)
hist_bin_num = int(args.hist_bin_num)
num_trees = int(args.num_trees)
max_tree_depth = int(args.max_tree_depth)
gt = scipy.io.loadmat(args.ground_truth)
img_files = sorted(os.listdir(args.input_folder))
base_gt_illuminants = []
black_levels = []
if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
#NUS 8-camera dataset format
base_gt_illuminants = gt["groundtruth_illuminants"]
black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
elif "real_rgb" in gt.keys():
#Gehler-Shi dataset format
base_gt_illuminants = gt["real_rgb"]
black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
else:
print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
sys.exit(1)
(base_gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
features = []
gt_illuminants = []
i=0
sz = len(img_files)
random.seed(1234)
inst = cv2.xphoto.createLearningBasedWB()
inst.setRangeMaxVal(255)
inst.setSaturationThreshold(0.98)
inst.setHistBinNum(hist_bin_num)
for file in img_files:
if (i>=img_range[0] and i<img_range[1]) or (img_range[0]==img_range[1]==0):
cur_path = os.path.join(args.input_folder,file)
im = cv2.imread(cur_path, -1).astype(np.float32)
im -= black_levels[i]
im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num)
cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist())
gt_illuminants.append(base_gt_illuminants[i].tolist())
......@@ -241,7 +269,7 @@ if __name__ == '__main__':
im_8bit[:,:,1] *= G_coef
im_8bit[:,:,2] *= R_coef
im_8bit = convert_to_8bit(im)
cur_img_features = cv2.xphoto.extractSimpleFeatures(im_8bit, None, 255, 0.98, hist_bin_num)
cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
features.append(cur_img_features.tolist())
illum = base_gt_illuminants[i]
illum[0] *= R_coef
......@@ -255,10 +283,8 @@ if __name__ == '__main__':
print("\nLearning the model...")
model = learn_regression_tree_ensemble(features, gt_illuminants, num_trees, max_tree_depth)
print("Generating code...")
str = generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
"--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented})
f = open(args.out,"w")
f.write(str)
f.close()
print("Writing the model...")
generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
"--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented},
use_YML, args.out)
print("Done")
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
#include "opencv2/imgproc/types_c.h"
const char* keys =
{
"{i || input image name}"
"{o || output image name}"
};
int main( int argc, const char** argv )
{
bool printHelp = ( argc == 1 );
printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "--help" );
printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "-h" );
if ( printHelp )
{
printf("\nThis sample demonstrates simple color balance algorithm\n"
"Call:\n"
" simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
return 0;
}
cv::CommandLineParser parser(argc, argv, keys);
if ( !parser.check() )
{
parser.printErrors();
return -1;
}
std::string inFilename = parser.get<std::string>("i");
std::string outFilename = parser.get<std::string>("o");
cv::Mat src = cv::imread(inFilename, 1);
if ( src.empty() )
{
printf("Cannot read image file: %s\n", inFilename.c_str());
return -1;
}
cv::Mat res(src.size(), src.type());
cv::xphoto::balanceWhite(src, res, cv::xphoto::WHITE_BALANCE_SIMPLE);
if ( outFilename == "" )
{
cv::namedWindow("after white balance", 1);
cv::imshow("after white balance", res);
cv::waitKey(0);
}
else
cv::imwrite(outFilename, res);
return 0;
}
\ No newline at end of file
......@@ -49,6 +49,54 @@ namespace xphoto
void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh);
void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh);
class GrayworldWBImpl : public GrayworldWB
{
private:
float thresh;
public:
GrayworldWBImpl() { thresh = 0.9f; }
float getSaturationThreshold() const { return thresh; }
void setSaturationThreshold(float val) { thresh = val; }
void balanceWhite(InputArray _src, OutputArray _dst)
{
CV_Assert(!_src.empty());
CV_Assert(_src.isContinuous());
CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
Mat src = _src.getMat();
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB),
dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
};
/* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh
* (version for CV_8UC3)
*/
......@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai
}
}
void autowbGrayworld(InputArray _src, OutputArray _dst, float thresh)
{
Mat src = _src.getMat();
CV_Assert(!src.empty());
CV_Assert(src.isContinuous());
CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
int N = src.cols * src.rows, N3 = N * 3;
double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
if (src.type() == CV_8UC3)
{
uint sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
else if (src.type() == CV_16UC3)
{
uint64 sumB = 0, sumG = 0, sumR = 0;
calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
dsumB = (double)sumB;
dsumG = (double)sumG;
dsumR = (double)sumR;
}
// Find inverse of averages
double max_sum = max(dsumB, max(dsumR, dsumG));
const double eps = 0.1;
float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB), dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
// Use the inverse of averages as channel gains:
applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
}
Ptr<GrayworldWB> createGrayworldWB() { return makePtr<GrayworldWBImpl>(); }
}
}
......@@ -64,56 +64,115 @@ struct hist_elem
hist_elem(float _hist_val, Vec2f chromaticity) : hist_val(_hist_val), r(chromaticity[0]), g(chromaticity[1]) {}
};
bool operator<(const hist_elem &a, const hist_elem &b);
void getColorPalleteMode(Vec2f &dst, hist_elem *pallete, int pallete_sz, float bandwidth);
void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, float saturation_thresh);
void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src,
Mat &mask);
void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_pallete_mode, Mat &src, Mat &mask,
int hist_bin_num, int max_val);
bool operator<(const hist_elem &a, const hist_elem &b) { return a.hist_val > b.hist_val; }
/* Returns the most high-density point (i.e. mode) of the color pallete.
* Uses a simplistic kernel density estimator with a Epanechnikov kernel and
* fixed bandwidth.
*/
void getColorPalleteMode(Vec2f &dst, hist_elem *pallete, int pallete_sz, float bandwidth)
class LearningBasedWBImpl : public LearningBasedWB
{
float max_density = -1.0f;
float denom = bandwidth * bandwidth;
for (int i = 0; i < pallete_sz; i++)
private:
int range_max_val, hist_bin_num, palette_size;
float saturation_thresh, palette_bandwidth, prediction_thresh;
int num_trees, num_tree_nodes, tree_depth;
uchar *feature_idx;
float *thresh_vals, *leaf_vals;
Mat feature_idx_Mat, thresh_vals_Mat, leaf_vals_Mat;
Mat mask;
int src_max_val;
void preprocessing(Mat &src);
void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src);
void getColorPaletteMode(Vec2f &dst, hist_elem *palette);
void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode, Mat &src);
float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals, float *tree_leaf_vals);
Vec2f predictIlluminant(vector<Vec2f> features);
public:
LearningBasedWBImpl(String path_to_model)
{
range_max_val = 255;
saturation_thresh = 0.98f;
hist_bin_num = 64;
palette_size = 300;
palette_bandwidth = 0.1f;
prediction_thresh = 0.025f;
if (path_to_model.empty())
{
/* use the default model */
num_trees = _num_trees;
num_tree_nodes = _num_tree_nodes;
feature_idx = _feature_idx;
thresh_vals = _thresh_vals;
leaf_vals = _leaf_vals;
}
else
{
float cur_density = 0.0f;
float cur_dist_sq;
/* load model from file */
FileStorage fs(path_to_model, 0);
num_trees = fs["num_trees"];
num_tree_nodes = fs["num_tree_nodes"];
fs["feature_idx"] >> feature_idx_Mat;
fs["thresh_vals"] >> thresh_vals_Mat;
fs["leaf_vals"] >> leaf_vals_Mat;
feature_idx = feature_idx_Mat.ptr<uchar>();
thresh_vals = thresh_vals_Mat.ptr<float>();
leaf_vals = leaf_vals_Mat.ptr<float>();
}
}
int getRangeMaxVal() const { return range_max_val; }
void setRangeMaxVal(int val) { range_max_val = val; }
float getSaturationThreshold() const { return saturation_thresh; }
void setSaturationThreshold(float val) { saturation_thresh = val; }
for (int j = 0; j < pallete_sz; j++)
int getHistBinNum() const { return hist_bin_num; }
void setHistBinNum(int val) { hist_bin_num = val; }
void extractSimpleFeatures(InputArray _src, OutputArray _dst)
{
cur_dist_sq = (pallete[i].r - pallete[j].r) * (pallete[i].r - pallete[j].r) +
(pallete[i].g - pallete[j].g) * (pallete[i].g - pallete[j].g);
cur_density += max((1.0f - (cur_dist_sq / denom)), 0.0f);
CV_Assert(!_src.empty());
CV_Assert(_src.isContinuous());
CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
Mat src = _src.getMat();
vector<Vec2f> dst(num_features);
preprocessing(src);
getAverageAndBrightestColorChromaticity(dst[0], dst[1], src);
getHistogramBasedFeatures(dst[2], dst[3], src);
Mat(dst).convertTo(_dst, CV_32F);
}
if (cur_density > max_density)
void balanceWhite(InputArray _src, OutputArray _dst)
{
max_density = cur_density;
dst[0] = pallete[i].r;
dst[1] = pallete[i].g;
}
CV_Assert(!_src.empty());
CV_Assert(_src.isContinuous());
CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
Mat src = _src.getMat();
vector<Vec2f> features;
extractSimpleFeatures(src, features);
Vec2f illuminant = predictIlluminant(features);
float denom = 1 - illuminant[0] - illuminant[1];
float gainB = 1.0f;
float gainG = denom / illuminant[1];
float gainR = denom / illuminant[0];
applyChannelGains(src, _dst, gainB, gainG, gainR);
}
}
};
/* Computes a mask for non-saturated pixels and maximum pixel value
* which are then used for feature computation
*/
void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, float saturation_thresh)
void LearningBasedWBImpl::preprocessing(Mat &src)
{
dst_mask = Mat(src.size(), CV_8U);
uchar *mask_ptr = dst_mask.ptr<uchar>();
mask.create(src.size(), CV_8U);
uchar *mask_ptr = mask.ptr<uchar>();
int src_len = src.rows * src.cols;
int thresh = (int)(saturation_thresh * range_max_val);
int i = 0;
int local_max;
dst_max_val = -1;
src_max_val = -1;
if (src.type() == CV_8UC3)
{
......@@ -133,15 +192,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
v_store(global_max, v_global_max);
for (int j = 0; j < 16; j++)
{
if (global_max[j] > dst_max_val)
dst_max_val = global_max[j];
if (global_max[j] > src_max_val)
src_max_val = global_max[j];
}
#endif
for (; i < src_len; i++)
{
local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2]));
if (local_max > dst_max_val)
dst_max_val = local_max;
if (local_max > src_max_val)
src_max_val = local_max;
if (local_max < thresh)
mask_ptr[i] = 255;
else
......@@ -166,15 +225,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
v_store(global_max, v_global_max);
for (int j = 0; j < 8; j++)
{
if (global_max[j] > dst_max_val)
dst_max_val = global_max[j];
if (global_max[j] > src_max_val)
src_max_val = global_max[j];
}
#endif
for (; i < src_len; i++)
{
local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2]));
if (local_max > dst_max_val)
dst_max_val = local_max;
if (local_max > src_max_val)
src_max_val = local_max;
if (local_max < thresh)
mask_ptr[i] = 255;
else
......@@ -183,8 +242,8 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
}
}
void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src,
Mat &mask)
void LearningBasedWBImpl::getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity,
Vec2f &brightest_chromaticity, Mat &src)
{
int i = 0;
int src_len = src.rows * src.cols;
......@@ -376,15 +435,42 @@ void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f
}
}
void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_pallete_mode, Mat &src, Mat &mask,
int hist_bin_num, int max_val)
/* Returns the most high-density point (i.e. mode) of the color palette.
* Uses a simplistic kernel density estimator with a Epanechnikov kernel and
* fixed bandwidth.
*/
void LearningBasedWBImpl::getColorPaletteMode(Vec2f &dst, hist_elem *palette)
{
float max_density = -1.0f;
float denom = palette_bandwidth * palette_bandwidth;
for (int i = 0; i < palette_size; i++)
{
float cur_density = 0.0f;
float cur_dist_sq;
for (int j = 0; j < palette_size; j++)
{
cur_dist_sq = (palette[i].r - palette[j].r) * (palette[i].r - palette[j].r) +
(palette[i].g - palette[j].g) * (palette[i].g - palette[j].g);
cur_density += max((1.0f - (cur_dist_sq / denom)), 0.0f);
}
if (cur_density > max_density)
{
max_density = cur_density;
dst[0] = palette[i].r;
dst[1] = palette[i].g;
}
}
}
void LearningBasedWBImpl::getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode,
Mat &src)
{
const int pallete_size = 300;
const float pallete_bandwidth = 0.1f;
MatND hist;
int channels[] = {0, 1, 2};
int histSize[] = {hist_bin_num, hist_bin_num, hist_bin_num};
float range[] = {0, (float)max(hist_bin_num, max_val)};
float range[] = {0, (float)max(hist_bin_num, src_max_val)};
const float *ranges[] = {range, range, range};
calcHist(&src, 1, channels, mask, hist, 3, histSize, ranges);
......@@ -406,10 +492,10 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity
}
getChromaticity(dominant_chromaticity, (float)dominant_R, (float)dominant_G, (float)dominant_B);
vector<hist_elem> pallete;
pallete.reserve(pallete_size);
vector<hist_elem> palette;
palette.reserve(palette_size);
hist_ptr = hist.ptr<float>();
// extract top pallete_size most common colors and add them to the pallete:
// extract top palette_size most common colors and add them to the palette:
for (int i = 0; i < hist_bin_num; i++)
for (int j = 0; j < hist_bin_num; j++)
for (int k = 0; k < hist_bin_num; k++)
......@@ -424,45 +510,28 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity
getChromaticity(chromaticity, (float)k, (float)j, (float)i);
hist_elem el(bin_count, chromaticity);
if (pallete.size() < pallete_size)
if (palette.size() < (uint)palette_size)
{
pallete.push_back(el);
if (pallete.size() == pallete_size)
make_heap(pallete.begin(), pallete.end());
palette.push_back(el);
if (palette.size() == (uint)palette_size)
make_heap(palette.begin(), palette.end());
}
else if (bin_count > pallete.front().hist_val)
else if (bin_count > palette.front().hist_val)
{
pop_heap(pallete.begin(), pallete.end());
pallete.back() = el;
push_heap(pallete.begin(), pallete.end());
pop_heap(palette.begin(), palette.end());
palette.back() = el;
push_heap(palette.begin(), palette.end());
}
hist_ptr++;
}
getColorPalleteMode(chromaticity_pallete_mode, (hist_elem *)(&pallete[0]), (int)pallete.size(), pallete_bandwidth);
getColorPaletteMode(chromaticity_palette_mode, (hist_elem *)(&palette[0]));
}
void extractSimpleFeatures(InputArray _src, OutputArray _dst, int range_max_val, float saturation_thresh,
int hist_bin_num)
{
Mat src = _src.getMat();
CV_Assert(!src.empty());
CV_Assert(src.isContinuous());
CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
vector<Vec2f> dst(num_features);
Mat mask;
int max_val = 0;
preprocessing(mask, max_val, src, range_max_val, saturation_thresh);
getAverageAndBrightestColorChromaticity(dst[0], dst[1], src, mask);
getHistogramBasedFeatures(dst[2], dst[3], src, mask, hist_bin_num, max_val);
Mat(dst).convertTo(_dst, CV_32F);
}
inline float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals, float *tree_leaf_vals)
float LearningBasedWBImpl::regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals,
float *tree_leaf_vals)
{
int node_idx = 0;
int depth = (int)round(log(num_tree_nodes) / log(2));
for (int i = 0; i < depth; i++)
for (int i = 0; i < tree_depth; i++)
{
if (src[tree_feature_idx[node_idx]] <= tree_thresh_vals[node_idx])
node_idx = 2 * node_idx + 1;
......@@ -472,22 +541,14 @@ inline float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tr
return tree_leaf_vals[node_idx - num_tree_nodes + 1];
}
void autowbLearningBased(InputArray _src, OutputArray _dst, int range_max_val, float saturation_thresh,
int hist_bin_num)
Vec2f LearningBasedWBImpl::predictIlluminant(vector<Vec2f> features)
{
const float prediction_thresh = 0.025f;
Mat src = _src.getMat();
CV_Assert(!src.empty());
CV_Assert(src.isContinuous());
CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
vector<Vec2f> features;
extractSimpleFeatures(src, features, range_max_val, saturation_thresh, hist_bin_num);
int feature_model_size = 2 * (num_tree_nodes - 1);
int local_model_size = num_features * feature_model_size;
int feature_model_size_leaf = 2 * num_tree_nodes;
int local_model_size_leaf = num_features * feature_model_size_leaf;
tree_depth = (int)round(log(num_tree_nodes) / log(2));
vector<float> consensus_r, consensus_g;
vector<float> all_r, all_g;
for (int i = 0; i < num_trees; i++)
......@@ -538,12 +599,13 @@ void autowbLearningBased(InputArray _src, OutputArray _dst, int range_max_val, f
nth_element(consensus_g.begin(), consensus_g.begin() + consensus_g.size() / 2, consensus_g.end());
illuminant_g = consensus_g[consensus_g.size() / 2];
}
return Vec2f(illuminant_r, illuminant_g);
}
float denom = 1 - illuminant_r - illuminant_g;
float gainB = 1.0f;
float gainG = denom / illuminant_g;
float gainR = denom / illuminant_r;
applyChannelGains(src, _dst, gainB, gainG, gainR);
Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model)
{
Ptr<LearningBasedWB> inst = makePtr<LearningBasedWBImpl>(path_to_model);
return inst;
}
}
}
......@@ -2,10 +2,10 @@
* using the following parameters:
--num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0
*/
const int num_trees = 20;
const int num_features = 4;
const int num_tree_nodes = 16;
unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
const int _num_trees = 20;
const int _num_tree_nodes = 16;
unsigned char _feature_idx[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
......@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] =
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,
1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
float _thresh_vals[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
.193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f, .324f,
.794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f, .363f, .047f,
.351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f, .271f, .477f, .636f,
......@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
.550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f, .271f,
.372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f, .283f, .352f,
.516f, .685f, .855f};
float leaf_vals[num_trees * num_features * 2 * num_tree_nodes] = {
float _leaf_vals[_num_trees * num_features * 2 * _num_tree_nodes] = {
.011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f, .069f,
.165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f, .014f, .040f,
.061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f, .064f, .155f, .177f,
......
......@@ -37,50 +37,39 @@
//
//M*/
#include <vector>
#include <algorithm>
#include <iterator>
#include <iostream>
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include <iterator>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/core/core_c.h"
#include "opencv2/core/types.hpp"
#include "opencv2/core/types_c.h"
#include "opencv2/imgproc.hpp"
#include "opencv2/xphoto.hpp"
namespace cv
{
namespace xphoto
{
template <typename T>
void balanceWhite(std::vector < Mat_<T> > &src, Mat &dst,
const float inputMin, const float inputMax,
const float outputMin, const float outputMax, const int algorithmType)
{
switch ( algorithmType )
{
case WHITE_BALANCE_SIMPLE:
{
template <typename T>
void balanceWhiteSimple(std::vector<Mat_<T> > &src, Mat &dst, const float inputMin, const float inputMax,
const float outputMin, const float outputMax, const float p)
{
/********************* Simple white balance *********************/
float s1 = 2.0f; // low quantile
float s2 = 2.0f; // high quantile
const float s1 = p; // low quantile
const float s2 = p; // high quantile
int depth = 2; // depth of histogram tree
if (src[0].depth() != CV_8U)
++depth;
int bins = 16; // number of bins at each histogram level
int nElements = int( pow((float)bins, (float)depth) );
int nElements = int(pow((float)bins, (float)depth));
// number of elements in histogram tree
for (size_t i = 0; i < src.size(); ++i)
{
std::vector <int> hist(nElements, 0);
std::vector<int> hist(nElements, 0);
typename Mat_<T>::iterator beginIt = src[i].begin();
typename Mat_<T>::iterator endIt = src[i].end();
......@@ -97,19 +86,19 @@ namespace xphoto
for (int j = 0; j < depth; ++j)
{
int currentBin = int( (val - minValue + 1e-4f) / interval );
int currentBin = int((val - minValue + 1e-4f) / interval);
++hist[pos + currentBin];
pos = (pos + currentBin)*bins;
pos = (pos + currentBin) * bins;
minValue = minValue + currentBin*interval;
minValue = minValue + currentBin * interval;
maxValue = minValue + interval;
interval /= bins;
}
}
int total = int( src[i].total() );
int total = int(src[i].total());
int p1 = 0, p2 = bins - 1;
int n1 = 0, n2 = total;
......@@ -134,78 +123,90 @@ namespace xphoto
n2 -= hist[p2--];
maxValue -= interval;
}
p2 = p2*bins - 1;
p2 = p2 * bins - 1;
interval /= bins;
}
src[i] = (outputMax - outputMin) * (src[i] - minValue)
/ (maxValue - minValue) + outputMin;
src[i] = (outputMax - outputMin) * (src[i] - minValue) / (maxValue - minValue) + outputMin;
}
/****************************************************************/
break;
}
default:
CV_Error_( CV_StsNotImplemented,
("Unsupported algorithm type (=%d)", algorithmType) );
}
dst.create(/**/ src[0].size(), CV_MAKETYPE( src[0].depth(), int( src.size() ) ) /**/);
dst.create(/**/ src[0].size(), CV_MAKETYPE(src[0].depth(), int(src.size())) /**/);
cv::merge(src, dst);
}
class SimpleWBImpl : public SimpleWB
{
private:
float inputMin, inputMax, outputMin, outputMax, p;
public:
SimpleWBImpl()
{
inputMin = 0.0f;
inputMax = 255.0f;
outputMin = 0.0f;
outputMax = 255.0f;
p = 2.0f;
}
/*!
* Wrappers over different white balance algorithm
*
* \param src : source image (RGB)
* \param dst : destination image
*
* \param inputMin : minimum input value
* \param inputMax : maximum input value
* \param outputMin : minimum output value
* \param outputMax : maximum output value
*
* \param algorithmType : type of the algorithm to use
*/
void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
const float inputMin, const float inputMax,
const float outputMin, const float outputMax)
float getInputMin() const { return inputMin; }
void setInputMin(float val) { inputMin = val; }
float getInputMax() const { return inputMax; }
void setInputMax(float val) { inputMax = val; }
float getOutputMin() const { return outputMin; }
void setOutputMin(float val) { outputMin = val; }
float getOutputMax() const { return outputMax; }
void setOutputMax(float val) { outputMax = val; }
float getP() const { return p; }
void setP(float val) { p = val; }
void balanceWhite(InputArray _src, OutputArray _dst)
{
switch ( src.depth() )
CV_Assert(!_src.empty());
CV_Assert(_src.depth() == CV_8U || _src.depth() == CV_16S || _src.depth() == CV_32S || _src.depth() == CV_32F);
Mat src = _src.getMat();
Mat &dst = _dst.getMatRef();
switch (src.depth())
{
case CV_8U:
{
std::vector < Mat_<uchar> > mv;
std::vector<Mat_<uchar> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_16S:
{
std::vector < Mat_<short> > mv;
std::vector<Mat_<short> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_32S:
{
std::vector < Mat_<int> > mv;
std::vector<Mat_<int> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
case CV_32F:
{
std::vector < Mat_<float> > mv;
std::vector<Mat_<float> > mv;
split(src, mv);
balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
break;
}
default:
CV_Error_( CV_StsNotImplemented,
("Unsupported source image format (=%d)", src.type()) );
break;
}
}
};
Ptr<SimpleWB> createSimpleWB() { return makePtr<SimpleWBImpl>(); }
}
}
......@@ -7,6 +7,7 @@ namespace cvtest
cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/";
int nTests = 12;
float threshold = 0.005f;
cv::Ptr<cv::xphoto::WhiteBalancer> wb = cv::xphoto::createSimpleWB();
for (int i = 0; i < nTests; ++i)
{
......@@ -18,7 +19,7 @@ namespace cvtest
cv::Mat previousResult = cv::imread( previousResultName, 1 );
cv::Mat currentResult;
cv::xphoto::balanceWhite(src, currentResult, cv::xphoto::WHITE_BALANCE_SIMPLE);
wb->balanceWhite(src, currentResult);
cv::Mat sqrError = ( currentResult - previousResult )
.mul( currentResult - previousResult );
......
......@@ -69,6 +69,8 @@ namespace cvtest {
const int nTests = 14;
const float wb_thresh = 0.5f;
const float acc_thresh = 2.f;
Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
wb->setSaturationThreshold(wb_thresh);
for ( int i = 0; i < nTests; ++i )
{
......@@ -80,13 +82,13 @@ namespace cvtest {
ref_autowbGrayworld(src, referenceResult, wb_thresh);
Mat currentResult;
xphoto::autowbGrayworld(src, currentResult, wb_thresh);
wb->balanceWhite(src, currentResult);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
// test the 16-bit depth:
Mat currentResult_16U, src_16U;
src.convertTo(src_16U, CV_16UC3, 256.0);
xphoto::autowbGrayworld(src_16U, currentResult_16U, wb_thresh);
wb->balanceWhite(src_16U, currentResult_16U);
currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0);
ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
}
......
......@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression)
Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200));
vector<Vec2f> dst_features;
xphoto::extractSimpleFeatures(test_im, dst_features, 255, 0.98f, 64);
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
wb->setRangeMaxVal(255);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(64);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
......@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression)
// check 16 bit depth:
test_im.convertTo(test_im, CV_16U, 256.0);
xphoto::extractSimpleFeatures(test_im, dst_features, 65535, 0.98f, 64);
wb->setRangeMaxVal(65535);
wb->setSaturationThreshold(0.98f);
wb->setHistBinNum(128);
wb->extractSimpleFeatures(test_im, dst_features);
ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
......
Training the learning-based white balance algorithm {#tutorial_xphoto_training_white_balance}
===================================================
Introduction
------------
Many traditional white balance algorithms are statistics-based, i.e. they rely on the fact that certain assumptions should hold in properly white-balanced images
like the well-known grey-world assumption. However, better results can often be achieved by leveraging large datasets of images with ground-truth
illuminants in a learning-based framework. This tutorial demonstrates how to train a learning-based white balance algorithm and evaluate the quality of the results.
How to train a model
--------------------
-# Download a dataset for training. In this tutorial we will use the [Gehler-Shi dataset ](http://www.cs.sfu.ca/~colour/data/shi_gehler/). Extract all 568 training images
in one folder. A file containing ground-truth illuminant values (real_illum_568..mat) is downloaded separately.
-# We will be using a [Python script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/learn_color_balance.py) for training.
Call it with the following parameters:
@code
python learn_color_balance.py -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 0,378 --num_trees 30 --max_tree_depth 6 --num_augmented 0
@endcode
This should start training a model on the first 378 images (2/3 of the whole dataset). We set the size of the model to be 30 regression tree pairs per feature and limit
the tree depth to be no more then 6. By default the resulting model will be saved to color_balance_model.yml
-# Use the trained model by passing its path when constructing an instance of LearningBasedWB:
@code{.cpp}
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB(modelFilename);
@endcode
How to evaluate a model
----------------------
-# We will use a [benchmarking script ](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/color_balance_benchmark.py) to compare
the model that we've trained with the classic grey-world algorithm on the remaining 1/3 of the dataset. Call the script with the following parameters:
@code
python color_balance_benchmark.py -a grayworld,learning_based:color_balance_model.yml -m <full path to folder containing the model> -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 379,567 -d "img"
@endcode
-# The objective evaluation results are stored in white_balance_eval_result.html and the resulting white-balanced images are stored in the img folder for a qualitative
comparison of algorithms. Different algorithms are compared in terms of angular error between the estimated and ground-truth illuminants.
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment