data.cpp 35 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>
43 44
#include <algorithm>
#include <iterator>
45

46
namespace cv { namespace ml {
47

48 49
// Sentinel stored in place of a missing sample value (the value of
// TrainData::missingValue()); preprocessCategorical() and loadCSV() compare against it.
static const float MISSED_VAL = TrainData::missingValue();
// Type code decodeElem() reports for a missing CSV element. It aliases VAR_ORDERED,
// so a missing element cannot be told apart from an ordered one by its type code alone.
static const int VAR_MISSED = VAR_ORDERED;
50

51
// Empty out-of-line definition of the TrainData destructor.
TrainData::~TrainData() {}
52

53 54 55 56 57 58 59
/** Returns the samples selected by the test index set, or an empty Mat when
    no test split has been configured. Selection is done by getSubVector(). */
Mat TrainData::getTestSamples() const
{
    const Mat testIdx = getTestSampleIdx();
    const Mat allSamples = getSamples();
    if( testIdx.empty() )
        return Mat();   // no test part
    return getSubVector(allSamples, testIdx);
}

60
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
61
{
62 63 64 65 66 67
    if( idx.empty() )
        return vec;
    int i, j, n = idx.checkVector(1, CV_32S);
    int type = vec.type();
    CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
    int dims = 1, m;
68

69 70 71 72 73 74 75 76 77 78
    if( vec.cols == 1 || vec.rows == 1 )
    {
        dims = 1;
        m = vec.cols + vec.rows - 1;
    }
    else
    {
        dims = vec.cols;
        m = vec.rows;
    }
79

80
    Mat subvec;
81

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
    if( vec.cols == m )
        subvec.create(dims, n, type);
    else
        subvec.create(n, dims, type);
    if( type == CV_32S )
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<int>(i) = vec.at<int>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<int>(i, j) = vec.at<int>(k, j);
        }
    else if( type == CV_32F )
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<float>(i) = vec.at<float>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<float>(i, j) = vec.at<float>(k, j);
        }
    else
        for( i = 0; i < n; i++ )
        {
            int k = idx.at<int>(i);
            CV_Assert( 0 <= k && k < m );
            if( dims == 1 )
                subvec.at<double>(i) = vec.at<double>(k);
            else
                for( j = 0; j < dims; j++ )
                    subvec.at<double>(i, j) = vec.at<double>(k, j);
        }
    return subvec;
120 121
}

122
class TrainDataImpl : public TrainData
123
{
124 125
public:
    typedef std::map<String, int> MapType;
126

127 128 129 130 131
    // Creates an empty data set. The file handle is nulled first because
    // clear() calls closeFile(), which inspects it.
    TrainDataImpl()
    {
        file = NULL;
        clear();
    }
132

133
    // Closes the file handle opened by loadCSV() if it is still open.
    virtual ~TrainDataImpl() { closeFile(); }
134

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
    int getLayout() const { return layout; }

    // Samples in use: the size of the sample subset when one is set, otherwise
    // every row (ROW_SAMPLE) or every column (COL_SAMPLE) of `samples`.
    int getNSamples() const
    {
        if( !sampleIdx.empty() )
            return (int)sampleIdx.total();
        if( layout == ROW_SAMPLE )
            return samples.rows;
        return samples.cols;
    }

    // Training samples: the explicit train subset, or all samples in use.
    int getNTrainSamples() const
    {
        if( trainSampleIdx.empty() )
            return getNSamples();
        return (int)trainSampleIdx.total();
    }

    // Test samples: the explicit test subset, or none at all.
    int getNTestSamples() const
    {
        if( testSampleIdx.empty() )
            return 0;
        return (int)testSampleIdx.total();
    }

    // Active input variables: the selected subset, or all of them.
    int getNVars() const
    {
        if( varIdx.empty() )
            return getNAllVars();
        return (int)varIdx.total();
    }

    // All input variables: columns (ROW_SAMPLE) or rows (COL_SAMPLE) of `samples`.
    int getNAllVars() const
    {
        if( layout == ROW_SAMPLE )
            return samples.cols;
        return samples.rows;
    }
157

158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    Mat getSamples() const { return samples; }
    Mat getResponses() const { return responses; }
    Mat getMissing() const { return missing; }
    Mat getVarIdx() const { return varIdx; }
    Mat getVarType() const { return varType; }
    int getResponseType() const
    {
        return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
    }
    Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
    Mat getTestSampleIdx() const { return testSampleIdx; }
    Mat getSampleWeights() const
    {
        return sampleWeights;
    }
    Mat getTrainSampleWeights() const
    {
        return getSubVector(sampleWeights, getTrainSampleIdx());
    }
    Mat getTestSampleWeights() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
    }
    Mat getTrainResponses() const
    {
        return getSubVector(responses, getTrainSampleIdx());
    }
    Mat getTrainNormCatResponses() const
    {
        return getSubVector(normCatResponses, getTrainSampleIdx());
    }
    Mat getTestResponses() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(responses, idx);
    }
    Mat getTestNormCatResponses() const
    {
        Mat idx = getTestSampleIdx();
        return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
    }
    Mat getNormCatResponses() const { return normCatResponses; }
    Mat getClassLabels() const { return classLabels; }
    Mat getClassCounters() const { return classCounters; }
    int getCatCount(int vi) const
    {
        int n = (int)catOfs.total();
        CV_Assert( 0 <= vi && vi < n );
        Vec2i ofs = catOfs.at<Vec2i>(vi);
        return ofs[1] - ofs[0];
    }
210

211 212
    // Per-variable [begin, end) ranges into catMap; ordered variables get an empty (0, 0) range.
    Mat getCatOfs() const
    {
        return catOfs;
    }
    // Concatenated, per-variable sorted category values of the categorical variables.
    Mat getCatMap() const
    {
        return catMap;
    }

    // Per-variable substitute values for missing data (-1 for categorical, 0 otherwise).
    Mat getDefaultSubstValues() const
    {
        return missingSubst;
    }
215

216 217 218 219 220 221 222
    // Closes the CSV file handle if one is open and always leaves it null.
    void closeFile()
    {
        if( file != NULL )
            fclose(file);
        file = NULL;
    }
    void clear()
    {
        closeFile();
        samples.release();
        missing.release();
        varType.release();
223
        varSymbolFlags.release();
224 225 226 227 228 229 230 231 232 233 234 235
        responses.release();
        sampleIdx.release();
        trainSampleIdx.release();
        testSampleIdx.release();
        normCatResponses.release();
        classLabels.release();
        classCounters.release();
        catMap.release();
        catOfs.release();
        nameMap = MapType();
        layout = ROW_SAMPLE;
    }
236

237
    // Maps a hash of a categorical variable's sorted labels (first, last, count) to the
    // index of the first variable that produced it; setData() uses it to share identical maps.
    typedef std::map<int, int> CatMapHash;
238

239 240 241
    /** Installs a new data set after validating and normalizing every input.

        @param _samples        CV_32F or CV_32S matrix: one sample per row (ROW_SAMPLE)
                               or per column (COL_SAMPLE).
        @param _layout         ROW_SAMPLE or COL_SAMPLE.
        @param _responses      optional CV_32F/CV_32S responses: a vector with one value
                               per sample, or one row (ROW_SAMPLE) / column (COL_SAMPLE)
                               of outputs per sample. Stored continuous, one row per sample.
        @param _varIdx         optional input-variable subset (CV_32S indices or CV_8U mask);
                               stored as sorted indices.
        @param _sampleIdx      optional sample subset (CV_32S indices or CV_8U mask).
        @param _sampleWeights  optional CV_32F weights, one per sample; defaults to ones.
        @param _varType        optional CV_8U type per variable (inputs, then outputs).
        @param _missing        optional CV_8U mask with the same size as _samples.

        Multi-output responses must all be VAR_ORDERED. For each categorical input a
        sorted category map is built (identical maps are shared between variables); when
        the first response is categorical, class indices, labels and counts are computed.
    */
    void setData(InputArray _samples, int _layout, InputArray _responses,
                 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
                 InputArray _varType, InputArray _missing)
    {
        clear();

        CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
        samples = _samples.getMat();
        layout = _layout;
        responses = _responses.getMat();
        varIdx = _varIdx.getMat();
        sampleIdx = _sampleIdx.getMat();
        sampleWeights = _sampleWeights.getMat();
        varType = _varType.getMat();
        missing = _missing.getMat();

        int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
        int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
        int i, noutputvars = 0;

        CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );

        if( !sampleIdx.empty() )
        {
            CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(sampleIdx, true, 0, 0, nsamples)) ||
                       sampleIdx.checkVector(1, CV_8U, true) == nsamples );
            if( sampleIdx.type() == CV_8U )
                sampleIdx = convertMaskToIdx(sampleIdx);
        }

        if( !sampleWeights.empty() )
        {
            CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
        }
        else
        {
            sampleWeights = Mat::ones(nsamples, 1, CV_32F);
        }

        if( !varIdx.empty() )
        {
            CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
                       checkRange(varIdx, true, 0, 0, ninputvars)) ||
                       varIdx.checkVector(1, CV_8U, true) == ninputvars );
            if( varIdx.type() == CV_8U )
                varIdx = convertMaskToIdx(varIdx);
            varIdx = varIdx.clone();
            std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
        }

        if( !responses.empty() )
        {
            CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
            if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
                noutputvars = 1;
            else
            {
                CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
                           (layout == COL_SAMPLE && responses.cols == nsamples) );
                noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
            }
            // Bring responses to a continuous matrix with one row per sample. Only
            // COL_SAMPLE multi-output responses need a transpose (a non-continuous
            // 1-D vector keeps the historical transpose). A non-continuous
            // ROW_SAMPLE multi-output matrix is already oriented correctly and must
            // only be compacted: transposing it would swap samples and outputs.
            if( (layout == COL_SAMPLE && noutputvars > 1) ||
                (noutputvars == 1 && !responses.isContinuous()) )
            {
                Mat temp;
                transpose(responses, temp);
                responses = temp;
            }
            else if( !responses.isContinuous() )
                responses = responses.clone();
        }

        int nvars = ninputvars + noutputvars;

        if( !varType.empty() )
        {
            CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
                       checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
        }
        else
        {
            varType.create(1, nvars, CV_8U);
            varType = Scalar::all(VAR_ORDERED);
            // By default a single integer (CV_32S) response is treated as categorical.
            if( noutputvars == 1 )
                varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
        }

        if( noutputvars > 1 )
        {
            for( i = 0; i < noutputvars; i++ )
                CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
        }

        catOfs = Mat::zeros(1, nvars, CV_32SC2);
        missingSubst = Mat::zeros(1, nvars, CV_32F);

        vector<int> labels, counters, sortbuf, tempCatMap;
        vector<Vec2i> tempCatOfs;
        CatMapHash ofshash;

        // The missing mask is only validated here; substitutes are not estimated from it.
        if( !missing.empty() )
        {
            CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
        }

        // we iterate through all the variables. For each categorical variable we build a map
        // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
        // often many categorical variables are similar, so we compress the map - try to re-use
        // maps for different variables if they are identical
        for( i = 0; i < ninputvars; i++ )
        {
            Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);

            if( varType.at<uchar>(i) == VAR_CATEGORICAL )
            {
                preprocessCategorical(values_i, 0, labels, 0, sortbuf);
                missingSubst.at<float>(i) = -1.f;
                int j, m = (int)labels.size();
                CV_Assert( m > 0 );
                int a = labels.front(), b = labels.back();
                const int* currmap = &labels[0];
                int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
                CatMapHash::iterator it = ofshash.find(hashval);
                if( it != ofshash.end() )
                {
                    int vi = it->second;
                    Vec2i ofs0 = tempCatOfs[vi];
                    int m0 = ofs0[1] - ofs0[0];
                    const int* map0 = &tempCatMap[ofs0[0]];
                    if( m0 == m && map0[0] == a && map0[m0-1] == b )
                    {
                        for( j = 0; j < m; j++ )
                            if( map0[j] != currmap[j] )
                                break;
                        if( j == m )
                        {
                            // re-use the map
                            tempCatOfs.push_back(ofs0);
                            continue;
                        }
                    }
                }
                else
                    ofshash[hashval] = i;
                Vec2i ofs;
                ofs[0] = (int)tempCatMap.size();
                ofs[1] = ofs[0] + m;
                tempCatOfs.push_back(ofs);
                std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
            }
            else
            {
                // Ordered variable: no category map; missing values are replaced by 0
                // (the mean over non-missing values is not computed).
                tempCatOfs.push_back(Vec2i(0, 0));
                missingSubst.at<float>(i) = 0.f;
            }
        }

        if( !tempCatOfs.empty() )
        {
            Mat(tempCatOfs).copyTo(catOfs);
            Mat(tempCatMap).copyTo(catMap);
        }

        if( noutputvars > 0 && varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
        {
            preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
            Mat(labels).copyTo(classLabels);
            Mat(counters).copyTo(classCounters);
        }
    }

414 415 416 417 418 419 420 421 422
    // Turns a CV_8U mask vector into a 1 x nnz CV_32S row holding, in increasing
    // order, the positions of its non-zero elements.
    Mat convertMaskToIdx(const Mat& mask)
    {
        const int length = mask.cols + mask.rows - 1;
        Mat positions(1, countNonZero(mask), CV_32S);
        int filled = 0;
        for( int pos = 0; pos < length; pos++ )
        {
            if( mask.at<uchar>(pos) == 0 )
                continue;
            positions.at<int>(filled) = pos;
            filled++;
        }
        return positions;
    }
423

424 425 426 427 428 429 430 431 432 433 434 435 436 437
    // Strict weak ordering of element indices by the integer each one refers to:
    // index i is compared through data[i*step] (step > 1 for strided storage).
    struct CmpByIdx
    {
        CmpByIdx(const int* _data, int _step)
            : data(_data), step(_step)
        {
        }

        bool operator ()(int i, int j) const
        {
            const int lhs = data[i*step];
            const int rhs = data[j*step];
            return lhs < rhs;
        }

        const int* data;
        int step;
    };

    void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
                               vector<int>* counters, vector<int>& sortbuf)
    {
        CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
        int* odata = 0;
        int ostep = 0;
438

439 440 441 442 443 444
        if(normdata)
        {
            normdata->create(data.size(), CV_32S);
            odata = normdata->ptr<int>();
            ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
        }
445

446 447 448 449 450
        int i, n = data.cols + data.rows - 1;
        sortbuf.resize(n*2);
        int* idx = &sortbuf[0];
        int* idata = (int*)data.ptr<int>();
        int istep = data.isContinuous() ? 1 : (int)data.step1();
451

452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467
        if( data.type() == CV_32F )
        {
            idata = idx + n;
            const float* fdata = data.ptr<float>();
            for( i = 0; i < n; i++ )
            {
                if( fdata[i*istep] == MISSED_VAL )
                    idata[i] = -1;
                else
                {
                    idata[i] = cvRound(fdata[i*istep]);
                    CV_Assert( (float)idata[i] == fdata[i*istep] );
                }
            }
            istep = 1;
        }
468

469 470
        for( i = 0; i < n; i++ )
            idx[i] = i;
471

472
        std::sort(idx, idx + n, CmpByIdx(idata, istep));
473

474 475 476
        int clscount = 1;
        for( i = 1; i < n; i++ )
            clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
477

478 479 480
        int clslabel = -1;
        int prev = ~idata[idx[0]*istep];
        int previdx = 0;
481

482 483 484 485 486
        labels.resize(clscount);
        if(counters)
            counters->resize(clscount);

        for( i = 0; i < n; i++ )
487
        {
488 489
            int l = idata[idx[i]*istep];
            if( l != prev )
490
            {
491 492 493 494 495 496 497
                clslabel++;
                labels[clslabel] = l;
                int k = i - previdx;
                if( clslabel > 0 && counters )
                    counters->at(clslabel-1) = k;
                prev = l;
                previdx = i;
498
            }
499 500
            if(odata)
                odata[idx[i]*ostep] = clslabel;
501
        }
502 503
        if(counters)
            counters->at(clslabel) = i - previdx;
504 505
    }

506 507 508 509 510 511 512 513
    bool loadCSV(const String& filename, int headerLines,
                 int responseStartIdx, int responseEndIdx,
                 const String& varTypeSpec, char delimiter, char missch)
    {
        const int M = 1000000;
        const char delimiters[3] = { ' ', delimiter, '\0' };
        int nvars = 0;
        bool varTypesSet = false;
514

515
        clear();
516

517
        file = fopen( filename.c_str(), "rt" );
518

519 520
        if( !file )
            return false;
521

522 523 524 525
        std::vector<char> _buf(M);
        std::vector<float> allresponses;
        std::vector<float> rowvals;
        std::vector<uchar> vtypes, rowtypes;
526
        std::vector<uchar> vsymbolflags;
527 528
        bool haveMissed = false;
        char* buf = &_buf[0];
529

530 531
        int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
        int ninputvars = 0, noutputvars = 0;
532

533 534 535
        Mat tempSamples, tempMissing, tempResponses;
        MapType tempNameMap;
        int catCounter = 1;
536

537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561
        // skip header lines
        int lineno = 0;
        for(;;lineno++)
        {
            if( !fgets(buf, M, file) )
                break;
            if(lineno < headerLines )
                continue;
            // trim trailing spaces
            int idx = (int)strlen(buf)-1;
            while( idx >= 0 && isspace(buf[idx]) )
                buf[idx--] = '\0';
            // skip spaces in the beginning
            char* ptr = buf;
            while( *ptr != '\0' && isspace(*ptr) )
                ptr++;
            // skip commented off lines
            if(*ptr == '#')
                continue;
            rowvals.clear();
            rowtypes.clear();

            char* token = strtok(buf, delimiters);
            if (!token)
                break;
562

563 564 565 566 567 568 569
            for(;;)
            {
                float val=0.f; int tp = 0;
                decodeElem( token, val, tp, missch, tempNameMap, catCounter );
                if( tp == VAR_MISSED )
                    haveMissed = true;
                rowvals.push_back(val);
570
                rowtypes.push_back((uchar)tp);
571 572 573 574
                token = strtok(NULL, delimiters);
                if (!token)
                    break;
            }
575

576 577 578 579 580 581 582 583 584 585 586 587
            if( nvars == 0 )
            {
                if( rowvals.empty() )
                    CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
                nvars = (int)rowvals.size();
                if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
                {
                    setVarTypes(varTypeSpec, nvars, vtypes);
                    varTypesSet = true;
                }
                else
                    vtypes = rowtypes;
588 589 590
                vsymbolflags.resize(nvars);
                for( i = 0; i < nvars; i++ )
                    vsymbolflags[i] = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
591

592 593 594 595 596 597 598 599
                ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
                ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
                CV_Assert(ridx1 > ridx0);
                noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
                ninputvars = nvars - noutputvars;
            }
            else
                CV_Assert( nvars == (int)rowvals.size() );
600

601 602 603 604 605
            // check var types
            for( i = 0; i < nvars; i++ )
            {
                CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
                           (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
606 607 608 609 610
                uchar sflag = (uchar)(rowtypes[i] == VAR_CATEGORICAL);
                if( vsymbolflags[i] == VAR_MISSED )
                    vsymbolflags[i] = sflag;
                else
                    CV_Assert(vsymbolflags[i] == sflag || rowtypes[i] == VAR_MISSED);
611
            }
612

613 614 615 616 617 618 619 620 621 622 623
            if( ridx0 >= 0 )
            {
                for( i = ridx1; i < nvars; i++ )
                    std::swap(rowvals[i], rowvals[i-noutputvars]);
                for( i = ninputvars; i < nvars; i++ )
                    allresponses.push_back(rowvals[i]);
                rowvals.pop_back();
            }
            Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
            tempSamples.push_back(rmat);
        }
624

625
        closeFile();
626

627 628 629
        int nsamples = tempSamples.rows;
        if( nsamples == 0 )
            return false;
630

631 632
        if( haveMissed )
            compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
633

634 635 636 637 638 639 640 641 642 643 644 645
        if( ridx0 >= 0 )
        {
            for( i = ridx1; i < nvars; i++ )
                std::swap(vtypes[i], vtypes[i-noutputvars]);
            if( noutputvars > 1 )
            {
                for( i = ninputvars; i < nvars; i++ )
                    if( vtypes[i] == VAR_CATEGORICAL )
                        CV_Error(CV_StsBadArg,
                                 "If responses are vector values, not scalars, they must be marked as ordered responses");
            }
        }
646

647 648 649 650 651 652 653 654
        if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
        {
            for( i = 0; i < nsamples; i++ )
                if( allresponses[i] != cvRound(allresponses[i]) )
                    break;
            if( i == nsamples )
                vtypes[ninputvars] = VAR_CATEGORICAL;
        }
655

Lorena García's avatar
Lorena García committed
656
        //If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
Lorena García's avatar
Lorena García committed
657
        if (noutputvars != 0){
Lorena García's avatar
Lorena García committed
658 659 660
            Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), Mat(vtypes).clone(), tempMissing);
Lorena García's avatar
Lorena García committed
661 662 663 664 665 666 667
        }
        else{
            Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
            zero_mat.copyTo(tempResponses);
            setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
                    noArray(), noArray(), tempMissing);
        }
668 669
        bool ok = !samples.empty();
        if(ok)
670
        {
671
            std::swap(tempNameMap, nameMap);
672 673
            Mat(vsymbolflags).copyTo(varSymbolFlags);
        }
674 675
        return ok;
    }
676

677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701
    void decodeElem( const char* token, float& elem, int& type,
                     char missch, MapType& namemap, int& counter ) const
    {
        char* stopstring = NULL;
        elem = (float)strtod( token, &stopstring );
        if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
        {
            elem = MISSED_VAL;
            type = VAR_MISSED;
        }
        else if( *stopstring != '\0' )
        {
            MapType::iterator it = namemap.find(token);
            if( it == namemap.end() )
            {
                elem = (float)counter;
                namemap[token] = counter++;
            }
            else
                elem = (float)it->second;
            type = VAR_CATEGORICAL;
        }
        else
            type = VAR_ORDERED;
    }
702

703 704 705 706 707 708
    void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
    {
        const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
          "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
        const char* str = s.c_str();
        int specCounter = 0;
709

710
        vtypes.resize(nvars);
711

712 713 714 715 716 717 718
        for( int k = 0; k < 2; k++ )
        {
            const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
            int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
            if( ptr ) // parse ord/cat str
            {
                char* stopstring = NULL;
719

720 721 722 723 724 725 726
                if( ptr[3] == '\0' )
                {
                    for( int i = 0; i < nvars; i++ )
                        vtypes[i] = (uchar)tp;
                    specCounter = nvars;
                    break;
                }
727

728 729
                if ( ptr[3] != '[')
                    CV_Error( CV_StsBadArg, errmsg );
730

731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758
                ptr += 4; // pass "ord["
                do
                {
                    int b1 = (int)strtod( ptr, &stopstring );
                    if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
                        CV_Error( CV_StsBadArg, errmsg );
                    ptr = stopstring + 1;
                    if( (stopstring[0] == ',') || (stopstring[0] == ']'))
                    {
                        CV_Assert( 0 <= b1 && b1 < nvars );
                        vtypes[b1] = (uchar)tp;
                        specCounter++;
                    }
                    else
                    {
                        if( stopstring[0] == '-')
                        {
                            int b2 = (int)strtod( ptr, &stopstring);
                            if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
                                CV_Error( CV_StsBadArg, errmsg );
                            ptr = stopstring + 1;
                            CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
                            for (int i = b1; i <= b2; i++)
                                vtypes[i] = (uchar)tp;
                            specCounter += b2 - b1 + 1;
                        }
                        else
                            CV_Error( CV_StsBadArg, errmsg );
759

760 761 762 763 764
                    }
                }
                while(*stopstring != ']');
            }
        }
765

766 767
        if( specCounter != nvars )
            CV_Error( CV_StsBadArg, "type of some variables is not specified" );
768 769
    }

770
    void setTrainTestSplitRatio(double ratio, bool shuffle)
771
    {
772
        CV_Assert( 0. <= ratio && ratio <= 1. );
773
        setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
774 775
    }

776
    void setTrainTestSplit(int count, bool shuffle)
777
    {
778
        int i, nsamples = getNSamples();
779
        CV_Assert( 0 <= count && count < nsamples );
780 781 782

        trainSampleIdx.release();
        testSampleIdx.release();
783

784 785 786 787 788
        if( count == 0 )
            trainSampleIdx = sampleIdx;
        else if( count == nsamples )
            testSampleIdx = sampleIdx;
        else
789
        {
790
            Mat mask(1, nsamples, CV_8U);
791
            uchar* mptr = mask.ptr();
792 793 794 795 796 797 798 799 800
            for( i = 0; i < nsamples; i++ )
                mptr[i] = (uchar)(i < count);
            trainSampleIdx.create(1, count, CV_32S);
            testSampleIdx.create(1, nsamples - count, CV_32S);
            int j0 = 0, j1 = 0;
            const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
            int* trainptr = trainSampleIdx.ptr<int>();
            int* testptr = testSampleIdx.ptr<int>();
            for( i = 0; i < nsamples; i++ )
801
            {
802 803 804
                int idx = sptr ? sptr[i] : i;
                if( mptr[i] )
                    trainptr[j0++] = idx;
805
                else
806
                    testptr[j1++] = idx;
807
            }
808 809
            if( shuffle )
                shuffleTrainTest();
810
        }
811
    }
812

813
    void shuffleTrainTest()
814
    {
815
        if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
816
        {
817 818 819 820 821 822
            int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
            int* trainIdx = trainSampleIdx.ptr<int>();
            int* testIdx = testSampleIdx.ptr<int>();
            RNG& rng = theRNG();

            for( i = 0; i < nsamples; i++)
823
            {
824 825 826 827 828
                int a = rng.uniform(0, nsamples);
                int b = rng.uniform(0, nsamples);
                int* ptra = trainIdx;
                int* ptrb = trainIdx;
                if( a >= ntrain )
829
                {
830 831 832
                    ptra = testIdx;
                    a -= ntrain;
                    CV_Assert( a < ntest );
833
                }
834 835 836 837 838 839 840
                if( b >= ntrain )
                {
                    ptrb = testIdx;
                    b -= ntrain;
                    CV_Assert( b < ntest );
                }
                std::swap(ptra[a], ptrb[b]);
841 842 843 844
            }
        }
    }

845 846 847
    // Returns the sample matrix as CV_32F in the requested layout.
    //   _layout         - ROW_SAMPLE or COL_SAMPLE: layout of the returned matrix.
    //   compressSamples - when true, keep only the training samples
    //                     (getTrainSampleIdx()); when false, keep every sample
    //                     stored in `samples`.
    //   compressVars    - when true, keep only the variables listed by getVarIdx();
    //                     when false, keep every variable stored in `samples`.
    // When no selection and no transposition is required, the stored `samples`
    // matrix itself is returned instead of a newly built one.
    Mat getTrainSamples(int _layout,
                        bool compressSamples,
                        bool compressVars) const
    {
        if( samples.empty() )
            return samples;

        if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
            (!compressVars || varIdx.empty()) &&
            layout == _layout )
            return samples;

        const bool rowSamples = layout == ROW_SAMPLE;

        // Apply each subset only when its flag asks for it. Reaching this point
        // (e.g. because only the layout differs) must not force a selection the
        // caller did not request.
        int drows = rowSamples ? samples.rows : samples.cols;
        Mat sidx;
        if( compressSamples )
        {
            drows = getNTrainSamples();
            sidx = getTrainSampleIdx();
        }

        int dcols = rowSamples ? samples.cols : samples.rows;
        Mat vidx;
        if( compressVars )
        {
            dcols = getNVars();
            vidx = getVarIdx();
        }

        const float* src0 = samples.ptr<float>();
        const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
        const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
        // Element strides between consecutive samples / variables in the stored layout.
        size_t sstep0 = samples.step/samples.elemSize();
        size_t sstep = rowSamples ? sstep0 : 1;
        size_t vstep = rowSamples ? 1 : sstep0;

        // The copy loop below always fills output rows. For COL_SAMPLE output the
        // rows are variables, so exchange the roles of samples and variables.
        if( _layout == COL_SAMPLE )
        {
            std::swap(drows, dcols);
            std::swap(sptr, vptr);
            std::swap(sstep, vstep);
        }

        Mat dsamples(drows, dcols, CV_32F);

        for( int i = 0; i < drows; i++ )
        {
            const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
            float* dst = dsamples.ptr<float>(i);

            for( int j = 0; j < dcols; j++ )
                dst[j] = src[(vptr ? vptr[j] : j)*vstep];
        }

        return dsamples;
    }

    void getValues( int vi, InputArray _sidx, float* values ) const
888
    {
889
        Mat sidx = _sidx.getMat();
890
        int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
891
        CV_Assert( 0 <= vi && vi < getNAllVars() );
892
        CV_Assert( n >= 0 );
893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908
        const int* s = n > 0 ? sidx.ptr<int>() : 0;
        if( n == 0 )
            n = nsamples;

        size_t step = samples.step/samples.elemSize();
        size_t sstep = layout == ROW_SAMPLE ? step : 1;
        size_t vstep = layout == ROW_SAMPLE ? 1 : step;

        const float* src = samples.ptr<float>() + vi*vstep;
        float subst = missingSubst.at<float>(vi);
        for( i = 0; i < n; i++ )
        {
            int j = i;
            if( s )
            {
                j = s[i];
909
                CV_Assert( 0 <= j && j < nsamples );
910 911 912 913 914
            }
            values[i] = src[j*sstep];
            if( values[i] == MISSED_VAL )
                values[i] = subst;
        }
915 916
    }

917
    void getNormCatValues( int vi, InputArray _sidx, int* values ) const
918
    {
919 920 921 922 923
        float* fvalues = (float*)values;
        getValues(vi, _sidx, fvalues);
        int i, n = (int)_sidx.total();
        Vec2i ofs = catOfs.at<Vec2i>(vi);
        int m = ofs[1] - ofs[0];
924

925 926
        CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
        const int* cmap = &catMap.at<int>(ofs[0]);
927
        bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
928

929
        if( fastMap )
930
        {
931 932 933 934 935 936 937
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int idx = val - cmap[0];
                CV_Assert(cmap[idx] == val);
                values[i] = idx;
            }
938
        }
939 940 941 942 943 944
        else
        {
            for( i = 0; i < n; i++ )
            {
                int val = cvRound(fvalues[i]);
                int a = 0, b = m, c = -1;
945

946 947 948 949 950 951 952 953 954 955
                while( a < b )
                {
                    c = (a + b) >> 1;
                    if( val < cmap[c] )
                        b = c;
                    else if( val > cmap[c] )
                        a = c+1;
                    else
                        break;
                }
956

957 958 959 960
                CV_DbgAssert( c >= 0 && val == cmap[c] );
                values[i] = c;
            }
        }
961 962
    }

963 964 965 966
    // Copies the selected variables of sample `sidx` into `buf`.
    //   _vidx - single-channel CV_32S vector of variable indices, each asserted to lie
    //           in [0, getNAllVars()). When it has zero elements, all getNAllVars()
    //           variables are copied in natural order.
    //   sidx  - sample index, asserted to lie in [0, getNSamples()).
    //   buf   - non-null output buffer; receives one value per selected variable.
    // Values are copied as stored; no MISSED_VAL substitution is done here.
    void getSample(InputArray _vidx, int sidx, float* buf) const
    {
        CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
        Mat vidx = _vidx.getMat();
        int n = vidx.checkVector(1, CV_32S);
        const int nvars = getNAllVars();
        CV_Assert( n >= 0 );

        const size_t step = samples.step/samples.elemSize();
        const bool rowSamples = layout == ROW_SAMPLE;
        const size_t vstep = rowSamples ? 1 : step;
        const float* src = samples.ptr<float>() + sidx*(rowSamples ? step : 1);

        if( n == 0 )
        {
            // No explicit selection: copy every variable in natural order.
            for( int j = 0; j < nvars; j++ )
                buf[j] = src[j*vstep];
            return;
        }

        const int* vptr = vidx.ptr<int>();
        for( int i = 0; i < n; i++ )
        {
            const int j = vptr[i];
            CV_Assert( 0 <= j && j < nvars );
            buf[i] = src[j*vstep];
        }
    }

    FILE* file;
    int layout;
992
    Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
993 994 995 996 997
    Mat sampleIdx, trainSampleIdx, testSampleIdx;
    Mat sampleWeights, catMap, catOfs;
    Mat normCatResponses, classLabels, classCounters;
    MapType nameMap;
};
998

999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023
void TrainData::getNames(std::vector<String>& names) const
{
    const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
    CV_Assert(impl != 0);
    size_t n = impl->nameMap.size();
    TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(),
                                           it_end = impl->nameMap.end();
    names.resize(n+1);
    names[0] = "?";
    for( ; it != it_end; ++it )
    {
        String s = it->first;
        int label = it->second;
        CV_Assert( label > 0 && label <= (int)n );
        names[label] = s;
    }
}

// Returns the varSymbolFlags matrix of the underlying TrainDataImpl (asserts that
// this object is one). The returned cv::Mat header shares the stored data.
Mat TrainData::getVarSymbolFlags() const
{
    const TrainDataImpl* const impl = dynamic_cast<const TrainDataImpl*>(this);
    CV_Assert(impl != 0);
    const Mat& flags = impl->varSymbolFlags;
    return flags;
}

// Creates a TrainData object from a CSV file. All parameters are forwarded to
// TrainDataImpl::loadCSV(). Returns an empty pointer when loadCSV() reports failure.
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
                                      int headerLines,
                                      int responseStartIdx,
                                      int responseEndIdx,
                                      const String& varTypeSpec,
                                      char delimiter, char missch)
{
    CV_TRACE_FUNCTION_SKIP_NESTED();
    Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
    const bool loaded = td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx,
                                    varTypeSpec, delimiter, missch);
    if( !loaded )
        return Ptr<TrainData>();
    return td;
}

// Builds a TrainData object from in-memory arrays. All arguments are forwarded to
// TrainDataImpl::setData(), with noArray() passed as its last argument.
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
                                 InputArray varType)
{
    CV_TRACE_FUNCTION_SKIP_NESTED();
    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    impl->setData(samples, layout, responses, varIdx, sampleIdx,
                  sampleWeights, varType, noArray());
    return impl;
}

}}

1050
/* End of file. */