Commit 7120355e authored by Maria Dimashova's avatar Maria Dimashova

updated points_classifier sample to use bayes classifier after distributions estimation by EM

parent eaf0d38f
...@@ -442,16 +442,30 @@ void find_decision_boundary_EM() ...@@ -442,16 +442,30 @@ void find_decision_boundary_EM()
Mat trainSamples, trainClasses; Mat trainSamples, trainClasses;
prepare_train_data( trainSamples, trainClasses ); prepare_train_data( trainSamples, trainClasses );
cv::EM em; vector<cv::EM> em_models(classColors.size());
cv::EM::Params params;
params.nclusters = classColors.size();
params.covMatType = cv::EM::COV_MAT_GENERIC;
params.startStep = cv::EM::START_AUTO_STEP;
params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1);
// learn classifier CV_Assert((int)trainClasses.total() == trainSamples.rows);
em.train( trainSamples, Mat(), params, &trainClasses ); CV_Assert((int)trainClasses.type() == CV_32SC1);
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
{
const int componentCount = 3;
em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
Mat modelSamples;
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{
if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
modelSamples.push_back(trainSamples.row(sampleIndex));
}
// learn models
if(!modelSamples.empty())
em_models[modelIndex].train(modelSamples);
}
// classify coordinate plane points using the bayes classifier, i.e.
// y(x) = arg max_i=1_modelsCount likelihoods_i(x)
Mat testSample(1, 2, CV_32FC1 ); Mat testSample(1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep ) for( int y = 0; y < img.rows; y += testStep )
{ {
...@@ -460,7 +474,16 @@ void find_decision_boundary_EM() ...@@ -460,7 +474,16 @@ void find_decision_boundary_EM()
testSample.at<float>(0) = (float)x; testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y; testSample.at<float>(1) = (float)y;
int response = (int)em.predict( testSample ); Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
{
if(em_models[modelIndex].isTrained())
em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
}
Point maxLoc;
minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
int response = maxLoc.x;
circle( imgDst, Point(x,y), 2, classColors[response], 1 ); circle( imgDst, Point(x,y), 2, classColors[response], 1 );
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment