// data.cpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>
43 44
#include <algorithm>
#include <iterator>
45

46
namespace cv { namespace ml {
47

48 49
// Sentinel stored in place of a sample element whose value is missing
// (decodeElem() assigns it when a CSV token is the "missing" marker).
static const float MISSED_VAL = TrainData::missingValue();
// Type tag reported for missing elements; it aliases VAR_ORDERED, so a missing
// token carries the same type code as an ordered one.
static const int VAR_MISSED = VAR_ORDERED;
50

51
// Out-of-line definition of the (empty) TrainData destructor.
TrainData::~TrainData()
{
}
52

53 54 55 56 57 58 59
// Returns the entries of the sample matrix selected by the test index
// (via getSubVector); an empty matrix is returned when no test index is set.
Mat TrainData::getTestSamples() const
{
    const Mat testIdx = getTestSampleIdx();
    const Mat allSamples = getSamples();
    if( testIdx.empty() )
        return Mat();
    return getSubVector(allSamples, testIdx);
}

60
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
61
{
62 63 64 65 66 67
    if( idx.empty() )
        return vec;
    int i, j, n = idx.checkVector(1, CV_32S);
    int type = vec.type();
    CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
    int dims = 1, m;
68

69 70 71 72 73 74 75 76 77 78
    if( vec.cols == 1 || vec.rows == 1 )
    {
        dims = 1;
        m = vec.cols + vec.rows - 1;
    }
    else
    {
        dims = vec.cols;
        m = vec.rows;
    }
79

80
    Mat subvec;
81

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
    if( vec.cols == m )
        subvec.create(dims, n, type);
    else
        subvec.create(n, dims, type);
    if( type == CV_32S )
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<int>(i) = vec.at<int>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<int>(i, j) = vec.at<int>(k, j);
        }
    else if( type == CV_32F )
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<float>(i) = vec.at<float>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<float>(i, j) = vec.at<float>(k, j);
        }
    else
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<double>(i) = vec.at<double>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<double>(i, j) = vec.at<double>(k, j);
        }
    return subvec;
120 121
}

122
class TrainDataImpl : public TrainData
123
{
124 125
public:
    typedef std::map<String, int> MapType;
126

127 128 129 130 131
    TrainDataImpl()
    {
        file = 0;
        clear();
    }
132

133
    virtual ~TrainDataImpl() { closeFile(); }
134

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
    // Layout of the stored sample matrix (ROW_SAMPLE or COL_SAMPLE).
    int getLayout() const
    {
        return layout;
    }

    // Number of active samples: the size of sampleIdx when it is set,
    // otherwise the extent of the sample matrix along the sample axis.
    int getNSamples() const
    {
        if( !sampleIdx.empty() )
            return (int)sampleIdx.total();
        return layout == ROW_SAMPLE ? samples.rows : samples.cols;
    }

    // Size of the training subset; all active samples when no split is set.
    int getNTrainSamples() const
    {
        if( trainSampleIdx.empty() )
            return getNSamples();
        return (int)trainSampleIdx.total();
    }

    // Size of the test subset; 0 when no split is set.
    int getNTestSamples() const
    {
        if( testSampleIdx.empty() )
            return 0;
        return (int)testSampleIdx.total();
    }

    // Number of active variables: the size of varIdx when it is set,
    // otherwise the number of all variables.
    int getNVars() const
    {
        if( varIdx.empty() )
            return getNAllVars();
        return (int)varIdx.total();
    }

    // Extent of the sample matrix along the variable axis.
    int getNAllVars() const
    {
        if( layout == ROW_SAMPLE )
            return samples.cols;
        return samples.rows;
    }
157

158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    Mat getSamples() const { return samples; }
    Mat getResponses() const { return responses; }
    Mat getMissing() const { return missing; }
    Mat getVarIdx() const { return varIdx; }
    Mat getVarType() const { return varType; }
    int getResponseType() const
    {
        return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
    }
    Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
    Mat getTestSampleIdx() const { return testSampleIdx; }
    Mat getSampleWeights() const
    {
        return sampleWeights;
    }
    Mat getTrainSampleWeights() const
    {
        return getSubVector(sampleWeights, getTrainSampleIdx());
    }
    Mat getTestSampleWeights() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
    }
    Mat getTrainResponses() const
    {
        return getSubVector(responses, getTrainSampleIdx());
    }
    Mat getTrainNormCatResponses() const
    {
        return getSubVector(normCatResponses, getTrainSampleIdx());
    }
    Mat getTestResponses() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(responses, idx);
    }
    Mat getTestNormCatResponses() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
    }
    Mat getNormCatResponses() const { return normCatResponses; }
    Mat getClassLabels() const { return classLabels; }
    Mat getClassCounters() const { return classCounters; }
    int getCatCount(int vi) const
    {
        int n = (int)catOfs.total();
        CV_Assert( 0 <= vi && vi < n );
        Vec2i ofs = catOfs.at<Vec2i>(vi);
        return ofs[1] - ofs[0];
    }
210

211 212
    // Per-variable [begin, end) offsets into the category map.
    Mat getCatOfs() const
    {
        return catOfs;
    }
    // Concatenated category label tables referenced by getCatOfs().
    Mat getCatMap() const
    {
        return catMap;
    }

    // Per-variable substitution values filled in by setData().
    Mat getDefaultSubstValues() const
    {
        return missingSubst;
    }
215

216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
    // Closes the CSV file handle if one is open and marks it as closed.
    void closeFile()
    {
        if( file )
            fclose(file);
        file = 0;
    }
    void clear()
    {
        closeFile();
        samples.release();
        missing.release();
        varType.release();
        responses.release();
        sampleIdx.release();
        trainSampleIdx.release();
        testSampleIdx.release();
        normCatResponses.release();
        classLabels.release();
        classCounters.release();
        catMap.release();
        catOfs.release();
        nameMap = MapType();
        layout = ROW_SAMPLE;
    }
235

236
    // Lookup used by setData(): hash of a category label set -> index of the
    // first variable seen with that hash, so identical maps can be shared.
    typedef std::map<int, int> CatMapHash;
237

238 239 240
    // Installs a new data set, discarding whatever was stored before.
    //
    //  _samples       - CV_32F or CV_32S sample matrix
    //  _layout        - ROW_SAMPLE or COL_SAMPLE
    //  _responses     - optional CV_32F/CV_32S responses, one entry (or one
    //                   row/column, depending on the layout) per sample
    //  _varIdx        - optional variable subset: CV_32S indices or CV_8U mask
    //  _sampleIdx     - optional sample subset: CV_32S indices or CV_8U mask
    //  _sampleWeights - optional CV_32F weights, one per sample (default: ones)
    //  _varType       - optional CV_8U type vector covering input and output
    //                   variables (default: all ordered; a single integer
    //                   response is marked categorical)
    //  _missing       - optional CV_8U mask with the same size as the samples
    //
    // Inconsistent arguments raise a CV_Assert failure.
    void setData(InputArray _samples, int _layout, InputArray _responses,
                 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
                 InputArray _varType, InputArray _missing)
    {
        clear();

        CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
        samples = _samples.getMat();
        layout = _layout;
        responses = _responses.getMat();
        varIdx = _varIdx.getMat();
        sampleIdx = _sampleIdx.getMat();
        sampleWeights = _sampleWeights.getMat();
        varType = _varType.getMat();
        missing = _missing.getMat();

        int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
        int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
        int i, noutputvars = 0;

        CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );

        if( !sampleIdx.empty() )
        {
            CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(sampleIdx, true, 0, 0, nsamples)) ||
                       sampleIdx.checkVector(1, CV_8U, true) == nsamples );
            if( sampleIdx.type() == CV_8U )
                sampleIdx = convertMaskToIdx(sampleIdx);
        }

        if( !sampleWeights.empty() )
        {
            CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
        }
        else
        {
            sampleWeights = Mat::ones(nsamples, 1, CV_32F);
        }

        if( !varIdx.empty() )
        {
            CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(varIdx, true, 0, 0, ninputvars)) ||
                       varIdx.checkVector(1, CV_8U, true) == ninputvars );
            if( varIdx.type() == CV_8U )
                varIdx = convertMaskToIdx(varIdx);
            // work on a private, sorted copy; the caller's array is untouched
            varIdx = varIdx.clone();
            std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
        }

        if( !responses.empty() )
        {
            CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
            if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
                noutputvars = 1;
            else
            {
                CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
                           (layout == COL_SAMPLE && responses.cols == nsamples) );
                noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
            }
            if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
            {
                Mat temp;
                transpose(responses, temp);
                responses = temp;
            }
        }

        int nvars = ninputvars + noutputvars;

        if( !varType.empty() )
        {
            CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
                       checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
        }
        else
        {
            varType.create(1, nvars, CV_8U);
            varType = Scalar::all(VAR_ORDERED);
            if( noutputvars == 1 )
                varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
        }

        // vector-valued responses are only supported as ordered variables
        if( noutputvars > 1 )
        {
            for( i = 0; i < noutputvars; i++ )
                CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
        }

        catOfs = Mat::zeros(1, nvars, CV_32SC2);
        missingSubst = Mat::zeros(1, nvars, CV_32F);

        vector<int> labels, counters, sortbuf, tempCatMap;
        vector<Vec2i> tempCatOfs;
        CatMapHash ofshash;

        if( !missing.empty() )
        {
            CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
        }

        // we iterate through all the variables. For each categorical variable we build a map
        // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
        // often many categorical variables are similar, so we compress the map - try to re-use
        // maps for different variables if they are identical
        for( i = 0; i < ninputvars; i++ )
        {
            Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);

            if( varType.at<uchar>(i) == VAR_CATEGORICAL )
            {
                preprocessCategorical(values_i, 0, labels, 0, sortbuf);
                missingSubst.at<float>(i) = -1.f;
                int j, m = (int)labels.size();
                CV_Assert( m > 0 );
                int a = labels.front(), b = labels.back();
                const int* currmap = &labels[0];
                // cheap fingerprint of the label set: first label, last label, size
                int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
                CatMapHash::iterator it = ofshash.find(hashval);
                if( it != ofshash.end() )
                {
                    int vi = it->second;
                    Vec2i ofs0 = tempCatOfs[vi];
                    int m0 = ofs0[1] - ofs0[0];
                    const int* map0 = &tempCatMap[ofs0[0]];
                    if( m0 == m && map0[0] == a && map0[m0-1] == b )
                    {
                        for( j = 0; j < m; j++ )
                            if( map0[j] != currmap[j] )
                                break;
                        if( j == m )
                        {
                            // re-use the map
                            tempCatOfs.push_back(ofs0);
                            continue;
                        }
                    }
                }
                else
                    ofshash[hashval] = i;
                Vec2i ofs;
                ofs[0] = (int)tempCatMap.size();
                ofs[1] = ofs[0] + m;
                tempCatOfs.push_back(ofs);
                std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
            }
            else
            {
                // ordered variable: empty category range, zero as the substitute
                tempCatOfs.push_back(Vec2i(0, 0));
                missingSubst.at<float>(i) = 0.f;
            }
        }

        if( !tempCatOfs.empty() )
        {
            Mat(tempCatOfs).copyTo(catOfs);
            Mat(tempCatMap).copyTo(catMap);
        }

        // varType has no response entry when no responses were given, so the
        // element at ninputvars may only be read if there is an output variable
        if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
        {
            preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
            Mat(labels).copyTo(classLabels);
            Mat(counters).copyTo(classCounters);
        }
    }

413 414 415 416 417 418 419 420 421
    // Turns a byte mask into a 1 x nz CV_32S row holding the positions of its
    // non-zero entries in increasing order. The mask is indexed as a vector
    // of cols + rows - 1 elements.
    Mat convertMaskToIdx(const Mat& mask)
    {
        const int selected = countNonZero(mask);
        const int total = mask.cols + mask.rows - 1;
        Mat idx(1, selected, CV_32S);

        int written = 0;
        for( int pos = 0; pos < total; pos++ )
        {
            if( !mask.at<uchar>(pos) )
                continue;
            idx.at<int>(written) = pos;
            written++;
        }
        return idx;
    }
422

423 424 425 426 427 428 429 430 431 432 433 434 435 436
    // Ordering predicate for index sorting: two positions are compared by the
    // strided integer values they refer to, i.e. data[pos*step].
    struct CmpByIdx
    {
        CmpByIdx(const int* _data, int _step)
            : data(_data), step(_step)
        {
        }

        bool operator ()(int i, int j) const
        {
            const int lhs = data[i*step];
            const int rhs = data[j*step];
            return lhs < rhs;
        }

        const int* data;
        int step;
    };

    void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
                               vector<int>* counters, vector<int>& sortbuf)
    {
        CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
        int* odata = 0;
        int ostep = 0;
437

438 439 440 441 442 443
        if(normdata)
        {
            normdata->create(data.size(), CV_32S);
            odata = normdata->ptr<int>();
            ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
        }
444

445 446 447 448 449
        int i, n = data.cols + data.rows - 1;
        sortbuf.resize(n*2);
        int* idx = &sortbuf[0];
        int* idata = (int*)data.ptr<int>();
        int istep = data.isContinuous() ? 1 : (int)data.step1();
450

451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466
        if( data.type() == CV_32F )
        {
            idata = idx + n;
            const float* fdata = data.ptr<float>();
            for( i = 0; i < n; i++ )
            {
                if( fdata[i*istep] == MISSED_VAL )
                    idata[i] = -1;
                else
                {
                    idata[i] = cvRound(fdata[i*istep]);
                    CV_Assert( (float)idata[i] == fdata[i*istep] );
                }
            }
            istep = 1;
        }
467

468 469
        for( i = 0; i < n; i++ )
            idx[i] = i;
470

471
        std::sort(idx, idx + n, CmpByIdx(idata, istep));
472

473 474 475
        int clscount = 1;
        for( i = 1; i < n; i++ )
            clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
476

477 478 479
        int clslabel = -1;
        int prev = ~idata[idx[0]*istep];
        int previdx = 0;
480

481 482 483 484 485
        labels.resize(clscount);
        if(counters)
            counters->resize(clscount);

        for( i = 0; i < n; i++ )
486
        {
487 488
            int l = idata[idx[i]*istep];
            if( l != prev )
489
            {
490 491 492 493 494 495 496
                clslabel++;
                labels[clslabel] = l;
                int k = i - previdx;
                if( clslabel > 0 && counters )
                    counters->at(clslabel-1) = k;
                prev = l;
                previdx = i;
497
            }
498 499
            if(odata)
                odata[idx[i]*ostep] = clslabel;
500
        }
501 502
        if(counters)
            counters->at(clslabel) = i - previdx;
503 504
    }

505 506 507 508 509 510 511 512
    bool loadCSV(const String& filename, int headerLines,
                 int responseStartIdx, int responseEndIdx,
                 const String& varTypeSpec, char delimiter, char missch)
    {
        const int M = 1000000;
        const char delimiters[3] = { ' ', delimiter, '\0' };
        int nvars = 0;
        bool varTypesSet = false;
513

514
        clear();
515

516
        file = fopen( filename.c_str(), "rt" );
517

518 519
        if( !file )
            return false;
520

521 522 523 524 525 526
        std::vector<char> _buf(M);
        std::vector<float> allresponses;
        std::vector<float> rowvals;
        std::vector<uchar> vtypes, rowtypes;
        bool haveMissed = false;
        char* buf = &_buf[0];
527

528 529
        int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
        int ninputvars = 0, noutputvars = 0;
530

531 532 533
        Mat tempSamples, tempMissing, tempResponses;
        MapType tempNameMap;
        int catCounter = 1;
534

535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
        // skip header lines
        int lineno = 0;
        for(;;lineno++)
        {
            if( !fgets(buf, M, file) )
                break;
            if(lineno < headerLines )
                continue;
            // trim trailing spaces
            int idx = (int)strlen(buf)-1;
            while( idx >= 0 && isspace(buf[idx]) )
                buf[idx--] = '\0';
            // skip spaces in the beginning
            char* ptr = buf;
            while( *ptr != '\0' && isspace(*ptr) )
                ptr++;
            // skip commented off lines
            if(*ptr == '#')
                continue;
            rowvals.clear();
            rowtypes.clear();

            char* token = strtok(buf, delimiters);
            if (!token)
                break;
560

561 562 563 564 565 566 567
            for(;;)
            {
                float val=0.f; int tp = 0;
                decodeElem( token, val, tp, missch, tempNameMap, catCounter );
                if( tp == VAR_MISSED )
                    haveMissed = true;
                rowvals.push_back(val);
568
                rowtypes.push_back((uchar)tp);
569 570 571 572
                token = strtok(NULL, delimiters);
                if (!token)
                    break;
            }
573

574 575 576 577 578 579 580 581 582 583 584 585
            if( nvars == 0 )
            {
                if( rowvals.empty() )
                    CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
                nvars = (int)rowvals.size();
                if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
                {
                    setVarTypes(varTypeSpec, nvars, vtypes);
                    varTypesSet = true;
                }
                else
                    vtypes = rowtypes;
586

587 588 589 590 591 592 593 594
                ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
                ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
                CV_Assert(ridx1 > ridx0);
                noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
                ninputvars = nvars - noutputvars;
            }
            else
                CV_Assert( nvars == (int)rowvals.size() );
595

596 597 598 599 600 601
            // check var types
            for( i = 0; i < nvars; i++ )
            {
                CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
                           (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
            }
602

603 604 605 606 607 608 609 610 611 612 613
            if( ridx0 >= 0 )
            {
                for( i = ridx1; i < nvars; i++ )
                    std::swap(rowvals[i], rowvals[i-noutputvars]);
                for( i = ninputvars; i < nvars; i++ )
                    allresponses.push_back(rowvals[i]);
                rowvals.pop_back();
            }
            Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
            tempSamples.push_back(rmat);
        }
614

615
        closeFile();
616

617 618 619
        int nsamples = tempSamples.rows;
        if( nsamples == 0 )
            return false;
620

621 622
        if( haveMissed )
            compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
623

624 625 626 627 628 629 630 631 632 633 634 635
        if( ridx0 >= 0 )
        {
            for( i = ridx1; i < nvars; i++ )
                std::swap(vtypes[i], vtypes[i-noutputvars]);
            if( noutputvars > 1 )
            {
                for( i = ninputvars; i < nvars; i++ )
                    if( vtypes[i] == VAR_CATEGORICAL )
                        CV_Error(CV_StsBadArg,
                                 "If responses are vector values, not scalars, they must be marked as ordered responses");
            }
        }
636

637 638 639 640 641 642 643 644
        if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
        {
            for( i = 0; i < nsamples; i++ )
                if( allresponses[i] != cvRound(allresponses[i]) )
                    break;
            if( i == nsamples )
                vtypes[ninputvars] = VAR_CATEGORICAL;
        }
645

Lorena García's avatar
Lorena García committed
646
        //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
Lorena García's avatar
Lorena García committed
647
        if (noutputvars != 0){
Lorena García's avatar
Lorena García committed
648 649 650
            Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), Mat(vtypes).clone(), tempMissing);
Lorena García's avatar
Lorena García committed
651 652 653 654 655 656 657
        }
        else{
            Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
            zero_mat.copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), noArray(), tempMissing);
        }
658 659 660 661 662
        bool ok = !samples.empty();
        if(ok)
            std::swap(tempNameMap, nameMap);
        return ok;
    }
663

664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688
    void decodeElem( const char* token, float& elem, int& type,
                     char missch, MapType& namemap, int& counter ) const
    {
        char* stopstring = NULL;
        elem = (float)strtod( token, &stopstring );
        if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
        {
            elem = MISSED_VAL;
            type = VAR_MISSED;
        }
        else if( *stopstring != '\0' )
        {
            MapType::iterator it = namemap.find(token);
            if( it == namemap.end() )
            {
                elem = (float)counter;
                namemap[token] = counter++;
            }
            else
                elem = (float)it->second;
            type = VAR_CATEGORICAL;
        }
        else
            type = VAR_ORDERED;
    }
689

690 691 692 693 694 695
    void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
    {
        const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
          "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
        const char* str = s.c_str();
        int specCounter = 0;
696

697
        vtypes.resize(nvars);
698

699 700 701 702 703 704 705
        for( int k = 0; k < 2; k++ )
        {
            const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
            int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
            if( ptr ) // parse ord/cat str
            {
                char* stopstring = NULL;
706

707 708 709 710 711 712 713
                if( ptr[3] == '\0' )
                {
                    for( int i = 0; i < nvars; i++ )
                        vtypes[i] = (uchar)tp;
                    specCounter = nvars;
                    break;
                }
714

715 716
                if ( ptr[3] != '[')
                    CV_Error( CV_StsBadArg, errmsg );
717

718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745
                ptr += 4; // pass "ord["
                do
                {
                    int b1 = (int)strtod( ptr, &stopstring );
                    if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                        CV_Error( CV_StsBadArg, errmsg );
                    ptr = stopstring + 1;
                    if( (stopstring[0] == ',') || (stopstring[0] == ']'))
                    {
                        CV_Assert( 0 <= b1 && b1 < nvars );
                        vtypes[b1] = (uchar)tp;
                        specCounter++;
                    }
                    else
                    {
                        if( stopstring[0] == '-')
                        {
                            int b2 = (int)strtod( ptr, &stopstring);
                            if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                                CV_Error( CV_StsBadArg, errmsg );
                            ptr = stopstring + 1;
                            CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
                            for (int i = b1; i <= b2; i++)
                                vtypes[i] = (uchar)tp;
                            specCounter += b2 - b1 + 1;
                        }
                        else
                            CV_Error( CV_StsBadArg, errmsg );
746

747 748 749
                    }
                }
                while(*stopstring != ']');
750

751 752 753 754
                if( stopstring[1] != '\0' && stopstring[1] != ',')
                    CV_Error( CV_StsBadArg, errmsg );
            }
        }
755

756 757
        if( specCounter != nvars )
            CV_Error( CV_StsBadArg, "type of some variables is not specified" );
758 759
    }

760
    void setTrainTestSplitRatio(double ratio, bool shuffle)
761
    {
762
        CV_Assert( 0. <= ratio && ratio <= 1. );
763
        setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
764 765
    }

766
    void setTrainTestSplit(int count, bool shuffle)
767
    {
768
        int i, nsamples = getNSamples();
769
        CV_Assert( 0 <= count && count < nsamples );
770 771 772

        trainSampleIdx.release();
        testSampleIdx.release();
773

774 775 776 777 778
        if( count == 0 )
            trainSampleIdx = sampleIdx;
        else if( count == nsamples )
            testSampleIdx = sampleIdx;
        else
779
        {
780
            Mat mask(1, nsamples, CV_8U);
781
            uchar* mptr = mask.ptr();
782 783 784 785 786 787 788 789 790
            for( i = 0; i < nsamples; i++ )
                mptr[i] = (uchar)(i < count);
            trainSampleIdx.create(1, count, CV_32S);
            testSampleIdx.create(1, nsamples - count, CV_32S);
            int j0 = 0, j1 = 0;
            const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
            int* trainptr = trainSampleIdx.ptr<int>();
            int* testptr = testSampleIdx.ptr<int>();
            for( i = 0; i < nsamples; i++ )
791
            {
792 793 794
                int idx = sptr ? sptr[i] : i;
                if( mptr[i] )
                    trainptr[j0++] = idx;
795
                else
796
                    testptr[j1++] = idx;
797
            }
798 799
            if( shuffle )
                shuffleTrainTest();
800
        }
801
    }
802

803
    void shuffleTrainTest()
804
    {
805
        if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
806
        {
807 808 809 810 811 812
            int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
            int* trainIdx = trainSampleIdx.ptr<int>();
            int* testIdx = testSampleIdx.ptr<int>();
            RNG& rng = theRNG();

            for( i = 0; i < nsamples; i++)
813
            {
814 815 816 817 818
                int a = rng.uniform(0, nsamples);
                int b = rng.uniform(0, nsamples);
                int* ptra = trainIdx;
                int* ptrb = trainIdx;
                if( a >= ntrain )
819
                {
820 821 822
                    ptra = testIdx;
                    a -= ntrain;
                    CV_Assert( a < ntest );
823
                }
824 825 826 827 828 829 830
                if( b >= ntrain )
                {
                    ptrb = testIdx;
                    b -= ntrain;
                    CV_Assert( b < ntest );
                }
                std::swap(ptra[a], ptrb[b]);
831 832 833 834
            }
        }
    }

835 836 837
    /** Returns the training samples as a CV_32F matrix in the requested layout.
     *
     *  @param _layout          desired layout of the result (compared against
     *                          COL_SAMPLE and the stored layout).
     *  @param compressSamples  when false, the presence of sample index lists does
     *                          not prevent returning the stored matrix as is.
     *  @param compressVars     when false, the presence of varIdx does not prevent
     *                          returning the stored matrix as is.
     *  @return The stored `samples` matrix itself when it is empty or when no
     *          selection/re-layout is needed; otherwise a newly allocated matrix
     *          gathered through getTrainSampleIdx() and getVarIdx().
     */
    Mat getTrainSamples(int _layout,
                        bool compressSamples,
                        bool compressVars) const
    {
        if( samples.empty() )
            return samples;

        // Fast path: nothing to select and the layout already matches.
        const bool samplesAsIs = !compressSamples ||
                                 (trainSampleIdx.empty() && sampleIdx.empty());
        const bool varsAsIs = !compressVars || varIdx.empty();
        if( samplesAsIs && varsAsIs && layout == _layout )
            return samples;

        int outRows = getNTrainSamples(), outCols = getNVars();
        Mat sampleSel = getTrainSampleIdx(), varSel = getVarIdx();

        const float* base = samples.ptr<float>();
        const int* rowSel = sampleSel.empty() ? 0 : sampleSel.ptr<int>();
        const int* colSel = varSel.empty() ? 0 : varSel.ptr<int>();

        // Element strides (in floats) for stepping over samples and over variables
        // in the stored matrix.
        const size_t elemStride = samples.step/samples.elemSize();
        size_t rowStride, colStride;
        if( layout == ROW_SAMPLE )
        {
            rowStride = elemStride;
            colStride = 1;
        }
        else
        {
            rowStride = 1;
            colStride = elemStride;
        }

        // For column-sample output the destination rows enumerate variables,
        // so the roles of the two axes are exchanged.
        if( _layout == COL_SAMPLE )
        {
            std::swap(outRows, outCols);
            std::swap(rowSel, colSel);
            std::swap(rowStride, colStride);
        }

        Mat result(outRows, outCols, CV_32F);

        for( int r = 0; r < outRows; r++ )
        {
            const int srcRow = rowSel ? rowSel[r] : r;
            const float* srcLine = base + srcRow*rowStride;
            float* dstLine = result.ptr<float>(r);

            for( int c = 0; c < outCols; c++ )
            {
                const int srcCol = colSel ? colSel[c] : c;
                dstLine[c] = srcLine[srcCol*colStride];
            }
        }

        return result;
    }

877
    void getValues( int vi, InputArray _sidx, float* values ) const
878
    {
879
        Mat sidx = _sidx.getMat();
880
        int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
881
        CV_Assert( 0 <= vi && vi < getNAllVars() );
882
        CV_Assert( n >= 0 );
883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898
        const int* s = n > 0 ? sidx.ptr<int>() : 0;
        if( n == 0 )
            n = nsamples;

        size_t step = samples.step/samples.elemSize();
        size_t sstep = layout == ROW_SAMPLE ? step : 1;
        size_t vstep = layout == ROW_SAMPLE ? 1 : step;

        const float* src = samples.ptr<float>() + vi*vstep;
        float subst = missingSubst.at<float>(vi);
        for( i = 0; i < n; i++ )
        {
            int j = i;
            if( s )
            {
                j = s[i];
899
                CV_Assert( 0 <= j && j < nsamples );
900 901 902 903 904
            }
            values[i] = src[j*sstep];
            if( values[i] == MISSED_VAL )
                values[i] = subst;
        }
905 906
    }

907
    void getNormCatValues( int vi, InputArray _sidx, int* values ) const
908
    {
909 910 911 912 913
        float* fvalues = (float*)values;
        getValues(vi, _sidx, fvalues);
        int i, n = (int)_sidx.total();
        Vec2i ofs = catOfs.at<Vec2i>(vi);
        int m = ofs[1] - ofs[0];
914

915 916
        CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
        const int* cmap = &catMap.at<int>(ofs[0]);
917
        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
918

919
        if( fastMap )
920
        {
921 922 923 924 925 926 927
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int idx = val - cmap[0];
                CV_Assert(cmap[idx] == val);
                values[i] = idx;
            }
928
        }
929 930 931 932 933 934
        else
        {
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int a = 0, b = m, c = -1;
935

936 937 938 939 940 941 942 943 944 945
                while( a < b )
                {
                    c = (a + b) >> 1;
                    if( val < cmap[c] )
                        b = c;
                    else if( val > cmap[c] )
                        a = c+1;
                    else
                        break;
                }
946

947 948 949 950
                CV_DbgAssert( c >= 0 && val == cmap[c] );
                values[i] = c;
            }
        }
951 952
    }

953 954 955 956
    /** Copies selected variables of one sample into `buf`.
     *
     *  @param _vidx  vector of CV_32S variable indices, each in [0, getNAllVars()).
     *                A vector with zero elements selects variables
     *                0..getNAllVars()-1.
     *  @param sidx   sample index; must satisfy 0 <= sidx < getNSamples().
     *  @param buf    non-null output buffer; receives one float per selected variable.
     */
    void getSample(InputArray _vidx, int sidx, float* buf) const
    {
        CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
        Mat vidx = _vidx.getMat();
        int n = vidx.checkVector(1, CV_32S);
        const int nvars = getNAllVars();
        CV_Assert( n >= 0 );

        // No explicit selection means "every variable, in order".
        const int* chosen = 0;
        if( n > 0 )
            chosen = vidx.ptr<int>();
        else
            n = nvars;

        const size_t elemStep = samples.step/samples.elemSize();
        const bool rowSamples = layout == ROW_SAMPLE;
        const size_t sampleStride = rowSamples ? elemStep : 1;
        const size_t varStride = rowSamples ? 1 : elemStep;

        const float* sampleBase = samples.ptr<float>() + sidx*sampleStride;
        for( int k = 0; k < n; k++ )
        {
            int j = k;
            if( chosen )
            {
                j = chosen[k];
                CV_Assert( 0 <= j && j < nvars );
            }
            buf[k] = sampleBase[j*varStride];
        }
    }
979

980 981 982 983 984 985 986 987
    // C stdio handle; not used by the accessors in this part of the file.
    // NOTE(review): presumably owned by the CSV loading code - confirm.
    FILE* file;
    // Orientation of `samples`; compared against ROW_SAMPLE / COL_SAMPLE to pick
    // the per-sample and per-variable strides.
    int layout;
    // `samples` is read as float data; `varIdx` is the variable selection;
    // `missingSubst` holds, per variable, the float written in place of MISSED_VAL
    // by getValues(). The remaining matrices are not used in this part of the file.
    Mat samples, missing, varType, varIdx, responses, missingSubst;
    // Sample selections, accessed as int index vectors: the overall subset and its
    // train/test partition.
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
    // `catOfs` stores one Vec2i [begin, end) range per variable into `catMap`, whose
    // int entries are the category labels of that variable (see getNormCatValues()).
    Mat sampleWeights, catMap, catOfs;
    Mat normCatResponses, classLabels, classCounters;
    MapType nameMap;
};
988

989 990 991 992 993 994
/** Creates a TrainDataImpl and fills it from a CSV file via TrainDataImpl::loadCSV().
 *
 *  All arguments are forwarded unchanged to loadCSV().
 *  @return The populated object, or a released (empty) pointer when loadCSV()
 *          reports failure.
 */
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
                                      int headerLines,
                                      int responseStartIdx,
                                      int responseEndIdx,
                                      const String& varTypeSpec,
                                      char delimiter, char missch)
{
    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    const bool loaded = impl->loadCSV(filename, headerLines,
                                      responseStartIdx, responseEndIdx,
                                      varTypeSpec, delimiter, missch);
    if( !loaded )
        impl.release();
    return impl;
}

1002 1003 1004
/** Creates a TrainDataImpl from in-memory arrays.
 *
 *  The arguments are forwarded to TrainDataImpl::setData(), with noArray() passed
 *  as its final argument.
 *  @return The newly created training data object.
 */
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
                                 InputArray varType)
{
    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    impl->setData(samples, layout, responses,
                  varIdx, sampleIdx, sampleWeights,
                  varType, noArray());
    return impl;
}

1011 1012
}}

1013
/* End of file. */