data.cpp 35.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>
43 44
#include <algorithm>
#include <iterator>
45

46 47
#include <opencv2/core/utils/logger.hpp>

48
namespace cv { namespace ml {
49

50 51
static const float MISSED_VAL = TrainData::missingValue();
static const int VAR_MISSED = VAR_ORDERED;
52

53
TrainData::~TrainData() {}
54

55
/** Selects the elements of a 1D vector listed in idx.
 *
 *  A single-row input is forwarded to getSubMatrix() as COL_SAMPLE, anything else
 *  as ROW_SAMPLE. A warning is logged when the input is neither a single row nor
 *  a single column.
 */
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
{
    const bool isRowVector = (vec.rows == 1);
    const bool isColVector = (vec.cols == 1);
    if (!isColVector && !isRowVector)
    {
        CV_LOG_WARNING(NULL, "'getSubVector(const Mat& vec, const Mat& idx)' call with non-1D input is deprecated. It is not designed to work with 2D matrixes (especially with 'cv::ml::COL_SAMPLE' layout).");
    }
    if (isRowVector)
        return getSubMatrix(vec, idx, cv::ml::COL_SAMPLE);
    return getSubMatrix(vec, idx, cv::ml::ROW_SAMPLE);
}

/** Copies the samples of m selected by idx into a new matrix.
 *
 *  @tparam T     element type used to move values between the matrices.
 *  @param m      source matrix; samples are rows (ROW_SAMPLE) or columns (COL_SAMPLE).
 *  @param idx    continuous CV_32S vector of sample indices.
 *  @param layout ROW_SAMPLE or COL_SAMPLE.
 *  @return matrix with the same layout as m holding only the selected samples.
 *
 *  Each index is range-checked against the number of samples implied by layout.
 */
template<typename T>
Mat getSubMatrixImpl(const Mat& m, const Mat& idx, int layout)
{
    const int nidx = idx.checkVector(1, CV_32S);
    const bool colLayout = (layout == COL_SAMPLE);
    const int dims = colLayout ? m.rows : m.cols;
    const int nsamples = colLayout ? m.cols : m.rows;

    Mat subm;
    if (colLayout)
        subm.create(dims, nidx, m.type());
    else
        subm.create(nidx, dims, m.type());

    for (int i = 0; i < nidx; ++i)
    {
        const int k = idx.at<int>(i);
        CV_CheckGE(k, 0, "Bad idx");
        CV_CheckLT(k, nsamples, "Bad idx or layout");

        if (dims == 1)
        {
            // Single-index at() addresses both row and column vectors.
            subm.at<T>(i) = m.at<T>(k);
            continue;
        }

        if (colLayout)
        {
            for (int j = 0; j < dims; ++j)
                subm.at<T>(j, i) = m.at<T>(j, k);
        }
        else
        {
            for (int j = 0; j < dims; ++j)
                subm.at<T>(i, j) = m.at<T>(k, j);
        }
    }
    return subm;
}

/** Returns the samples of m selected by idx, or m itself when idx is empty.
 *
 *  Only CV_32S, CV_32F and CV_64F matrices are accepted. Both 32-bit types are
 *  copied through the int instantiation of getSubMatrixImpl.
 */
Mat TrainData::getSubMatrix(const Mat& m, const Mat& idx, int layout)
{
    if (idx.empty())
        return m;

    int type = m.type();
    CV_CheckType(type, type == CV_32S || type == CV_32F || type == CV_64F, "");

    switch (type)
    {
    case CV_32S:
    case CV_32F:  // 32-bit
        return getSubMatrixImpl<int>(m, idx, layout);
    case CV_64F:  // 64-bit
        return getSubMatrixImpl<double>(m, idx, layout);
    default:
        break;
    }
    CV_Error(Error::StsInternal, "");
}

113

114
class TrainDataImpl CV_FINAL : public TrainData
115
{
116 117
public:
    typedef std::map<String, int> MapType;
118

119 120 121 122 123
    TrainDataImpl()
    {
        file = 0;
        clear();
    }
124

125
    virtual ~TrainDataImpl() { closeFile(); }
126

127 128
    int getLayout() const CV_OVERRIDE { return layout; }
    // Number of samples: the size of sampleIdx when it is set, otherwise the
    // sample dimension of the full samples matrix for the current layout.
    int getNSamples() const CV_OVERRIDE
    {
        if (!sampleIdx.empty())
            return (int)sampleIdx.total();
        if (layout == ROW_SAMPLE)
            return samples.rows;
        return samples.cols;
    }
133
    int getNTrainSamples() const CV_OVERRIDE
134 135 136
    {
        return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
    }
137
    int getNTestSamples() const CV_OVERRIDE
138 139 140
    {
        return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
    }
141
    int getNVars() const CV_OVERRIDE
142 143 144
    {
        return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
    }
145
    int getNAllVars() const CV_OVERRIDE
146 147 148
    {
        return layout == ROW_SAMPLE ? samples.cols : samples.rows;
    }
149

150 151 152
    Mat getTestSamples() const CV_OVERRIDE
    {
        Mat idx = getTestSampleIdx();
153
        return idx.empty() ? Mat() : getSubMatrix(samples, idx, getLayout());
154 155
    }

156 157 158 159 160 161
    Mat getSamples() const CV_OVERRIDE { return samples; }
    Mat getResponses() const CV_OVERRIDE { return responses; }
    Mat getMissing() const CV_OVERRIDE { return missing; }
    Mat getVarIdx() const CV_OVERRIDE { return varIdx; }
    Mat getVarType() const CV_OVERRIDE { return varType; }
    int getResponseType() const CV_OVERRIDE
162 163 164
    {
        return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
    }
165 166 167
    Mat getTrainSampleIdx() const CV_OVERRIDE { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
    Mat getTestSampleIdx() const CV_OVERRIDE { return testSampleIdx; }
    Mat getSampleWeights() const CV_OVERRIDE
168 169 170
    {
        return sampleWeights;
    }
171
    Mat getTrainSampleWeights() const CV_OVERRIDE
172
    {
173
        return getSubVector(sampleWeights, getTrainSampleIdx());  // 1D-vector
174
    }
175
    Mat getTestSampleWeights() const CV_OVERRIDE
176 177
    {
        Mat idx = getTestSampleIdx();
178
        return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);  // 1D-vector
179
    }
180
    Mat getTrainResponses() const CV_OVERRIDE
181
    {
182
        return getSubMatrix(responses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE);  // col-based responses are transposed in setData()
183
    }
184
    Mat getTrainNormCatResponses() const CV_OVERRIDE
185
    {
186
        return getSubMatrix(normCatResponses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE);  // like 'responses'
187
    }
188
    Mat getTestResponses() const CV_OVERRIDE
189 190
    {
        Mat idx = getTestSampleIdx();
191
        return idx.empty() ? Mat() : getSubMatrix(responses, idx, cv::ml::ROW_SAMPLE);  // col-based responses are transposed in setData()
192
    }
193
    Mat getTestNormCatResponses() const CV_OVERRIDE
194 195
    {
        Mat idx = getTestSampleIdx();
196
        return idx.empty() ? Mat() : getSubMatrix(normCatResponses, idx, cv::ml::ROW_SAMPLE);  // like 'responses'
197
    }
198 199
    Mat getNormCatResponses() const CV_OVERRIDE { return normCatResponses; }
    Mat getClassLabels() const CV_OVERRIDE { return classLabels; }
200
    Mat getClassCounters() const { return classCounters; }
201
    // Returns the length of the [begin, end) range stored in catOfs for variable vi.
    int getCatCount(int vi) const CV_OVERRIDE
    {
        int n = (int)catOfs.total();
        CV_Assert( 0 <= vi && vi < n );
        const Vec2i& range = catOfs.at<Vec2i>(vi);
        const int first = range[0];
        const int last = range[1];
        return last - first;
    }
208

209 210
    Mat getCatOfs() const CV_OVERRIDE { return catOfs; }
    Mat getCatMap() const CV_OVERRIDE { return catMap; }
211

212
    Mat getDefaultSubstValues() const CV_OVERRIDE { return missingSubst; }
213

214 215 216 217 218 219 220
    void closeFile() { if(file) fclose(file); file=0; }
    void clear()
    {
        closeFile();
        samples.release();
        missing.release();
        varType.release();
221
        varSymbolFlags.release();
222 223 224 225 226 227 228 229 230 231 232 233
        responses.release();
        sampleIdx.release();
        trainSampleIdx.release();
        testSampleIdx.release();
        normCatResponses.release();
        classLabels.release();
        classCounters.release();
        catMap.release();
        catOfs.release();
        nameMap = MapType();
        layout = ROW_SAMPLE;
    }
234

235
    typedef std::map<int, int> CatMapHash;
236

237 238 239
    /** Installs a new data set, replacing whatever the object held before.
     *
     *  @param _samples       CV_32F or CV_32S sample matrix.
     *  @param _layout        ROW_SAMPLE or COL_SAMPLE.
     *  @param _responses     optional CV_32F/CV_32S responses: a vector with one value per
     *                        sample, or a matrix laid out like the samples.
     *  @param _varIdx        optional variable subset (CV_32S indices or CV_8U mask).
     *  @param _sampleIdx     optional sample subset (CV_32S indices or CV_8U mask).
     *  @param _sampleWeights optional CV_32F weights, one per sample; defaults to all ones.
     *  @param _varType       optional CV_8U vector with one type per input and output variable.
     *  @param _missing       optional CV_8U mask of the same size as the samples.
     *
     *  Invalid arguments raise through CV_Assert. Multi-column responses are always
     *  stored with samples as rows, whatever the sample layout is.
     */
    void setData(InputArray _samples, int _layout, InputArray _responses,
                 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
                 InputArray _varType, InputArray _missing)
    {
        clear();

        CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
        samples = _samples.getMat();
        layout = _layout;
        responses = _responses.getMat();
        varIdx = _varIdx.getMat();
        sampleIdx = _sampleIdx.getMat();
        sampleWeights = _sampleWeights.getMat();
        varType = _varType.getMat();
        missing = _missing.getMat();

        int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
        int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
        int i, noutputvars = 0;

        CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );

        if( !sampleIdx.empty() )
        {
            CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(sampleIdx, true, 0, 0, nsamples)) ||
                       sampleIdx.checkVector(1, CV_8U, true) == nsamples );
            if( sampleIdx.type() == CV_8U )
                sampleIdx = convertMaskToIdx(sampleIdx);
        }

        if( !sampleWeights.empty() )
        {
            CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
        }
        else
        {
            sampleWeights = Mat::ones(nsamples, 1, CV_32F);
        }

        if( !varIdx.empty() )
        {
            CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(varIdx, true, 0, 0, ninputvars)) ||
                       varIdx.checkVector(1, CV_8U, true) == ninputvars );
            if( varIdx.type() == CV_8U )
                varIdx = convertMaskToIdx(varIdx);
            // work on a private, sorted copy so the caller's array is not reordered
            varIdx = varIdx.clone();
            std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
        }

        if( !responses.empty() )
        {
            CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
            if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
                noutputvars = 1;
            else
            {
                CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
                           (layout == COL_SAMPLE && responses.cols == nsamples) );
                noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
            }

            if( layout == COL_SAMPLE && noutputvars > 1 )
            {
                // column-sample responses are stored with samples as rows
                Mat temp;
                transpose(responses, temp);
                responses = temp;
            }
            else if( !responses.isContinuous() )
            {
                if( noutputvars > 1 )
                {
                    // Row-sample matrix that is merely non-continuous: transposing it
                    // (as was done before) would swap samples and variables, so just
                    // make a continuous copy with the same orientation.
                    responses = responses.clone();
                }
                else
                {
                    // 1D responses: the transposed copy is continuous
                    Mat temp;
                    transpose(responses, temp);
                    responses = temp;
                }
            }
        }

        int nvars = ninputvars + noutputvars;

        if( !varType.empty() )
        {
            CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
                       checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
        }
        else
        {
            // default: everything ordered, except a single integer-typed response
            varType.create(1, nvars, CV_8U);
            varType = Scalar::all(VAR_ORDERED);
            if( noutputvars == 1 )
                varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
        }

        if( noutputvars > 1 )
        {
            for( i = 0; i < noutputvars; i++ )
                CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
        }

        catOfs = Mat::zeros(1, nvars, CV_32SC2);
        missingSubst = Mat::zeros(1, nvars, CV_32F);

        vector<int> labels, counters, sortbuf, tempCatMap;
        vector<Vec2i> tempCatOfs;
        CatMapHash ofshash;

        bool haveMissing = !missing.empty();
        if( haveMissing )
        {
            CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
        }

        // we iterate through all the variables. For each categorical variable we build a map
        // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
        // often many categorical variables are similar, so we compress the map - try to re-use
        // maps for different variables if they are identical
        for( i = 0; i < ninputvars; i++ )
        {
            Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);

            if( varType.at<uchar>(i) == VAR_CATEGORICAL )
            {
                preprocessCategorical(values_i, 0, labels, 0, sortbuf);
                missingSubst.at<float>(i) = -1.f;
                int j, m = (int)labels.size();
                CV_Assert( m > 0 );
                int a = labels.front(), b = labels.back();
                const int* currmap = &labels[0];
                // cheap signature of the map: first label, last label and label count
                int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
                CatMapHash::iterator it = ofshash.find(hashval);
                if( it != ofshash.end() )
                {
                    int vi = it->second;
                    Vec2i ofs0 = tempCatOfs[vi];
                    int m0 = ofs0[1] - ofs0[0];
                    const int* map0 = &tempCatMap[ofs0[0]];
                    if( m0 == m && map0[0] == a && map0[m0-1] == b )
                    {
                        for( j = 0; j < m; j++ )
                            if( map0[j] != currmap[j] )
                                break;
                        if( j == m )
                        {
                            // re-use the map
                            tempCatOfs.push_back(ofs0);
                            continue;
                        }
                    }
                }
                else
                    ofshash[hashval] = i;
                Vec2i ofs;
                ofs[0] = (int)tempCatMap.size();
                ofs[1] = ofs[0] + m;
                tempCatOfs.push_back(ofs);
                std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
            }
            else
            {
                // ordered variable: empty category range, substitute value fixed at 0
                tempCatOfs.push_back(Vec2i(0, 0));
                missingSubst.at<float>(i) = 0.f;
            }
        }

        if( !tempCatOfs.empty() )
        {
            Mat(tempCatOfs).copyTo(catOfs);
            Mat(tempCatMap).copyTo(catMap);
        }

        if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
        {
            preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
            Mat(labels).copyTo(classLabels);
            Mat(counters).copyTo(classCounters);
        }
    }

412 413 414 415 416 417 418 419 420
    // Converts a 1D CV_8U mask into a 1 x nz CV_32S row holding the positions of
    // its non-zero elements in increasing order.
    Mat convertMaskToIdx(const Mat& mask)
    {
        const int selected = countNonZero(mask);
        const int length = mask.cols + mask.rows - 1;

        Mat idx(1, selected, CV_32S);
        int out = 0;
        for (int pos = 0; pos < length; ++pos)
        {
            if (!mask.at<uchar>(pos))
                continue;
            idx.at<int>(out) = pos;
            ++out;
        }
        return idx;
    }
421

422 423 424 425 426 427 428 429 430 431 432 433 434 435
    // Comparator over element positions: orders two positions by the values found
    // at data[position * step].
    struct CmpByIdx
    {
        const int* data;
        int step;

        CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}

        bool operator ()(int i, int j) const
        {
            const int lhs = data[i*step];
            const int rhs = data[j*step];
            return lhs < rhs;
        }
    };

    void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
                               vector<int>* counters, vector<int>& sortbuf)
    {
        CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
        int* odata = 0;
        int ostep = 0;
436

437 438 439 440 441 442
        if(normdata)
        {
            normdata->create(data.size(), CV_32S);
            odata = normdata->ptr<int>();
            ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
        }
443

444 445 446 447 448
        int i, n = data.cols + data.rows - 1;
        sortbuf.resize(n*2);
        int* idx = &sortbuf[0];
        int* idata = (int*)data.ptr<int>();
        int istep = data.isContinuous() ? 1 : (int)data.step1();
449

450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465
        if( data.type() == CV_32F )
        {
            idata = idx + n;
            const float* fdata = data.ptr<float>();
            for( i = 0; i < n; i++ )
            {
                if( fdata[i*istep] == MISSED_VAL )
                    idata[i] = -1;
                else
                {
                    idata[i] = cvRound(fdata[i*istep]);
                    CV_Assert( (float)idata[i] == fdata[i*istep] );
                }
            }
            istep = 1;
        }
466

467 468
        for( i = 0; i < n; i++ )
            idx[i] = i;
469

470
        std::sort(idx, idx + n, CmpByIdx(idata, istep));
471

472 473 474
        int clscount = 1;
        for( i = 1; i < n; i++ )
            clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
475

476 477 478
        int clslabel = -1;
        int prev = ~idata[idx[0]*istep];
        int previdx = 0;
479

480 481 482 483 484
        labels.resize(clscount);
        if(counters)
            counters->resize(clscount);

        for( i = 0; i < n; i++ )
485
        {
486 487
            int l = idata[idx[i]*istep];
            if( l != prev )
488
            {
489 490 491 492 493 494 495
                clslabel++;
                labels[clslabel] = l;
                int k = i - previdx;
                if( clslabel > 0 && counters )
                    counters->at(clslabel-1) = k;
                prev = l;
                previdx = i;
496
            }
497 498
            if(odata)
                odata[idx[i]*ostep] = clslabel;
499
        }
500 501
        if(counters)
            counters->at(clslabel) = i - previdx;
502 503
    }

504 505 506 507 508 509 510 511
    bool loadCSV(const String& filename, int headerLines,
                 int responseStartIdx, int responseEndIdx,
                 const String& varTypeSpec, char delimiter, char missch)
    {
        const int M = 1000000;
        const char delimiters[3] = { ' ', delimiter, '\0' };
        int nvars = 0;
        bool varTypesSet = false;
512

513
        clear();
514

515
        file = fopen( filename.c_str(), "rt" );
516

517 518
        if( !file )
            return false;
519

520 521 522 523
        std::vector<char> _buf(M);
        std::vector<float> allresponses;
        std::vector<float> rowvals;
        std::vector<uchar> vtypes, rowtypes;
524
        std::vector<uchar> vsymbolflags;
525 526
        bool haveMissed = false;
        char* buf = &_buf[0];
527

528 529
        int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
        int ninputvars = 0, noutputvars = 0;
530

531 532 533
        Mat tempSamples, tempMissing, tempResponses;
        MapType tempNameMap;
        int catCounter = 1;
534

535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
        // skip header lines
        int lineno = 0;
        for(;;lineno++)
        {
            if( !fgets(buf, M, file) )
                break;
            if(lineno < headerLines )
                continue;
            // trim trailing spaces
            int idx = (int)strlen(buf)-1;
            while( idx >= 0 && isspace(buf[idx]) )
                buf[idx--] = '\0';
            // skip spaces in the beginning
            char* ptr = buf;
            while( *ptr != '\0' && isspace(*ptr) )
                ptr++;
            // skip commented off lines
            if(*ptr == '#')
                continue;
            rowvals.clear();
            rowtypes.clear();

            char* token = strtok(buf, delimiters);
            if (!token)
                break;
560

561 562 563 564 565 566 567
            for(;;)
            {
                float val=0.f; int tp = 0;
                decodeElem( token, val, tp, missch, tempNameMap, catCounter );
                if( tp == VAR_MISSED )
                    haveMissed = true;
                rowvals.push_back(val);
568
                rowtypes.push_back((uchar)tp);
569 570 571 572
                token = strtok(NULL, delimiters);
                if (!token)
                    break;
            }
573

574 575 576 577 578 579 580 581 582 583 584 585
            if( nvars == 0 )
            {
                if( rowvals.empty() )
                    CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
                nvars = (int)rowvals.size();
                if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
                {
                    setVarTypes(varTypeSpec, nvars, vtypes);
                    varTypesSet = true;
                }
                else
                    vtypes = rowtypes;
586 587 588
                vsymbolflags.resize(nvars);
                for( i = 0; i < nvars; i++ )
                    vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
589

590 591 592 593 594 595 596 597
                ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
                ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
                CV_Assert(ridx1 > ridx0);
                noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
                ninputvars = nvars - noutputvars;
            }
            else
                CV_Assert( nvars == (int)rowvals.size() );
598

599 600 601 602 603
            // check var types
            for( i = 0; i < nvars; i++ )
            {
                CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
                           (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
604 605 606 607 608
                uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
                if( vsymbolflags[i] == VAR_MISSED )
                    vsymbolflags[i] = sflag;
                else
                    CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
609
            }
610

611 612 613 614 615 616 617 618 619 620 621
            if( ridx0 >= 0 )
            {
                for( i = ridx1; i < nvars; i++ )
                    std::swap(rowvals[i], rowvals[i-noutputvars]);
                for( i = ninputvars; i < nvars; i++ )
                    allresponses.push_back(rowvals[i]);
                rowvals.pop_back();
            }
            Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
            tempSamples.push_back(rmat);
        }
622

623
        closeFile();
624

625 626 627
        int nsamples = tempSamples.rows;
        if( nsamples == 0 )
            return false;
628

629 630
        if( haveMissed )
            compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
631

632 633 634 635 636 637 638 639 640 641 642 643
        if( ridx0 >= 0 )
        {
            for( i = ridx1; i < nvars; i++ )
                std::swap(vtypes[i], vtypes[i-noutputvars]);
            if( noutputvars > 1 )
            {
                for( i = ninputvars; i < nvars; i++ )
                    if( vtypes[i] == VAR_CATEGORICAL )
                        CV_Error(CV_StsBadArg,
                                 "If responses are vector values, not scalars, they must be marked as ordered responses");
            }
        }
644

645 646 647 648 649 650 651 652
        if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
        {
            for( i = 0; i < nsamples; i++ )
                if( allresponses[i] != cvRound(allresponses[i]) )
                    break;
            if( i == nsamples )
                vtypes[ninputvars] = VAR_CATEGORICAL;
        }
653

Lorena García's avatar
Lorena García committed
654
        //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
Lorena García's avatar
Lorena García committed
655
        if (noutputvars != 0){
Lorena García's avatar
Lorena García committed
656 657 658
            Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), Mat(vtypes).clone(), tempMissing);
Lorena García's avatar
Lorena García committed
659 660 661 662 663 664 665
        }
        else{
            Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
            zero_mat.copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), noArray(), tempMissing);
        }
666 667
        bool ok = !samples.empty();
        if(ok)
668
        {
669
            std::swap(tempNameMap, nameMap);
670 671
            Mat(vsymbolflags).copyTo(varSymbolFlags);
        }
672 673
        return ok;
    }
674

675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699
    void decodeElem( const char* token, float& elem, int& type,
                     char missch, MapType& namemap, int& counter ) const
    {
        char* stopstring = NULL;
        elem = (float)strtod( token, &stopstring );
        if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
        {
            elem = MISSED_VAL;
            type = VAR_MISSED;
        }
        else if( *stopstring != '\0' )
        {
            MapType::iterator it = namemap.find(token);
            if( it == namemap.end() )
            {
                elem = (float)counter;
                namemap[token] = counter++;
            }
            else
                elem = (float)it->second;
            type = VAR_CATEGORICAL;
        }
        else
            type = VAR_ORDERED;
    }
700

701 702 703 704 705 706
    void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
    {
        const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
          "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
        const char* str = s.c_str();
        int specCounter = 0;
707

708
        vtypes.resize(nvars);
709

710 711 712 713 714 715 716
        for( int k = 0; k < 2; k++ )
        {
            const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
            int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
            if( ptr ) // parse ord/cat str
            {
                char* stopstring = NULL;
717

718 719 720 721 722 723 724
                if( ptr[3] == '\0' )
                {
                    for( int i = 0; i < nvars; i++ )
                        vtypes[i] = (uchar)tp;
                    specCounter = nvars;
                    break;
                }
725

726 727
                if ( ptr[3] != '[')
                    CV_Error( CV_StsBadArg, errmsg );
728

729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756
                ptr += 4; // pass "ord["
                do
                {
                    int b1 = (int)strtod( ptr, &stopstring );
                    if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                        CV_Error( CV_StsBadArg, errmsg );
                    ptr = stopstring + 1;
                    if( (stopstring[0] == ',') || (stopstring[0] == ']'))
                    {
                        CV_Assert( 0 <= b1 && b1 < nvars );
                        vtypes[b1] = (uchar)tp;
                        specCounter++;
                    }
                    else
                    {
                        if( stopstring[0] == '-')
                        {
                            int b2 = (int)strtod( ptr, &stopstring);
                            if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                                CV_Error( CV_StsBadArg, errmsg );
                            ptr = stopstring + 1;
                            CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
                            for (int i = b1; i <= b2; i++)
                                vtypes[i] = (uchar)tp;
                            specCounter += b2 - b1 + 1;
                        }
                        else
                            CV_Error( CV_StsBadArg, errmsg );
757

758 759 760 761 762
                    }
                }
                while(*stopstring != ']');
            }
        }
763

764 765
        if( specCounter != nvars )
            CV_Error( CV_StsBadArg, "type of some variables is not specified" );
766 767
    }

768
    // Splits the data so that round(ratio * number of samples) samples go to the
    // training part; ratio must lie in [0, 1].
    void setTrainTestSplitRatio(double ratio, bool shuffle) CV_OVERRIDE
    {
        CV_Assert( 0. <= ratio && ratio <= 1. );
        const int trainCount = cvRound(getNSamples()*ratio);
        setTrainTestSplit(trainCount, shuffle);
    }

774
    /** Assigns the first count active samples to the training part and the remaining
     *  ones to the test part, optionally shuffling both parts afterwards.
     *
     *  count must satisfy 0 <= count < number of samples. With count == 0 no test
     *  part is created and the training indices are the current sample indices.
     */
    void setTrainTestSplit(int count, bool shuffle) CV_OVERRIDE
    {
        int nsamples = getNSamples();
        CV_Assert( 0 <= count && count < nsamples );

        trainSampleIdx.release();
        testSampleIdx.release();

        if( count == 0 )
        {
            trainSampleIdx = sampleIdx;
            return;
        }
        if( count == nsamples )
        {
            // not reachable while the assertion above requires count < nsamples
            testSampleIdx = sampleIdx;
            return;
        }

        trainSampleIdx.create(1, count, CV_32S);
        testSampleIdx.create(1, nsamples - count, CV_32S);

        const int* activeIdx = sampleIdx.empty() ? 0 : sampleIdx.ptr<int>();
        int* trainOut = trainSampleIdx.ptr<int>();
        int* testOut = testSampleIdx.ptr<int>();

        for( int pos = 0; pos < nsamples; ++pos )
        {
            const int sampleId = activeIdx ? activeIdx[pos] : pos;
            if( pos < count )
                trainOut[pos] = sampleId;
            else
                testOut[pos - count] = sampleId;
        }

        if( shuffle )
            shuffleTrainTest();
    }
810

811
    /**
     * Randomly mixes the entries of the training and test index vectors.
     *
     * Treats trainSampleIdx followed by testSampleIdx as one logical array of
     * getNSamples() positions and performs getNSamples() swaps of two randomly drawn
     * positions, using the generator returned by theRNG(). Does nothing when either
     * index vector is empty.
     */
    void shuffleTrainTest() CV_OVERRIDE
    {
        if( trainSampleIdx.empty() || testSampleIdx.empty() )
            return;

        const int total = getNSamples();
        const int ntrain = getNTrainSamples();
        const int ntest = getNTestSamples();
        int* const trainBuf = trainSampleIdx.ptr<int>();
        int* const testBuf = testSampleIdx.ptr<int>();
        RNG& rng = theRNG();

        for( int iter = 0; iter < total; iter++ )
        {
            // Draw both positions first (in this order) so the random sequence is consumed
            // the same way on every iteration.
            int a = rng.uniform(0, total);
            int b = rng.uniform(0, total);

            // Positions >= ntrain fall into the test vector; rebase them to its start.
            int* bufA = trainBuf;
            if( a >= ntrain )
            {
                bufA = testBuf;
                a -= ntrain;
                CV_Assert( a < ntest );
            }

            int* bufB = trainBuf;
            if( b >= ntrain )
            {
                bufB = testBuf;
                b -= ntrain;
                CV_Assert( b < ntest );
            }

            std::swap(bufA[a], bufB[b]);
        }
    }

843 844
    /**
     * Returns the training samples in the requested layout.
     *
     * @param _layout         Requested layout; COL_SAMPLE produces one variable per row,
     *                        any other value produces one sample per row.
     * @param compressSamples When false, the presence of sample index vectors alone does not
     *                        force a repack.
     * @param compressVars    When false, the presence of varIdx alone does not force a repack.
     * @return `samples` itself when it is empty or when no repacking is needed; otherwise a
     *         newly created CV_32F matrix gathered through the training-sample and variable
     *         index vectors.
     */
    Mat getTrainSamples(int _layout,
                        bool compressSamples,
                        bool compressVars) const CV_OVERRIDE
    {
        if( samples.empty() )
            return samples;

        const bool allSamplesKept = !compressSamples || (trainSampleIdx.empty() && sampleIdx.empty());
        const bool allVarsKept = !compressVars || varIdx.empty();
        if( allSamplesKept && allVarsKept && layout == _layout )
            return samples;

        const int ntrain = getNTrainSamples();
        const int nvars = getNVars();
        Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();

        // Null index pointers stand for the identity mapping.
        const int* sampleMap = !sidx.empty() ? sidx.ptr<int>() : 0;
        const int* varMap = !vidx.empty() ? vidx.ptr<int>() : 0;

        // Strides (in elements) between consecutive samples / variables in the stored matrix.
        const size_t elemStep = samples.step/samples.elemSize();
        const bool storedByRows = layout == ROW_SAMPLE;
        const size_t sampleStride = storedByRows ? elemStep : 1;
        const size_t varStride = storedByRows ? 1 : elemStep;

        // For COL_SAMPLE output the roles of samples and variables are exchanged:
        // output rows iterate over variables and output columns over samples.
        const bool varsAsRows = _layout == COL_SAMPLE;
        const int outRows = varsAsRows ? nvars : ntrain;
        const int outCols = varsAsRows ? ntrain : nvars;
        const int* rowMap = varsAsRows ? varMap : sampleMap;
        const int* colMap = varsAsRows ? sampleMap : varMap;
        const size_t rowStride = varsAsRows ? varStride : sampleStride;
        const size_t colStride = varsAsRows ? sampleStride : varStride;

        const float* base = samples.ptr<float>();
        Mat dsamples(outRows, outCols, CV_32F);

        for( int r = 0; r < outRows; r++ )
        {
            const float* src = base + (rowMap ? rowMap[r] : r)*rowStride;
            float* dst = dsamples.ptr<float>(r);

            for( int c = 0; c < outCols; c++ )
                dst[c] = src[(colMap ? colMap[c] : c)*colStride];
        }

        return dsamples;
    }

885
    /**
     * Copies the values of one variable for a set of samples into `values`.
     *
     * @param vi     Variable index; must be in [0, getNAllVars()).
     * @param _sidx  Vector of CV_32S sample indices, each in [0, getNSamples()). When it is
     *               empty, all getNSamples() samples are read in order.
     * @param values Output buffer receiving one float per requested sample. Entries equal
     *               to MISSED_VAL are replaced with missingSubst for this variable.
     */
    void getValues( int vi, InputArray _sidx, float* values ) const CV_OVERRIDE
    {
        Mat sidx = _sidx.getMat();
        int n = sidx.checkVector(1, CV_32S);
        const int nsamples = getNSamples();
        CV_Assert( 0 <= vi && vi < getNAllVars() );
        CV_Assert( n >= 0 );

        // No explicit index list: walk every sample with the identity mapping.
        const int* s = 0;
        if( n > 0 )
            s = sidx.ptr<int>();
        else
            n = nsamples;

        const bool storedByRows = layout == ROW_SAMPLE;
        const size_t step = samples.step/samples.elemSize();
        const size_t sstep = storedByRows ? step : 1;
        const size_t vstep = storedByRows ? 1 : step;

        const float* src = samples.ptr<float>() + vi*vstep;
        const float subst = missingSubst.at<float>(vi);

        for( int i = 0; i < n; i++ )
        {
            int j = i;
            if( s )
            {
                j = s[i];
                CV_Assert( 0 <= j && j < nsamples );
            }
            const float v = src[j*sstep];
            values[i] = (v == MISSED_VAL) ? subst : v;
        }
    }

915
    /**
     * Reads the values of a categorical variable and converts each one to its position
     * inside that variable's slice of catMap.
     *
     * @param vi     Index of a categorical variable (its catOfs entry must describe a
     *               non-empty range).
     * @param _sidx  Sample indices as accepted by getValues(); an empty list means all
     *               getNSamples() samples.
     * @param values Output buffer with one int per requested sample. It is first filled
     *               with floats by getValues() and then converted in place, so the code
     *               relies on int and float having the same size.
     */
    void getNormCatValues( int vi, InputArray _sidx, int* values ) const CV_OVERRIDE
    {
        float* fvalues = (float*)values;
        getValues(vi, _sidx, fvalues);

        // getValues() treats an empty index list as "all samples" and has just written
        // getNSamples() floats; convert the same number of entries here instead of none.
        int i, n = (int)_sidx.total();
        if( n == 0 )
            n = getNSamples();

        Vec2i ofs = catOfs.at<Vec2i>(vi);
        int m = ofs[1] - ofs[0];

        CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
        const int* cmap = &catMap.at<int>(ofs[0]);
        // When the first and last mapped values span exactly m integers, the value
        // minus cmap[0] is used directly as the index.
        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);

        if( fastMap )
        {
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int idx = val - cmap[0];
                // Range-check idx before it is used to index cmap.
                CV_Assert( 0 <= idx && idx < m && cmap[idx] == val );
                values[i] = idx;
            }
        }
        else
        {
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int a = 0, b = m, c = -1;

                // Binary search for val in cmap[0..m).
                while( a < b )
                {
                    c = (a + b) >> 1;
                    if( val < cmap[c] )
                        b = c;
                    else if( val > cmap[c] )
                        a = c+1;
                    else
                        break;
                }

                // A failed lookup is only detected in debug builds.
                CV_DbgAssert( c >= 0 && val == cmap[c] );
                values[i] = c;
            }
        }
    }

961
    /**
     * Copies selected variables of a single sample into `buf`.
     *
     * @param _vidx Vector of CV_32S variable indices, each in [0, getNAllVars()). When it is
     *              empty, all getNAllVars() variables are copied in order.
     * @param sidx  Sample index; must be in [0, getNSamples()).
     * @param buf   Non-null output buffer receiving one float per requested variable.
     */
    void getSample(InputArray _vidx, int sidx, float* buf) const CV_OVERRIDE
    {
        CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
        Mat vidx = _vidx.getMat();
        int n = vidx.checkVector(1, CV_32S);
        const int nvars = getNAllVars();
        CV_Assert( n >= 0 );

        // Without an explicit variable list, every variable is read in storage order.
        const int* varMap = 0;
        if( n > 0 )
            varMap = vidx.ptr<int>();
        else
            n = nvars;

        const bool storedByRows = layout == ROW_SAMPLE;
        const size_t step = samples.step/samples.elemSize();
        const size_t sstep = storedByRows ? step : 1;
        const size_t vstep = storedByRows ? 1 : step;

        const float* src = samples.ptr<float>() + sidx*sstep;

        for( int i = 0; i < n; i++ )
        {
            int j = i;
            if( varMap )
            {
                j = varMap[i];
                CV_Assert( 0 <= j && j < nvars );
            }
            buf[i] = src[j*vstep];
        }
    }
987

988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008
    /**
     * Fills `names` so that names[label] is the name stored in nameMap for that label.
     *
     * The vector is resized to nameMap.size()+1; slot 0 is set to "?" and every label
     * in nameMap must lie in [1, nameMap.size()].
     */
    void getNames(std::vector<String>& names) const CV_OVERRIDE
    {
        const size_t n = nameMap.size();

        names.resize(n+1);
        names[0] = "?";

        for( TrainDataImpl::MapType::const_iterator entry = nameMap.begin();
             entry != nameMap.end(); ++entry )
        {
            const int label = entry->second;
            CV_Assert( label > 0 && label <= (int)n );
            names[label] = entry->first;
        }
    }

    /** Returns the stored varSymbolFlags matrix as is. */
    Mat getVarSymbolFlags() const CV_OVERRIDE
    {
        return varSymbolFlags;
    }

1009 1010
    // C stdio handle; its use is not visible in this part of the class
    // (presumably owned by the CSV loading code - TODO confirm).
    FILE* file;
    // Storage layout of `samples`; the accessors compare it against ROW_SAMPLE to pick
    // the element strides between samples and between variables.
    int layout;
1011
    // `samples` is read as float data and addressed through `layout`. A non-empty `varIdx`
    // makes getTrainSamples() repack variables when compression is requested.
    // `missingSubst` holds one float per variable that getValues() substitutes for MISSED_VAL.
    Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
1012 1013 1014 1015 1016
    // Sample index vectors. setTrainTestSplit() builds trainSampleIdx/testSampleIdx as 1-row
    // CV_32S matrices, or assigns sampleIdx to trainSampleIdx when no split is requested.
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
    // catOfs stores one Vec2i per variable: [ofs[0], ofs[1]) is that variable's slice of the
    // int array catMap (getNormCatValues() requires the slice to be non-empty).
    Mat sampleWeights, catMap, catOfs;
    Mat normCatResponses, classLabels, classCounters;
    // Maps a name to a positive integer label; getNames() turns it into a label-indexed vector.
    MapType nameMap;
};
1017

1018

1019 1020 1021 1022 1023 1024
/**
 * Creates a TrainDataImpl and populates it through TrainDataImpl::loadCSV().
 *
 * All arguments are forwarded unchanged to loadCSV().
 *
 * @return The populated object, or an empty Ptr when loadCSV() reports failure.
 */
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
                                      int headerLines,
                                      int responseStartIdx,
                                      int responseEndIdx,
                                      const String& varTypeSpec,
                                      char delimiter, char missch)
{
    CV_TRACE_FUNCTION_SKIP_NESTED();

    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    const bool loaded = impl->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx,
                                      varTypeSpec, delimiter, missch);
    // On failure hand back an empty pointer rather than a half-initialized object.
    if( !loaded )
        impl.release();
    return impl;
}

1033 1034 1035
/**
 * Creates a TrainDataImpl from in-memory arrays.
 *
 * All arguments are forwarded unchanged to TrainDataImpl::setData(), with noArray()
 * passed as its final argument.
 *
 * @return The newly created object.
 */
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
                                 InputArray varType)
{
    CV_TRACE_FUNCTION_SKIP_NESTED();

    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    impl->setData(samples, layout, responses,
                  varIdx, sampleIdx, sampleWeights,
                  varType, noArray());
    return impl;
}

1043 1044
}}

1045
/* End of file. */