Commit 3dfa9178 authored by Maria Dimashova

refactored train and predict methods of em

parent 8f7e5811
...@@ -213,7 +213,7 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) { ...@@ -213,7 +213,7 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) {
cv::Mat lbls; cv::Mat lbls;
EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001)); EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001));
em_model.train(cvarrToMat(samples), lbls); em_model.train(cvarrToMat(samples), noArray(), lbls);
if(labels) if(labels)
lbls.copyTo(cvarrToMat(labels)); lbls.copyTo(cvarrToMat(labels));
......
...@@ -1826,7 +1826,7 @@ public: ...@@ -1826,7 +1826,7 @@ public:
CV_WRAP cv::Mat getWeights() const; CV_WRAP cv::Mat getWeights() const;
CV_WRAP cv::Mat getProbs() const; CV_WRAP cv::Mat getProbs() const;
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; } CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? logLikelihood : DBL_MAX; }
#endif #endif
CV_WRAP virtual void clear(); CV_WRAP virtual void clear();
...@@ -1847,7 +1847,7 @@ protected: ...@@ -1847,7 +1847,7 @@ protected:
cv::EM emObj; cv::EM emObj;
cv::Mat probs; cv::Mat probs;
double likelihood; double logLikelihood;
CvMat meansHdr; CvMat meansHdr;
std::vector<CvMat> covsHdrs; std::vector<CvMat> covsHdrs;
......
...@@ -56,12 +56,12 @@ CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step, ...@@ -56,12 +56,12 @@ CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit) probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
{} {}
CvEM::CvEM() : likelihood(DBL_MAX) CvEM::CvEM() : logLikelihood(DBL_MAX)
{ {
} }
CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx, CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
CvEMParams params, CvMat* labels ) : likelihood(DBL_MAX) CvEMParams params, CvMat* labels ) : logLikelihood(DBL_MAX)
{ {
train(samples, sample_idx, params, labels); train(samples, sample_idx, params, labels);
} }
...@@ -96,16 +96,14 @@ void CvEM::write( CvFileStorage* _fs, const char* name ) const ...@@ -96,16 +96,14 @@ void CvEM::write( CvFileStorage* _fs, const char* name ) const
double CvEM::calcLikelihood( const Mat &input_sample ) const double CvEM::calcLikelihood( const Mat &input_sample ) const
{ {
double likelihood; return emObj.predict(input_sample)[0];
emObj.predict(input_sample, noArray(), &likelihood);
return likelihood;
} }
float float
CvEM::predict( const CvMat* _sample, CvMat* _probs ) const CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
{ {
Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample); Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
int cls = emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray()); int cls = static_cast<int>(emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray())[1]);
if(_probs) if(_probs)
{ {
if( prbs.data != prbs0.data ) if( prbs.data != prbs0.data )
...@@ -203,29 +201,27 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx, ...@@ -203,29 +201,27 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
CvEMParams _params, Mat* _labels ) CvEMParams _params, Mat* _labels )
{ {
CV_Assert(_sample_idx.empty()); CV_Assert(_sample_idx.empty());
Mat prbs, weights, means, likelihoods; Mat prbs, weights, means, logLikelihoods;
std::vector<Mat> covsHdrs; std::vector<Mat> covsHdrs;
init_params(_params, prbs, weights, means, covsHdrs); init_params(_params, prbs, weights, means, covsHdrs);
emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit); emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
bool isOk = false; bool isOk = false;
if( _params.start_step == EM::START_AUTO_STEP ) if( _params.start_step == EM::START_AUTO_STEP )
isOk = emObj.train(_samples, _labels ? _OutputArray(*_labels) : cv::noArray(), isOk = emObj.train(_samples,
probs, likelihoods); logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
else if( _params.start_step == EM::START_E_STEP ) else if( _params.start_step == EM::START_E_STEP )
isOk = emObj.trainE(_samples, means, covsHdrs, weights, isOk = emObj.trainE(_samples, means, covsHdrs, weights,
_labels ? _OutputArray(*_labels) : cv::noArray(), logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
probs, likelihoods);
else if( _params.start_step == EM::START_M_STEP ) else if( _params.start_step == EM::START_M_STEP )
isOk = emObj.trainM(_samples, prbs, isOk = emObj.trainM(_samples, prbs,
_labels ? _OutputArray(*_labels) : cv::noArray(), logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
probs, likelihoods);
else else
CV_Error(CV_StsBadArg, "Bad start type of EM algorithm"); CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
if(isOk) if(isOk)
{ {
likelihoods = sum(likelihoods).val[0]; logLikelihood = sum(logLikelihoods).val[0];
set_mat_hdrs(); set_mat_hdrs();
} }
...@@ -235,8 +231,7 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx, ...@@ -235,8 +231,7 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
float float
CvEM::predict( const Mat& _sample, Mat* _probs ) const CvEM::predict( const Mat& _sample, Mat* _probs ) const
{ {
int cls = emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray()); return static_cast<float>(emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray())[1]);
return (float)cls;
} }
int CvEM::getNClusters() const int CvEM::getNClusters() const
......
...@@ -577,27 +577,26 @@ public: ...@@ -577,27 +577,26 @@ public:
CV_WRAP virtual void clear(); CV_WRAP virtual void clear();
CV_WRAP virtual bool train(InputArray samples, CV_WRAP virtual bool train(InputArray samples,
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray());
OutputArray logLikelihoods=noArray());
CV_WRAP virtual bool trainE(InputArray samples, CV_WRAP virtual bool trainE(InputArray samples,
InputArray means0, InputArray means0,
InputArray covs0=noArray(), InputArray covs0=noArray(),
InputArray weights0=noArray(), InputArray weights0=noArray(),
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray());
OutputArray logLikelihoods=noArray());
CV_WRAP virtual bool trainM(InputArray samples, CV_WRAP virtual bool trainM(InputArray samples,
InputArray probs0, InputArray probs0,
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(), OutputArray labels=noArray(),
OutputArray probs=noArray(), OutputArray probs=noArray());
OutputArray logLikelihoods=noArray());
CV_WRAP int predict(InputArray sample, CV_WRAP Vec2d predict(InputArray sample,
OutputArray probs=noArray(), OutputArray probs=noArray()) const;
CV_OUT double* logLikelihood=0) const;
CV_WRAP bool isTrained() const; CV_WRAP bool isTrained() const;
...@@ -613,9 +612,9 @@ protected: ...@@ -613,9 +612,9 @@ protected:
const Mat* weights0); const Mat* weights0);
bool doTrain(int startStep, bool doTrain(int startStep,
OutputArray logLikelihoods,
OutputArray labels, OutputArray labels,
OutputArray probs, OutputArray probs);
OutputArray logLikelihoods);
virtual void eStep(); virtual void eStep();
virtual void mStep(); virtual void mStep();
...@@ -623,7 +622,7 @@ protected: ...@@ -623,7 +622,7 @@ protected:
void decomposeCovs(); void decomposeCovs();
void computeLogWeightDivDet(); void computeLogWeightDivDet();
void computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const; Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
// all inner matrices have type CV_64FC1 // all inner matrices have type CV_64FC1
CV_PROP_RW int nclusters; CV_PROP_RW int nclusters;
......
...@@ -81,22 +81,22 @@ void EM::clear() ...@@ -81,22 +81,22 @@ void EM::clear()
bool EM::train(InputArray samples, bool EM::train(InputArray samples,
OutputArray logLikelihoods,
OutputArray labels, OutputArray labels,
OutputArray probs, OutputArray probs)
OutputArray logLikelihoods)
{ {
Mat samplesMat = samples.getMat(); Mat samplesMat = samples.getMat();
setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0); setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
return doTrain(START_AUTO_STEP, labels, probs, logLikelihoods); return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
} }
bool EM::trainE(InputArray samples, bool EM::trainE(InputArray samples,
InputArray _means0, InputArray _means0,
InputArray _covs0, InputArray _covs0,
InputArray _weights0, InputArray _weights0,
OutputArray logLikelihoods,
OutputArray labels, OutputArray labels,
OutputArray probs, OutputArray probs)
OutputArray logLikelihoods)
{ {
Mat samplesMat = samples.getMat(); Mat samplesMat = samples.getMat();
vector<Mat> covs0; vector<Mat> covs0;
...@@ -106,24 +106,24 @@ bool EM::trainE(InputArray samples, ...@@ -106,24 +106,24 @@ bool EM::trainE(InputArray samples,
setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0, setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
!_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0); !_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0);
return doTrain(START_E_STEP, labels, probs, logLikelihoods); return doTrain(START_E_STEP, logLikelihoods, labels, probs);
} }
bool EM::trainM(InputArray samples, bool EM::trainM(InputArray samples,
InputArray _probs0, InputArray _probs0,
OutputArray logLikelihoods,
OutputArray labels, OutputArray labels,
OutputArray probs, OutputArray probs)
OutputArray logLikelihoods)
{ {
Mat samplesMat = samples.getMat(); Mat samplesMat = samples.getMat();
Mat probs0 = _probs0.getMat(); Mat probs0 = _probs0.getMat();
setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0); setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
return doTrain(START_M_STEP, labels, probs, logLikelihoods); return doTrain(START_M_STEP, logLikelihoods, labels, probs);
} }
int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const Vec2d EM::predict(InputArray _sample, OutputArray _probs) const
{ {
Mat sample = _sample.getMat(); Mat sample = _sample.getMat();
CV_Assert(isTrained()); CV_Assert(isTrained());
...@@ -136,16 +136,14 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) c ...@@ -136,16 +136,14 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) c
sample = tmp; sample = tmp;
} }
int label;
Mat probs; Mat probs;
if( _probs.needed() ) if( _probs.needed() )
{ {
_probs.create(1, nclusters, CV_64FC1); _probs.create(1, nclusters, CV_64FC1);
probs = _probs.getMat(); probs = _probs.getMat();
} }
computeProbabilities(sample, label, !probs.empty() ? &probs : 0, logLikelihood);
return label; return computeProbabilities(sample, !probs.empty() ? &probs : 0);
} }
bool EM::isTrained() const bool EM::isTrained() const
...@@ -394,7 +392,7 @@ void EM::computeLogWeightDivDet() ...@@ -394,7 +392,7 @@ void EM::computeLogWeightDivDet()
} }
} }
bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArray logLikelihoods) bool EM::doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
{ {
int dim = trainSamples.cols; int dim = trainSamples.cols;
// Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP // Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
...@@ -472,7 +470,7 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr ...@@ -472,7 +470,7 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
return true; return true;
} }
void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const Vec2d EM::computeProbabilities(const Mat& sample, Mat* probs) const
{ {
// L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)] // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
// q = arg(max_k(L_ik)) // q = arg(max_k(L_ik))
...@@ -488,7 +486,7 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* ...@@ -488,7 +486,7 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
int dim = sample.cols; int dim = sample.cols;
Mat L(1, nclusters, CV_64FC1); Mat L(1, nclusters, CV_64FC1);
label = 0; int label = 0;
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++) for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{ {
const Mat centeredSample = sample - means.row(clusterIndex); const Mat centeredSample = sample - means.row(clusterIndex);
...@@ -511,9 +509,6 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* ...@@ -511,9 +509,6 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
label = clusterIndex; label = clusterIndex;
} }
if(!probs && !logLikelihood)
return;
double maxLVal = L.at<double>(label); double maxLVal = L.at<double>(label);
Mat expL_Lmax = L; // exp(L_ij - L_iq) Mat expL_Lmax = L; // exp(L_ij - L_iq)
for(int i = 0; i < L.cols; i++) for(int i = 0; i < L.cols; i++)
...@@ -528,8 +523,11 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* ...@@ -528,8 +523,11 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
expL_Lmax.copyTo(*probs); expL_Lmax.copyTo(*probs);
} }
if(logLikelihood) Vec2d res;
*logLikelihood = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI; res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
res[1] = label;
return res;
} }
void EM::eStep() void EM::eStep()
...@@ -547,8 +545,9 @@ void EM::eStep() ...@@ -547,8 +545,9 @@ void EM::eStep()
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++) for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{ {
Mat sampleProbs = trainProbs.row(sampleIndex); Mat sampleProbs = trainProbs.row(sampleIndex);
computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex), Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs);
&sampleProbs, &trainLogLikelihoods.at<double>(sampleIndex)); trainLogLikelihoods.at<double>(sampleIndex) = res[0];
trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
} }
} }
......
...@@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params, ...@@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
cv::EM em(params.nclusters, params.covMatType, params.termCrit); cv::EM em(params.nclusters, params.covMatType, params.termCrit);
if( params.startStep == EM::START_AUTO_STEP ) if( params.startStep == EM::START_AUTO_STEP )
em.train( trainData, labels ); em.train( trainData, noArray(), labels );
else if( params.startStep == EM::START_E_STEP ) else if( params.startStep == EM::START_E_STEP )
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels ); em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
else if( params.startStep == EM::START_M_STEP ) else if( params.startStep == EM::START_M_STEP )
em.trainM( trainData, *params.probs, labels ); em.trainM( trainData, *params.probs, noArray(), labels );
// check train error // check train error
if( !calcErr( labels, trainLabels, sizes, err , false, false ) ) if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
...@@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params, ...@@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
for( int i = 0; i < testData.rows; i++ ) for( int i = 0; i < testData.rows; i++ )
{ {
Mat sample = testData.row(i); Mat sample = testData.row(i);
double likelihood = 0;
Mat probs; Mat probs;
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood ); labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
} }
if( !calcErr( labels, testLabels, sizes, err, false, false ) ) if( !calcErr( labels, testLabels, sizes, err, false, false ) )
{ {
...@@ -523,7 +522,7 @@ protected: ...@@ -523,7 +522,7 @@ protected:
Mat firstResult(samples.rows, 1, CV_32SC1); Mat firstResult(samples.rows, 1, CV_32SC1);
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)
firstResult.at<int>(i) = em.predict(samples.row(i)); firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
// Write out // Write out
string filename = tempfile() + ".xml"; string filename = tempfile() + ".xml";
...@@ -564,7 +563,7 @@ protected: ...@@ -564,7 +563,7 @@ protected:
int errCaseCount = 0; int errCaseCount = 0;
for( int i = 0; i < samples.rows; i++) for( int i = 0; i < samples.rows; i++)
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1; errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
if( errCaseCount > 0 ) if( errCaseCount > 0 )
{ {
...@@ -637,10 +636,9 @@ protected: ...@@ -637,10 +636,9 @@ protected:
const double lambda = 1.; const double lambda = 1.;
for(int i = 0; i < samples.rows; i++) for(int i = 0; i < samples.rows; i++)
{ {
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
Mat sample = samples.row(i); Mat sample = samples.row(i);
model0.predict(sample, noArray(), &sampleLogLikelihoods0); double sampleLogLikelihoods0 = model0.predict(sample)[0];
model1.predict(sample, noArray(), &sampleLogLikelihoods1); double sampleLogLikelihoods1 = model1.predict(sample)[0];
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1; int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
......
...@@ -478,7 +478,7 @@ void find_decision_boundary_EM() ...@@ -478,7 +478,7 @@ void find_decision_boundary_EM()
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++) for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
{ {
if(em_models[modelIndex].isTrained()) if(em_models[modelIndex].isTrained())
em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) ); logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
} }
Point maxLoc; Point maxLoc;
minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc); minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment