/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

namespace cv {
namespace ml {
using std::vector;
TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}
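
// A minimal usage sketch (not part of this file; "samples" and "responses"
// below are hypothetical). These parameters are exposed through the public
// cv::ml::DTrees setters:
//
//     Ptr<DTrees> dtree = DTrees::create();
//     dtree->setMaxDepth(8);
//     dtree->setMinSampleCount(10);
//     dtree->setCVFolds(0); // disable built-in cross-validation pruning
//     Ptr<TrainData> td = TrainData::create(samples, ROW_SAMPLE, responses);
//     dtree->train(td);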
DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}
DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    vector<int> subsampleIdx;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}
DTreesImpl::DTreesImpl() : _isClassifier(false) {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();
    w.release();
    _isClassifier = false;
}
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);
    Mat vtype = data->getVarType();
    vtype.copyTo(varType);
    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);
    int nallvars = data->getNAllVars();
    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);
    initCompVarIdx();
    w->maxSubsetSize = 0;
    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);
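    // From here on maxSubsetSize counts 32-bit words, not categories:
    // categorical splits are stored as packed bit masks, so a variable with
    // k categories needs (k + 31)/32 ints (and at least one).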
    data->getSampleWeights().copyTo(w->sample_weights);
    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;
    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();
        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );
            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );
            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    w.release();
    endTraining();
    return ok;
}
const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}
int DTreesImpl::addTree(const vector<int>& sidx )
{
    // clamp the shift: maxDepth defaults to INT_MAX, and 1 << maxDepth would overflow
    size_t n = (params.getMaxDepth() > 0 ? (size_t)1 << std::min(params.getMaxDepth(), 24) : 1024) + w->wnodes.size();
    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();
    int cv_n = params.getCVFolds();
    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX; // would come from pruneCV(root), which is currently disabled
    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();
    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;
        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // This check keeps the subsets index in a valid range: when
                // ssize == 0 the resize above does not actually grow the buffer,
                // so taking &subsets[split.subsetOfs] would be out of bounds.
                // It also skips a useless memcpy call when the size is zero.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }
        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;
            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}
int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();
    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();
    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;
    calcValue( nidx, sidx );
    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }
    if( can_split )
        node.split = findBestSplit( sidx );
    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);
    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");
        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf.data(), *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;
    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }
    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();
    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));
    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }
    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.
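        //
        // A worked example (hypothetical numbers): if the samples in a node
        // carry total class weights {A: 3.0, B: 1.5}, the node value is the
        // label of A and the node risk is 1.5, the total weight of the
        // misclassified B samples.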

        // compute the number of instances of each class
        double* cls_count = buf.data();
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
519 520 521 522 523 524 525 526 527
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
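        //
        // Equivalently, with sample weights w_i the node statistics are
        //     value = sum_i(w_i*Y_i) / sum_i(w_i)
        //     risk  = (sum_i(w_i*Y_i^2) - value^2 * sum_i(w_i)) / sum_i(w_i),
        // i.e. the weighted mean and the weighted variance of the responses,
        // matching the node_risk computation at the end of this branch.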
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf.data(), *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }
            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }
        CV_Assert(fabs(sumw) > 0);
        node->node_risk = sum2 - (sum/sumw)*sum;
        node->node_risk /= sumw;
        node->value = sum/sumw;
    }
}
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)buf.data();
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
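            // Note: val == sum_k lcw[k]^2/L + sum_k rcw[k]^2/R written over the
            // common denominator L*R; maximizing it is equivalent to minimizing
            // the weighted Gini impurity of the two branches.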
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
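// It is used by findSplitCatClass() when a categorical variable has more than
// params.getMaxCategories() values: enumerating all ~2^mi category subsets
// would be infeasible, so the per-category class distributions are first
// grouped into at most maxCategories clusters and the subset search is done
// over the clusters instead.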
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf.data(), *c_weights = buf.data() + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);
        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }
        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;
        modified = false;
        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }
            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = buf.data();
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf.data() + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        CV_Assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }
    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }
    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;
        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;
            // determine index of the changed bit.
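            // (Subsets are enumerated in Gray-code order, so exactly one bit
            // changes per iteration. The position of the changed bit is read
            // off the IEEE-754 exponent: after converting the isolated 16-bit
            // half of diff to float, (u.i >> 23) - 127 equals its floor(log2).)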
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }
        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;
        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }
        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }
    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)buf.data();
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];
    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;
        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
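            // As in the classification case, val == lsum^2/L + rsum^2/R over the
            // common denominator L*R; maximizing it minimizes the weighted sum
            // of squared errors after the split.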
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);
    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = buf.data() + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);
    w->data->getNormCatValues(vi, _sidx, cat_labels);
    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;
    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;
    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }
    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }
    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());
    // revert to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];
        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;
            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();
    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];
    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf.data();
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)buf.data();
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
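
    // The return value becomes the node's default direction, used at prediction
    // time when the split variable is missing: -1 (left) if the left branch
    // received more total weight, +1 (right) otherwise.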
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.
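    //
    // The alpha computed by updateTreeRNC() below is the usual cost-complexity
    // parameter: alpha = (R(node) - R(subtree)) / (#leaves - 1), the per-leaf
    // increase in risk incurred by collapsing a subtree into its root.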

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;
        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
1251
                sum_err += err_jk.at<double>(j, ti);
1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;
        if( pidx < 0 )
            break;
        nidx = w->wnodes[pidx].right;
    }
    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );
    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf.data();
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();
    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;
    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;
        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];
                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of input categorical variable is not an integer" );
                        CV_Assert(cmap != NULL);
                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );
                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }
        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }
    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;
    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;
    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);
    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}
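
// Prediction usage sketch (hedged: "dtree", "nvars" and "sampleData" are
// hypothetical). For a classifier the default PREDICT_MAX_VOTE path returns
// the winning class label; adding RAW_OUTPUT yields the class index instead:
//
//     Mat sample(1, nvars, CV_32F, sampleData);
//     float label = dtree->predict(sample);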
void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();
    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();
    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);
    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}
void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();
    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;
    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";
    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }
    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";
    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;
    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule for when to use the inverse categorical split notation,
        // to achieve a more compact and clear representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;
        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        fs << (!split.inversed ? "le" : "gt") << split.c;
    fs << "}";
}

void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];
    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;
    if( _isClassifier )
        fs << "norm_class_idx" << node.classIdx;
    if( node.split >= 0 )
    {
        fs << "splits" << "[";
        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
            writeSplit( fs, splitidx );
        fs << "]";
    }

    fs << "}";
}

void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";
    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;
1674

1675
        if( pidx < 0 )
1676 1677
            break;

1678
        nidx = nodes[pidx].right;
1679 1680
    }

1681
    fs << "]";
1682 1683
}
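
// The loop above performs a preorder depth-first traversal using the stored
// parent links instead of recursion or an explicit stack: descend along left
// children, then climb while the current node is its parent's right child,
// and finally hop to the right sibling. For a small tree (illustrative)
//
//        A
//       / \
//      B   C
//
// the nodes are emitted in the order A, B, C.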

void DTreesImpl::write( FileStorage& fs ) const
{
    writeFormat(fs);
    writeParams(fs);
    writeTree(fs, roots[0]);
}

void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;
    /*int var_all = (int)fn["var_all"];
    int var_count = (int)fn["var_count"];
    int cat_var_count = (int)fn["cat_var_count"];
    int ord_var_count = (int)fn["ord_var_count"];*/

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are optional and may be absent
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            // store the flag in params0; setDParams(params0) below would
            // otherwise overwrite a value assigned directly to params
            params0.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    int format = 0;
    fn["format"] >> format;
    bool isLegacy = format < 3;

    int varAll = (int)fn["var_all"];
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (unsigned)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // other elements in "catMap" are "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;
    initCompVarIdx();
    setDParams(params0);
}
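
// Illustrative example of the legacy (format < 3) conversion above, with
// made-up values: for varType = [ORDERED, CATEGORICAL, CATEGORICAL, <response>]
// and cat_count = [3, 2], the loop produces catOfs = [(0,0), (0,3), (3,5)];
// catMap entries at index 5 and beyond are then moved into classLabels.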
int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi < (int)varType.size() ); // must be a valid index into varMapping
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}
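
// The category subset is a packed bitset of 32-bit words: "val >> 5" selects
// the word and "val & 31" the bit within it, so e.g. category 37 sets bit 5
// of subset[1]. Undoing a "not_in" list is then just a bitwise complement of
// every word (subset[i] ^= -1), which is why inversed categorical splits are
// normalized away here.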

int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}
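
// This mirrors writeNode: the first split read becomes the node's primary
// split, and subsequent entries are re-chained through Split::next.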

int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}
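
// Reconstruction relies on the nodes having been written in preorder:
// an internal node (split >= 0) becomes the pending parent for the nodes
// that follow, while after a leaf we climb back past ancestors whose right
// child is already filled, so the next node attaches as the deepest
// missing child.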

void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}

Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}

Ptr<DTrees> DTrees::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<DTrees>(filepath, nodeName);
}
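
// A minimal usage sketch of the save/load round trip implemented above
// (illustrative only -- the file name, data and parameter values are made up):
//
//   Ptr<DTrees> model = DTrees::create();
//   model->setMaxDepth(4);
//   model->setMinSampleCount(1);
//   model->setCVFolds(0);  // no pruning on this toy data
//   Mat samples = (Mat_<float>(4,2) << 0,0, 0,1, 1,0, 1,1);
//   Mat responses = (Mat_<int>(4,1) << 0, 1, 1, 0);
//   model->train(TrainData::create(samples, ROW_SAMPLE, responses));
//   model->save("tree.yml");                        // invokes write()
//   Ptr<DTrees> loaded = DTrees::load("tree.yml");  // invokes read()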


}
}

/* End of file. */