Commit e20d570e authored by P. Druzhkov's avatar P. Druzhkov

brief gbt documentation added. some sample fixes made. code updated.

parent 9c071c6a
This diff is collapsed.
...@@ -15,7 +15,7 @@ Most of the classification and regression algorithms are implemented as C++ clas ...@@ -15,7 +15,7 @@ Most of the classification and regression algorithms are implemented as C++ clas
support_vector_machines support_vector_machines
decision_trees decision_trees
boosting boosting
gradient_boosted_trees
random_trees random_trees
expectation_maximization expectation_maximization
neural_networks neural_networks
...@@ -1571,6 +1571,38 @@ public: ...@@ -1571,6 +1571,38 @@ public:
// Response value prediction // Response value prediction
// //
// API // API
// virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
// INPUT
// sample - input sample of the same type as in the training set.
// missing - missing values mask. missing=0 if there are no
// missing values in sample vector.
// weak_responses - predictions of all of the trees.
// not implemented (!)
// slice - part of the ensemble used for prediction.
// slice = CV_WHOLE_SEQ when all trees are used.
// k - number of ensemble used.
// k is in {-1,0,1,..,<count of output classes-1>}.
// in the case of classification problem
// <count of output classes-1> ensembles are built.
// If k = -1 ordinary prediction is the result,
// otherwise function gives the prediction of the
// k-th ensemble only.
// OUTPUT
// RESULT
// Predicted value.
*/
virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
/*
// Response value prediction.
// Parallel version (in the case of TBB existence)
//
// API
// virtual float predict( const CvMat* sample, const CvMat* missing=0, // virtual float predict( const CvMat* sample, const CvMat* missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const; int k=-1 ) const;
...@@ -1599,7 +1631,7 @@ public: ...@@ -1599,7 +1631,7 @@ public:
int k=-1 ) const; int k=-1 ) const;
/* /*
// Delete all temporary data. // Deletes all the data.
// //
// API // API
// virtual void clear(); // virtual void clear();
...@@ -1607,7 +1639,7 @@ public: ...@@ -1607,7 +1639,7 @@ public:
// INPUT // INPUT
// OUTPUT // OUTPUT
// delete data, weak, orig_response, sum_response, // delete data, weak, orig_response, sum_response,
// weak_eval, ubsample_train, subsample_test, // weak_eval, subsample_train, subsample_test,
// sample_idx, missing, lass_labels // sample_idx, missing, lass_labels
// delta = 0.0 // delta = 0.0
// RESULT // RESULT
...@@ -1623,7 +1655,7 @@ public: ...@@ -1623,7 +1655,7 @@ public:
// //
// INPUT // INPUT
// data - dataset // data - dataset
// type - defines which error is to compute^ train (CV_TRAIN_ERROR) or // type - defines which error is to compute: train (CV_TRAIN_ERROR) or
// test (CV_TEST_ERROR). // test (CV_TEST_ERROR).
// OUTPUT // OUTPUT
// resp - vector of predicitons // resp - vector of predicitons
...@@ -1633,7 +1665,6 @@ public: ...@@ -1633,7 +1665,6 @@ public:
virtual float calc_error( CvMLData* _data, int type, virtual float calc_error( CvMLData* _data, int type,
std::vector<float> *resp = 0 ); std::vector<float> *resp = 0 );
/* /*
// //
// Write parameters of the gtb model and data. Write learned model. // Write parameters of the gtb model and data. Write learned model.
...@@ -1852,7 +1883,6 @@ protected: ...@@ -1852,7 +1883,6 @@ protected:
CvMat* orig_response; CvMat* orig_response;
CvMat* sum_response; CvMat* sum_response;
CvMat* sum_response_tmp; CvMat* sum_response_tmp;
CvMat* weak_eval;
CvMat* sample_idx; CvMat* sample_idx;
CvMat* subsample_train; CvMat* subsample_train;
CvMat* subsample_test; CvMat* subsample_test;
......
This diff is collapsed.
...@@ -125,7 +125,10 @@ int main(int argc, char** argv) ...@@ -125,7 +125,10 @@ int main(int argc, char** argv)
print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() ); print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
printf("======GBTREES=====\n"); printf("======GBTREES=====\n");
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true)); if (categorical_response)
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
else
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
} }
else else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment