Commit 0cf1de8e authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #3236 from vpisarev:fix_traincascade

parents a6748466 5947519f
...@@ -198,7 +198,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName, ...@@ -198,7 +198,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
cout << endl << "===== TRAINING " << i << "-stage =====" << endl; cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
cout << "<BEGIN" << endl; cout << "<BEGIN" << endl;
if ( !updateTrainingSet( tempLeafFARate ) ) if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
{ {
cout << "Train dataset for temp stage can not be filled. " cout << "Train dataset for temp stage can not be filled. "
"Branch training terminated." << endl; "Branch training terminated." << endl;
...@@ -284,17 +284,17 @@ int CvCascadeClassifier::predict( int sampleIdx ) ...@@ -284,17 +284,17 @@ int CvCascadeClassifier::predict( int sampleIdx )
return 1; return 1;
} }
bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio) bool CvCascadeClassifier::updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio)
{ {
int64 posConsumed = 0, negConsumed = 0; int64 posConsumed = 0, negConsumed = 0;
imgReader.restart(); imgReader.restart();
int posCount = fillPassedSamples( 0, numPos, true, posConsumed ); int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );
if( !posCount ) if( !posCount )
return false; return false;
cout << "POS count : consumed " << posCount << " : " << (int)posConsumed << endl; cout << "POS count : consumed " << posCount << " : " << (int)posConsumed << endl;
int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible
int negCount = fillPassedSamples( posCount, proNumNeg, false, negConsumed ); int negCount = fillPassedSamples( posCount, proNumNeg, false, minimumAcceptanceRatio, negConsumed );
if ( !negCount ) if ( !negCount )
return false; return false;
...@@ -304,7 +304,7 @@ bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio) ...@@ -304,7 +304,7 @@ bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio)
return true; return true;
} }
int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, int64& consumed ) int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, double minimumAcceptanceRatio, int64& consumed )
{ {
int getcount = 0; int getcount = 0;
Mat img(cascadeParams.winSize, CV_8UC1); Mat img(cascadeParams.winSize, CV_8UC1);
...@@ -312,6 +312,9 @@ int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositiv ...@@ -312,6 +312,9 @@ int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositiv
{ {
for( ; ; ) for( ; ; )
{ {
if( consumed != 0 && ((double)getcount+1)/(double)(int64)consumed <= minimumAcceptanceRatio )
return getcount;
bool isGetImg = isPositive ? imgReader.getPos( img ) : bool isGetImg = isPositive ? imgReader.getPos( img ) :
imgReader.getNeg( img ); imgReader.getNeg( img );
if( !isGetImg ) if( !isGetImg )
......
...@@ -101,8 +101,8 @@ private: ...@@ -101,8 +101,8 @@ private:
int predict( int sampleIdx ); int predict( int sampleIdx );
void save( const std::string cascadeDirName, bool baseFormat = false ); void save( const std::string cascadeDirName, bool baseFormat = false );
bool load( const std::string cascadeDirName ); bool load( const std::string cascadeDirName );
bool updateTrainingSet( double& acceptanceRatio ); bool updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio );
int fillPassedSamples( int first, int count, bool isPositive, int64& consumed ); int fillPassedSamples( int first, int count, bool isPositive, double requiredAcceptanceRatio, int64& consumed );
void writeParams( cv::FileStorage &fs ) const; void writeParams( cv::FileStorage &fs ) const;
void writeStages( cv::FileStorage &fs, const cv::Mat& featureMap ) const; void writeStages( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment