tree_engine.cpp 2.72 KB
Newer Older
1
#include "opencv2/ml/ml.hpp"
2
#include "opencv2/core/core.hpp"
3
#include "opencv2/core/utility.hpp"
4
#include <stdio.h>
5
#include <string>
6
#include <map>
Gary Bradski's avatar
Gary Bradski committed
7

8 9 10
using namespace cv;
using namespace cv::ml;

11
// Prints a short description of the sample and its command-line usage to stdout.
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}
20

21
// Trains `model` on `data` and reports the outcome on stdout:
// either "Training failed", or the error computed on the train subset
// followed by the error computed on the test subset.
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    if( !model->train(data) )
    {
        printf("Training failed\n");
        return;
    }

    // The boolean passed to calcError picks the subset the error is
    // measured on (false is reported as "train", true as "test").
    const double train_err = model->calcError(data, false, noArray());
    printf( "train error: %f\n", train_err );

    const double test_err = model->calcError(data, true, noArray());
    printf( "test error: %f\n\n", test_err );
}

35
int main(int argc, char** argv)
36
{
37
    if(argc < 2)
38 39
    {
        help();
40 41 42 43
        return 0;
    }
    const char* filename = 0;
    int response_idx = 0;
44
    std::string typespec;
45

46 47 48 49
    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
            sscanf(argv[++i], "%d", &response_idx);
50 51
        else if(strcmp(argv[i], "-ts") == 0)
            typespec = argv[++i];
52 53 54 55 56 57 58 59
        else if(argv[i][0] != '-' )
            filename = argv[i];
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
60
    }
61

62
    printf("\nReading in %s...\n\n",filename);
63
    const double train_test_split_ratio = 0.5;
64

65
    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
66

67
    if( data.empty() )
68
    {
69 70 71
        printf("ERROR: File %s can not be read\n", filename);
        return 0;
    }
72

73
    data->setTrainTestSplitRatio(train_test_split_ratio);
74

75 76 77
    printf("======DTREE=====\n");
    Ptr<DTrees> dtree = DTrees::create(DTrees::Params( 10, 2, 0, false, 16, 0, false, false, Mat() ));
    train_and_print_errs(dtree, data);
78

79 80 81 82 83
    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create(Boost::Params(Boost::GENTLE, 100, 0.95, 2, false, Mat()));
        train_and_print_errs(boost, data);
84
    }
85 86 87 88

    printf("======RTREES=====\n");
    Ptr<RTrees> rtrees = RTrees::create(RTrees::Params(10, 2, 0, false, 16, Mat(), false, 0, TermCriteria(TermCriteria::MAX_ITER, 100, 0)));
    train_and_print_errs(rtrees, data);
89 90

    return 0;
91
}