Commit 4331f76d authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

add hack to disable optimization of linear svms; improved precision of…

add hack to disable optimization of linear svms; improved precision of optimize_linear_svm; add the relevant test, which however requires some big database (so it's disabled by default)
parent 63a5587d
......@@ -1551,6 +1551,8 @@ void CvSVM::optimize_linear_svm()
return;
int var_count = get_var_count();
cv::AutoBuffer<double> vbuf;
double* v = vbuf;
int sample_size = (int)(var_count*sizeof(sv[0][0]));
float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
......@@ -1558,15 +1560,17 @@ void CvSVM::optimize_linear_svm()
{
new_sv[i] = (float*)cvMemStorageAlloc(storage, sample_size);
float* dst = new_sv[i];
memset(dst, 0, sample_size);
memset(v, 0, var_count*sizeof(v[0]));
int j, k, sv_count = df[i].sv_count;
for( j = 0; j < sv_count; j++ )
{
const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j];
double a = df[i].alpha[j];
for( k = 0; k < var_count; k++ )
dst[k] = (float)(dst[k] + src[k]*a);
v[k] += src[k]*a;
}
for( k = 0; k < var_count; k++ )
dst[k] = (float)v[k];
df[i].sv_count = 1;
df[i].alpha[0] = 1.;
if( class_count > 1 && df[i].sv_index )
......@@ -2570,7 +2574,8 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
}
optimize_linear_svm();
if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 )
optimize_linear_svm();
create_kernel();
__END__;
......
......@@ -133,4 +133,32 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
TEST(DISABLED_ML_SVM, linear_save_load)
{
CvSVM svm1, svm2, svm3;
svm1.load("SVM45_X_38-1.xml");
svm2.load("SVM45_X_38-2.xml");
string tname = tempfile("a.xml");
svm2.save(tname.c_str());
svm3.load(tname.c_str());
ASSERT_EQ(svm1.get_var_count(), svm2.get_var_count());
ASSERT_EQ(svm1.get_var_count(), svm3.get_var_count());
int m = 10000, n = svm1.get_var_count();
Mat samples(m, n, CV_32F), r1, r2, r3;
randu(samples, 0., 1.);
svm1.predict(samples, r1);
svm2.predict(samples, r2);
svm3.predict(samples, r3);
double eps = 1e-4;
EXPECT_LE(norm(r1, r2, NORM_INF), eps);
EXPECT_LE(norm(r1, r3, NORM_INF), eps);
remove(tname.c_str());
}
/* End of file. */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment