/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

44
namespace cv
45 46
{

47
// Lower bound applied to covariance eigenvalues (via max()) before they are
// inverted in decomposeCovs() and mStep().
const double minEigenValue = DBL_EPSILON;
48 49

///////////////////////////////////////////////////////////////////////////////////////////////////////
50

51
// Construct an EM model.
//   _nclusters  - number of mixture components.
//   _covMatType - covariance matrix constraint (validated later against the EM::COV_MAT_* values).
//   _criteria   - termination criteria; only the parts flagged in _criteria.type are used,
//                 the others fall back to DEFAULT_MAX_ITERS and an epsilon of 0.
EM::EM(int _nclusters, int _covMatType, const TermCriteria& _criteria)
{
    const bool hasIterationLimit = (_criteria.type & TermCriteria::MAX_ITER) != 0;
    const bool hasEpsilon = (_criteria.type & TermCriteria::EPS) != 0;

    nclusters = _nclusters;
    covMatType = _covMatType;

    if(hasIterationLimit)
        maxIters = _criteria.maxCount;
    else
        maxIters = DEFAULT_MAX_ITERS;

    if(hasEpsilon)
        epsilon = _criteria.epsilon;
    else
        epsilon = 0;
}

59
// Destructor. No explicit cleanup is performed here; the clear() call is disabled.
EM::~EM()
{
    // clear();
}

64
void EM::clear()
65
{
66 67
    trainSamples.release();
    trainProbs.release();
68
    trainLogLikelihoods.release();
69
    trainLabels.release();
70

71 72 73
    weights.release();
    means.release();
    covs.clear();
74

75 76 77 78 79
    covsEigenValues.clear();
    invCovsEigenValues.clear();
    covsRotateMats.clear();

    logWeightDivDet.release();
80 81
}

82 83 84 85
    
// Train the model from samples alone (EM::START_AUTO_STEP): no initial
// probabilities, means, covariances or weights are supplied.
//   samples        - training samples.
//   labels, probs, logLikelihoods - optional per-sample outputs filled by doTrain().
// Returns the result of doTrain().
bool EM::train(InputArray samples,
               OutputArray labels,
               OutputArray probs,
               OutputArray logLikelihoods)
{
    const int step = START_AUTO_STEP;

    Mat samplesMat = samples.getMat();
    setTrainData(step, samplesMat, 0, 0, 0, 0);

    return doTrain(step, labels, probs, logLikelihoods);
}
92

93 94 95 96 97 98
// Train the model starting from the E-step (EM::START_E_STEP).
//   samples   - training samples.
//   _means0   - initial means (required by checkTrainData() for this start step).
//   _covs0    - optional initial covariance matrices.
//   _weights0 - optional initial mixture weights.
//   labels, probs, logLikelihoods - optional per-sample outputs filled by doTrain().
// Each optional input is forwarded to setTrainData() only when it is non-empty;
// an empty input is passed as a null pointer.
// Returns the result of doTrain().
bool EM::trainE(InputArray samples,
                InputArray _means0,
                InputArray _covs0,
                InputArray _weights0,
                OutputArray labels,
                OutputArray probs,
                OutputArray logLikelihoods)
{
    Mat samplesMat = samples.getMat();
    vector<Mat> covs0;
    _covs0.getMatVector(covs0);

    Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();

    // The weights test was previously inverted: supplied weights were dropped
    // and omitted weights were forwarded as a pointer to an empty matrix.
    setTrainData(START_E_STEP, samplesMat, 0,
                 !_means0.empty() ? &means0 : 0,
                 !_covs0.empty() ? &covs0 : 0,
                 !_weights0.empty() ? &weights0 : 0);
    return doTrain(START_E_STEP, labels, probs, logLikelihoods);
}
111

112 113 114 115
// Train the model starting from the M-step (EM::START_M_STEP).
//   samples - training samples.
//   _probs0 - initial per-sample probabilities; forwarded only when non-empty.
//   labels, probs, logLikelihoods - optional per-sample outputs filled by doTrain().
// Returns the result of doTrain().
bool EM::trainM(InputArray samples,
                InputArray _probs0,
                OutputArray labels,
                OutputArray probs,
                OutputArray logLikelihoods)
{
    const int step = START_M_STEP;

    Mat samplesMat = samples.getMat();
    Mat probs0 = _probs0.getMat();

    const Mat* initialProbs = 0;
    if(!_probs0.empty())
        initialProbs = &probs0;

    setTrainData(step, samplesMat, initialProbs, 0, 0, 0);
    return doTrain(step, labels, probs, logLikelihoods);
}

125
    
126
// Classify a single sample with the trained model.
//   _sample       - the sample; converted to CV_64FC1 when it has another type.
//   _probs        - optional output, created as a 1 x nclusters CV_64FC1 row.
//   logLikelihood - optional output forwarded to computeProbabilities().
// Returns the label chosen by computeProbabilities().
// Asserts that the model is trained and the sample is not empty.
int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const
{
    Mat sample = _sample.getMat();
    CV_Assert(isTrained());

    CV_Assert(!sample.empty());
    if(sample.type() != CV_64FC1)
    {
        Mat sample64;
        sample.convertTo(sample64, CV_64FC1);
        sample = sample64;
    }

    Mat probs;
    Mat* probsPtr = 0;
    if( _probs.needed() )
    {
        _probs.create(1, nclusters, CV_64FC1);
        probs = _probs.getMat();
        if(!probs.empty())
            probsPtr = &probs;
    }

    int label = 0; // always overwritten by computeProbabilities()
    computeProbabilities(sample, label, probsPtr, logLikelihood);
    return label;
}
150

151 152 153 154
// The model counts as trained once the means matrix is non-empty.
bool EM::isTrained() const
{
    const bool hasMeans = !means.empty();
    return hasMeans;
}
155 156


157
static
158 159 160
void checkTrainData(int startStep, const Mat& samples,
                    int nclusters, int covMatType, const Mat* probs, const Mat* means,
                    const vector<Mat>* covs, const Mat* weights)
161 162 163
{
    // Check samples.
    CV_Assert(!samples.empty());
164
    CV_Assert(samples.channels() == 1);
165 166 167 168 169

    int nsamples = samples.rows;
    int dim = samples.cols;

    // Check training params.
170 171 172 173 174
    CV_Assert(nclusters > 0);
    CV_Assert(nclusters <= nsamples);
    CV_Assert(startStep == EM::START_AUTO_STEP ||
              startStep == EM::START_E_STEP ||
              startStep == EM::START_M_STEP);
175 176 177
    CV_Assert(covMatType == EM::COV_MAT_GENERIC ||
              covMatType == EM::COV_MAT_DIAGONAL ||
              covMatType == EM::COV_MAT_SPHERICAL);
178 179 180 181

    CV_Assert(!probs ||
        (!probs->empty() &&
         probs->rows == nsamples && probs->cols == nclusters &&
182
         (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));
183 184 185 186

    CV_Assert(!weights ||
        (!weights->empty() &&
         (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
187
         (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));
188 189 190 191

    CV_Assert(!means ||
        (!means->empty() &&
         means->rows == nclusters && means->cols == dim &&
192
         means->channels() == 1));
193 194 195 196 197

    CV_Assert(!covs ||
        (!covs->empty() &&
         static_cast<int>(covs->size()) == nclusters));
    if(covs)
198
    {
199 200
        const Size covSize(dim, dim);
        for(size_t i = 0; i < covs->size(); i++)
201
        {
202
            const Mat& m = (*covs)[i];
203
            CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
204 205
        }
    }
206

207
    if(startStep == EM::START_E_STEP)
208
    {
209
        CV_Assert(means);
210
    }
211
    else if(startStep == EM::START_M_STEP)
212
    {
213
        CV_Assert(probs);
214
    }
215
}
216

217
static
218
void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
219
{
220 221
    if(src.type() == dstType && !isAlwaysClone)
        dst = src;
222
    else
223
        src.convertTo(dst, dstType);
224 225
}

226
static
227
void preprocessProbability(Mat& probs)
228
{
229
    max(probs, 0., probs);
230

231
    const double uniformProbability = (double)(1./probs.cols);
232 233
    for(int y = 0; y < probs.rows; y++)
    {
234
        Mat sampleProbs = probs.row(y);
235

236
        double maxVal = 0;
237
        minMaxLoc(sampleProbs, 0, &maxVal);
238 239 240
        if(maxVal < FLT_EPSILON)
            sampleProbs.setTo(uniformProbability);
        else
241
            normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
242
    }
243
}
244

245 246 247 248 249
void EM::setTrainData(int startStep, const Mat& samples,
                      const Mat* probs0,
                      const Mat* means0,
                      const vector<Mat>* covs0,
                      const Mat* weights0)
250
{
251
    clear();
252

253
    checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
254

255
    bool isKMeansInit = (startStep == EM::START_AUTO_STEP) || (startStep == EM::START_E_STEP && (covs0 == 0 || weights0 == 0));
256
    // Set checked data
257
    preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
258

259
    // set probs
260
    if(probs0 && startStep == EM::START_M_STEP)
261
    {
262
        preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
263
        preprocessProbability(trainProbs);
264 265
    }

266
    // set weights
267
    if(weights0 && (startStep == EM::START_E_STEP && covs0))
268
    {
269
        weights0->convertTo(weights, CV_64FC1);
270 271
        weights.reshape(1,1);
        preprocessProbability(weights);
272 273
    }

274
    // set means
275 276
    if(means0 && (startStep == EM::START_E_STEP/* || startStep == EM::START_AUTO_STEP*/))
        means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
277

278
    // set covs
279
    if(covs0 && (startStep == EM::START_E_STEP && weights0))
280
    {
281
        covs.resize(nclusters);
282
        for(size_t i = 0; i < covs0->size(); i++)
283
            (*covs0)[i].convertTo(covs[i], CV_64FC1);
284 285 286
    }
}

287
void EM::decomposeCovs()
288
{
289 290 291 292 293 294 295 296
    CV_Assert(!covs.empty());
    covsEigenValues.resize(nclusters);
    if(covMatType == EM::COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        CV_Assert(!covs[clusterIndex].empty());
297

298
        SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
299

300
        if(covMatType == EM::COV_MAT_SPHERICAL)
301
        {
302 303
            double maxSingularVal = svd.w.at<double>(0);
            covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
304
        }
305
        else if(covMatType == EM::COV_MAT_DIAGONAL)
306
        {
307
            covsEigenValues[clusterIndex] = svd.w;
308
        }
309 310 311 312 313
        else //EM::COV_MAT_GENERIC
        {
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
        }
314
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
315
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
316 317
    }
}
318

319
void EM::clusterTrainSamples()
320
{
321 322 323
    int nsamples = trainSamples.rows;

    // Cluster samples, compute/update means
324 325

    // Convert samples and means to 32F, because kmeans requires this type.
326 327 328 329 330 331 332 333 334 335 336 337 338
    Mat trainSamplesFlt, meansFlt;
    if(trainSamples.type() != CV_32FC1)
        trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
    else
        trainSamplesFlt = trainSamples;
    if(!means.empty())
    {
        if(means.type() != CV_32FC1)
            means.convertTo(meansFlt, CV_32FC1);
        else
            meansFlt = means;
    }

339
    Mat labels;
340 341
    kmeans(trainSamplesFlt, nclusters, labels,  TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5), 10, KMEANS_PP_CENTERS, meansFlt);

342
    // Convert samples and means back to 64F.
343 344
    CV_Assert(meansFlt.type() == CV_32FC1);
    if(trainSamples.type() != CV_64FC1)
345 346 347 348 349
    {
        Mat trainSamplesBuffer;
        trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
        trainSamples = trainSamplesBuffer;
    }
350
    meansFlt.convertTo(means, CV_64FC1);
351 352

    // Compute weights and covs
353
    weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
354 355
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
356
    {
357
        Mat clusterSamples;
358
        for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
359
        {
360
            if(labels.at<int>(sampleIndex) == clusterIndex)
361
            {
362
                const Mat sample = trainSamples.row(sampleIndex);
363
                clusterSamples.push_back(sample);
364 365
            }
        }
366
        CV_Assert(!clusterSamples.empty());
367

368
        calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
369 370
            CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
        weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
371 372
    }

373
    decomposeCovs();
374 375
}

376
void EM::computeLogWeightDivDet()
377
{
378
    CV_Assert(!covsEigenValues.empty());
379

380
    Mat logWeights;
381
    cv::max(weights, DBL_MIN, weights);
382
    log(weights, logWeights);
383

384
    logWeightDivDet.create(1, nclusters, CV_64FC1);
385
    // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
386

387
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
388
    {
389
        double logDetCov = 0.;
390
        for(int di = 0; di < covsEigenValues[clusterIndex].cols; di++)
391
            logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0));
392

393
        logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
394 395
    }
}
396

397
bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArray logLikelihoods)
398 399 400 401
{
    int dim = trainSamples.cols;
    // Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
    if(startStep != EM::START_M_STEP)
402
    {
403
        if(covs.empty())
404
        {
405
            CV_Assert(weights.empty());
406
            clusterTrainSamples();
407 408 409
        }
    }

410 411 412 413 414
    if(!covs.empty() && covsEigenValues.empty() )
    {
        CV_Assert(invCovsEigenValues.empty());
        decomposeCovs();
    }
415

416 417
    if(startStep == EM::START_M_STEP)
        mStep();
418

419
    double trainLogLikelihood, prevTrainLogLikelihood = 0.;
420 421 422
    for(int iter = 0; ; iter++)
    {
        eStep();
423
        trainLogLikelihood = sum(trainLogLikelihoods)[0];
424

425
        if(iter >= maxIters - 1)
426
            break;
427

428
        double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
429
        if( iter != 0 &&
430 431
            (trainLogLikelihoodDelta < -DBL_EPSILON ||
             trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
432
            break;
433

434
        mStep();
435

436
        prevTrainLogLikelihood = trainLogLikelihood;
437 438
    }

439
    if( trainLogLikelihood <= -DBL_MAX/10000. )
440 441
    {
        clear();
442
        return false;
443
    }
444

445 446 447
    // postprocess covs
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
448
    {
449
        if(covMatType == EM::COV_MAT_SPHERICAL)
450
        {
451 452
            covs[clusterIndex].create(dim, dim, CV_64FC1);
            setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
453
        }
454
        else if(covMatType == EM::COV_MAT_DIAGONAL)
455 456 457
        {
            covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
        }
458
    }
459 460 461 462 463
    
    if(labels.needed())
        trainLabels.copyTo(labels);
    if(probs.needed())
        trainProbs.copyTo(probs);
464 465
    if(logLikelihoods.needed())
        trainLogLikelihoods.copyTo(logLikelihoods);
466 467 468 469
    
    trainSamples.release();
    trainProbs.release();
    trainLabels.release();
470
    trainLogLikelihoods.release();
471

472
    return true;
473 474
}

475
void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const
476
{
477 478
    // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
    // q = arg(max_k(L_ik))
479
    // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
480 481
    // see Alex Smola's blog http://blog.smola.org/page/2 for
    // details on the log-sum-exp trick
482

483 484
    CV_Assert(!means.empty());
    CV_Assert(sample.type() == CV_64FC1);
Maria Dimashova's avatar
Maria Dimashova committed
485
    CV_Assert(sample.rows == 1);
486
    CV_Assert(sample.cols == means.cols);
487

488
    int dim = sample.cols;
489

490
    Mat L(1, nclusters, CV_64FC1);
491 492
    label = 0;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
493
    {
494
        const Mat centeredSample = sample - means.row(clusterIndex);
495

496
        Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ?
497
                centeredSample : centeredSample * covsRotateMats[clusterIndex];
498

499
        double Lval = 0;
500
        for(int di = 0; di < dim; di++)
501
        {
502 503
            double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0);
            double val = rotatedCenteredSample.at<double>(di);
504
            Lval += w * val * val;
505
        }
506
        CV_DbgAssert(!logWeightDivDet.empty());
507 508
        Lval = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
        L.at<double>(clusterIndex) = Lval;
509

510
        if(Lval > L.at<double>(label))
511 512
            label = clusterIndex;
    }
513

514
    if(!probs && !logLikelihood)
515
        return;
516

Maria Dimashova's avatar
Maria Dimashova committed
517
    double maxLVal = L.at<double>(label);
518
    Mat expL_Lmax = L; // exp(L_ij - L_iq)
Maria Dimashova's avatar
Maria Dimashova committed
519 520
    for(int i = 0; i < L.cols; i++)
        expL_Lmax.at<double>(i) = std::exp(L.at<double>(i) - maxLVal);
521
    double expDiffSum = sum(expL_Lmax)[0]; // sum_j(exp(L_ij - L_iq))
522

523 524 525
    if(probs)
    {
        probs->create(1, nclusters, CV_64FC1);
526
        double factor = 1./expDiffSum;
527 528 529
        expL_Lmax *= factor;
        expL_Lmax.copyTo(*probs);
    }
530

531
    if(logLikelihood)
532
        *logLikelihood = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
533 534
}

535
void EM::eStep()
536
{
537
    // Compute probs_ik from means_k, covs_k and weights_k.
538
    trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
539
    trainLabels.create(trainSamples.rows, 1, CV_32SC1);
540
    trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
541

542
    computeLogWeightDivDet();
543

544 545 546
    CV_DbgAssert(trainSamples.type() == CV_64FC1);
    CV_DbgAssert(means.type() == CV_64FC1);

547 548
    for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
    {
549
        Mat sampleProbs = trainProbs.row(sampleIndex);
550
        computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex),
551
                             &sampleProbs, &trainLogLikelihoods.at<double>(sampleIndex));
552
    }
553 554
}

555
void EM::mStep()
556
{
557 558
    // Update means_k, covs_k and weights_k from probs_ik
    int dim = trainSamples.cols;
559

560 561 562
    // Update weights
    // not normalized first
    reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
563

564 565 566
    // Update means
    means.create(nclusters, dim, CV_64FC1);
    means = Scalar(0);
567

568 569 570 571 572 573 574
    const double minPosWeight = trainSamples.rows * DBL_EPSILON;
    double minWeight = DBL_MAX;
    int minWeightClusterIndex = -1;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;
575

576
        if(weights.at<double>(clusterIndex) < minWeight)
577
        {
578 579
            minWeight = weights.at<double>(clusterIndex);
            minWeightClusterIndex = clusterIndex;
580 581
        }

582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
        Mat clusterMean = means.row(clusterIndex);
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
        clusterMean /= weights.at<double>(clusterIndex);
    }

    // Update covsEigenValues and invCovsEigenValues
    covs.resize(nclusters);
    covsEigenValues.resize(nclusters);
    if(covMatType == EM::COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;

        if(covMatType != EM::COV_MAT_SPHERICAL)
            covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
        else
            covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

604
        if(covMatType == EM::COV_MAT_GENERIC)
605
            covs[clusterIndex].create(dim, dim, CV_64FC1);
606

607 608
        Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
            covsEigenValues[clusterIndex] : covs[clusterIndex];
609

610
        clusterCov = Scalar(0);
611

612 613 614 615
        Mat centeredSample;
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
        {
            centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
616

617 618 619
            if(covMatType == EM::COV_MAT_GENERIC)
                clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
            else
620
            {
621 622
                double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                for(int di = 0; di < dim; di++ )
623
                {
624 625
                    double val = centeredSample.at<double>(di);
                    clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
626
                }
627
            }
628
        }
629

630 631
        if(covMatType == EM::COV_MAT_SPHERICAL)
            clusterCov /= dim;
632

633
        clusterCov /= weights.at<double>(clusterIndex);
634

635 636 637 638 639 640 641
        // Update covsRotateMats for EM::COV_MAT_GENERIC only
        if(covMatType == EM::COV_MAT_GENERIC)
        {
            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
        }
642

643
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
644

645 646 647
        // update invCovsEigenValues
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
    }
648

649 650 651 652 653 654 655 656 657 658 659 660
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
        {
            Mat clusterMean = means.row(clusterIndex);
            means.row(minWeightClusterIndex).copyTo(clusterMean);
            covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
            covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
            if(covMatType == EM::COV_MAT_GENERIC)
                covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
            invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
        }
661
    }
662 663 664

    // Normalize weights
    weights /= trainSamples.rows;
665 666
}

667
void EM::read(const FileNode& fn)
668
{
669
    Algorithm::read(fn);
670

671 672
    decomposeCovs();
    computeLogWeightDivDet();
673 674
}

675
static Algorithm* createEM()
676
{
677
    return new EM;
678
}
679
static AlgorithmInfo em_info("StatModel.EM", createEM);
680

681
// Return the AlgorithmInfo descriptor for EM. On the first call the
// parameters "nclusters", "covMatType", "weights", "means" and "covs" are
// added to it, using a temporary EM object.
AlgorithmInfo* EM::info() const
{
    static volatile bool initialized = false;
    if( initialized )
        return &em_info;

    EM obj;

    // Training settings.
    em_info.addParam(obj, "nclusters", obj.nclusters);
    em_info.addParam(obj, "covMatType", obj.covMatType);

    // Learned mixture parameters.
    em_info.addParam(obj, "weights", obj.weights);
    em_info.addParam(obj, "means", obj.means);
    em_info.addParam(obj, "covs", obj.covs);

    initialized = true;
    return &em_info;
}
698
} // namespace cv
699 700

/* End of file. */