/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistributions of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <stdarg.h>
#include <ctype.h>

/****************************************************************************************\
                                COPYRIGHT NOTICE
                                ----------------

  The code has been derived from libsvm library (version 2.6)
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).

  Here is the original copyright:
------------------------------------------------------------------------------------------
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither name of copyright holders nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.


    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\****************************************************************************************/

namespace cv { namespace ml {

typedef float Qfloat;
const int QFLOAT_TYPE = DataDepth<Qfloat>::value;

// Param Grid
static void checkParamGrid(const ParamGrid& pg)
{
    if( pg.minVal > pg.maxVal )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be less than the upper one" );
    if( pg.minVal < DBL_EPSILON )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be positive" );
    if( pg.logStep < 1. + FLT_EPSILON )
        CV_Error( CV_StsBadArg, "Grid step must be greater than 1" );
}
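
// Note: a grid (minVal, maxVal, logStep) is meant to be swept as a geometric
// progression -- roughly
//     for( double v = pg.minVal; v < pg.maxVal; v *= pg.logStep ) ...
// (a sketch of how trainAuto consumes it, not the exact loop) -- which is why
// logStep must be strictly greater than 1.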

// SVM training parameters
struct SvmParams
{
    int         svmType;
    int         kernelType;
    double      gamma;
    double      coef0;
    double      degree;
    double      C;
    double      nu;
    double      p;
    Mat         classWeights;
    TermCriteria termCrit;

    SvmParams()
    {
        svmType = SVM::C_SVC;
        kernelType = SVM::RBF;
        degree = 0;
        gamma = 1;
        coef0 = 0;
        C = 1;
        nu = 0;
        p = 0;
        termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
    }

    SvmParams( int _svmType, int _kernelType,
            double _degree, double _gamma, double _coef0,
            double _Con, double _nu, double _p,
            const Mat& _classWeights, TermCriteria _termCrit )
    {
        svmType = _svmType;
        kernelType = _kernelType;
        degree = _degree;
        gamma = _gamma;
        coef0 = _coef0;
        C = _Con;
        nu = _nu;
        p = _p;
        classWeights = _classWeights;
        termCrit = _termCrit;
    }

};

/////////////////////////////////////// SVM kernel ///////////////////////////////////////

class SVMKernelImpl : public SVM::Kernel
{
public:
    SVMKernelImpl( const SvmParams& _params = SvmParams() )
    {
        params = _params;
    }

    int getType() const
    {
        return params.kernelType;
    }

    // computes results[j] = alpha*<vecs[j], another> + beta;
    // shared helper for the linear, polynomial and sigmoid kernels
    void calc_non_rbf_base( int vcount, int var_count, const float* vecs,
                            const float* another, Qfloat* results,
                            double alpha, double beta )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += sample[k]*another[k] + sample[k+1]*another[k+1] +
                sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
            for( ; k < var_count; k++ )
                s += sample[k]*another[k];
            results[j] = (Qfloat)(s*alpha + beta);
        }
    }

    void calc_linear( int vcount, int var_count, const float* vecs,
                      const float* another, Qfloat* results )
    {
        calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
    }

    void calc_poly( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        calc_non_rbf_base( vcount, var_count, vecs, another, results, params.gamma, params.coef0 );
        if( vcount > 0 )
            pow( R, params.degree, R );
    }

    void calc_sigmoid( int vcount, int var_count, const float* vecs,
                       const float* another, Qfloat* results )
    {
        int j;
        calc_non_rbf_base( vcount, var_count, vecs, another, results,
                          -2*params.gamma, -2*params.coef0 );
        // TODO: speed this up
        for( j = 0; j < vcount; j++ )
        {
            // compute tanh(t/2) in a numerically stable way, t = results[j]
            Qfloat t = results[j];
            Qfloat e = std::exp(-std::abs(t));
            if( t > 0 )
                results[j] = (Qfloat)((1. - e)/(1. + e));
            else
                results[j] = (Qfloat)((e - 1.)/(e + 1.));
        }
    }

    void calc_rbf( int vcount, int var_count, const float* vecs,
                   const float* another, Qfloat* results )
    {
        double gamma = -params.gamma;
        int j, k;

        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;

            for( k = 0; k <= var_count - 4; k += 4 )
            {
                double t0 = sample[k] - another[k];
                double t1 = sample[k+1] - another[k+1];

                s += t0*t0 + t1*t1;

                t0 = sample[k+2] - another[k+2];
                t1 = sample[k+3] - another[k+3];

                s += t0*t0 + t1*t1;
            }

            for( ; k < var_count; k++ )
            {
                double t0 = sample[k] - another[k];
                s += t0*t0;
            }
            results[j] = (Qfloat)(s*gamma);
        }

        if( vcount > 0 )
        {
            Mat R( 1, vcount, QFLOAT_TYPE, results );
            exp( R, R );
        }
    }

    /// Histogram intersection kernel
    void calc_intersec( int vcount, int var_count, const float* vecs,
                        const float* another, Qfloat* results )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += std::min(sample[k],another[k]) + std::min(sample[k+1],another[k+1]) +
                std::min(sample[k+2],another[k+2]) + std::min(sample[k+3],another[k+3]);
            for( ; k < var_count; k++ )
                s += std::min(sample[k],another[k]);
            results[j] = (Qfloat)(s);
        }
    }

    /// Exponential chi2 kernel
    void calc_chi2( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        double gamma = -params.gamma;
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double chi2 = 0;
            for(k = 0 ; k < var_count; k++ )
            {
                double d = sample[k]-another[k];
                double divisor = sample[k]+another[k];
                /// if divisor == 0, the Chi2 distance would be zero,
                // but the calculation would raise an error because of dividing by zero
                if (divisor != 0)
                {
                    chi2 += d*d/divisor;
                }
            }
            results[j] = (Qfloat) (gamma*chi2);
        }
        if( vcount > 0 )
            exp( R, R );
    }

    void calc( int vcount, int var_count, const float* vecs,
               const float* another, Qfloat* results )
    {
        switch( params.kernelType )
        {
        case SVM::LINEAR:
            calc_linear(vcount, var_count, vecs, another, results);
            break;
        case SVM::RBF:
            calc_rbf(vcount, var_count, vecs, another, results);
            break;
        case SVM::POLY:
            calc_poly(vcount, var_count, vecs, another, results);
            break;
        case SVM::SIGMOID:
            calc_sigmoid(vcount, var_count, vecs, another, results);
            break;
        case SVM::CHI2:
            calc_chi2(vcount, var_count, vecs, another, results);
            break;
        case SVM::INTER:
            calc_intersec(vcount, var_count, vecs, another, results);
            break;
        default:
            CV_Error(CV_StsBadArg, "Unknown kernel type");
        }
        const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
        for( int j = 0; j < vcount; j++ )
        {
            if( results[j] > max_val )
                results[j] = max_val;
        }
    }

    SvmParams params;
};
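
// For reference, derived from the code above (s = <x, y> denotes the dot product;
// gamma, coef0 and degree come from params), the kernels evaluate as:
//
//   LINEAR:   K(x, y) = s
//   POLY:     K(x, y) = (gamma*s + coef0)^degree
//   SIGMOID:  K(x, y) = tanh(t/2), where t = -2*(gamma*s + coef0)
//   RBF:      K(x, y) = exp(-gamma*||x - y||^2)
//   CHI2:     K(x, y) = exp(-gamma * sum_i (x_i - y_i)^2/(x_i + y_i))
//   INTER:    K(x, y) = sum_i min(x_i, y_i)
//
// All results are additionally clamped to FLT_MAX*1e-3 in calc().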


/////////////////////////////////////////////////////////////////////////

static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
                           vector<int>& sidx_all, vector<int>& class_ranges )
{
    int i, nsamples = _samples.rows;
    CV_Assert( _responses.isContinuous() && _responses.checkVector(1, CV_32S) == nsamples );

    setRangeVector(sidx_all, nsamples);

    const int* rptr = _responses.ptr<int>();
    std::sort(sidx_all.begin(), sidx_all.end(), cmp_lt_idx<int>(rptr));
    class_ranges.clear();
    class_ranges.push_back(0);

    for( i = 0; i < nsamples; i++ )
    {
        if( i == nsamples-1 || rptr[sidx_all[i]] != rptr[sidx_all[i+1]] )
            class_ranges.push_back(i+1);
    }
}
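
// After the call, sidx_all holds the sample indices sorted by class label, and
// class_ranges delimits the classes: the indices of the k-th class occupy
// sidx_all[class_ranges[k]] .. sidx_all[class_ranges[k+1]-1]. For example (up to
// tie order), responses {2,0,1,0} give sidx_all = {1,3,2,0} and class_ranges = {0,2,3,4}.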

//////////////////////// SVM implementation //////////////////////////////

Ptr<ParamGrid> SVM::getDefaultGridPtr( int param_id)
{
  ParamGrid grid = getDefaultGrid(param_id); // this is not a nice solution..
  return makePtr<ParamGrid>(grid.minVal, grid.maxVal, grid.logStep);
}

ParamGrid SVM::getDefaultGrid( int param_id )
{
    ParamGrid grid;
    if( param_id == SVM::C )
    {
        grid.minVal = 0.1;
        grid.maxVal = 500;
        grid.logStep = 5; // total iterations = 5
    }
    else if( param_id == SVM::GAMMA )
    {
        grid.minVal = 1e-5;
        grid.maxVal = 0.6;
        grid.logStep = 15; // total iterations = 4
    }
    else if( param_id == SVM::P )
    {
        grid.minVal = 0.01;
        grid.maxVal = 100;
        grid.logStep = 7; // total iterations = 4
    }
    else if( param_id == SVM::NU )
    {
        grid.minVal = 0.01;
        grid.maxVal = 0.2;
        grid.logStep = 3; // total iterations = 3
    }
    else if( param_id == SVM::COEF )
    {
        grid.minVal = 0.1;
        grid.maxVal = 300;
        grid.logStep = 14; // total iterations = 3
    }
    else if( param_id == SVM::DEGREE )
    {
        grid.minVal = 0.01;
        grid.maxVal = 4;
        grid.logStep = 7; // total iterations = 3
    }
    else
        cvError( CV_StsBadArg, "SVM::getDefaultGrid", "Invalid type of parameter "
                "(use one of SVM::C, SVM::GAMMA et al.)", __FILE__, __LINE__ );
    return grid;
}
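
// Note: the "total iterations" counts above refer to how many candidate values
// the corresponding grid yields when swept as a geometric progression from
// minVal towards maxVal with ratio logStep (see checkParamGrid above).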


class SVMImpl : public SVM
{
public:
    struct DecisionFunc
    {
        DecisionFunc(double _rho, int _ofs) : rho(_rho), ofs(_ofs) {}
        DecisionFunc() : rho(0.), ofs(0) {}
        double rho;
        int ofs;
    };

    // Generalized SMO+SVMlight algorithm
    // Solves:
    //
    //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
    //
    //      y^T \alpha = \delta
    //      y_i = +1 or -1
    //      0 <= alpha_i <= Cp for y_i = 1
    //      0 <= alpha_i <= Cn for y_i = -1
    //
    // Given:
    //
    //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
    //  l is the size of vectors and matrices
    //  eps is the stopping criterion
    //
    // solution will be put in \alpha, objective value will be put in obj
    //
    class Solver
    {
    public:
        enum { MIN_CACHE_SIZE = (40 << 20) /* 40Mb */, MAX_CACHE_SIZE = (500 << 20) /* 500Mb */ };

        typedef bool (Solver::*SelectWorkingSet)( int& i, int& j );
        typedef Qfloat* (Solver::*GetRow)( int i, Qfloat* row, Qfloat* dst, bool existed );
        typedef void (Solver::*CalcRho)( double& rho, double& r );

        struct KernelRow
        {
            KernelRow() { idx = -1; prev = next = 0; }
            KernelRow(int _idx, int _prev, int _next) : idx(_idx), prev(_prev), next(_next) {}
            int idx;
            int prev;
            int next;
        };

        struct SolutionInfo
        {
            SolutionInfo() { obj = rho = upper_bound_p = upper_bound_n = r = 0; }
            double obj;
            double rho;
            double upper_bound_p;
            double upper_bound_n;
            double r;   // for Solver_NU
        };

        void clear()
        {
            alpha_vec = 0;
            select_working_set_func = 0;
            calc_rho_func = 0;
            get_row_func = 0;
            lru_cache.clear();
        }

        Solver( const Mat& _samples, const vector<schar>& _y,
                vector<double>& _alpha, const vector<double>& _b,
                double _Cp, double _Cn,
                const Ptr<SVM::Kernel>& _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho,
                TermCriteria _termCrit )
        {
            clear();

            samples = _samples;
            sample_count = samples.rows;
            var_count = samples.cols;

            y_vec = _y;
            alpha_vec = &_alpha;
            alpha_count = (int)alpha_vec->size();
            b_vec = _b;
            kernel = _kernel;

            C[0] = _Cn;
            C[1] = _Cp;
            eps = _termCrit.epsilon;
            max_iter = _termCrit.maxCount;

            G_vec.resize(alpha_count);
            alpha_status_vec.resize(alpha_count);
            buf[0].resize(sample_count*2);
            buf[1].resize(sample_count*2);

            select_working_set_func = _select_working_set;
            CV_Assert(select_working_set_func != 0);

            calc_rho_func = _calc_rho;
            CV_Assert(calc_rho_func != 0);

            get_row_func = _get_row;
            CV_Assert(get_row_func != 0);

            // assume that for large training sets ~25% of Q matrix is used
            int64 csize = (int64)sample_count*sample_count/4;
            csize = std::max(csize, (int64)(MIN_CACHE_SIZE/sizeof(Qfloat)) );
            csize = std::min(csize, (int64)(MAX_CACHE_SIZE/sizeof(Qfloat)) );
            max_cache_size = (int)((csize + sample_count-1)/sample_count);
            max_cache_size = std::min(std::max(max_cache_size, 1), sample_count);
            cache_size = 0;

            lru_cache.clear();
            lru_cache.resize(sample_count+1, KernelRow(-1, 0, 0));
            lru_first = lru_last = 0;
            lru_cache_data.create(max_cache_size, sample_count, QFLOAT_TYPE);
        }

        Qfloat* get_row_base( int i, bool* _existed )
        {
            int i1 = i < sample_count ? i : i - sample_count;
            KernelRow& kr = lru_cache[i1+1];
            if( _existed )
                *_existed = kr.idx >= 0;
            if( kr.idx < 0 )
            {
                if( cache_size < max_cache_size )
                {
                    kr.idx = cache_size;
                    cache_size++;
                    if (!lru_last)
                        lru_last = i1+1;
                }
                else
                {
                    // cache is full: evict the least recently used row
                    KernelRow& last = lru_cache[lru_last];
                    kr.idx = last.idx;
                    last.idx = -1;
                    lru_cache[last.prev].next = 0;
                    lru_last = last.prev;
                    last.prev = 0;
                    last.next = 0;
                }
                kernel->calc( sample_count, var_count, samples.ptr<float>(),
                              samples.ptr<float>(i1), lru_cache_data.ptr<Qfloat>(kr.idx) );
            }
            else
            {
                // unlink the row from its current position in the LRU list
                if( kr.next )
                    lru_cache[kr.next].prev = kr.prev;
                else
                    lru_last = kr.prev;
                if( kr.prev )
                    lru_cache[kr.prev].next = kr.next;
                else
                    lru_first = kr.next;
            }
            // move the row to the head of the LRU list
            if (lru_first)
                lru_cache[lru_first].prev = i1+1;
            kr.next = lru_first;
            kr.prev = 0;
            lru_first = i1+1;

            return lru_cache_data.ptr<Qfloat>(kr.idx);
        }

        Qfloat* get_row_svc( int i, Qfloat* row, Qfloat*, bool existed )
        {
            if( !existed )
            {
                const schar* _y = &y_vec[0];
                int j, len = sample_count;

                if( _y[i] > 0 )
                {
                    for( j = 0; j < len; j++ )
                        row[j] = _y[j]*row[j];
                }
                else
                {
                    for( j = 0; j < len; j++ )
                        row[j] = -_y[j]*row[j];
                }
            }
            return row;
        }

        Qfloat* get_row_one_class( int, Qfloat* row, Qfloat*, bool )
        {
            return row;
        }

        Qfloat* get_row_svr( int i, Qfloat* row, Qfloat* dst, bool )
        {
            int j, len = sample_count;
            Qfloat* dst_pos = dst;
            Qfloat* dst_neg = dst + len;
            if( i >= len )
                std::swap(dst_pos, dst_neg);

            for( j = 0; j < len; j++ )
            {
                Qfloat t = row[j];
                dst_pos[j] = t;
                dst_neg[j] = -t;
            }
            return dst;
        }

        Qfloat* get_row( int i, float* dst )
        {
            bool existed = false;
            float* row = get_row_base( i, &existed );
            return (this->*get_row_func)( i, row, dst, existed );
        }

        #undef is_upper_bound
        #define is_upper_bound(i) (alpha_status[i] > 0)

        #undef is_lower_bound
        #define is_lower_bound(i) (alpha_status[i] < 0)

        #undef is_free
        #define is_free(i) (alpha_status[i] == 0)

        #undef get_C
        #define get_C(i) (C[y[i]>0])

        #undef update_alpha_status
        #define update_alpha_status(i) \
            alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

        #undef reconstruct_gradient
        #define reconstruct_gradient() /* empty for now */

        bool solve_generic( SolutionInfo& si )
        {
            const schar* y = &y_vec[0];
            double* alpha = &alpha_vec->at(0);
            schar* alpha_status = &alpha_status_vec[0];
            double* G = &G_vec[0];
            double* b = &b_vec[0];

            int iter = 0;
            int i, j, k;

            // 1. initialize gradient and alpha status
            for( i = 0; i < alpha_count; i++ )
            {
                update_alpha_status(i);
                G[i] = b[i];
                if( fabs(G[i]) > 1e200 )
                    return false;
            }

            for( i = 0; i < alpha_count; i++ )
            {
                if( !is_lower_bound(i) )
                {
                    const Qfloat *Q_i = get_row( i, &buf[0][0] );
                    double alpha_i = alpha[i];

                    for( j = 0; j < alpha_count; j++ )
                        G[j] += alpha_i*Q_i[j];
                }
            }

            // 2. optimization loop
            for(;;)
            {
                const Qfloat *Q_i, *Q_j;
                double C_i, C_j;
                double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
                double delta_alpha_i, delta_alpha_j;

        #ifdef _DEBUG
                for( i = 0; i < alpha_count; i++ )
                {
                    if( fabs(G[i]) > 1e+300 )
                        return false;

                    if( fabs(alpha[i]) > 1e16 )
                        return false;
                }
        #endif

                if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
                    break;

                Q_i = get_row( i, &buf[0][0] );
                Q_j = get_row( j, &buf[1][0] );

                C_i = get_C(i);
                C_j = get_C(j);

                alpha_i = old_alpha_i = alpha[i];
                alpha_j = old_alpha_j = alpha[j];

                if( y[i] != y[j] )
                {
                    double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
                    double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double diff = alpha_i - alpha_j;
                    alpha_i += delta;
                    alpha_j += delta;

                    if( diff > 0 && alpha_j < 0 )
                    {
                        alpha_j = 0;
                        alpha_i = diff;
                    }
                    else if( diff <= 0 && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = -diff;
                    }

                    if( diff > C_i - C_j && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = C_i - diff;
                    }
                    else if( diff <= C_i - C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = C_j + diff;
                    }
                }
                else
                {
                    double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
                    double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double sum = alpha_i + alpha_j;
                    alpha_i -= delta;
                    alpha_j += delta;

                    if( sum > C_i && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = sum - C_i;
                    }
                    else if( sum <= C_i && alpha_j < 0)
                    {
                        alpha_j = 0;
                        alpha_i = sum;
                    }

                    if( sum > C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = sum - C_j;
                    }
                    else if( sum <= C_j && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = sum;
                    }
                }

                // update alpha
                alpha[i] = alpha_i;
                alpha[j] = alpha_j;
                update_alpha_status(i);
                update_alpha_status(j);

                // update G
                delta_alpha_i = alpha_i - old_alpha_i;
                delta_alpha_j = alpha_j - old_alpha_j;

                for( k = 0; k < alpha_count; k++ )
                    G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
            }

            // calculate rho
            (this->*calc_rho_func)( si.rho, si.r );

            // calculate objective value
            for( i = 0, si.obj = 0; i < alpha_count; i++ )
                si.obj += alpha[i] * (G[i] + b[i]);

            si.obj *= 0.5;

            si.upper_bound_p = C[1];
            si.upper_bound_n = C[0];

            return true;
        }
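
        // Each pass of the loop above is a standard two-variable SMO step: pick a
        // maximally violating pair (i, j) via select_working_set_func, solve the
        // two-variable subproblem in closed form (delta from Q_i[i]+Q_j[j] +/- 2*Q_i[j]),
        // clip the result to the feasible box, and update the gradient G incrementally.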

        // returns true if the solution is already optimal, false otherwise
        bool select_working_set( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
            int Gmax2_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y = +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y = -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                }
            }

            out_i = Gmax1_idx;
            out_j = Gmax2_idx;

            return Gmax1 + Gmax2 < eps;
        }

        void calc_rho( double& rho, double& r )
        {
            int nr_free = 0;
            double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double yG = y[i]*G[i];

                if( is_lower_bound(i) )
                {
                    if( y[i] > 0 )
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else if( is_upper_bound(i) )
                {
                    if( y[i] < 0)
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else
                {
                    ++nr_free;
                    sum_free += yG;
                }
            }

            rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
            r = 0;
        }

        // returns true if the solution is already optimal, false otherwise
        bool select_working_set_nu_svm( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
            int Gmax2_idx = -1;

            double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
            int Gmax3_idx = -1;

            double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
            int Gmax4_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y == +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y == -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
                    {
                        Gmax3 = t;
                        Gmax3_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
                    {
                        Gmax4 = t;
                        Gmax4_idx = i;
                    }
                }
            }

            if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
                return true;

            if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
            {
                out_i = Gmax1_idx;
                out_j = Gmax2_idx;
            }
            else
            {
                out_i = Gmax3_idx;
                out_j = Gmax4_idx;
            }
            return false;
        }

        void calc_rho_nu_svm( double& rho, double& r )
        {
            int nr_free1 = 0, nr_free2 = 0;
            double ub1 = DBL_MAX, ub2 = DBL_MAX;
            double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
            double sum_free1 = 0, sum_free2 = 0;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double G_i = G[i];
                if( y[i] > 0 )
                {
                    if( is_lower_bound(i) )
                        ub1 = MIN( ub1, G_i );
                    else if( is_upper_bound(i) )
                        lb1 = MAX( lb1, G_i );
                    else
                    {
                        ++nr_free1;
                        sum_free1 += G_i;
                    }
                }
                else
                {
                    if( is_lower_bound(i) )
                        ub2 = MIN( ub2, G_i );
                    else if( is_upper_bound(i) )
                        lb2 = MAX( lb2, G_i );
                    else
                    {
                        ++nr_free2;
                        sum_free2 += G_i;
                    }
                }
            }

            double r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
            double r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;

            rho = (r1 - r2)*0.5;
            r = (r1 + r2)*0.5;
        }

        /*
        ///////////////////////// construct and solve various formulations ///////////////////////
        */
        static bool solve_c_svc( const Mat& _samples, const vector<schar>& _y,
                                 double _Cp, double _Cn, const Ptr<SVM::Kernel>& _kernel,
                                 vector<double>& _alpha, SolutionInfo& _si, TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.assign(sample_count, 0.);
            vector<double> _b(sample_count, -1.);

            Solver solver( _samples, _y, _alpha, _b, _Cp, _Cn, _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            // fold the labels into the coefficients: alpha_i <- y_i*alpha_i
            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i];

            return true;
        }


        static bool solve_nu_svc( const Mat& _samples, const vector<schar>& _y,
                                  double nu, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.resize(sample_count);
            vector<double> _b(sample_count, 0.);

            double sum_pos = nu * sample_count * 0.5;
            double sum_neg = nu * sample_count * 0.5;

            for( int i = 0; i < sample_count; i++ )
            {
                double a;
                if( _y[i] > 0 )
                {
                    a = std::min(1.0, sum_pos);
                    sum_pos -= a;
                }
                else
                {
                    a = std::min(1.0, sum_neg);
                    sum_neg -= a;
                }
                _alpha[i] = a;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            double inv_r = 1./_si.r;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i]*inv_r;

            _si.rho *= inv_r;
            _si.obj *= (inv_r*inv_r);
            _si.upper_bound_p = inv_r;
            _si.upper_bound_n = inv_r;

            return true;
        }

        static bool solve_one_class( const Mat& _samples, double nu,
                                     const Ptr<SVM::Kernel>& _kernel,
                                     vector<double>& _alpha, SolutionInfo& _si,
                                     TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            vector<schar> _y(sample_count, 1);
            vector<double> _b(sample_count, 0.);

            int i, n = cvRound( nu*sample_count );

            _alpha.resize(sample_count);
            for( i = 0; i < sample_count; i++ )
                _alpha[i] = i < n ? 1 : 0;

            if( n < sample_count )
                _alpha[n] = nu * sample_count - n;
            else
                _alpha[n-1] = nu * sample_count - (n-1);

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_one_class,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            return solver.solve_generic(_si);
        }
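
        // Note: the initial point above follows the usual one-class formulation:
        // the alphas sum to nu*sample_count, with the first cvRound(nu*sample_count)
        // coefficients set to 1 and the remainder placed in a single fractional entry.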

        static bool solve_eps_svr( const Mat& _samples, const vector<float>& _yf,
                                   double p, double C, const Ptr<SVM::Kernel>& _kernel,
                                   vector<double>& _alpha, SolutionInfo& _si,
                                   TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.assign(alpha_count, 0.);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _b[i] = p - _yf[i];
                _y[i] = 1;

                _b[i+sample_count] = p + _yf[i];
                _y[i+sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, C, C, _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }
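
        // Note: as in libsvm, the l-sample epsilon-SVR problem is solved as a
        // 2l-variable problem: alpha^+ occupies the first half, alpha^- the second,
        // and the final coefficient of sample i is alpha^+_i - alpha^-_i
        // (the subtraction in the loop above).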


        static bool solve_nu_svr( const Mat& _samples, const vector<float>& _yf,
                                  double nu, double C, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;
            double sum = C * nu * sample_count * 0.5;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.resize(alpha_count);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _alpha[i] = _alpha[i + sample_count] = std::min(sum, C);
                sum -= _alpha[i];

                _b[i] = -_yf[i];
                _y[i] = 1;

                _b[i + sample_count] = _yf[i];
                _y[i + sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }

        int sample_count;
        int var_count;
        int cache_size;
        int max_cache_size;
        Mat samples;
        SvmParams params;
        vector<KernelRow> lru_cache;
        int lru_first;
        int lru_last;
        Mat lru_cache_data;

        int alpha_count;

        vector<double> G_vec;
        vector<double>* alpha_vec;
        vector<schar> y_vec;
        // -1 - lower bound, 0 - free, 1 - upper bound
        vector<schar> alpha_status_vec;
        vector<double> b_vec;

        vector<Qfloat> buf[2];
        double eps;
        int max_iter;
        double C[2];  // C[0] == Cn, C[1] == Cp
        Ptr<SVM::Kernel> kernel;

        SelectWorkingSet select_working_set_func;
        CalcRho calc_rho_func;
        GetRow get_row_func;
    };
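
    // Note on the Solver's kernel cache: rows of the Q matrix are computed on demand
    // by get_row_base() and kept in an intrusive doubly-linked LRU list
    // (lru_cache / lru_first / lru_last); at most max_cache_size rows, derived from
    // the MIN_CACHE_SIZE..MAX_CACHE_SIZE byte budget, are resident at any time.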

    //////////////////////////////////////////////////////////////////////////////////////////
    SVMImpl()
    {
        clear();
        checkParams();
    }

    ~SVMImpl()
    {
        clear();
    }

    void clear()
    {
        decision_func.clear();
        df_alpha.clear();
        df_index.clear();
        sv.release();
        uncompressed_sv.release();
    }

    Mat getUncompressedSupportVectors_() const
    {
        return uncompressed_sv;
    }

    Mat getSupportVectors() const
    {
        return sv;
    }

    CV_IMPL_PROPERTY(int, Type, params.svmType)
    CV_IMPL_PROPERTY(double, Gamma, params.gamma)
    CV_IMPL_PROPERTY(double, Coef0, params.coef0)
    CV_IMPL_PROPERTY(double, Degree, params.degree)
    CV_IMPL_PROPERTY(double, C, params.C)
    CV_IMPL_PROPERTY(double, Nu, params.nu)
    CV_IMPL_PROPERTY(double, P, params.p)
    CV_IMPL_PROPERTY_S(cv::Mat, ClassWeights, params.classWeights)
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)

    int getKernelType() const
    {
        return params.kernelType;
    }

    void setKernel(int kernelType)
    {
        params.kernelType = kernelType;
        if (kernelType != CUSTOM)
            kernel = makePtr<SVMKernelImpl>(params);
    }

    void setCustomKernel(const Ptr<Kernel> &_kernel)
    {
        params.kernelType = CUSTOM;
        kernel = _kernel;
    }

    void checkParams()
    {
        int kernelType = params.kernelType;
        if (kernelType != CUSTOM)
        {
            if( kernelType != LINEAR && kernelType != POLY &&
                kernelType != SIGMOID && kernelType != RBF &&
                kernelType != INTER && kernelType != CHI2)
                CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );

            if( kernelType == LINEAR )
                params.gamma = 1;
            else if( params.gamma <= 0 )
                CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );

            if( kernelType != SIGMOID && kernelType != POLY )
                params.coef0 = 0;
            else if( params.coef0 < 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );

            if( kernelType != POLY )
                params.degree = 0;
            else if( params.degree <= 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );

            kernel = makePtr<SVMKernelImpl>(params);
        }
        else
        {
            if (!kernel)
                CV_Error( CV_StsBadArg, "Custom kernel is not set" );
        }

        int svmType = params.svmType;

        if( svmType != C_SVC && svmType != NU_SVC &&
            svmType != ONE_CLASS && svmType != EPS_SVR &&
            svmType != NU_SVR )
            CV_Error( CV_StsBadArg, "Unknown/unsupported SVM type" );

        if( svmType == ONE_CLASS || svmType == NU_SVC )
            params.C = 0;
        else if( params.C <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter C must be positive" );

        if( svmType == C_SVC || svmType == EPS_SVR )
            params.nu = 0;
        else if( params.nu <= 0 || params.nu >= 1 )
            CV_Error( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );

        if( svmType != EPS_SVR )
            params.p = 0;
        else if( params.p <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter p must be positive" );

        if( svmType != C_SVC )
            params.classWeights.release();

        if( !(params.termCrit.type & TermCriteria::EPS) )
            params.termCrit.epsilon = DBL_EPSILON;
        params.termCrit.epsilon = std::max(params.termCrit.epsilon, DBL_EPSILON);
        if( !(params.termCrit.type & TermCriteria::COUNT) )
            params.termCrit.maxCount = INT_MAX;
        params.termCrit.maxCount = std::max(params.termCrit.maxCount, 1);
    }
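
    // For reference, a minimal client-side sketch of parameters that pass these
    // checks (using the public ml API; data/labels are placeholders):
    //
    //     Ptr<SVM> svm = SVM::create();
    //     svm->setType(SVM::C_SVC);
    //     svm->setKernel(SVM::RBF);
    //     svm->setC(1.0);                     // must be positive for C_SVC
    //     svm->setGamma(0.5);                 // must be positive for RBF
    //     svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS,
    //                                       1000, FLT_EPSILON));
    //     svm->train(data, ROW_SAMPLE, labels);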

    void setParams( const SvmParams& _params)
    {
        params = _params;
        checkParams();
    }

    int getSVCount(int i) const
    {
        return (i < (int)(decision_func.size()-1) ? decision_func[i+1].ofs :
                (int)df_index.size()) - decision_func[i].ofs;
    }

    bool do_train( const Mat& _samples, const Mat& _responses )
    {
        int svmType = params.svmType;
        int i, j, k, sample_count = _samples.rows;
        vector<double> _alpha;
        Solver::SolutionInfo sinfo;

        CV_Assert( _samples.type() == CV_32F );
        var_count = _samples.cols;

        if( svmType == ONE_CLASS || svmType == EPS_SVR || svmType == NU_SVR )
        {
            int sv_count = 0;
            decision_func.clear();

            vector<float> _yf;
            if( !_responses.empty() )
                _responses.convertTo(_yf, CV_32F);

            bool ok =
            svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, params.termCrit ) : false;

            if( !ok )
                return false;

            for( i = 0; i < sample_count; i++ )
                sv_count += fabs(_alpha[i]) > 0;

            CV_Assert(sv_count != 0);

            sv.create(sv_count, _samples.cols, CV_32F);
            df_alpha.resize(sv_count);
            df_index.resize(sv_count);

            for( i = k = 0; i < sample_count; i++ )
            {
                if( std::abs(_alpha[i]) > 0 )
                {
                    _samples.row(i).copyTo(sv.row(k));
                    df_alpha[k] = _alpha[i];
                    df_index[k] = k;
                    k++;
                }
            }

            decision_func.push_back(DecisionFunc(sinfo.rho, 0));
        }
        else
        {
            int class_count = (int)class_labels.total();
            vector<int> svidx, sidx, sidx_all, sv_tab(sample_count, 0);
            Mat temp_samples, class_weights;
            vector<int> class_ranges;
            vector<schar> temp_y;
            double nu = params.nu;
            CV_Assert( svmType == C_SVC || svmType == NU_SVC );

            if( svmType == C_SVC && !params.classWeights.empty() )
            {
                const Mat cw = params.classWeights;

                if( (cw.cols != 1 && cw.rows != 1) ||
                    (int)cw.total() != class_count ||
                    (cw.type() != CV_32F && cw.type() != CV_64F) )
                    CV_Error( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
                        "containing as many elements as the number of classes" );

                cw.convertTo(class_weights, CV_64F, params.C);
                //normalize(cw, class_weights, params.C, 0, NORM_L1, CV_64F);
            }

            decision_func.clear();
            df_alpha.clear();
            df_index.clear();

            sortSamplesByClasses( _samples, _responses, sidx_all, class_ranges );

            // check that during cross-validation the samples of all the classes are present
            if( class_ranges[class_count] <= 0 )
                CV_Error( CV_StsBadArg, "During cross-validation one or more of the classes has "
                "fallen out of the sample. Try to reduce <Params::k_fold>" );

            if( svmType == NU_SVC )
            {
                // check if nu is feasible
                for( i = 0; i < class_count; i++ )
                {
                    int ci = class_ranges[i+1] - class_ranges[i];
                    for( j = i+1; j < class_count; j++ )
                    {
                        int cj = class_ranges[j+1] - class_ranges[j];
                        if( nu*(ci + cj)*0.5 > std::min( ci, cj ) )
                            // TODO: add some diagnostic
                            return false;
                    }
                }
            }

            size_t samplesize = _samples.cols*_samples.elemSize();

            // train n*(n-1)/2 classifiers
            for( i = 0; i < class_count; i++ )
            {
                for( j = i+1; j < class_count; j++ )
                {
                    int si = class_ranges[i], ci = class_ranges[i+1] - si;
                    int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                    double Cp = params.C, Cn = Cp;

                    temp_samples.create(ci + cj, _samples.cols, _samples.type());
                    sidx.resize(ci + cj);
                    temp_y.resize(ci + cj);

                    // form input for the binary classification problem
                    for( k = 0; k < ci+cj; k++ )
                    {
                        int idx = k < ci ? si+k : sj+k-ci;
                        memcpy(temp_samples.ptr(k), _samples.ptr(sidx_all[idx]), samplesize);
                        sidx[k] = sidx_all[idx];
                        temp_y[k] = k < ci ? 1 : -1;
                    }

                    if( !class_weights.empty() )
                    {
                        Cp = class_weights.at<double>(i);
                        Cn = class_weights.at<double>(j);
                    }

                    DecisionFunc df;
                    bool ok = params.svmType == C_SVC ?
                                Solver::solve_c_svc( temp_samples, temp_y, Cp, Cn,
                                                     kernel, _alpha, sinfo, params.termCrit ) :
                              params.svmType == NU_SVC ?
                                Solver::solve_nu_svc( temp_samples, temp_y, params.nu,
                                                      kernel, _alpha, sinfo, params.termCrit ) :
                              false;
                    if( !ok )
                        return false;
                    df.rho = sinfo.rho;
                    df.ofs = (int)df_index.size();
                    decision_func.push_back(df);

                    for( k = 0; k < ci + cj; k++ )
                    {
                        if( std::abs(_alpha[k]) > 0 )
                        {
                            int idx = k < ci ? si+k : sj+k-ci;
                            sv_tab[sidx_all[idx]] = 1;
                            df_index.push_back(sidx_all[idx]);
                            df_alpha.push_back(_alpha[k]);
                        }
                    }
                }
            }

            // enumerate the support vectors: sv_tab[i] becomes the 1-based index
            // of sample i among the support vectors, or stays 0 if it is not one
            for( i = 0, k = 0; i < sample_count; i++ )
            {
                if( sv_tab[i] )
                    sv_tab[i] = ++k;
            }

            int sv_total = k;
            sv.create(sv_total, _samples.cols, _samples.type());

            for( i = 0; i < sample_count; i++ )
            {
                if( !sv_tab[i] )
                    continue;
                memcpy(sv.ptr(sv_tab[i]-1), _samples.ptr(i), samplesize);
            }

            // remap df_index from training-sample indices to support-vector indices
            int n = (int)df_index.size();
            for( i = 0; i < n; i++ )
            {
                CV_Assert( sv_tab[df_index[i]] > 0 );
                df_index[i] = sv_tab[df_index[i]] - 1;
            }
        }

        optimize_linear_svm();

        return true;
    }

    void optimize_linear_svm()
    {
        // we optimize only linear SVM: compress all the support vectors into one.
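        // with a linear kernel the decision function is
        //     f(x) = sum_j alpha_j*<sv_j, x> - rho = <w, x> - rho,
        // where w = sum_j alpha_j*sv_j, so each decision function can be
        // replaced by the single vector w taken with a unit weight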
        if( params.kernelType != LINEAR )
            return;

        int i, df_count = (int)decision_func.size();

        for( i = 0; i < df_count; i++ )
        {
            if( getSVCount(i) != 1 )
                break;
        }

        // if every decision function uses a single support vector,
        // the model is already compressed, so there is nothing to do
        if( i == df_count )
            return;

        AutoBuffer<double> vbuf(var_count);
        double* v = vbuf;
        Mat new_sv(df_count, var_count, CV_32F);

        vector<DecisionFunc> new_df;

        for( i = 0; i < df_count; i++ )
        {
            float* dst = new_sv.ptr<float>(i);
            memset(v, 0, var_count*sizeof(v[0]));
            int j, k, sv_count = getSVCount(i);
            const DecisionFunc& df = decision_func[i];
            const int* sv_index = &df_index[df.ofs];
            const double* sv_alpha = &df_alpha[df.ofs];
            for( j = 0; j < sv_count; j++ )
            {
                const float* src = sv.ptr<float>(sv_index[j]);
                double a = sv_alpha[j];
                for( k = 0; k < var_count; k++ )
                    v[k] += src[k]*a;
            }
            for( k = 0; k < var_count; k++ )
                dst[k] = (float)v[k];
            new_df.push_back(DecisionFunc(df.rho, i));
        }

        setRangeVector(df_index, df_count);
        df_alpha.assign(df_count, 1.);
        sv.copyTo(uncompressed_sv);
        std::swap(sv, new_sv);
        std::swap(decision_func, new_df);
    }

    bool train( const Ptr<TrainData>& data, int )
    {
        clear();

        checkParams();

        int svmType = params.svmType;
        Mat samples = data->getTrainSamples();
        Mat responses;

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            if( responses.empty() )
                CV_Error(CV_StsBadArg, "in the case of classification problem the responses must be categorical; "
                                       "either specify varType when creating TrainData, or pass integer responses");
            class_labels = data->getClassLabels();
        }
        else
            responses = data->getTrainResponses();

        if( !do_train( samples, responses ))
        {
            clear();
            return false;
        }

        return true;
    }

    class TrainAutoBody : public ParallelLoopBody
    {
    public:
        TrainAutoBody(const vector<SvmParams>& _parameters,
                      const cv::Mat& _samples,
                      const cv::Mat& _responses,
                      const cv::Mat& _labels,
                      const vector<int>& _sidx,
                      bool _is_classification,
                      int _k_fold,
                      std::vector<double>& _result) :
        parameters(_parameters), samples(_samples), responses(_responses), labels(_labels),
        sidx(_sidx), is_classification(_is_classification), k_fold(_k_fold), result(_result)
        {}

        void operator()( const cv::Range& range ) const
        {
            int sample_count = samples.rows;
            int var_count_ = samples.cols;
            size_t sample_size = var_count_*samples.elemSize();

            int test_sample_count = (sample_count + k_fold/2)/k_fold;
            int train_sample_count = sample_count - test_sample_count;

            // Use a local SVM instance in each parallel job so no state is shared
            cv::Ptr<SVMImpl> svm = makePtr<SVMImpl>();
            svm->class_labels = labels;

            int rtype = responses.type();

            Mat temp_train_samples(train_sample_count, var_count_, CV_32F);
            Mat temp_test_samples(test_sample_count, var_count_, CV_32F);
            Mat temp_train_responses(train_sample_count, 1, rtype);
            Mat temp_test_responses;

            for( int p = range.start; p < range.end; p++ )
            {
                svm->setParams(parameters[p]);

                double error = 0;
                for( int k = 0; k < k_fold; k++ )
                {
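                    // fold k is a sliding window over the permuted index array:
                    // the train_sample_count indices starting at <start> (with
                    // wrap-around) are used for training and the following
                    // test_sample_count indices form the validation fold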
                    int start = (k*sample_count + k_fold/2)/k_fold;
                    for( int i = 0; i < train_sample_count; i++ )
                    {
                        int j = sidx[(i+start)%sample_count];
                        memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
                        if( is_classification )
                            temp_train_responses.at<int>(i) = responses.at<int>(j);
                        else if( !responses.empty() )
                            temp_train_responses.at<float>(i) = responses.at<float>(j);
                    }

                    // Train the SVM on <train_sample_count> samples
                    if( !svm->do_train( temp_train_samples, temp_train_responses ))
                        continue;

                    for( int i = 0; i < test_sample_count; i++ )
                    {
                        int j = sidx[(i+start+train_sample_count) % sample_count];
                        memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
                    }

                    svm->predict(temp_test_samples, temp_test_responses, 0);
                    for( int i = 0; i < test_sample_count; i++ )
                    {
                        float val = temp_test_responses.at<float>(i);
                        int j = sidx[(i+start+train_sample_count) % sample_count];
                        if( is_classification )
                            error += (float)(val != responses.at<int>(j));
                        else
                        {
                            val -= responses.at<float>(j);
                            error += val*val;
                        }
                    }
                }

                result[p] = error;
            }
        }

    private:
        const vector<SvmParams>& parameters;
        const cv::Mat& samples;
        const cv::Mat& responses;
        const cv::Mat& labels;
        const vector<int>& sidx;
        bool is_classification;
        int k_fold;
        std::vector<double>& result;
    };

    bool trainAuto( const Ptr<TrainData>& data, int k_fold,
                    ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
                    ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
                    bool balanced )
    {
        checkParams();

        int svmType = params.svmType;
        RNG rng((uint64)-1);

        if( svmType == ONE_CLASS )
            // current implementation of "auto" svm does not support the 1-class case.
            return train( data, 0 );

        clear();

        CV_Assert( k_fold >= 2 );

        // All the parameters except, possibly, <coef0> are positive.
        // <coef0> is nonnegative
        #define CHECK_GRID(grid, param) \
        if( grid.logStep <= 1 ) \
        { \
            grid.minVal = grid.maxVal = params.param; \
            grid.logStep = 10; \
        } \
        else \
            checkParamGrid(grid)
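        // a grid with logStep <= 1 is treated as degenerate: the parameter is
        // pinned to its current value and drops out of the search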

        CHECK_GRID(C_grid, C);
        CHECK_GRID(gamma_grid, gamma);
        CHECK_GRID(p_grid, p);
        CHECK_GRID(nu_grid, nu);
        CHECK_GRID(coef_grid, coef0);
        CHECK_GRID(degree_grid, degree);

        // collapse the grids of parameters that the chosen kernel/SVM type does not use:
        if( params.kernelType != POLY )
            degree_grid.minVal = degree_grid.maxVal = params.degree;
        if( params.kernelType == LINEAR )
            gamma_grid.minVal = gamma_grid.maxVal = params.gamma;
        if( params.kernelType != POLY && params.kernelType != SIGMOID )
            coef_grid.minVal = coef_grid.maxVal = params.coef0;
        if( svmType == NU_SVC || svmType == ONE_CLASS )
            C_grid.minVal = C_grid.maxVal = params.C;
        if( svmType == C_SVC || svmType == EPS_SVR )
            nu_grid.minVal = nu_grid.maxVal = params.nu;
        if( svmType != EPS_SVR )
            p_grid.minVal = p_grid.maxVal = params.p;

        Mat samples = data->getTrainSamples();
        Mat responses;
        bool is_classification = false;
        Mat class_labels0;
        int class_count = (int)class_labels.total();

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            class_labels = data->getClassLabels();
            class_count = (int)class_labels.total();
            is_classification = true;

            vector<int> temp_class_labels;
            setRangeVector(temp_class_labels, class_count);

            // temporarily replace class labels with 0, 1, ..., NCLASSES-1
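            // getTrainNormCatResponses() returns responses renumbered to
            // 0..NCLASSES-1, so with these temporary labels the values that
            // predict() returns can be compared with the responses directly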
            class_labels0 = class_labels;
            class_labels = Mat(temp_class_labels).clone();
        }
        else
            responses = data->getTrainResponses();

        CV_Assert(samples.type() == CV_32F);

        int sample_count = samples.rows;
        var_count = samples.cols;

        vector<int> sidx;
        setRangeVector(sidx, sample_count);

        // randomly permute training samples
        for( int i = 0; i < sample_count; i++ )
        {
            int i1 = rng.uniform(0, sample_count);
            int i2 = rng.uniform(0, sample_count);
            std::swap(sidx[i1], sidx[i2]);
        }

        if( is_classification && class_count == 2 && balanced )
        {
            // reshuffle the training set in such a way that
            // instances of each class are divided more or less evenly
            // between the k_fold parts.
            vector<int> sidx0, sidx1;

            for( int i = 0; i < sample_count; i++ )
            {
                if( responses.at<int>(sidx[i]) == 0 )
                    sidx0.push_back(sidx[i]);
                else
                    sidx1.push_back(sidx[i]);
            }

            int n0 = (int)sidx0.size(), n1 = (int)sidx1.size();
            int a0 = 0, a1 = 0;
            sidx.clear();
            for( int k = 0; k < k_fold; k++ )
            {
                int b0 = ((k+1)*n0 + k_fold/2)/k_fold, b1 = ((k+1)*n1 + k_fold/2)/k_fold;
                int a = (int)sidx.size(), b = a + (b0 - a0) + (b1 - a1);
                for( int i = a0; i < b0; i++ )
                    sidx.push_back(sidx0[i]);
                for( int i = a1; i < b1; i++ )
                    sidx.push_back(sidx1[i]);
                for( int i = 0; i < (b - a); i++ )
                {
                    int i1 = rng.uniform(a, b);
                    int i2 = rng.uniform(a, b);
                    std::swap(sidx[i1], sidx[i2]);
                }
                a0 = b0; a1 = b1;
            }
        }

        // If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
        #define FOR_IN_GRID(var, grid) \
            for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )
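        // e.g. minVal=0.1, maxVal=100, logStep=10 makes the loop visit
        // 0.1, 1 and 10; maxVal itself is excluded unless minVal == maxVal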

        // Create the list of parameters to test
        std::vector<SvmParams> parameters;
        FOR_IN_GRID(C, C_grid)
        FOR_IN_GRID(gamma, gamma_grid)
        FOR_IN_GRID(p, p_grid)
        FOR_IN_GRID(nu, nu_grid)
        FOR_IN_GRID(coef0, coef_grid)
        FOR_IN_GRID(degree, degree_grid)
        {
            parameters.push_back(params);
        }

        std::vector<double> result(parameters.size());
        TrainAutoBody invoker(parameters, samples, responses, class_labels, sidx,
                              is_classification, k_fold, result);
        parallel_for_(cv::Range(0,(int)parameters.size()), invoker);

        // Extract the best parameters
        SvmParams best_params = params;
        double min_error = FLT_MAX;
        for( int i = 0; i < (int)result.size(); i++ )
        {
            if( result[i] < min_error )
            {
                min_error   = result[i];
                best_params = parameters[i];
            }
        }

        class_labels = class_labels0;
        setParams(best_params);
        return do_train( samples, responses );
    }

    struct PredictBody : ParallelLoopBody
    {
        PredictBody( const SVMImpl* _svm, const Mat& _samples, Mat& _results, bool _returnDFVal )
        {
            svm = _svm;
            results = &_results;
            samples = &_samples;
            returnDFVal = _returnDFVal;
        }

        void operator()( const Range& range ) const
        {
            int svmType = svm->params.svmType;
            int sv_total = svm->sv.rows;
            int class_count = !svm->class_labels.empty() ? (int)svm->class_labels.total() : svmType == ONE_CLASS ? 1 : 0;

            AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
            float* buffer = _buffer;

            int i, j, dfi, k, si;

            if( svmType == EPS_SVR || svmType == NU_SVR || svmType == ONE_CLASS )
            {
                for( si = range.start; si < range.end; si++ )
                {
                    const float* row_sample = samples->ptr<float>(si);
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(), row_sample, buffer );

                    const SVMImpl::DecisionFunc* df = &svm->decision_func[0];
                    double sum = -df->rho;
                    for( i = 0; i < sv_total; i++ )
                        sum += buffer[i]*svm->df_alpha[i];
                    float result = svm->params.svmType == ONE_CLASS && !returnDFVal ? (float)(sum > 0) : (float)sum;
                    results->at<float>(si) = result;
                }
            }
            else if( svmType == C_SVC || svmType == NU_SVC )
            {
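                // one-vs-one voting: evaluate every pairwise decision function;
                // each casts a vote for class i (sum > 0) or class j, and the
                // class with the most votes becomes the prediction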
                int* vote = (int*)(buffer + sv_total);

                for( si = range.start; si < range.end; si++ )
                {
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(),
                                       samples->ptr<float>(si), buffer );
                    double sum = 0.;

                    memset( vote, 0, class_count*sizeof(vote[0]));

                    for( i = dfi = 0; i < class_count; i++ )
                    {
                        for( j = i+1; j < class_count; j++, dfi++ )
                        {
                            const DecisionFunc& df = svm->decision_func[dfi];
                            sum = -df.rho;
                            int sv_count = svm->getSVCount(dfi);
                            const double* alpha = &svm->df_alpha[df.ofs];
                            const int* sv_index = &svm->df_index[df.ofs];
                            for( k = 0; k < sv_count; k++ )
                                sum += alpha[k]*buffer[sv_index[k]];

                            vote[sum > 0 ? i : j]++;
                        }
                    }

                    for( i = 1, k = 0; i < class_count; i++ )
                    {
                        if( vote[i] > vote[k] )
                            k = i;
                    }
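                    // a 2-class problem has a single decision function, whose raw
                    // value (still in <sum>) is returned when RAW_OUTPUT is set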
                    float result = returnDFVal && class_count == 2 ?
                        (float)sum : (float)(svm->class_labels.at<int>(k));
                    results->at<float>(si) = result;
                }
            }
            else
                CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                         "the SVM structure is probably corrupted" );
        }

        const SVMImpl* svm;
        const Mat* samples;
        Mat* results;
        bool returnDFVal;
    };

    bool trainAuto_(InputArray samples, int layout,
            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
    {
        Ptr<TrainData> data = TrainData::create(samples, layout, responses);
        return this->trainAuto(
                data, kfold,
                *Cgrid.get(),
                *gammaGrid.get(),
                *pGrid.get(),
                *nuGrid.get(),
                *coeffGrid.get(),
                *degreeGrid.get(),
                balanced);
    }


    float predict( InputArray _samples, OutputArray _results, int flags ) const
    {
        float result = 0;
        Mat samples = _samples.getMat(), results;
        int nsamples = samples.rows;
        bool returnDFVal = (flags & RAW_OUTPUT) != 0;

        CV_Assert( samples.cols == var_count && samples.type() == CV_32F );

        if( _results.needed() )
        {
            _results.create( nsamples, 1, samples.type() );
            results = _results.getMat();
        }
        else
        {
            CV_Assert( nsamples == 1 );
            results = Mat(1, 1, CV_32F, &result);
        }

        PredictBody invoker(this, samples, results, returnDFVal);
        if( nsamples < 10 )
            invoker(Range(0, nsamples));
        else
            parallel_for_(Range(0, nsamples), invoker);
        return result;
    }

    double getDecisionFunction(int i, OutputArray _alpha, OutputArray _svidx ) const
    {
        CV_Assert( 0 <= i && i < (int)decision_func.size());
        const DecisionFunc& df = decision_func[i];
        int count = getSVCount(i);
        Mat(1, count, CV_64F, (double*)&df_alpha[df.ofs]).copyTo(_alpha);
        Mat(1, count, CV_32S, (int*)&df_index[df.ofs]).copyTo(_svidx);
        return df.rho;
    }

    void write_params( FileStorage& fs ) const
    {
        int svmType = params.svmType;
        int kernelType = params.kernelType;

        String svm_type_str =
            svmType == C_SVC ? "C_SVC" :
            svmType == NU_SVC ? "NU_SVC" :
            svmType == ONE_CLASS ? "ONE_CLASS" :
            svmType == EPS_SVR ? "EPS_SVR" :
            svmType == NU_SVR ? "NU_SVR" : format("Unknown_%d", svmType);
        String kernel_type_str =
            kernelType == LINEAR ? "LINEAR" :
            kernelType == POLY ? "POLY" :
            kernelType == RBF ? "RBF" :
            kernelType == SIGMOID ? "SIGMOID" :
            kernelType == CHI2 ? "CHI2" :
            kernelType == INTER ? "INTER" : format("Unknown_%d", kernelType);

        fs << "svmType" << svm_type_str;

        // save kernel
        fs << "kernel" << "{" << "type" << kernel_type_str;

        if( kernelType == POLY )
            fs << "degree" << params.degree;

        if( kernelType != LINEAR )
            fs << "gamma" << params.gamma;

        if( kernelType == POLY || kernelType == SIGMOID )
            fs << "coef0" << params.coef0;

        fs << "}";

        if( svmType == C_SVC || svmType == EPS_SVR || svmType == NU_SVR )
            fs << "C" << params.C;

        if( svmType == NU_SVC || svmType == ONE_CLASS || svmType == NU_SVR )
            fs << "nu" << params.nu;

        if( svmType == EPS_SVR )
            fs << "p" << params.p;

        fs << "term_criteria" << "{:";
        if( params.termCrit.type & TermCriteria::EPS )
            fs << "epsilon" << params.termCrit.epsilon;
        if( params.termCrit.type & TermCriteria::COUNT )
            fs << "iterations" << params.termCrit.maxCount;
        fs << "}";
    }

    bool isTrained() const
    {
        return !sv.empty();
    }

    bool isClassifier() const
    {
        return params.svmType == C_SVC || params.svmType == NU_SVC || params.svmType == ONE_CLASS;
    }

    int getVarCount() const
    {
        return var_count;
    }

    String getDefaultName() const
    {
        return "opencv_ml_svm";
    }

    void write( FileStorage& fs ) const
    {
        int class_count = !class_labels.empty() ? (int)class_labels.total() :
                          params.svmType == ONE_CLASS ? 1 : 0;
        if( !isTrained() )
            CV_Error( CV_StsBadArg, "SVM model must be trained before it can be saved" );

        writeFormat(fs);
        write_params( fs );

        fs << "var_count" << var_count;

        if( class_count > 0 )
        {
            fs << "class_count" << class_count;

            if( !class_labels.empty() )
                fs << "class_labels" << class_labels;

            if( !params.classWeights.empty() )
                fs << "class_weights" << params.classWeights;
        }

        // write the joint collection of support vectors
        int i, sv_total = sv.rows;
        fs << "sv_total" << sv_total;
        fs << "support_vectors" << "[";
        for( i = 0; i < sv_total; i++ )
        {
            fs << "[:";
            fs.writeRaw("f", sv.ptr(i), sv.cols*sv.elemSize());
            fs << "]";
        }
        fs << "]";

        if ( !uncompressed_sv.empty() )
        {
            // write the joint collection of uncompressed support vectors
            int uncompressed_sv_total = uncompressed_sv.rows;
            fs << "uncompressed_sv_total" << uncompressed_sv_total;
            fs << "uncompressed_support_vectors" << "[";
            for( i = 0; i < uncompressed_sv_total; i++ )
            {
                fs << "[:";
                fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize());
                fs << "]";
            }
            fs << "]";
        }

        // write decision functions
        int df_count = (int)decision_func.size();

        fs << "decision_functions" << "[";
        for( i = 0; i < df_count; i++ )
        {
            const DecisionFunc& df = decision_func[i];
            int sv_count = getSVCount(i);
            fs << "{" << "sv_count" << sv_count
               << "rho" << df.rho
               << "alpha" << "[:";
            fs.writeRaw("d", (const uchar*)&df_alpha[df.ofs], sv_count*sizeof(df_alpha[0]));
            fs << "]";
            if( class_count >= 2 )
            {
                fs << "index" << "[:";
                fs.writeRaw("i", (const uchar*)&df_index[df.ofs], sv_count*sizeof(df_index[0]));
                fs << "]";
            }
            else
                CV_Assert( sv_count == sv_total );
            fs << "}";
        }
        fs << "]";
    }

    void read_params( const FileNode& fn )
    {
        SvmParams _params;

        // check for old naming
        String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
        int svmType =
            svm_type_str == "C_SVC" ? C_SVC :
            svm_type_str == "NU_SVC" ? NU_SVC :
            svm_type_str == "ONE_CLASS" ? ONE_CLASS :
            svm_type_str == "EPS_SVR" ? EPS_SVR :
            svm_type_str == "NU_SVR" ? NU_SVR : -1;

        if( svmType < 0 )
            CV_Error( CV_StsParseError, "Missing or invalid SVM type" );

        FileNode kernel_node = fn["kernel"];
        if( kernel_node.empty() )
            CV_Error( CV_StsParseError, "SVM kernel tag is not found" );

        String kernel_type_str = (String)kernel_node["type"];
        int kernelType =
            kernel_type_str == "LINEAR" ? LINEAR :
            kernel_type_str == "POLY" ? POLY :
            kernel_type_str == "RBF" ? RBF :
            kernel_type_str == "SIGMOID" ? SIGMOID :
            kernel_type_str == "CHI2" ? CHI2 :
            kernel_type_str == "INTER" ? INTER : CUSTOM;

        if( kernelType == CUSTOM )
            CV_Error( CV_StsParseError, "Invalid SVM kernel type (or custom kernel)" );

        _params.svmType = svmType;
        _params.kernelType = kernelType;
        _params.degree = (double)kernel_node["degree"];
        _params.gamma = (double)kernel_node["gamma"];
        _params.coef0 = (double)kernel_node["coef0"];

        _params.C = (double)fn["C"];
        _params.nu = (double)fn["nu"];
        _params.p = (double)fn["p"];
        _params.classWeights = Mat();

        FileNode tcnode = fn["term_criteria"];
        if( !tcnode.empty() )
        {
            _params.termCrit.epsilon = (double)tcnode["epsilon"];
            _params.termCrit.maxCount = (int)tcnode["iterations"];
            _params.termCrit.type = (_params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
                                   (_params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
        }
        else
            _params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );

        setParams( _params );
    }

    void read( const FileNode& fn )
    {
        clear();

        // read SVM parameters
        read_params( fn );

        // and top-level data
        int i, sv_total = (int)fn["sv_total"];
        var_count = (int)fn["var_count"];
        int class_count = (int)fn["class_count"];

        if( sv_total <= 0 || var_count <= 0 )
            CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

        FileNode m = fn["class_labels"];
        if( !m.empty() )
            m >> class_labels;
        m = fn["class_weights"];
        if( !m.empty() )
            m >> params.classWeights;

        if( class_count > 1 && (class_labels.empty() || (int)class_labels.total() != class_count))
            CV_Error( CV_StsParseError, "Array of class labels is missing or invalid" );

        // read support vectors
        FileNode sv_node = fn["support_vectors"];

        CV_Assert((int)sv_node.size() == sv_total);

        sv.create(sv_total, var_count, CV_32F);
        FileNodeIterator sv_it = sv_node.begin();
        for( i = 0; i < sv_total; i++, ++sv_it )
        {
            (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
        }

        int uncompressed_sv_total = (int)fn["uncompressed_sv_total"];

        if( uncompressed_sv_total > 0 )
        {
            // read uncompressed support vectors
            FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];

            CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
            uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);

            FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
            for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
            {
                (*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
            }
        }

        // read decision functions
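        // (classification stores one decision function per pair of classes;
        // regression and one-class models store exactly one)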
        int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
        FileNode df_node = fn["decision_functions"];

        CV_Assert((int)df_node.size() == df_count);

        FileNodeIterator df_it = df_node.begin();
        for( i = 0; i < df_count; i++, ++df_it )
        {
            FileNode dfi = *df_it;
            DecisionFunc df;
            int sv_count = (int)dfi["sv_count"];
            int ofs = (int)df_index.size();
            df.rho = (double)dfi["rho"];
            df.ofs = ofs;
            df_index.resize(ofs + sv_count);
            df_alpha.resize(ofs + sv_count);
            dfi["alpha"].readRaw("d", (uchar*)&df_alpha[ofs], sv_count*sizeof(df_alpha[0]));
            if( class_count >= 2 )
                dfi["index"].readRaw("i", (uchar*)&df_index[ofs], sv_count*sizeof(df_index[0]));
            decision_func.push_back(df);
        }
        if( class_count < 2 )
            setRangeVector(df_index, sv_total);
        if( (int)fn["optimize_linear"] != 0 )
            optimize_linear_svm();
    }

    SvmParams params;
    Mat class_labels;
    int var_count;
    Mat sv, uncompressed_sv;
    vector<DecisionFunc> decision_func;
    vector<double> df_alpha;
    vector<int> df_index;

    Ptr<Kernel> kernel;
};


Ptr<SVM> SVM::create()
{
    return makePtr<SVMImpl>();
}
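
/* Example usage (an illustrative sketch, not part of the library; assumes the
   public cv::ml API, with hypothetical user-supplied CV_32F matrices
   "trainData", "labels" and "query"):

    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::RBF);
    svm->setC(1.);
    svm->setGamma(0.5);
    svm->train(trainData, ROW_SAMPLE, labels);
    float response = svm->predict(query);
*/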

Ptr<SVM> SVM::load(const String& filepath)
{
    FileStorage fs;
    fs.open(filepath, FileStorage::READ);

    Ptr<SVM> svm = makePtr<SVMImpl>();

    ((SVMImpl*)svm.get())->read(fs.getFirstTopLevelNode());
    return svm;
}

Mat SVM::getUncompressedSupportVectors() const
{
    const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
    if(!this_)
        CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
    return this_->getUncompressedSupportVectors_();
}

bool SVM::trainAuto(InputArray samples, int layout,
            InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
            Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
            Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
{
  SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
  if (!this_) {
    CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
  }
  return this_->trainAuto_(samples, layout, responses,
    kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
}

}
}

/* End of file. */