Commit 48ea65e6 authored by Maria Dimashova's avatar Maria Dimashova

fixed traincascade for ordered features

parent b4f17ab7
...@@ -1066,7 +1066,7 @@ CvBoost::train( const CvMat* _train_data, int _tflag, ...@@ -1066,7 +1066,7 @@ CvBoost::train( const CvMat* _train_data, int _tflag,
if( !tree->train( data, subsample_mask, this ) ) if( !tree->train( data, subsample_mask, this ) )
{ {
delete tree; delete tree;
continue; break;
} }
//cvCheckArr( get_weak_response()); //cvCheckArr( get_weak_response());
cvSeqPush( weak, &tree ); cvSeqPush( weak, &tree );
......
...@@ -718,7 +718,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) ...@@ -718,7 +718,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
// co - array of count/offset pairs (to handle duplicated values in _subsample_idx) // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
int* co, cur_ofs = 0; int* co, cur_ofs = 0;
int vi, i; int vi, i;
int work_var_count = get_work_var_count(); int workVarCount = get_work_var_count();
int count = isubsample_idx->rows + isubsample_idx->cols - 1; int count = isubsample_idx->rows + isubsample_idx->cols - 1;
root = new_node( 0, count, 1, 0 ); root = new_node( 0, count, 1, 0 );
...@@ -740,7 +740,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) ...@@ -740,7 +740,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
} }
cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float))); cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
for( vi = 0; vi < work_var_count; vi++ ) for( vi = 0; vi < workVarCount; vi++ )
{ {
int ci = get_var_type(vi); int ci = get_var_type(vi);
...@@ -841,14 +841,14 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) ...@@ -841,14 +841,14 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
if (is_buf_16u) if (is_buf_16u)
{ {
unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
get_work_var_count()*sample_count + root->offset); workVarCount*sample_count + root->offset);
for (i = 0; i < count; i++) for (i = 0; i < count; i++)
sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]]; sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
} }
else else
{ {
int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols +
get_work_var_count()*sample_count + root->offset; workVarCount*sample_count + root->offset;
for (i = 0; i < count; i++) for (i = 0; i < count; i++)
sample_idx_dst[i] = sample_idx_src[sidx[i]]; sample_idx_dst[i] = sample_idx_src[sidx[i]];
} }
...@@ -1622,13 +1622,19 @@ bool CvDTree::do_train( const CvMat* _subsample_idx ) ...@@ -1622,13 +1622,19 @@ bool CvDTree::do_train( const CvMat* _subsample_idx )
CV_CALL( try_split_node(root)); CV_CALL( try_split_node(root));
if( data->params.cv_folds > 0 ) if( root->split )
CV_CALL( prune_cv() ); {
CV_Assert( root->left );
CV_Assert( root->right );
if( data->params.cv_folds > 0 )
CV_CALL( prune_cv() );
if( !data->shared ) if( !data->shared )
data->free_train_data(); data->free_train_data();
result = true; result = true;
}
__END__; __END__;
......
This diff is collapsed.
...@@ -32,6 +32,8 @@ struct CvCascadeBoostTrainData : CvDTreeTrainData ...@@ -32,6 +32,8 @@ struct CvCascadeBoostTrainData : CvDTreeTrainData
const CvDTreeParams& _params=CvDTreeParams() ); const CvDTreeParams& _params=CvDTreeParams() );
void precalculate(); void precalculate();
virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf ); virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf); virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf ); virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
...@@ -67,7 +69,7 @@ public: ...@@ -67,7 +69,7 @@ public:
const CvCascadeBoostParams& _params=CvCascadeBoostParams() ); const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
virtual float predict( int sampleIdx, bool returnSum = false ) const; virtual float predict( int sampleIdx, bool returnSum = false ) const;
float getThreshold() const { return threshold; }; float getThreshold() const { return threshold; }
void write( FileStorage &fs, const Mat& featureMap ) const; void write( FileStorage &fs, const Mat& featureMap ) const;
bool read( const FileNode &node, const CvFeatureEvaluator* _featureEvaluator, bool read( const FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
const CvCascadeBoostParams& _params ); const CvCascadeBoostParams& _params );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment