Commit dda5b885 authored by Vlad Shakhuro's avatar Vlad Shakhuro

Make Stump real-valued

parent ee55e8d6
...@@ -65,7 +65,7 @@ public: ...@@ -65,7 +65,7 @@ public:
ACFFeatureEvaluator(const std::vector<Point>& features); ACFFeatureEvaluator(const std::vector<Point>& features);
/* Set channels for feature evaluation */ /* Set channels for feature evaluation */
void setChannels(const std::vector<Mat> >& channels); void setChannels(const std::vector<Mat>& channels);
/* Set window position */ /* Set window position */
void setPosition(Size position); void setPosition(Size position);
......
#include <cmath>
#include <algorithm>
using std::swap;
#include "waldboost.hpp" #include "waldboost.hpp"
using cv::Mat; using cv::Mat;
...@@ -6,6 +11,7 @@ using cv::sort; ...@@ -6,6 +11,7 @@ using cv::sort;
using cv::sortIdx; using cv::sortIdx;
using cv::adas::Stump; using cv::adas::Stump;
using cv::adas::WaldBoost; using cv::adas::WaldBoost;
using cv::Ptr;
/* Cumulative sum by rows */ /* Cumulative sum by rows */
static void cumsum(const Mat_<float>& src, Mat_<float> dst) static void cumsum(const Mat_<float>& src, Mat_<float> dst)
...@@ -81,28 +87,37 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) ...@@ -81,28 +87,37 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
cumsum(neg_weights, neg_cum_weights); cumsum(neg_weights, neg_cum_weights);
/* Compute total weights of positive and negative samples */ /* Compute total weights of positive and negative samples */
int pos_total_weight = 0, neg_total_weight = 0; float pos_total_weight = pos_cum_weights.at<float>(0, weights.cols - 1);
for( int col = 0; col < labels.cols; ++col ) float neg_total_weight = neg_cum_weights.at<float>(0, weights.cols - 1);
{
if( labels.at<int>(0, col) == +1)
pos_total_weight += weights.at<float>(0, col); float eps = 1. / 4 * labels.cols;
else
neg_total_weight += weights.at<float>(0, col);
}
/* Compute minimal error */ /* Compute minimal error */
float min_err = FLT_MAX; float min_err = FLT_MAX;
int min_row = -1; int min_row = -1;
int min_col = -1; int min_col = -1;
int min_polarity = 0; int min_polarity = 0;
float min_pos_value = 1, min_neg_value = -1;
for( int row = 0; row < sorted_weights.rows; ++row ) for( int row = 0; row < sorted_weights.rows; ++row )
{ {
for( int col = 0; col < sorted_weights.cols - 1; ++col ) for( int col = 0; col < sorted_weights.cols - 1; ++col )
{ {
float err; float err, h_pos, h_neg;
// Direct polarity
err = pos_cum_weights.at<float>(row, col) + float pos_wrong = pos_cum_weights.at<float>(row, col);
(neg_total_weight - neg_cum_weights.at<float>(row, col)); float pos_right = pos_total_weight - pos_wrong;
float neg_right = neg_cum_weights.at<float>(row, col);
float neg_wrong = neg_total_weight - neg_right;
h_pos = .5 * log((pos_right + eps) / (pos_wrong + eps));
h_neg = .5 * log((neg_wrong + eps) / (neg_right + eps));
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
if( err < min_err ) if( err < min_err )
{ {
...@@ -110,11 +125,19 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) ...@@ -110,11 +125,19 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
min_row = row; min_row = row;
min_col = col; min_col = col;
min_polarity = +1; min_polarity = +1;
min_pos_value = h_pos;
min_neg_value = h_neg;
} }
// Opposite polarity
swap(pos_right, pos_wrong);
swap(neg_right, neg_wrong);
h_pos = -h_pos;
h_neg = -h_neg;
err = sqrt(pos_right * neg_wrong) + sqrt(pos_wrong * neg_right);
err = (pos_total_weight - pos_cum_weights.at<float>(row, col)) +
neg_cum_weights.at<float>(row, col);
if( err < min_err ) if( err < min_err )
{ {
...@@ -122,6 +145,8 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) ...@@ -122,6 +145,8 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
min_row = row; min_row = row;
min_col = col; min_col = col;
min_polarity = -1; min_polarity = -1;
min_pos_value = h_pos;
min_neg_value = h_neg;
} }
} }
} }
...@@ -130,18 +155,13 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights) ...@@ -130,18 +155,13 @@ int Stump::train(const Mat& data, const Mat& labels, const Mat& weights)
threshold_ = ( sorted_data.at<int>(min_row, min_col) + threshold_ = ( sorted_data.at<int>(min_row, min_col) +
sorted_data.at<int>(min_row, min_col + 1) ) / 2; sorted_data.at<int>(min_row, min_col + 1) ) / 2;
polarity_ = min_polarity; polarity_ = min_polarity;
pos_value_ = min_pos_value;
neg_value_ = min_neg_value;
return min_row; return min_row;
} }
static inline int sign(int value) float Stump::predict(int value)
{
if (value > 0)
return +1;
return -1;
}
int Stump::predict(int value)
{ {
return polarity_ * sign(value - threshold_); return polarity_ * (value - threshold_) > 0 ? pos_value_ : neg_value_;
} }
...@@ -56,11 +56,13 @@ class Stump ...@@ -56,11 +56,13 @@ class Stump
public: public:
/* Initialize zero stump */ /* Initialize zero stump */
Stump(): threshold_(0), polarity_(1) {}; Stump(): threshold_(0), polarity_(1), pos_value_(1), neg_value_(-1) {}
/* Initialize stump with given threshold and polarity */ /* Initialize stump with given threshold, polarity
Stump(int threshold, int polarity): threshold_(threshold), and classification values */
polarity_(polarity) {}; Stump(int threshold, int polarity, float pos_value, float neg_value):
threshold_(threshold), polarity_(polarity),
pos_value_(pos_value), neg_value_(neg_value) {}
/* Train stump for given data /* Train stump for given data
...@@ -77,21 +79,20 @@ public: ...@@ -77,21 +79,20 @@ public:
/* Predict object class given /* Predict object class given
value — feature value. Feature must be the same as chose during training value — feature value. Feature must be the same as was chosen
stump during training stump
Returns object class from {-1, +1} Returns real value, sign(value) means class
*/ */
int predict(int value); float predict(int value);
private: private:
/* Stump decision threshold */ /* Stump decision threshold */
int threshold_; int threshold_;
/* Stump polarity, can be from {-1, +1} */ /* Stump polarity, can be from {-1, +1} */
int polarity_; int polarity_;
/* Stump decision rule: /* Classification values for positive and negative classes */
h(value) = polarity * sign(value - threshold) float pos_value_, neg_value_;
*/
}; };
struct WaldBoostParams struct WaldBoostParams
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment