// tree_engine.cpp
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
#include "opencv2/core/utility.hpp"
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <map>

// Writes the sample's usage/help text to stdout.
// The text contains no format specifiers, so it is emitted verbatim.
static void help()
{
    static const char usage_text[] =
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
        "CvDTree dtree;\n"
        "CvBoost boost;\n"
        "CvRTrees rtrees;\n"
        "CvERTrees ertrees;\n"
        "CvGBTrees gbtrees;\n"
        "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
        "-c specifies that the response is categorical (it's ordered by default) and\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n";

    fputs(usage_text, stdout);
}

// Counts the distinct class labels among the responses held by `data`.
// Each response is read as a float; if any value is not integral (it differs
// from its rounded value), the responses are not class labels and -1 is
// returned immediately.
static int count_classes(CvMLData& data)
{
    const cv::Mat responses = cv::cvarrToMat(data.get_responses());
    const int total = static_cast<int>(responses.total());

    std::map<int, int> seen_labels;  // used as a set: only the keys matter
    for (int idx = 0; idx < total; ++idx)
    {
        const float value = responses.at<float>(idx);
        const int label = cvRound(value);
        if (label != value)
            return -1;  // fractional response -> not a categorical label
        seen_labels[label] = 1;
    }
    return static_cast<int>(seen_labels.size());
}

static void print_result(float train_err, float test_err, const CvMat* _var_imp)
40 41 42
{
    printf( "train error    %f\n", train_err );
    printf( "test error    %f\n\n", test_err );
43

44
    if (_var_imp)
45
    {
46
        cv::Mat var_imp = cv::cvarrToMat(_var_imp), sorted_idx;
47
        cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
48

49 50 51 52
        printf( "variable importance:\n" );
        int i, n = (int)var_imp.total();
        int type = var_imp.type();
        CV_Assert(type == CV_32F || type == CV_64F);
53

54
        for( i = 0; i < n; i++)
55
        {
56 57
            int k = sorted_idx.at<int>(i);
            printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
58 59 60 61 62
        }
    }
    printf("\n");
}

63
int main(int argc, char** argv)
64
{
65
    if(argc < 2)
66 67
    {
        help();
68 69 70 71 72
        return 0;
    }
    const char* filename = 0;
    int response_idx = 0;
    bool categorical_response = false;
73

74 75 76 77 78 79 80 81 82 83 84 85 86 87
    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
            sscanf(argv[++i], "%d", &response_idx);
        else if(strcmp(argv[i], "-c") == 0)
            categorical_response = true;
        else if(argv[i][0] != '-' )
            filename = argv[i];
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
88
    }
89

90
    printf("\nReading in %s...\n\n",filename);
91 92 93 94
    CvDTree dtree;
    CvBoost boost;
    CvRTrees rtrees;
    CvERTrees ertrees;
95
    CvGBTrees gbtrees;
96 97 98

    CvMLData data;

99

100
    CvTrainTestSplit spl( 0.5f );
101

102
    if ( data.read_csv( filename ) == 0)
103
    {
104 105 106
        data.set_response_idx( response_idx );
        if(categorical_response)
            data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
107
        data.set_train_test_split( &spl );
108

109 110 111 112
        printf("======DTREE=====\n");
        dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
        print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );

113 114
        if( categorical_response && count_classes(data) == 2 )
        {
115 116
        printf("======BOOST=====\n");
        boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
117
        print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
118
        }
119 120 121 122 123 124

        printf("======RTREES=====\n");
        rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
        print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );

        printf("======ERTREES=====\n");
125
        ertrees.train( &data, CvRTParams( 18, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
126
        print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
127

128
        printf("======GBTREES=====\n");
129 130 131 132
        if (categorical_response)
            gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
        else
            gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
133
        print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
134 135 136 137 138
    }
    else
        printf("File can not be read");

    return 0;
139
}