Commit 008a1c91 authored by Maria Dimashova's avatar Maria Dimashova

fixed em test

parent 94c258cf
......@@ -87,8 +87,10 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vecto
r = r * (*cit) + *mit;
if( labelType == CV_32FC1 )
labels.at<float>(p, 0) = (float)l;
else
else if( labelType == CV_32SC1 )
labels.at<int>(p, 0) = l;
else
CV_DbgAssert(0);
}
}
}
......@@ -201,20 +203,23 @@ void CV_KMeansTest::run( int /*start_from*/ )
generateData( data, labels, sizes, means, covs, CV_32SC1 );
int code = cvtest::TS::OK;
float err;
Mat bestLabels;
// 1. flag==KMEANS_PP_CENTERS
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
// 2. flag==KMEANS_RANDOM_CENTERS
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
......@@ -224,9 +229,10 @@ void CV_KMeansTest::run( int /*start_from*/ )
for( int i = 0; i < 0.5f * pointsCount; i++ )
bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
......@@ -261,9 +267,10 @@ void CV_KNearestTest::run( int /*start_from*/ )
KNearest knearest;
knearest.train( trainData, trainLabels );
knearest.find_nearest( testData, 4, &bestLabels );
if( calcErr( bestLabels, testLabels, sizes, true ) > 0.01f )
float err = calcErr( bestLabels, testLabels, sizes, true );
if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
ts->set_failed_test_info( code );
......@@ -294,15 +301,17 @@ void CV_EMTest::run( int /*start_from*/ )
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
int code = cvtest::TS::OK;
float err;
ExpectationMaximization em;
CvEMParams params;
params.nclusters = 3;
em.train( trainData, Mat(), params, &bestLabels );
// check train error
if( calcErr( bestLabels, trainLabels, sizes, true ) > 0.002f )
err = calcErr( bestLabels, trainLabels, sizes, false );
if( err > 0.002f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy on train data" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on train data.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
......@@ -313,9 +322,10 @@ void CV_EMTest::run( int /*start_from*/ )
Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
}
if( calcErr( bestLabels, testLabels, sizes, true ) > 0.005f )
err = calcErr( bestLabels, testLabels, sizes, false );
if( err > 0.005f )
{
ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
......@@ -324,4 +334,4 @@ void CV_EMTest::run( int /*start_from*/ )
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
TEST(ML_EMTest, accuracy) { CV_EMTest test; test.safe_run(); }
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment