Commit bad4ca2a authored by Vadim Pisarevsky

added the optional balanced cross-validation in SVM::train_auto (by arman, ticket #314)

parent 0cab986e
......@@ -540,7 +540,8 @@ public:
CvParamGrid pGrid = get_default_grid(CvSVM::P),
CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE) );
CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
bool balanced=false );
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
......@@ -561,7 +562,8 @@ public:
CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE) );
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
bool balanced=false);
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
#endif
......
......@@ -1593,10 +1593,27 @@ bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
return ok;
}
// Per-fold class-balance record used while balancing cross-validation folds.
//   val            - fraction of the fold's samples that fall in the smallest
//                    class, as last computed by eval()
//   ind            - index of the fold this record describes
//   count_smallest - number of smallest-class samples counted in the fold
//   count_biggest  - number of all other samples counted in the fold
// Kept as a plain aggregate (no constructors) because instances are created in
// raw memory obtained from cvAlloc.
struct indexedratio
{
    double val;
    int ind;
    int count_smallest, count_biggest;

    // Recomputes val from the two counters.
    // An empty fold (both counters zero) yields 0 rather than the NaN that
    // 0/0 would produce: NaN compares false against every value, which would
    // make any ordering of records by val unreliable.
    void eval()
    {
        const int total = count_smallest + count_biggest;
        val = total != 0 ? static_cast<double>(count_smallest) / total : 0.0;
    }
};
// qsort-style comparator that orders indexedratio records by ascending val.
// Returns -1 when a's val is smaller, 1 when it is larger, and 0 otherwise.
static int CV_CDECL
icvCmpIndexedratio( const void* a, const void* b )
{
    const double lhs = static_cast<const indexedratio*>(a)->val;
    const double rhs = static_cast<const indexedratio*>(b)->val;

    if( lhs < rhs )
        return -1;
    if( lhs > rhs )
        return 1;
    return 0;
}
bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid,
bool balanced)
{
bool ok = false;
CvMat* responses = 0;
......@@ -1757,6 +1774,105 @@ bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
else
CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
}
if (!is_regression && class_labels->cols==2 && balanced)
{
// count class samples
int num_0=0,num_1=0;
for (i=0; i<sample_count; ++i)
{
if (responses->data.i[i]==class_labels->data.i[0])
++num_0;
else
++num_1;
}
int label_smallest_class;
int label_biggest_class;
if (num_0 < num_1)
{
label_biggest_class = class_labels->data.i[1];
label_smallest_class = class_labels->data.i[0];
}
else
{
label_biggest_class = class_labels->data.i[0];
label_smallest_class = class_labels->data.i[1];
int y;
CV_SWAP(num_0,num_1,y);
}
const double class_ratio = (double) num_0/sample_count;
// calculate class ratio of each fold
indexedratio *ratios=0;
ratios = (indexedratio*) cvAlloc(k_fold*sizeof(*ratios));
for (int k=0, i_begin=0; k<k_fold; ++k, i_begin+=testset_size)
{
int count0=0;
int count1=0;
int i_end = i_begin + (k<k_fold-1 ? testset_size : last_testset_size);
for (int i=i_begin; i<i_end; ++i)
{
if (responses->data.i[i]==label_smallest_class)
++count0;
else
++count1;
}
ratios[k].ind = k;
ratios[k].count_smallest = count0;
ratios[k].count_biggest = count1;
ratios[k].eval();
}
// initial distance
qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
double old_dist = 0.0;
for (int k=0; k<k_fold; ++k)
old_dist += abs(ratios[k].val-class_ratio);
double new_dist = 1.0;
// iterate to make the folds more balanced
while (new_dist > 0.0)
{
if (ratios[0].count_biggest==0 || ratios[k_fold-1].count_smallest==0)
break; // we are not able to swap samples anymore
// what if we swap the samples, calculate the new distance
ratios[0].count_smallest++;
ratios[0].count_biggest--;
ratios[0].eval();
ratios[k_fold-1].count_smallest--;
ratios[k_fold-1].count_biggest++;
ratios[k_fold-1].eval();
qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
new_dist = 0.0;
for (int k=0; k<k_fold; ++k)
new_dist += abs(ratios[k].val-class_ratio);
if (new_dist < old_dist)
{
// swapping really improves, so swap the samples
// index of the biggest_class sample from the minimum ratio fold
int i1 = ratios[0].ind * testset_size;
for ( ; i1<sample_count; ++i1)
{
if (responses->data.i[i1]==label_biggest_class)
break;
}
// index of the smallest_class sample from the maximum ratio fold
int i2 = ratios[k_fold-1].ind * testset_size;
for ( ; i2<sample_count; ++i2)
{
if (responses->data.i[i2]==label_smallest_class)
break;
}
// swap
const float* temp;
int y;
CV_SWAP( samples[i1], samples[i2], temp );
CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
old_dist = new_dist;
}
else
break; // does not improve, so break the loop
}
cvFree(&ratios);
}
int* cls_lbls = class_labels ? class_labels->data.i : 0;
C = C_grid.min_val;
......@@ -2011,12 +2127,12 @@ bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
// cv::Mat overload of train_auto: a thin adapter over the CvMat* overload.
// Each Mat argument is converted to a CvMat header and the call is forwarded
// with _params, k_fold and all six parameter grids passed through unchanged.
// Returns whatever the CvMat* overload returns.
// NOTE(review): the last parameter line and the last argument line each appear
// twice below (once without and once with `balanced`); this looks like
// before/after diff residue rather than real code - confirm against the
// actual source file.
bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid, bool balanced )
{
// Wrap the Mat inputs as CvMat headers so their addresses can be handed to the
// pointer-based overload.
CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
// A var/sample index whose data pointer is null is forwarded as a null pointer
// instead of the address of its header.
return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
nu_grid, coef_grid, degree_grid);
nu_grid, coef_grid, degree_grid, balanced);
}
float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment