Commit 9783cd98 authored by Maria Dimashova's avatar Maria Dimashova

fixed ERTrees name for writing to/reading from xml

parent 8a8b3466
...@@ -125,6 +125,7 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void ) ...@@ -125,6 +125,7 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void )
#define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp" #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn" #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees" #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees" #define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"
#define CV_TRAIN_ERROR 0 #define CV_TRAIN_ERROR 0
...@@ -1041,6 +1042,7 @@ public: ...@@ -1041,6 +1042,7 @@ public:
CvForestTree* get_tree(int i) const; CvForestTree* get_tree(int i) const;
protected: protected:
virtual std::string getName() const;
virtual bool grow_forest( const CvTermCriteria term_crit ); virtual bool grow_forest( const CvTermCriteria term_crit );
...@@ -1114,6 +1116,7 @@ public: ...@@ -1114,6 +1116,7 @@ public:
#endif #endif
virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected: protected:
virtual std::string getName() const;
virtual bool grow_forest( const CvTermCriteria term_crit ); virtual bool grow_forest( const CvTermCriteria term_crit );
}; };
......
...@@ -1517,6 +1517,11 @@ CvERTrees::~CvERTrees() ...@@ -1517,6 +1517,11 @@ CvERTrees::~CvERTrees()
{ {
} }
std::string CvERTrees::getName() const
{
return CV_TYPE_NAME_ML_ERTREES;
}
bool CvERTrees::train( const CvMat* _train_data, int _tflag, bool CvERTrees::train( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx, const CvMat* _responses, const CvMat* _var_idx,
const CvMat* _sample_idx, const CvMat* _var_type, const CvMat* _sample_idx, const CvMat* _var_type,
......
...@@ -246,6 +246,10 @@ CvRTrees::~CvRTrees() ...@@ -246,6 +246,10 @@ CvRTrees::~CvRTrees()
clear(); clear();
} }
std::string CvRTrees::getName() const
{
return CV_TYPE_NAME_ML_RTREES;
}
CvMat* CvRTrees::get_active_var_mask() CvMat* CvRTrees::get_active_var_mask()
{ {
...@@ -726,7 +730,8 @@ void CvRTrees::write( CvFileStorage* fs, const char* name ) const ...@@ -726,7 +730,8 @@ void CvRTrees::write( CvFileStorage* fs, const char* name ) const
if( ntrees < 1 || !trees || nsamples < 1 ) if( ntrees < 1 || !trees || nsamples < 1 )
CV_Error( CV_StsBadArg, "Invalid CvRTrees object" ); CV_Error( CV_StsBadArg, "Invalid CvRTrees object" );
cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES ); std::string modelNodeName = this->getName();
cvStartWriteStruct( fs, name, CV_NODE_MAP, modelNodeName.c_str() );
cvWriteInt( fs, "nclasses", nclasses ); cvWriteInt( fs, "nclasses", nclasses );
cvWriteInt( fs, "nsamples", nsamples ); cvWriteInt( fs, "nsamples", nsamples );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment