Commit b5a71db7 authored by Maria Dimashova's avatar Maria Dimashova

modified FernClassifier::train(); remove old RTreeClassifier and added new…

modified FernClassifier::train(); remove old RTreeClassifier and added new implementation CalonderClassifier; removed old find_obj_calonder and added new one
parent 1135bc24
This diff is collapsed.
...@@ -228,6 +228,10 @@ void SurfDescriptorExtractor::write( FileStorage &fs ) const ...@@ -228,6 +228,10 @@ void SurfDescriptorExtractor::write( FileStorage &fs ) const
fs << "extended" << surf.extended; fs << "extended" << surf.extended;
} }
/****************************************************************************************\
* Factory functions for descriptor extractor and matcher creating *
\****************************************************************************************/
Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType ) Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType )
{ {
DescriptorExtractor* de = 0; DescriptorExtractor* de = 0;
...@@ -270,7 +274,9 @@ Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherT ...@@ -270,7 +274,9 @@ Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherT
return dm; return dm;
} }
/****************************************************************************************\
* BruteForceMatcher L2 specialization *
\****************************************************************************************/
template<> template<>
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2, void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& /*mask*/, vector<int>& matches ) const const Mat& /*mask*/, vector<int>& matches ) const
...@@ -317,7 +323,6 @@ void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const M ...@@ -317,7 +323,6 @@ void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const M
#endif #endif
} }
/****************************************************************************************\ /****************************************************************************************\
* GenericDescriptorMatch * * GenericDescriptorMatch *
\****************************************************************************************/ \****************************************************************************************/
...@@ -394,6 +399,9 @@ void GenericDescriptorMatch::clear() ...@@ -394,6 +399,9 @@ void GenericDescriptorMatch::clear()
collection.clear(); collection.clear();
} }
/*
* Factory function for GenericDescriptorMatch creating
*/
Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string &paramsFilename ) Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string &paramsFilename )
{ {
GenericDescriptorMatch *descriptorMatch = 0; GenericDescriptorMatch *descriptorMatch = 0;
...@@ -409,7 +417,7 @@ Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericD ...@@ -409,7 +417,7 @@ Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericD
} }
else if( ! genericDescritptorMatchType.compare ("CALONDER") ) else if( ! genericDescritptorMatchType.compare ("CALONDER") )
{ {
descriptorMatch = new CalonderDescriptorMatch (); //descriptorMatch = new CalonderDescriptorMatch ();
} }
if( !paramsFilename.empty() && descriptorMatch != 0 ) if( !paramsFilename.empty() && descriptorMatch != 0 )
...@@ -626,6 +634,7 @@ void OneWayDescriptorMatch::clear () ...@@ -626,6 +634,7 @@ void OneWayDescriptorMatch::clear ()
/****************************************************************************************\ /****************************************************************************************\
* CalonderDescriptorMatch * * CalonderDescriptorMatch *
\****************************************************************************************/ \****************************************************************************************/
#if 0
CalonderDescriptorMatch::Params::Params( const RNG& _rng, const PatchGenerator& _patchGen, CalonderDescriptorMatch::Params::Params( const RNG& _rng, const PatchGenerator& _patchGen,
int _numTrees, int _depth, int _views, int _numTrees, int _depth, int _views,
size_t _reducedNumDim, size_t _reducedNumDim,
...@@ -774,6 +783,7 @@ void CalonderDescriptorMatch::write( FileStorage& fs ) const ...@@ -774,6 +783,7 @@ void CalonderDescriptorMatch::write( FileStorage& fs ) const
fs << "numQuantBits" << params.numQuantBits; fs << "numQuantBits" << params.numQuantBits;
fs << "printStatus" << params.printStatus; fs << "printStatus" << params.printStatus;
} }
#endif
/****************************************************************************************\ /****************************************************************************************\
* FernDescriptorMatch * * FernDescriptorMatch *
...@@ -827,22 +837,13 @@ void FernDescriptorMatch::trainFernClassifier() ...@@ -827,22 +837,13 @@ void FernDescriptorMatch::trainFernClassifier()
{ {
assert( params.filename.empty() ); assert( params.filename.empty() );
vector<Point2f> points; vector<vector<Point2f> > points;
vector<Ptr<Mat> > refimgs; for( size_t imgIdx = 0; imgIdx < collection.images.size(); imgIdx++ )
vector<int> labels; KeyPoint::convert( collection.points[imgIdx], points[imgIdx] );
for( size_t imageIdx = 0; imageIdx < collection.images.size(); imageIdx++ )
{
for( size_t pointIdx = 0; pointIdx < collection.points[imageIdx].size(); pointIdx++ )
{
refimgs.push_back(new Mat (collection.images[imageIdx]));
points.push_back(collection.points[imageIdx][pointIdx].pt);
labels.push_back((int)pointIdx);
}
}
classifier = new FernClassifier( points, refimgs, labels, params.nclasses, params.patchSize, classifier = new FernClassifier( points, collection.images, vector<vector<int> >(), 0, // each points is a class
params.signatureSize, params.nstructs, params.structSize, params.nviews, params.patchSize, params.signatureSize, params.nstructs, params.structSize,
params.compressionMethod, params.patchGenerator ); params.nviews, params.compressionMethod, params.patchGenerator );
} }
} }
...@@ -966,4 +967,59 @@ void FernDescriptorMatch::clear () ...@@ -966,4 +967,59 @@ void FernDescriptorMatch::clear ()
classifier.release(); classifier.release();
} }
/****************************************************************************************\
* VectorDescriptorMatch *
\****************************************************************************************/
void VectorDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
{
Mat descriptors;
extractor->compute( image, keypoints, descriptors );
matcher->add( descriptors );
collection.add( Mat(), keypoints );
};
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& keypointIndices )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, keypointIndices );
};
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, matches );
}
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points,
vector<vector<DMatch> >& matches, float threshold )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, matches, threshold );
}
void VectorDescriptorMatch::clear()
{
GenericDescriptorMatch::clear();
matcher->clear();
}
void VectorDescriptorMatch::read( const FileNode& fn )
{
GenericDescriptorMatch::read(fn);
extractor->read (fn);
}
void VectorDescriptorMatch::write (FileStorage& fs) const
{
GenericDescriptorMatch::write(fs);
extractor->write (fs);
}
} }
...@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const ...@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const
} }
FernClassifier::FernClassifier(const vector<Point2f>& points, FernClassifier::FernClassifier(const vector<vector<Point2f> >& points,
const vector<Ptr<Mat> >& refimgs, const vector<Mat>& refimgs,
const vector<int>& labels, const vector<vector<int> >& labels,
int _nclasses, int _patchSize, int _nclasses, int _patchSize,
int _signatureSize, int _nstructs, int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod, int _structSize, int _nviews, int _compressionMethod,
...@@ -829,43 +829,58 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize, ...@@ -829,43 +829,58 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize,
} }
} }
static int calcNumPoints( const vector<vector<Point2f> >& points )
{
int count = 0;
for( size_t i = 0; i < points.size(); i++ )
count += points[i].size();
return count;
}
void FernClassifier::train(const vector<Point2f>& points, void FernClassifier::train(const vector<vector<Point2f> >& points,
const vector<Ptr<Mat> >& refimgs, const vector<Mat>& refimgs,
const vector<int>& labels, const vector<vector<int> >& labels,
int _nclasses, int _patchSize, int _nclasses, int _patchSize,
int _signatureSize, int _nstructs, int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod, int _structSize, int _nviews, int _compressionMethod,
const PatchGenerator& patchGenerator) const PatchGenerator& patchGenerator)
{ {
_nclasses = _nclasses > 0 ? _nclasses : (int)points.size(); CV_Assert( points.size() == refimgs.size() );
int numPoints = calcNumPoints( points );
_nclasses = (!labels.empty() && _nclasses>0) ? _nclasses : numPoints;
CV_Assert( labels.empty() || labels.size() == points.size() ); CV_Assert( labels.empty() || labels.size() == points.size() );
prepare(_nclasses, _patchSize, _signatureSize, _nstructs, prepare(_nclasses, _patchSize, _signatureSize, _nstructs,
_structSize, _nviews, _compressionMethod); _structSize, _nviews, _compressionMethod);
// pass all the views of all the samples through the generated trees and accumulate // pass all the views of all the samples through the generated trees and accumulate
// the statistics (posterior probabilities) in leaves. // the statistics (posterior probabilities) in leaves.
Mat patch; Mat patch;
int i, j, nsamples = (int)points.size();
RNG& rng = theRNG(); RNG& rng = theRNG();
for( i = 0; i < nsamples; i++ ) int globalPointIdx = 0;
for( size_t imgIdx = 0; imgIdx < points.size(); imgIdx++ )
{
const Point2f* imgPoints = &points[imgIdx][0];
const int* imgLabels = labels.empty() ? 0 : &labels[imgIdx][0];
for( size_t pointIdx = 0; pointIdx < points[imgIdx].size(); pointIdx++, globalPointIdx++ )
{ {
Point2f pt = points[i]; Point2f pt = imgPoints[pointIdx];
const Mat& src = *refimgs[i]; const Mat& src = refimgs[imgIdx];
int classId = labels.empty() ? i : labels[i]; int classId = imgLabels==0 ? globalPointIdx : imgLabels[pointIdx];
if( verbose && (i+1)*progressBarSize/nsamples != i*progressBarSize/nsamples ) if( verbose && (globalPointIdx+1)*progressBarSize/numPoints != globalPointIdx*progressBarSize/numPoints )
putchar('.'); putchar('.');
CV_Assert( 0 <= classId && classId < nclasses ); CV_Assert( 0 <= classId && classId < nclasses );
classCounters[classId] += _nviews; classCounters[classId] += _nviews;
for( j = 0; j < _nviews; j++ ) for( int v = 0; v < _nviews; v++ )
{ {
patchGenerator(src, pt, patch, patchSize, rng); patchGenerator(src, pt, patch, patchSize, rng);
for( int f = 0; f < nstructs; f++ ) for( int f = 0; f < nstructs; f++ )
posteriors[getLeaf(f, patch)*nclasses + classId]++; posteriors[getLeaf(f, patch)*nclasses + classId]++;
} }
} }
}
if( verbose ) if( verbose )
putchar('\n'); putchar('\n');
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment