Commit f7495e58 authored by Ilya Lysenkov

Updated GBT docs

parent 0f0573e7
@@ -104,36 +104,30 @@ CvGBTreesParams
---------------
.. ocv:class:: CvGBTreesParams

GBT training parameters. ::

    struct CvGBTreesParams : public CvDTreeParams
    {
        int weak_count;
        int loss_function_type;
        float subsample_portion;
        float shrinkage;

        CvGBTreesParams();
        CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
                         float subsample_portion, int max_depth, bool use_surrogates );
    };

The structure contains parameters for each single decision tree in the ensemble,
as well as the characteristics of the whole model. The structure is derived from
:ocv:class:`CvDTreeParams` but not all of the decision tree parameters are supported:
cross-validation, pruning, and class priorities are not used.
CvGBTreesParams::CvGBTreesParams
--------------------------------
.. ocv:function:: CvGBTreesParams::CvGBTreesParams()
.. ocv:function:: CvGBTreesParams::CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage, float subsample_portion, int max_depth, bool use_surrogates )
:param loss_function_type: Type of the loss function used for training
(see :ref:`Training the GBT model`). It must be one of the
following types: ``CvGBTrees::SQUARED_LOSS``, ``CvGBTrees::ABSOLUTE_LOSS``,
``CvGBTrees::HUBER_LOSS``, ``CvGBTrees::DEVIANCE_LOSS``. The first three
types are used for regression problems, and the last one for
classification.
:param weak_count: Count of boosting algorithm iterations. ``weak_count*K`` is the total
count of trees in the GBT model, where ``K`` is the output classes count
(equal to one in case of a regression).
:param shrinkage: Regularization parameter (see :ref:`Training the GBT model`).
:param subsample_portion: Portion of the whole training set used for each algorithm iteration.
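
For illustration only, a parameter set for a regression problem might be constructed as
follows; the particular values are arbitrary placeholders, not recommended defaults: ::

    // A sketch only: 100 boosting iterations of depth-3 trees with squared loss,
    // shrinkage 0.1, and 80% of the training samples used per iteration.
    CvGBTreesParams params( CvGBTrees::SQUARED_LOSS, // loss_function_type
                            100,                     // weak_count
                            0.1f,                    // shrinkage
                            0.8f,                    // subsample_portion
                            3,                       // max_depth
                            false );                 // use_surrogates
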
@@ -155,81 +149,29 @@ CvGBTrees
---------
.. ocv:class:: CvGBTrees

GBT model. ::

    class CvGBTrees : public CvStatModel
    {
    public:

        enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};

        CvGBTrees();
        CvGBTrees( const cv::Mat& trainData, int tflag,
                   const Mat& responses, const Mat& varIdx=Mat(),
                   const Mat& sampleIdx=Mat(), const cv::Mat& varType=Mat(),
                   const Mat& missingDataMask=Mat(),
                   CvGBTreesParams params=CvGBTreesParams() );

        virtual ~CvGBTrees();

        virtual bool train( const Mat& trainData, int tflag,
                            const Mat& responses, const Mat& varIdx=Mat(),
                            const Mat& sampleIdx=Mat(), const Mat& varType=Mat(),
                            const Mat& missingDataMask=Mat(),
                            CvGBTreesParams params=CvGBTreesParams(),
                            bool update=false );

        virtual bool train( CvMLData* data,
                            CvGBTreesParams params=CvGBTreesParams(),
                            bool update=false );

        virtual float predict( const Mat& sample, const Mat& missing=Mat(),
                               const Range& slice = Range::all(),
                               int k=-1 ) const;

        virtual void clear();

        virtual float calc_error( CvMLData* _data, int type,
                                  std::vector<float> *resp = 0 );

        virtual void write( CvFileStorage* fs, const char* name ) const;

        virtual void read( CvFileStorage* fs, CvFileNode* node );

    protected:

        CvDTreeTrainData* data;
        CvGBTreesParams params;
        CvSeq** weak;
        Mat& orig_response;
        Mat& sum_response;
        Mat& sum_response_tmp;
        Mat& weak_eval;
        Mat& sample_idx;
        Mat& subsample_train;
        Mat& subsample_test;
        Mat& missing;
        Mat& class_labels;
        RNG* rng;
        int class_count;
        float delta;
        float base_value;

        ...
    };

The class implements the Gradient boosted tree model as described in the beginning of this section.
CvGBTrees::CvGBTrees
--------------------
Default and training constructors.
.. ocv:function:: CvGBTrees::CvGBTrees()
.. ocv:function:: CvGBTrees::CvGBTrees( const Mat& trainData, int tflag, const Mat& responses, const Mat& varIdx=Mat(), const Mat& sampleIdx=Mat(), const Mat& varType=Mat(), const Mat& missingDataMask=Mat(), CvGBTreesParams params=CvGBTreesParams() )
.. ocv:cfunction:: CvGBTrees::CvGBTrees( const CvMat* trainData, int tflag, const CvMat* responses, const CvMat* varIdx=0, const CvMat* sampleIdx=0, const CvMat* varType=0, const CvMat* missingDataMask=0, CvGBTreesParams params=CvGBTreesParams() )
The constructors follow conventions of :ocv:func:`CvStatModel::CvStatModel`. See :ocv:func:`CvStatModel::train` for parameters descriptions.
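
For illustration, the training constructor can build and train a model in one step. The
matrices below are placeholders for real data, not a working dataset: ::

    // Assumed example data: 100 row samples with 5 features each.
    cv::Mat trainData( 100, 5, CV_32FC1 ), responses( 100, 1, CV_32FC1 );
    // ... fill trainData and responses ...

    // Construct and train with default CvGBTreesParams.
    CvGBTrees gbt( trainData, CV_ROW_SAMPLE, responses );
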
CvGBTrees::train
----------------
Trains a Gradient boosted tree model.

.. ocv:function:: bool CvGBTrees::train(const Mat& trainData, int tflag, const Mat& responses, const Mat& varIdx=Mat(), const Mat& sampleIdx=Mat(), const Mat& varType=Mat(), const Mat& missingDataMask=Mat(), CvGBTreesParams params=CvGBTreesParams(), bool update=false)
.. ocv:function:: bool CvGBTrees::train(CvMLData* data, CvGBTreesParams params=CvGBTreesParams(), bool update=false)
.. ocv:cfunction:: bool CvGBTrees::train( const CvMat* trainData, int tflag, const CvMat* responses, const CvMat* varIdx=0, const CvMat* sampleIdx=0, const CvMat* varType=0, const CvMat* missingDataMask=0, CvGBTreesParams params=CvGBTreesParams(), bool update=false )
.. ocv:cfunction:: bool CvGBTrees::train(CvMLData* data, CvGBTreesParams params=CvGBTreesParams(), bool update=false)
The first train method follows the common template (see :ocv:func:`CvStatModel::train`).
Both ``tflag`` values (``CV_ROW_SAMPLE``, ``CV_COL_SAMPLE``) are supported.
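
A minimal sketch of the first method, assuming ``trainData`` and ``responses`` have
already been prepared as row samples: ::

    CvGBTrees gbt;
    CvGBTreesParams params;   // default training parameters

    // Unused optional inputs (varIdx, sampleIdx, varType, missingDataMask)
    // are passed as empty matrices.
    gbt.train( trainData, CV_ROW_SAMPLE, responses,
               cv::Mat(), cv::Mat(), cv::Mat(), cv::Mat(), params );
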
@@ -248,14 +190,12 @@ as a :ocv:class:`CvGBTreesParams` structure.
CvGBTrees::predict
------------------
Predicts a response for an input sample.

.. ocv:function:: float CvGBTrees::predict(const Mat& sample, const Mat& missing=Mat(), const Range& slice = Range::all(), int k=-1) const
.. ocv:cfunction:: float CvGBTrees::predict( const CvMat* sample, const CvMat* missing=0, CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ, int k=-1 ) const
:param sample: Input feature vector that has the same format as every training set
element. If not all the variables were actually used during training,
``sample`` contains forged values at the appropriate places.
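
As an illustration, a prediction call might look like the following, assuming ``gbt`` is a
trained model and ``sample`` is a prepared single-row feature vector: ::

    // Predict using the whole ensemble; no missing measurements.
    float response = gbt.predict( sample, cv::Mat() );

    // Predict using only the first 10 weak trees of the ensemble.
    float partial = gbt.predict( sample, cv::Mat(), cv::Range( 0, 10 ) );
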
@@ -305,7 +245,7 @@ CvGBTrees::calc_error
Calculates a training or testing error.
.. ocv:function:: float CvGBTrees::calc_error( CvMLData* _data, int type, std::vector<float> *resp = 0 )
:param _data: Data set.
:param type: Parameter defining the error that should be computed: train (``CV_TRAIN_ERROR``) or test