test_gbttest.cpp 7.29 KB
Newer Older
1 2 3

#include "test_precomp.hpp"

4 5
#if 0

6 7 8 9 10 11 12
using namespace std;


class CV_GBTreesTest : public cvtest::BaseTest
{
public:
    CV_GBTreesTest();
Andrey Kamaev's avatar
Andrey Kamaev committed
13 14
    ~CV_GBTreesTest();

15 16 17 18 19 20 21
protected:
    void run(int);

    int TestTrainPredict(int test_num);
    int TestSaveLoad();

    int checkPredictError(int test_num);
Andrey Kamaev's avatar
Andrey Kamaev committed
22 23
    int checkLoadSave();

24 25
    string model_file_name1;
    string model_file_name2;
26

27 28
    string* datasets;
    string data_path;
Andrey Kamaev's avatar
Andrey Kamaev committed
29

30 31
    CvMLData* data;
    CvGBTrees* gtb;
Andrey Kamaev's avatar
Andrey Kamaev committed
32

33 34
    vector<float> test_resps1;
    vector<float> test_resps2;
35

Andrey Kamaev's avatar
Andrey Kamaev committed
36
    int64 initSeed;
37 38 39 40 41 42 43 44 45 46 47
};


// Returns the larger of the matrix's two dimensions (rows when they are equal).
int _get_len(const CvMat* mat)
{
    const int row_count = mat->rows;
    const int col_count = mat->cols;
    if (col_count > row_count)
        return col_count;
    return row_count;
}


// Starts with no dataset list, no data and no model, remembers the global
// RNG state, and reseeds the global RNG with one of a fixed set of seeds.
CV_GBTreesTest::CV_GBTreesTest()
    : datasets(0), data(0), gtb(0)
{
    // Candidate seeds; one of them is chosen below.
    int64 seeds[] = {
        CV_BIG_INT(0x00009fff4f9c8d52),
        CV_BIG_INT(0x0000a17166072c7c),
        CV_BIG_INT(0x0201b32115cd1f9a),
        CV_BIG_INT(0x0513cb37abcd1234),
        CV_BIG_INT(0x0001a2b3c4d5f678)
    };
    int seedCount = sizeof(seeds) / sizeof(seeds[0]);

    cv::RNG& globalRng = cv::theRNG();

    // Save the incoming state first so the destructor can put it back.
    initSeed = globalRng.state;

    // The index is drawn from the RNG itself before its state is replaced.
    int64 chosenSeed = seeds[globalRng(seedCount)];
    globalRng.state = chosenSeed;
}

// Releases everything the fixture owns and restores the global RNG state
// that was captured in the constructor.
CV_GBTreesTest::~CV_GBTreesTest()
{
    // The model can still be alive if run() left its loop early, so it must
    // be released here as well. Deleting a null pointer is a no-op, hence no
    // guards are needed.
    delete gtb;
    gtb = 0;

    delete data;
    data = 0;

    delete[] datasets;
    datasets = 0;

    cv::theRNG().state = initSeed;
}


int CV_GBTreesTest::TestTrainPredict(int test_num)
{
    int code = cvtest::TS::OK;
Andrey Kamaev's avatar
Andrey Kamaev committed
77

78 79 80 81
    int weak_count = 200;
    float shrinkage = 0.1f;
    float subsample_portion = 0.5f;
    int max_depth = 5;
82
    bool use_surrogates = false;
83 84 85 86 87 88 89
    int loss_function_type = 0;
    switch (test_num)
    {
        case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
        case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
        case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
        case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
Andrey Kamaev's avatar
Andrey Kamaev committed
90
        default  :
91 92 93 94 95 96 97 98 99 100 101
            {
            ts->printf( cvtest::TS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
            return cvtest::TS::FAIL_BAD_ARG_CHECK;
            }
    }

    int dataset_num = test_num == 0 ? 0 : 1;
    if (!data)
    {
        data = new CvMLData();
        data->set_delimiter(',');
Andrey Kamaev's avatar
Andrey Kamaev committed
102

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
        if (data->read_csv(datasets[dataset_num].c_str()))
        {
            ts->printf( cvtest::TS::LOG, "File reading error." );
            return cvtest::TS::FAIL_INVALID_TEST_DATA;
        }

        if (test_num == 0)
        {
            data->set_response_idx(57);
            data->set_var_types("ord[0-56],cat[57]");
        }
        else
        {
            data->set_response_idx(13);
            data->set_var_types("ord[0-2,4-13],cat[3]");
            subsample_portion = 0.7f;
        }

        int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
        CvTrainTestSplit spl( train_sample_count );
        data->set_train_test_split( &spl );
    }
Andrey Kamaev's avatar
Andrey Kamaev committed
125 126 127 128

    data->mix_train_and_test_idx();


129 130 131 132 133 134
    if (gtb) delete gtb;
    gtb = new CvGBTrees();
    bool tmp_code = true;
    tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
                          shrinkage, subsample_portion,
                          max_depth, use_surrogates));
Andrey Kamaev's avatar
Andrey Kamaev committed
135

136 137 138 139 140
    if (!tmp_code)
    {
        ts->printf( cvtest::TS::LOG, "Model training was failed.");
        return cvtest::TS::FAIL_INVALID_OUTPUT;
    }
Andrey Kamaev's avatar
Andrey Kamaev committed
141

142
    code = checkPredictError(test_num);
Andrey Kamaev's avatar
Andrey Kamaev committed
143

144 145 146 147 148 149 150 151 152
    return code;

}


int CV_GBTreesTest::checkPredictError(int test_num)
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;
Andrey Kamaev's avatar
Andrey Kamaev committed
153

154 155
    //float mean[] = {5.430247f, 13.5654f, 12.6569f, 13.1661f};
    //float sigma[] = {0.4162694f, 3.21161f, 3.43297f, 3.00624f};
Andrey Kamaev's avatar
Andrey Kamaev committed
156
    float mean[] = {5.80226f, 12.68689f, 13.49095f, 13.19628f};
157
    float sigma[] = {0.4764534f, 3.166919f, 3.022405f, 2.868722f};
Andrey Kamaev's avatar
Andrey Kamaev committed
158

159
    float current_error = gtb->calc_error(data, CV_TEST_ERROR);
Andrey Kamaev's avatar
Andrey Kamaev committed
160

161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
    if ( abs( current_error - mean[test_num]) > 6*sigma[test_num] )
    {
        ts->printf( cvtest::TS::LOG, "Test error is out of range:\n"
                    "abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/",
                    current_error, mean[test_num], 6*sigma[test_num] );
        return cvtest::TS::FAIL_BAD_ACCURACY;
    }

    return cvtest::TS::OK;

}


int CV_GBTreesTest::TestSaveLoad()
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;
Andrey Kamaev's avatar
Andrey Kamaev committed
178

179 180
    model_file_name1 = cv::tempfile();
    model_file_name2 = cv::tempfile();
181

182
    gtb->save(model_file_name1.c_str());
183
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
184
    gtb->load(model_file_name1.c_str());
185
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
186
    gtb->save(model_file_name2.c_str());
Andrey Kamaev's avatar
Andrey Kamaev committed
187

188
    return checkLoadSave();
Andrey Kamaev's avatar
Andrey Kamaev committed
189

190 191 192 193 194 195 196 197 198
}



int CV_GBTreesTest::checkLoadSave()
{
    int code = cvtest::TS::OK;

    // 1. compare files
199
    ifstream f1( model_file_name1.c_str() ), f2( model_file_name2.c_str() );
200
    string s1, s2;
Andrey Kamaev's avatar
Andrey Kamaev committed
201
    int lineIdx = 0;
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
    CV_Assert( f1.is_open() && f2.is_open() );
    for( ; !f1.eof() && !f2.eof(); lineIdx++ )
    {
        getline( f1, s1 );
        getline( f2, s2 );
        if( s1.compare(s2) )
        {
            ts->printf( cvtest::TS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
               lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    if( !f1.eof() || !f2.eof() )
    {
        ts->printf( cvtest::TS::LOG, "First and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
            lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
    }
    f1.close();
    f2.close();
    // delete temporary files
223 224
    remove( model_file_name1.c_str() );
    remove( model_file_name2.c_str() );
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244

    // 2. compare responses
    CV_Assert( test_resps1.size() == test_resps2.size() );
    vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
    for( ; it1 != test_resps1.end(); ++it1, ++it2 )
    {
        if( fabs(*it1 - *it2) > FLT_EPSILON )
        {
            ts->printf( cvtest::TS::LOG, "Responses predicted before saving and after loading are different" );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    return code;
}



void CV_GBTreesTest::run(int)
{

Andrey Kamaev's avatar
Andrey Kamaev committed
245
    string dataPath = string(ts->get_data_path());
246
    datasets = new string[2];
Andrey Kamaev's avatar
Andrey Kamaev committed
247 248
    datasets[0] = dataPath + string("spambase.data"); /*string("dataset_classification.csv");*/
    datasets[1] = dataPath + string("housing_.data");  /*string("dataset_regression.csv");*/
249 250 251 252 253

    int code = cvtest::TS::OK;

    for (int i = 0; i < 4; i++)
    {
Andrey Kamaev's avatar
Andrey Kamaev committed
254

255 256 257 258 259 260
        int temp_code = TestTrainPredict(i);
        if (temp_code != cvtest::TS::OK)
        {
            code = temp_code;
            break;
        }
Andrey Kamaev's avatar
Andrey Kamaev committed
261

262 263 264 265 266 267 268 269
        else if (i==0)
        {
            temp_code = TestSaveLoad();
            if (temp_code != cvtest::TS::OK)
                code = temp_code;
            delete data;
            data = 0;
        }
Andrey Kamaev's avatar
Andrey Kamaev committed
270

271 272 273 274 275
        delete gtb;
        gtb = 0;
    }
    delete data;
    data = 0;
Andrey Kamaev's avatar
Andrey Kamaev committed
276

277 278 279 280 281 282 283 284
    ts->set_failed_test_info( code );
}

/////////////////////////////////////////////////////////////////////////////
//////////////////// test registration  /////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////

// Registers the fixture with the test framework and runs it via safe_run().
TEST(ML_GBTrees, regression)
{
    CV_GBTreesTest test;
    test.safe_run();
}
285 286

#endif