Commit e20d570e authored by P. Druzhkov's avatar P. Druzhkov

brief gbt documentation added. some sample fixes made. code updated.

parent 9c071c6a
......@@ -15,7 +15,7 @@ Most of the classification and regression algorithms are implemented as C++ clas
support_vector_machines
decision_trees
boosting
gradient_boosted_trees
random_trees
expectation_maximization
neural_networks
......@@ -1571,7 +1571,7 @@ public:
// Response value prediction
//
// API
// virtual float predict( const CvMat* sample, const CvMat* missing=0,
// virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
......@@ -1594,12 +1594,44 @@ public:
// RESULT
// Predicted value.
*/
virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
/*
// Response value prediction.
// Parallel version (in the case of TBB existence)
//
// API
// virtual float predict( const CvMat* sample, const CvMat* missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
// INPUT
// sample - input sample of the same type as in the training set.
// missing - missing values mask. missing=0 if there are no
// missing values in sample vector.
// weak_responses - predictions of all of the trees.
// not implemented (!)
// slice - part of the ensemble used for prediction.
// slice = CV_WHOLE_SEQ when all trees are used.
// k - index of the ensemble used for prediction.
// k is in {-1,0,1,..,<count of output classes>-1}.
// In the case of a classification problem,
// <count of output classes> ensembles are built.
// If k = -1 the ordinary (aggregated) prediction is returned,
// otherwise the function gives the prediction of the
// k-th ensemble only.
// OUTPUT
// RESULT
// Predicted value.
*/
virtual float predict( const CvMat* sample, const CvMat* missing=0,
CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
/*
// Delete all temporary data.
// Deletes all the data.
//
// API
// virtual void clear();
......@@ -1607,7 +1639,7 @@ public:
// INPUT
// OUTPUT
// delete data, weak, orig_response, sum_response,
// weak_eval, ubsample_train, subsample_test,
// weak_eval, subsample_train, subsample_test,
// sample_idx, missing, class_labels
// delta = 0.0
// RESULT
......@@ -1623,7 +1655,7 @@ public:
//
// INPUT
// data - dataset
// type - defines which error is to compute^ train (CV_TRAIN_ERROR) or
// type - defines which error is to compute: train (CV_TRAIN_ERROR) or
// test (CV_TEST_ERROR).
// OUTPUT
// resp - vector of predictions
......@@ -1633,7 +1665,6 @@ public:
virtual float calc_error( CvMLData* _data, int type,
std::vector<float> *resp = 0 );
/*
//
// Write parameters of the gtb model and data. Write learned model.
......@@ -1852,7 +1883,6 @@ protected:
CvMat* orig_response;
CvMat* sum_response;
CvMat* sum_response_tmp;
CvMat* weak_eval;
CvMat* sample_idx;
CvMat* subsample_train;
CvMat* subsample_test;
......
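
For reference, a minimal usage sketch of the predict() interface documented above, written against the OpenCV 2.x ml module; the model object, sample matrix, and helper name are illustrative and not part of this commit. predict_serial() takes the same arguments and always runs the single-threaded path.

#include "opencv2/ml/ml.hpp"

// Sketch: predict the response for a single sample with a trained
// CvGBTrees model (all names here are illustrative).
float predict_one( const CvGBTrees& gbtrees, const CvMat* sample )
{
    // missing = 0         : the sample has no missing values
    // weakResponses = 0   : per-tree responses are not collected
    // slice = CV_WHOLE_SEQ: use every tree in the ensemble
    // k = -1              : ordinary (aggregated) prediction
    return gbtrees.predict( sample, 0, 0, CV_WHOLE_SEQ, -1 );
}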
......@@ -125,7 +125,10 @@ int main(int argc, char** argv)
print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
printf("======GBTREES=====\n");
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true));
if (categorical_response)
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
else
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
}
else
......
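
For context, the CvGBTreesParams constructor used in the sample takes the loss type, the number of weak trees, the shrinkage, the subsample portion, the maximal tree depth, and the use_surrogates flag. A sketch of the sample's GBT branch, assuming a CvMLData object `data` has already been loaded and split (the helper name is illustrative):

#include <cstdio>
#include "opencv2/ml/ml.hpp"

// Train a GBT model and report train/test errors, mirroring the sample above.
void run_gbtrees( CvMLData& data, bool categorical_response )
{
    CvGBTrees gbtrees;
    // CvGBTreesParams( loss_function_type, weak_count, shrinkage,
    //                  subsample_portion, max_depth, use_surrogates )
    if( categorical_response )
        gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false) );
    else
        gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false) );

    // calc_error() evaluates the fitted ensemble on the train or test subset.
    printf( "train error: %f\n", gbtrees.calc_error( &data, CV_TRAIN_ERROR ) );
    printf( "test error:  %f\n", gbtrees.calc_error( &data, CV_TEST_ERROR ) );
}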