em.cpp 27.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright( C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
//(including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort(including negligence or otherwise) arising in any way out of
// the use of this software, even ifadvised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

44
namespace cv
45
{
46 47
namespace ml
{
48

Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
49
const double minEigenValue = DBL_EPSILON;
50

51
class CV_EXPORTS EMImpl : public EM
52
{
53
public:
54 55 56 57 58 59 60 61

    int nclusters;
    int covMatType;
    TermCriteria termCrit;

    CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)

    void setClustersNumber(int val)
62
    {
63
        nclusters = val;
64
        CV_Assert(nclusters >= 1);
65
    }
66

67 68 69 70
    int getClustersNumber() const
    {
        return nclusters;
    }
71

72
    void setCovarianceMatrixType(int val)
73
    {
74 75 76 77
        covMatType = val;
        CV_Assert(covMatType == COV_MAT_SPHERICAL ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_GENERIC);
78 79
    }

80
    int getCovarianceMatrixType() const
81
    {
82
        return covMatType;
83
    }
84

85 86 87 88 89 90 91 92 93
    EMImpl()
    {
        nclusters = DEFAULT_NCLUSTERS;
        covMatType=EM::COV_MAT_DIAGONAL;
        termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
    }

    virtual ~EMImpl() {}

94 95 96 97 98 99
    void clear()
    {
        trainSamples.release();
        trainProbs.release();
        trainLogLikelihoods.release();
        trainLabels.release();
100

101 102 103
        weights.release();
        means.release();
        covs.clear();
104

105 106 107
        covsEigenValues.clear();
        invCovsEigenValues.clear();
        covsRotateMats.clear();
108

109 110
        logWeightDivDet.release();
    }
111

112 113 114
    bool train(const Ptr<TrainData>& data, int)
    {
        Mat samples = data->getTrainSamples(), labels;
115
        return trainEM(samples, labels, noArray(), noArray());
116 117
    }

118
    bool trainEM(InputArray samples,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
119
               OutputArray logLikelihoods,
120
               OutputArray labels,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
121
               OutputArray probs)
122 123 124 125 126
    {
        Mat samplesMat = samples.getMat();
        setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
        return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
    }
127

128
    bool trainE(InputArray samples,
129 130 131
                InputArray _means0,
                InputArray _covs0,
                InputArray _weights0,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
132
                OutputArray logLikelihoods,
133
                OutputArray labels,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
134
                OutputArray probs)
135 136 137 138
    {
        Mat samplesMat = samples.getMat();
        std::vector<Mat> covs0;
        _covs0.getMatVector(covs0);
139

140
        Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
141

142 143 144 145
        setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                     !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
        return doTrain(START_E_STEP, logLikelihoods, labels, probs);
    }
146

147
    bool trainM(InputArray samples,
148
                InputArray _probs0,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
149
                OutputArray logLikelihoods,
150
                OutputArray labels,
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
151
                OutputArray probs)
152
    {
153 154
        Mat samplesMat = samples.getMat();
        Mat probs0 = _probs0.getMat();
155

156 157
        setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
        return doTrain(START_M_STEP, logLikelihoods, labels, probs);
158 159
    }

160 161 162 163
    float predict(InputArray _inputs, OutputArray _outputs, int) const
    {
        bool needprobs = _outputs.needed();
        Mat samples = _inputs.getMat(), probs, probsrow;
164
        int ptype = CV_64F;
165 166
        float firstres = 0.f;
        int i, nsamples = samples.rows;
167

168 169 170 171
        if( needprobs )
        {
            if( _outputs.fixedType() )
                ptype = _outputs.type();
172
            _outputs.create(samples.rows, nclusters, ptype);
173
            probs = _outputs.getMat();
174 175 176
        }
        else
            nsamples = std::min(nsamples, 1);
177

178
        for( i = 0; i < nsamples; i++ )
179
        {
180 181 182 183 184
            if( needprobs )
                probsrow = probs.row(i);
            Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
            if( i == 0 )
                firstres = (float)res[1];
185
        }
186
        return firstres;
187 188
    }

189
    Vec2d predict2(InputArray _sample, OutputArray _probs) const
190
    {
191
        int ptype = CV_64F;
192 193
        Mat sample = _sample.getMat();
        CV_Assert(isTrained());
194

195 196 197 198 199 200 201
        CV_Assert(!sample.empty());
        if(sample.type() != CV_64FC1)
        {
            Mat tmp;
            sample.convertTo(tmp, CV_64FC1);
            sample = tmp;
        }
202
        sample = sample.reshape(1, 1);
203

204 205 206 207 208
        Mat probs;
        if( _probs.needed() )
        {
            if( _probs.fixedType() )
                ptype = _probs.type();
209
            _probs.create(1, nclusters, ptype);
210 211
            probs = _probs.getMat();
        }
212

213
        return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
214
    }
215

216
    bool isTrained() const
217
    {
218
        return !means.empty();
219 220
    }

221
    bool isClassifier() const
222
    {
223
        return true;
224 225
    }

226
    int getVarCount() const
227
    {
228
        return means.cols;
229 230
    }

231
    String getDefaultName() const
232
    {
233 234
        return "opencv_ml_em";
    }
235

236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
    static void checkTrainData(int startStep, const Mat& samples,
                               int nclusters, int covMatType, const Mat* probs, const Mat* means,
                               const std::vector<Mat>* covs, const Mat* weights)
    {
        // Check samples.
        CV_Assert(!samples.empty());
        CV_Assert(samples.channels() == 1);

        int nsamples = samples.rows;
        int dim = samples.cols;

        // Check training params.
        CV_Assert(nclusters > 0);
        CV_Assert(nclusters <= nsamples);
        CV_Assert(startStep == START_AUTO_STEP ||
                  startStep == START_E_STEP ||
                  startStep == START_M_STEP);
        CV_Assert(covMatType == COV_MAT_GENERIC ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_SPHERICAL);

        CV_Assert(!probs ||
            (!probs->empty() &&
             probs->rows == nsamples && probs->cols == nclusters &&
             (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

        CV_Assert(!weights ||
            (!weights->empty() &&
             (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
             (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

        CV_Assert(!means ||
            (!means->empty() &&
             means->rows == nclusters && means->cols == dim &&
             means->channels() == 1));

        CV_Assert(!covs ||
            (!covs->empty() &&
             static_cast<int>(covs->size()) == nclusters));
        if(covs)
276
        {
277 278 279 280 281 282
            const Size covSize(dim, dim);
            for(size_t i = 0; i < covs->size(); i++)
            {
                const Mat& m = (*covs)[i];
                CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
            }
283
        }
284 285

        if(startStep == START_E_STEP)
286
        {
287
            CV_Assert(means);
288
        }
289
        else if(startStep == START_M_STEP)
290
        {
291
            CV_Assert(probs);
292
        }
293
    }
294

295
    static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
296
    {
297 298
        if(src.type() == dstType && !isAlwaysClone)
            dst = src;
299
        else
300
            src.convertTo(dst, dstType);
301
    }
302

303
    static void preprocessProbability(Mat& probs)
304
    {
305
        max(probs, 0., probs);
306

307 308
        const double uniformProbability = (double)(1./probs.cols);
        for(int y = 0; y < probs.rows; y++)
309
        {
310
            Mat sampleProbs = probs.row(y);
311

312 313 314 315 316 317 318
            double maxVal = 0;
            minMaxLoc(sampleProbs, 0, &maxVal);
            if(maxVal < FLT_EPSILON)
                sampleProbs.setTo(uniformProbability);
            else
                normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
        }
319 320
    }

321 322 323 324 325 326 327
    void setTrainData(int startStep, const Mat& samples,
                      const Mat* probs0,
                      const Mat* means0,
                      const std::vector<Mat>* covs0,
                      const Mat* weights0)
    {
        clear();
328

329
        checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
330

331 332 333
        bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
        // Set checked data
        preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
334

335 336 337 338 339 340
        // set probs
        if(probs0 && startStep == START_M_STEP)
        {
            preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
            preprocessProbability(trainProbs);
        }
341

342 343 344 345
        // set weights
        if(weights0 && (startStep == START_E_STEP && covs0))
        {
            weights0->convertTo(weights, CV_64FC1);
346
            weights = weights.reshape(1,1);
347 348
            preprocessProbability(weights);
        }
349

350 351 352 353 354 355 356 357 358 359 360
        // set means
        if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
            means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);

        // set covs
        if(covs0 && (startStep == START_E_STEP && weights0))
        {
            covs.resize(nclusters);
            for(size_t i = 0; i < covs0->size(); i++)
                (*covs0)[i].convertTo(covs[i], CV_64FC1);
        }
361
    }
362

363
    void decomposeCovs()
364
    {
365 366 367 368 369 370
        CV_Assert(!covs.empty());
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
371
        {
372 373 374 375 376 377 378 379 380 381 382
            CV_Assert(!covs[clusterIndex].empty());

            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);

            if(covMatType == COV_MAT_SPHERICAL)
            {
                double maxSingularVal = svd.w.at<double>(0);
                covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
            }
            else if(covMatType == COV_MAT_DIAGONAL)
            {
art-programmer's avatar
art-programmer committed
383
                covsEigenValues[clusterIndex] = covs[clusterIndex].diag().clone(); //Preserve the original order of eigen values.
384 385 386 387 388 389 390 391
            }
            else //COV_MAT_GENERIC
            {
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }
            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
392 393 394
        }
    }

395
    void clusterTrainSamples()
396
    {
397
        int nsamples = trainSamples.rows;
398

399
        // Cluster samples, compute/update means
400

401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
        // Convert samples and means to 32F, because kmeans requires this type.
        Mat trainSamplesFlt, meansFlt;
        if(trainSamples.type() != CV_32FC1)
            trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
        else
            trainSamplesFlt = trainSamples;
        if(!means.empty())
        {
            if(means.type() != CV_32FC1)
                means.convertTo(meansFlt, CV_32FC1);
            else
                meansFlt = means;
        }

        Mat labels;
        kmeans(trainSamplesFlt, nclusters, labels,
               TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
               10, KMEANS_PP_CENTERS, meansFlt);
419

420 421 422 423 424 425 426 427 428
        // Convert samples and means back to 64F.
        CV_Assert(meansFlt.type() == CV_32FC1);
        if(trainSamples.type() != CV_64FC1)
        {
            Mat trainSamplesBuffer;
            trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
            trainSamples = trainSamplesBuffer;
        }
        meansFlt.convertTo(means, CV_64FC1);
429

430 431 432 433 434 435 436 437 438 439 440 441 442 443 444
        // Compute weights and covs
        weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            Mat clusterSamples;
            for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
            {
                if(labels.at<int>(sampleIndex) == clusterIndex)
                {
                    const Mat sample = trainSamples.row(sampleIndex);
                    clusterSamples.push_back(sample);
                }
            }
            CV_Assert(!clusterSamples.empty());
445

446 447 448 449
            calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
                CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
            weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
        }
450

451
        decomposeCovs();
452
    }
453

454
    void computeLogWeightDivDet()
455
    {
456 457 458 459 460 461 462 463 464 465 466 467 468 469
        CV_Assert(!covsEigenValues.empty());

        Mat logWeights;
        cv::max(weights, DBL_MIN, weights);
        log(weights, logWeights);

        logWeightDivDet.create(1, nclusters, CV_64FC1);
        // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            double logDetCov = 0.;
            const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
            for(int di = 0; di < evalCount; di++)
470
                logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
471 472 473

            logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
        }
474 475
    }

476
    bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
477
    {
478 479 480
        int dim = trainSamples.cols;
        // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
        if(startStep != START_M_STEP)
481
        {
482 483 484 485 486
            if(covs.empty())
            {
                CV_Assert(weights.empty());
                clusterTrainSamples();
            }
487
        }
488 489

        if(!covs.empty() && covsEigenValues.empty() )
490
        {
491 492
            CV_Assert(invCovsEigenValues.empty());
            decomposeCovs();
493
        }
494

495 496
        if(startStep == START_M_STEP)
            mStep();
497

498
        double trainLogLikelihood, prevTrainLogLikelihood = 0.;
499 500 501
        int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
            termCrit.maxCount : DEFAULT_MAX_ITERS;
        double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
502

503 504 505 506
        for(int iter = 0; ; iter++)
        {
            eStep();
            trainLogLikelihood = sum(trainLogLikelihoods)[0];
507

508 509 510 511 512 513 514 515
            if(iter >= maxIters - 1)
                break;

            double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
            if( iter != 0 &&
                (trainLogLikelihoodDelta < -DBL_EPSILON ||
                 trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
                break;
516

517
            mStep();
518

519 520 521 522
            prevTrainLogLikelihood = trainLogLikelihood;
        }

        if( trainLogLikelihood <= -DBL_MAX/10000. )
523
        {
524 525
            clear();
            return false;
526 527
        }

528 529 530 531
        // postprocess covs
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
532
            if(covMatType == COV_MAT_SPHERICAL)
533 534 535 536
            {
                covs[clusterIndex].create(dim, dim, CV_64FC1);
                setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
            }
537
            else if(covMatType == COV_MAT_DIAGONAL)
538 539 540 541
            {
                covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
            }
        }
542

543 544 545 546 547 548
        if(labels.needed())
            trainLabels.copyTo(labels);
        if(probs.needed())
            trainProbs.copyTo(probs);
        if(logLikelihoods.needed())
            trainLogLikelihoods.copyTo(logLikelihoods);
549

550 551 552 553
        trainSamples.release();
        trainProbs.release();
        trainLabels.release();
        trainLogLikelihoods.release();
554

555 556
        return true;
    }
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
557

558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590
    Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
    {
        // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
        // q = arg(max_k(L_ik))
        // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
        // see Alex Smola's blog http://blog.smola.org/page/2 for
        // details on the log-sum-exp trick

        int stype = sample.type();
        CV_Assert(!means.empty());
        CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
        CV_Assert(sample.size() == Size(means.cols, 1));

        int dim = sample.cols;

        Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
        int i, label = 0;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            const double* mptr = means.ptr<double>(clusterIndex);
            double* dptr = centeredSample.ptr<double>();
            if( stype == CV_32F )
            {
                const float* sptr = sample.ptr<float>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }
            else
            {
                const double* sptr = sample.ptr<double>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }
591

592 593
            Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
                    centeredSample : centeredSample * covsRotateMats[clusterIndex];
594

595 596 597 598 599 600 601 602 603
            double Lval = 0;
            for(int di = 0; di < dim; di++)
            {
                double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
                double val = rotatedCenteredSample.at<double>(di);
                Lval += w * val * val;
            }
            CV_DbgAssert(!logWeightDivDet.empty());
            L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
604

605 606 607
            if(L.at<double>(clusterIndex) > L.at<double>(label))
                label = clusterIndex;
        }
608

609 610 611 612 613 614 615 616
        double maxLVal = L.at<double>(label);
        double expDiffSum = 0;
        for( i = 0; i < L.cols; i++ )
        {
            double v = std::exp(L.at<double>(i) - maxLVal);
            L.at<double>(i) = v;
            expDiffSum += v; // sum_j(exp(L_ij - L_iq))
        }
617

618 619
        if(probs)
            L.convertTo(*probs, ptype, 1./expDiffSum);
620

621 622 623
        Vec2d res;
        res[0] = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
        res[1] = label;
624

625 626
        return res;
    }
627

628
    void eStep()
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
629
    {
630
        // Compute probs_ik from means_k, covs_k and weights_k.
631
        trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
632 633
        trainLabels.create(trainSamples.rows, 1, CV_32SC1);
        trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
634

635 636 637 638
        computeLogWeightDivDet();

        CV_DbgAssert(trainSamples.type() == CV_64FC1);
        CV_DbgAssert(means.type() == CV_64FC1);
639

Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
640
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
641 642 643 644 645 646
        {
            Mat sampleProbs = trainProbs.row(sampleIndex);
            Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
            trainLogLikelihoods.at<double>(sampleIndex) = res[0];
            trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
        }
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
647 648
    }

649
    void mStep()
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
650
    {
651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668
        // Update means_k, covs_k and weights_k from probs_ik
        int dim = trainSamples.cols;

        // Update weights
        // not normalized first
        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);

        // Update means
        means.create(nclusters, dim, CV_64FC1);
        means = Scalar(0);

        const double minPosWeight = trainSamples.rows * DBL_EPSILON;
        double minWeight = DBL_MAX;
        int minWeightClusterIndex = -1;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;
669

670 671 672 673 674
            if(weights.at<double>(clusterIndex) < minWeight)
            {
                minWeight = weights.at<double>(clusterIndex);
                minWeightClusterIndex = clusterIndex;
            }
675

676 677 678 679 680
            Mat clusterMean = means.row(clusterIndex);
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
            clusterMean /= weights.at<double>(clusterIndex);
        }
681

682 683 684 685 686 687 688
        // Update covsEigenValues and invCovsEigenValues
        covs.resize(nclusters);
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
689
        {
690 691
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;
692

693 694
            if(covMatType != COV_MAT_SPHERICAL)
                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
695
            else
696 697 698 699 700 701 702 703 704 705 706 707
                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

            if(covMatType == COV_MAT_GENERIC)
                covs[clusterIndex].create(dim, dim, CV_64FC1);

            Mat clusterCov = covMatType != COV_MAT_GENERIC ?
                covsEigenValues[clusterIndex] : covs[clusterIndex];

            clusterCov = Scalar(0);

            Mat centeredSample;
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
708
            {
709 710 711 712 713
                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);

                if(covMatType == COV_MAT_GENERIC)
                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
                else
714
                {
715 716 717 718 719 720
                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                    for(int di = 0; di < dim; di++ )
                    {
                        double val = centeredSample.at<double>(di);
                        clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                    }
721
                }
722 723
            }

724 725 726 727 728 729 730 731 732 733 734 735
            if(covMatType == COV_MAT_SPHERICAL)
                clusterCov /= dim;

            clusterCov /= weights.at<double>(clusterIndex);

            // Update covsRotateMats for COV_MAT_GENERIC only
            if(covMatType == COV_MAT_GENERIC)
            {
                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }
736

737
            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
738

739 740 741 742 743
            // update invCovsEigenValues
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
744
        {
745 746 747 748 749 750 751 752 753 754
            if(weights.at<double>(clusterIndex) <= minPosWeight)
            {
                Mat clusterMean = means.row(clusterIndex);
                means.row(minWeightClusterIndex).copyTo(clusterMean);
                covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
                covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
                if(covMatType == COV_MAT_GENERIC)
                    covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
                invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
            }
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
755
        }
756

757 758 759
        // Normalize weights
        weights /= trainSamples.rows;
    }
760

761 762
    void write_params(FileStorage& fs) const
    {
763 764 765 766 767 768
        fs << "nclusters" << nclusters;
        fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
                                 covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
                                 covMatType == COV_MAT_GENERIC ? String("generic") :
                                 format("unknown_%d", covMatType));
        writeTermCrit(fs, termCrit);
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
769
    }
770

771
    void write(FileStorage& fs) const
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
772
    {
773
        writeFormat(fs);
774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789
        fs << "training_params" << "{";
        write_params(fs);
        fs << "}";
        fs << "weights" << weights;
        fs << "means" << means;

        size_t i, n = covs.size();

        fs << "covs" << "[";
        for( i = 0; i < n; i++ )
            fs << covs[i];
        fs << "]";
    }

    void read_params(const FileNode& fn)
    {
790
        nclusters = (int)fn["nclusters"];
791
        String s = (String)fn["cov_mat_type"];
792
        covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
793 794
                             s == "diagonal" ? COV_MAT_DIAGONAL :
                             s == "generic" ? COV_MAT_GENERIC : -1;
795 796
        CV_Assert(covMatType >= 0);
        termCrit = readTermCrit(fn);
797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816
    }

    void read(const FileNode& fn)
    {
        clear();
        read_params(fn["training_params"]);

        fn["weights"] >> weights;
        fn["means"] >> means;

        FileNode cfn = fn["covs"];
        FileNodeIterator cfn_it = cfn.begin();
        int i, n = (int)cfn.size();
        covs.resize(n);

        for( i = 0; i < n; i++, ++cfn_it )
            (*cfn_it) >> covs[i];

        decomposeCovs();
        computeLogWeightDivDet();
817
    }
Vadim Pisarevsky's avatar
Vadim Pisarevsky committed
818

819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842
    Mat getWeights() const { return weights; }
    Mat getMeans() const { return means; }
    void getCovs(std::vector<Mat>& _covs) const
    {
        _covs.resize(covs.size());
        std::copy(covs.begin(), covs.end(), _covs.begin());
    }

    // all inner matrices have type CV_64FC1
    Mat trainSamples;
    Mat trainProbs;
    Mat trainLogLikelihoods;
    Mat trainLabels;

    Mat weights;
    Mat means;
    std::vector<Mat> covs;

    std::vector<Mat> covsEigenValues;
    std::vector<Mat> covsRotateMats;
    std::vector<Mat> invCovsEigenValues;
    Mat logWeightDivDet;
};

843
Ptr<EM> EM::create()
844
{
845
    return makePtr<EMImpl>();
846 847 848
}

}
849
} // namespace cv
850 851

/* End of file. */