Commit f2c252f8 authored by Maria Dimashova's avatar Maria Dimashova

moved to double in EM; fixed bug

parent b6452f4b
...@@ -205,6 +205,7 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params ) ...@@ -205,6 +205,7 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx, bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
CvEMParams _params, Mat* _labels ) CvEMParams _params, Mat* _labels )
{ {
CV_Assert(_sample_idx.empty());
Mat prbs, weights, means, likelihoods; Mat prbs, weights, means, likelihoods;
std::vector<Mat> covsHdrs; std::vector<Mat> covsHdrs;
init_params(_params, prbs, weights, means, covsHdrs); init_params(_params, prbs, weights, means, covsHdrs);
......
...@@ -578,7 +578,7 @@ public: ...@@ -578,7 +578,7 @@ public:
CV_WRAP virtual bool train(InputArray samples, CV_WRAP virtual bool train(InputArray samples,
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray(),
OutputArray likelihoods=noArray()); OutputArray logLikelihoods=noArray());
CV_WRAP virtual bool trainE(InputArray samples, CV_WRAP virtual bool trainE(InputArray samples,
InputArray means0, InputArray means0,
...@@ -586,17 +586,17 @@ public: ...@@ -586,17 +586,17 @@ public:
InputArray weights0=noArray(), InputArray weights0=noArray(),
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray(),
OutputArray likelihoods=noArray()); OutputArray logLikelihoods=noArray());
CV_WRAP virtual bool trainM(InputArray samples, CV_WRAP virtual bool trainM(InputArray samples,
InputArray probs0, InputArray probs0,
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray(),
OutputArray likelihoods=noArray()); OutputArray logLikelihoods=noArray());
CV_WRAP int predict(InputArray sample, CV_WRAP int predict(InputArray sample,
OutputArray probs=noArray(), OutputArray probs=noArray(),
CV_OUT double* likelihood=0) const; CV_OUT double* logLikelihood=0) const;
CV_WRAP bool isTrained() const; CV_WRAP bool isTrained() const;
...@@ -614,7 +614,7 @@ protected: ...@@ -614,7 +614,7 @@ protected:
bool doTrain(int startStep, bool doTrain(int startStep,
OutputArray labels, OutputArray labels,
OutputArray probs, OutputArray probs,
OutputArray likelihoods); OutputArray logLikelihoods);
virtual void eStep(); virtual void eStep();
virtual void mStep(); virtual void mStep();
...@@ -622,9 +622,9 @@ protected: ...@@ -622,9 +622,9 @@ protected:
void decomposeCovs(); void decomposeCovs();
void computeLogWeightDivDet(); void computeLogWeightDivDet();
void computeProbabilities(const Mat& sample, int& label, Mat* probs, float* likelihood) const; void computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const;
// all inner matrices have type CV_32FC1 // all inner matrices have type CV_64FC1
CV_PROP_RW int nclusters; CV_PROP_RW int nclusters;
CV_PROP_RW int covMatType; CV_PROP_RW int covMatType;
CV_PROP_RW int maxIters; CV_PROP_RW int maxIters;
...@@ -632,7 +632,7 @@ protected: ...@@ -632,7 +632,7 @@ protected:
Mat trainSamples; Mat trainSamples;
Mat trainProbs; Mat trainProbs;
Mat trainLikelihoods; Mat trainLogLikelihoods;
Mat trainLabels; Mat trainLabels;
Mat trainCounts; Mat trainCounts;
......
This diff is collapsed.
...@@ -45,33 +45,33 @@ using namespace std; ...@@ -45,33 +45,33 @@ using namespace std;
using namespace cv; using namespace cv;
static static
void defaultDistribs( Mat& means, vector<Mat>& covs ) void defaultDistribs( Mat& means, vector<Mat>& covs, int type=CV_32FC1 )
{ {
float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f}; float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f}; float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f}; float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
means.create(3, 2, CV_32FC1); means.create(3, 2, type);
Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 ); Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 ); Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 ); Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
means.resize(3), covs.resize(3); means.resize(3), covs.resize(3);
Mat mr0 = means.row(0); Mat mr0 = means.row(0);
m0.copyTo(mr0); m0.convertTo(mr0, type);
c0.copyTo(covs[0]); c0.convertTo(covs[0], type);
Mat mr1 = means.row(1); Mat mr1 = means.row(1);
m1.copyTo(mr1); m1.convertTo(mr1, type);
c1.copyTo(covs[1]); c1.convertTo(covs[1], type);
Mat mr2 = means.row(2); Mat mr2 = means.row(2);
m2.copyTo(mr2); m2.convertTo(mr2, type);
c2.copyTo(covs[2]); c2.convertTo(covs[2], type);
} }
// generate points sets by normal distributions // generate points sets by normal distributions
static static
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int labelType ) void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int dataType, int labelType )
{ {
vector<int>::const_iterator sit = sizes.begin(); vector<int>::const_iterator sit = sizes.begin();
int total = 0; int total = 0;
...@@ -79,7 +79,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& ...@@ -79,7 +79,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat&
total += *sit; total += *sit;
assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() ); assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );
assert( !data.empty() && data.rows == total ); assert( !data.empty() && data.rows == total );
assert( data.type() == CV_32FC1 ); assert( data.type() == dataType );
labels.create( data.rows, 1, labelType ); labels.create( data.rows, 1, labelType );
...@@ -98,7 +98,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& ...@@ -98,7 +98,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat&
assert( cit->rows == data.cols && cit->cols == data.cols ); assert( cit->rows == data.cols && cit->cols == data.cols );
for( int i = bi; i < ei; i++, p++ ) for( int i = bi; i < ei; i++, p++ )
{ {
Mat r(1, data.cols, CV_32FC1, data.ptr<float>(i)); Mat r = data.row(i);
r = r * (*cit) + *mit; r = r * (*cit) + *mit;
if( labelType == CV_32FC1 ) if( labelType == CV_32FC1 )
labels.at<float>(p, 0) = (float)l; labels.at<float>(p, 0) = (float)l;
...@@ -226,7 +226,7 @@ void CV_KMeansTest::run( int /*start_from*/ ) ...@@ -226,7 +226,7 @@ void CV_KMeansTest::run( int /*start_from*/ )
Mat means; Mat means;
vector<Mat> covs; vector<Mat> covs;
defaultDistribs( means, covs ); defaultDistribs( means, covs );
generateData( data, labels, sizes, means, covs, CV_32SC1 ); generateData( data, labels, sizes, means, covs, CV_32FC1, CV_32SC1 );
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
float err; float err;
...@@ -296,11 +296,11 @@ void CV_KNearestTest::run( int /*start_from*/ ) ...@@ -296,11 +296,11 @@ void CV_KNearestTest::run( int /*start_from*/ )
Mat means; Mat means;
vector<Mat> covs; vector<Mat> covs;
defaultDistribs( means, covs ); defaultDistribs( means, covs );
generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1 ); generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
// test data // test data
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels; Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32FC1 ); generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
KNearest knearest; KNearest knearest;
...@@ -392,7 +392,9 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params, ...@@ -392,7 +392,9 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
for( int i = 0; i < testData.rows; i++ ) for( int i = 0; i < testData.rows; i++ )
{ {
Mat sample = testData.row(i); Mat sample = testData.row(i);
labels.at<int>(i,0) = (int)em.predict( sample, noArray() ); double likelihood = 0;
Mat probs;
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
} }
if( !calcErr( labels, testLabels, sizes, err, false ) ) if( !calcErr( labels, testLabels, sizes, err, false ) )
{ {
...@@ -416,22 +418,22 @@ void CV_EMTest::run( int /*start_from*/ ) ...@@ -416,22 +418,22 @@ void CV_EMTest::run( int /*start_from*/ )
// Points distribution // Points distribution
Mat means; Mat means;
vector<Mat> covs; vector<Mat> covs;
defaultDistribs( means, covs ); defaultDistribs( means, covs, CV_64FC1 );
// train data // train data
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels; Mat trainData( pointsCount, 2, CV_64FC1 ), trainLabels;
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) ); vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 ); generateData( trainData, trainLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );
// test data // test data
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels; Mat testData( pointsCount, 2, CV_64FC1 ), testLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 ); generateData( testData, testLabels, sizes, means, covs, CV_64FC1, CV_32SC1 );
EM_Params params; EM_Params params;
params.nclusters = 3; params.nclusters = 3;
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1)); Mat probs(trainData.rows, params.nclusters, CV_64FC1, cv::Scalar(1));
params.probs = &probs; params.probs = &probs;
Mat weights(1, params.nclusters, CV_32FC1, cv::Scalar(1)); Mat weights(1, params.nclusters, CV_64FC1, cv::Scalar(1));
params.weights = &weights; params.weights = &weights;
params.means = &means; params.means = &means;
params.covs = &covs; params.covs = &covs;
...@@ -505,18 +507,18 @@ protected: ...@@ -505,18 +507,18 @@ protected:
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
cv::EM em(2); cv::EM em(2);
Mat samples = Mat(3,1,CV_32F); Mat samples = Mat(3,1,CV_64FC1);
samples.at<float>(0,0) = 1; samples.at<double>(0,0) = 1;
samples.at<float>(1,0) = 2; samples.at<double>(1,0) = 2;
samples.at<float>(2,0) = 3; samples.at<double>(2,0) = 3;
Mat labels; Mat labels;
em.train(samples, labels); em.train(samples, labels);
Mat firstResult(samples.rows, 1, CV_32FC1); Mat firstResult(samples.rows, 1, CV_32SC1);
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)
firstResult.at<float>(i) = (float)em.predict( samples.row(i) ); firstResult.at<int>(i) = em.predict(samples.row(i));
// Write out // Write out
string filename = tempfile() + ".xml"; string filename = tempfile() + ".xml";
...@@ -557,7 +559,7 @@ protected: ...@@ -557,7 +559,7 @@ protected:
int errCaseCount = 0; int errCaseCount = 0;
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<float>(i)) < FLT_EPSILON ? 0 : 1; errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
if( errCaseCount > 0 ) if( errCaseCount > 0 )
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment