Commit 56ea364c authored by Vlad Shakhuro's avatar Vlad Shakhuro

Separate WaldBoost and Stump

parent 3e0c58c5
#include "stump.hpp"
namespace cv
{
namespace adas
{
/* Cumulative sum by rows */
/* Cumulative sum by rows: dst(i, j) = src(i, 0) + ... + src(i, j).
   dst is a Mat_ header passed by value — it shares data with the
   caller's matrix, so writes are visible to the caller. */
static void cumsum(const Mat_<float>& src, Mat_<float> dst)
{
    CV_Assert(src.cols > 0);
    for( int i = 0; i < src.rows; ++i )
    {
        float running_total = src(i, 0);
        dst(i, 0) = running_total;
        for( int j = 1; j < src.cols; ++j )
        {
            running_total += src(i, j);
            dst(i, j) = running_total;
        }
    }
}
/* Train the stump on all candidate features at once.

   data    — CV_32S matrix of feature values, size M x N, one feature per row
   labels  — CV_32S matrix, size 1 x N, values from {-1, +1}
   weights — CV_32F matrix, size 1 x N, sample weights

   Chooses the (feature, threshold, polarity) triple that minimizes the
   weighted error sqrt(w+_right * w-_wrong) + sqrt(w+_wrong * w-_right),
   stores threshold/polarity/classification values in the object's fields
   and returns the index of the chosen feature (row of data, from 0). */
int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
{
    CV_Assert(labels.rows == 1 && labels.cols == data.cols);
    CV_Assert(weights.rows == 1 && weights.cols == data.cols);
    /* Need at least two samples per feature so that a threshold can be
       placed between two consecutive sorted values */
    CV_Assert(data.cols >= 2);
    /* Assert that data and labels have int type */
    /* Assert that weights have float type */

    /* Prepare labels for each feature rearranged according to sorted order */
    Mat sorted_labels(data.rows, data.cols, labels.type());
    Mat sorted_weights(data.rows, data.cols, weights.type());
    Mat indices;
    sortIdx(data, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING);
    for( int row = 0; row < indices.rows; ++row )
    {
        for( int col = 0; col < indices.cols; ++col )
        {
            /* sortIdx produces CV_32S indices — read them as int.
               (Reading the index matrix with at<float> was a bug: it
               reinterpreted the int bits as float.) */
            int idx = indices.at<int>(row, col);
            sorted_labels.at<int>(row, col) = labels.at<int>(0, idx);
            sorted_weights.at<float>(row, col) = weights.at<float>(0, idx);
        }
    }

    /* Sort feature values */
    Mat sorted_data(data.rows, data.cols, data.type());
    sort(data, sorted_data, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING);

    /* Split positive and negative weights */
    Mat pos_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
                                 sorted_weights.type());
    Mat neg_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
                                 sorted_weights.type());
    for( int row = 0; row < data.rows; ++row )
    {
        for( int col = 0; col < data.cols; ++col )
        {
            if( sorted_labels.at<int>(row, col) == +1 )
            {
                pos_weights.at<float>(row, col) =
                    sorted_weights.at<float>(row, col);
            }
            else
            {
                neg_weights.at<float>(row, col) =
                    sorted_weights.at<float>(row, col);
            }
        }
    }

    /* Compute cumulative sums for fast stump error computation */
    Mat pos_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
                                     sorted_weights.type());
    Mat neg_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
                                     sorted_weights.type());
    cumsum(pos_weights, pos_cum_weights);
    cumsum(neg_weights, neg_cum_weights);

    /* Compute total weights of positive and negative samples.
       Every row is a permutation of the same weights, so row 0 suffices */
    float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1);
    float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1);

    /* Smoothing term for the classification values: 1 / (4 * N).
       (The old expression `1. / 4 * labels.cols` parsed as N / 4 —
       a huge value that drowned out the actual weight ratios.) */
    float eps = 1.0f / (4 * labels.cols);

    /* Compute minimal error */
    float min_err = FLT_MAX;
    int min_row = -1;
    int min_col = -1;
    int min_polarity = 0;
    float min_pos_value = 1, min_neg_value = -1;

    for( int row = 0; row < sorted_weights.rows; ++row )
    {
        for( int col = 0; col < sorted_weights.cols - 1; ++col )
        {
            float err, h_pos, h_neg;

            /* Direct polarity: samples up to col are classified negative */
            float pos_wrong = pos_cum_weights.at<float>(row, col);
            float pos_right = pos_total_weight - pos_wrong;
            float neg_right = neg_cum_weights.at<float>(row, col);
            float neg_wrong = neg_total_weight - neg_right;

            h_pos = .5 * log((pos_right + eps) / (pos_wrong + eps));
            h_neg = .5 * log((neg_wrong + eps) / (neg_right + eps));

            err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
            if( err < min_err )
            {
                min_err = err;
                min_row = row;
                min_col = col;
                min_polarity = +1;
                min_pos_value = h_pos;
                min_neg_value = h_neg;
            }

            /* Opposite polarity.
               std::swap must be qualified: unqualified lookup inside
               namespace cv finds cv::swap, which has no overload for
               float, and ADL does not apply to fundamental types */
            std::swap(pos_right, pos_wrong);
            std::swap(neg_right, neg_wrong);
            h_pos = -h_pos;
            h_neg = -h_neg;

            err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
            if( err < min_err )
            {
                min_err = err;
                min_row = row;
                min_col = col;
                min_polarity = -1;
                min_pos_value = h_pos;
                min_neg_value = h_neg;
            }
        }
    }

    /* Compute threshold, store found values in fields */
    threshold_ = ( sorted_data.at<int>(min_row, min_col) +
                   sorted_data.at<int>(min_row, min_col + 1) ) / 2;
    polarity_ = min_polarity;
    pos_value_ = min_pos_value;
    neg_value_ = min_neg_value;

    return min_row;
}
/* Classify a feature value: returns pos_value_ when the value lies on
   the positive side of the threshold (according to polarity_),
   neg_value_ otherwise. sign of the result means the class. */
float Stump::predict(int value)
{
    if( polarity_ * (value - threshold_) > 0 )
        return pos_value_;
    return neg_value_;
}
} /* namespace adas */
} /* namespace cv */
#ifndef __OPENCV_ADAS_STUMP_HPP__
#define __OPENCV_ADAS_STUMP_HPP__
#include <opencv2/core.hpp>
namespace cv
{
namespace adas
{
/* Decision stump: a one-feature, one-threshold weak classifier used as
   the base learner for WaldBoost. Compares a single integer feature
   value against threshold_ (direction given by polarity_) and outputs
   one of two real-valued confidences. */
class Stump
{
public:
/* Initialize zero stump: threshold 0, direct polarity,
   classification values +1 / -1 */
Stump(): threshold_(0), polarity_(1), pos_value_(1), neg_value_(-1) {}
/* Initialize stump with given threshold, polarity
and classification values */
Stump(int threshold, int polarity, float pos_value, float neg_value):
threshold_(threshold), polarity_(polarity),
pos_value_(pos_value), neg_value_(neg_value) {}
/* Train stump for given data
data — matrix of feature values, size M x N, one feature per row
labels — matrix of sample class labels, size 1 x N. Labels can be from
{-1, +1}
weights — matrix of sample weights, size 1 x N
Returns chosen feature index. Feature enumeration starts from 0
*/
int train(const Mat& data, const Mat& labels, const Mat& weights);
/* Predict object class given
value — feature value. Feature must be the same as was chosen
during training stump
Returns real value, sign(value) means class
*/
float predict(int value);
private:
/* Stump decision threshold */
int threshold_;
/* Stump polarity, can be from {-1, +1} */
int polarity_;
/* Classification values for positive and negative classes
   (real-valued confidences returned by predict) */
float pos_value_, neg_value_;
};
} /* namespace adas */
} /* namespace cv */
#endif /* __OPENCV_ADAS_STUMP_HPP__ */
...@@ -5,170 +5,12 @@ using std::swap; ...@@ -5,170 +5,12 @@ using std::swap;
#include "waldboost.hpp" #include "waldboost.hpp"
using cv::Mat;
using cv::Mat_;
using cv::sum;
using cv::sort;
using cv::sortIdx;
using cv::adas::Stump;
using cv::adas::WaldBoost;
using cv::Ptr;
using cv::adas::ACFFeatureEvaluator;
using std::vector; using std::vector;
/* Cumulative sum by rows */ namespace cv
static void cumsum(const Mat_<float>& src, Mat_<float> dst)
{ {
CV_Assert(src.cols > 0); namespace adas
for( int row = 0; row < src.rows; ++row )
{
dst(row, 0) = src(row, 0);
for( int col = 1; col < src.cols; ++col )
{
dst(row, col) = dst(row, col - 1) + src(row, col);
}
}
}
int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
{ {
CV_Assert(labels.rows == 1 && labels.cols == data.cols);
CV_Assert(weights.rows == 1 && weights.cols == data.cols);
/* Assert that data and labels have int type */
/* Assert that weights have float type */
/* Prepare labels for each feature rearranged according to sorted order */
Mat sorted_labels(data.rows, data.cols, labels.type());
Mat sorted_weights(data.rows, data.cols, weights.type());
Mat indices;
sortIdx(data, indices, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING);
for( int row = 0; row < indices.rows; ++row )
{
for( int col = 0; col < indices.cols; ++col )
{
sorted_labels.at<int>(row, col) =
labels.at<int>(0, indices.at<int>(row, col));
sorted_weights.at<float>(row, col) =
weights.at<float>(0, indices.at<float>(row, col));
}
}
/* Sort feature values */
Mat sorted_data(data.rows, data.cols, data.type());
sort(data, sorted_data, cv::SORT_EVERY_ROW | cv::SORT_ASCENDING);
/* Split positive and negative weights */
Mat pos_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
sorted_weights.type());
Mat neg_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
sorted_weights.type());
for( int row = 0; row < data.rows; ++row )
{
for( int col = 0; col < data.cols; ++col )
{
if( sorted_labels.at<int>(row, col) == +1 )
{
pos_weights.at<float>(row, col) =
sorted_weights.at<float>(row, col);
}
else
{
neg_weights.at<float>(row, col) =
sorted_weights.at<float>(row, col);
}
}
}
/* Compute cumulative sums for fast stump error computation */
Mat pos_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
sorted_weights.type());
Mat neg_cum_weights = Mat::zeros(sorted_weights.rows, sorted_weights.cols,
sorted_weights.type());
cumsum(pos_weights, pos_cum_weights);
cumsum(neg_weights, neg_cum_weights);
/* Compute total weights of positive and negative samples */
float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1);
float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1);
float eps = 1. / 4 * labels.cols;
/* Compute minimal error */
float min_err = FLT_MAX;
int min_row = -1;
int min_col = -1;
int min_polarity = 0;
float min_pos_value = 1, min_neg_value = -1;
for( int row = 0; row < sorted_weights.rows; ++row )
{
for( int col = 0; col < sorted_weights.cols - 1; ++col )
{
float err, h_pos, h_neg;
// Direct polarity
float pos_wrong = pos_cum_weights.at<float>(row, col);
float pos_right = pos_total_weight - pos_wrong;
float neg_right = neg_cum_weights.at<float>(row, col);
float neg_wrong = neg_total_weight - neg_right;
h_pos = .5 * log((pos_right + eps) / (pos_wrong + eps));
h_neg = .5 * log((neg_wrong + eps) / (neg_right + eps));
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
if( err < min_err )
{
min_err = err;
min_row = row;
min_col = col;
min_polarity = +1;
min_pos_value = h_pos;
min_neg_value = h_neg;
}
// Opposite polarity
swap(pos_right, pos_wrong);
swap(neg_right, neg_wrong);
h_pos = -h_pos;
h_neg = -h_neg;
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
if( err < min_err )
{
min_err = err;
min_row = row;
min_col = col;
min_polarity = -1;
min_pos_value = h_pos;
min_neg_value = h_neg;
}
}
}
/* Compute threshold, store found values in fields */
threshold_ = ( sorted_data.at<int>(min_row, min_col) +
sorted_data.at<int>(min_row, min_col + 1) ) / 2;
polarity_ = min_polarity;
pos_value_ = min_pos_value;
neg_value_ = min_neg_value;
return min_row;
}
float Stump::predict(int value)
{
return polarity_ * (value - threshold_) > 0 ? pos_value_ : neg_value_;
}
WaldBoost::WaldBoost(const WaldBoostParams& params): params_(params) WaldBoost::WaldBoost(const WaldBoostParams& params): params_(params)
{ {
...@@ -275,3 +117,6 @@ float WaldBoost::predict(const Ptr<ACFFeatureEvaluator>& feature_evaluator) ...@@ -275,3 +117,6 @@ float WaldBoost::predict(const Ptr<ACFFeatureEvaluator>& feature_evaluator)
} }
return trace; return trace;
} }
} /* namespace adas */
} /* namespace cv */
...@@ -45,56 +45,13 @@ the use of this software, even if advised of the possibility of such damage. ...@@ -45,56 +45,13 @@ the use of this software, even if advised of the possibility of such damage.
#include <opencv2/core.hpp> #include <opencv2/core.hpp>
#include "acffeature.hpp" #include "acffeature.hpp"
#include "stump.hpp"
namespace cv namespace cv
{ {
namespace adas namespace adas
{ {
class Stump
{
public:
/* Initialize zero stump */
Stump(): threshold_(0), polarity_(1), pos_value_(1), neg_value_(-1) {}
/* Initialize stump with given threshold, polarity
and classification values */
Stump(int threshold, int polarity, float pos_value, float neg_value):
threshold_(threshold), polarity_(polarity),
pos_value_(pos_value), neg_value_(neg_value) {}
/* Train stump for given data
data — matrix of feature values, size M x N, one feature per row
labels — matrix of sample class labels, size 1 x N. Labels can be from
{-1, +1}
weights — matrix of sample weights, size 1 x N
Returns chosen feature index. Feature enumeration starts from 0
*/
int train(const Mat& data, const Mat& labels, const Mat& weights);
/* Predict object class given
value — feature value. Feature must be the same as was chosen
during training stump
Returns real value, sign(value) means class
*/
float predict(int value);
private:
/* Stump decision threshold */
int threshold_;
/* Stump polarity, can be from {-1, +1} */
int polarity_;
/* Classification values for positive and negative classes */
float pos_value_, neg_value_;
};
struct WaldBoostParams struct WaldBoostParams
{ {
int weak_count; int weak_count;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment