Commit b5a71db7 authored by Maria Dimashova's avatar Maria Dimashova

modified FernClassifier::train(); remove old RTreeClassifier and added new…

modified FernClassifier::train(); remove old RTreeClassifier and added new implementation CalonderClassifier; removed old find_obj_calonder and added new one
parent 1135bc24
This diff is collapsed.
......@@ -228,6 +228,10 @@ void SurfDescriptorExtractor::write( FileStorage &fs ) const
fs << "extended" << surf.extended;
}
/****************************************************************************************\
* Factory functions for descriptor extractor and matcher creating *
\****************************************************************************************/
Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType )
{
DescriptorExtractor* de = 0;
......@@ -270,7 +274,9 @@ Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherT
return dm;
}
/****************************************************************************************\
* BruteForceMatcher L2 specialization *
\****************************************************************************************/
template<>
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& /*mask*/, vector<int>& matches ) const
......@@ -317,7 +323,6 @@ void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const M
#endif
}
/****************************************************************************************\
* GenericDescriptorMatch *
\****************************************************************************************/
......@@ -394,6 +399,9 @@ void GenericDescriptorMatch::clear()
collection.clear();
}
/*
* Factory function for GenericDescriptorMatch creating
*/
Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string &paramsFilename )
{
GenericDescriptorMatch *descriptorMatch = 0;
......@@ -409,7 +417,7 @@ Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericD
}
else if( ! genericDescritptorMatchType.compare ("CALONDER") )
{
descriptorMatch = new CalonderDescriptorMatch ();
//descriptorMatch = new CalonderDescriptorMatch ();
}
if( !paramsFilename.empty() && descriptorMatch != 0 )
......@@ -626,6 +634,7 @@ void OneWayDescriptorMatch::clear ()
/****************************************************************************************\
* CalonderDescriptorMatch *
\****************************************************************************************/
#if 0
CalonderDescriptorMatch::Params::Params( const RNG& _rng, const PatchGenerator& _patchGen,
int _numTrees, int _depth, int _views,
size_t _reducedNumDim,
......@@ -774,6 +783,7 @@ void CalonderDescriptorMatch::write( FileStorage& fs ) const
fs << "numQuantBits" << params.numQuantBits;
fs << "printStatus" << params.printStatus;
}
#endif
/****************************************************************************************\
* FernDescriptorMatch *
......@@ -827,22 +837,13 @@ void FernDescriptorMatch::trainFernClassifier()
{
assert( params.filename.empty() );
vector<Point2f> points;
vector<Ptr<Mat> > refimgs;
vector<int> labels;
for( size_t imageIdx = 0; imageIdx < collection.images.size(); imageIdx++ )
{
for( size_t pointIdx = 0; pointIdx < collection.points[imageIdx].size(); pointIdx++ )
{
refimgs.push_back(new Mat (collection.images[imageIdx]));
points.push_back(collection.points[imageIdx][pointIdx].pt);
labels.push_back((int)pointIdx);
}
}
vector<vector<Point2f> > points;
for( size_t imgIdx = 0; imgIdx < collection.images.size(); imgIdx++ )
KeyPoint::convert( collection.points[imgIdx], points[imgIdx] );
classifier = new FernClassifier( points, refimgs, labels, params.nclasses, params.patchSize,
params.signatureSize, params.nstructs, params.structSize, params.nviews,
params.compressionMethod, params.patchGenerator );
classifier = new FernClassifier( points, collection.images, vector<vector<int> >(), 0, // each points is a class
params.patchSize, params.signatureSize, params.nstructs, params.structSize,
params.nviews, params.compressionMethod, params.patchGenerator );
}
}
......@@ -966,4 +967,59 @@ void FernDescriptorMatch::clear ()
classifier.release();
}
/****************************************************************************************\
* VectorDescriptorMatch *
\****************************************************************************************/
void VectorDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
{
Mat descriptors;
extractor->compute( image, keypoints, descriptors );
matcher->add( descriptors );
collection.add( Mat(), keypoints );
};
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& keypointIndices )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, keypointIndices );
};
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, matches );
}
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points,
vector<vector<DMatch> >& matches, float threshold )
{
Mat descriptors;
extractor->compute( image, points, descriptors );
matcher->match( descriptors, matches, threshold );
}
void VectorDescriptorMatch::clear()
{
GenericDescriptorMatch::clear();
matcher->clear();
}
void VectorDescriptorMatch::read( const FileNode& fn )
{
GenericDescriptorMatch::read(fn);
extractor->read (fn);
}
void VectorDescriptorMatch::write (FileStorage& fs) const
{
GenericDescriptorMatch::write(fs);
extractor->write (fs);
}
}
......@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const
}
FernClassifier::FernClassifier(const vector<Point2f>& points,
const vector<Ptr<Mat> >& refimgs,
const vector<int>& labels,
FernClassifier::FernClassifier(const vector<vector<Point2f> >& points,
const vector<Mat>& refimgs,
const vector<vector<int> >& labels,
int _nclasses, int _patchSize,
int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod,
......@@ -829,43 +829,58 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize,
}
}
static int calcNumPoints( const vector<vector<Point2f> >& points )
{
int count = 0;
for( size_t i = 0; i < points.size(); i++ )
count += points[i].size();
return count;
}
void FernClassifier::train(const vector<Point2f>& points,
const vector<Ptr<Mat> >& refimgs,
const vector<int>& labels,
void FernClassifier::train(const vector<vector<Point2f> >& points,
const vector<Mat>& refimgs,
const vector<vector<int> >& labels,
int _nclasses, int _patchSize,
int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod,
const PatchGenerator& patchGenerator)
{
_nclasses = _nclasses > 0 ? _nclasses : (int)points.size();
CV_Assert( points.size() == refimgs.size() );
int numPoints = calcNumPoints( points );
_nclasses = (!labels.empty() && _nclasses>0) ? _nclasses : numPoints;
CV_Assert( labels.empty() || labels.size() == points.size() );
prepare(_nclasses, _patchSize, _signatureSize, _nstructs,
_structSize, _nviews, _compressionMethod);
// pass all the views of all the samples through the generated trees and accumulate
// the statistics (posterior probabilities) in leaves.
Mat patch;
int i, j, nsamples = (int)points.size();
RNG& rng = theRNG();
for( i = 0; i < nsamples; i++ )
int globalPointIdx = 0;
for( size_t imgIdx = 0; imgIdx < points.size(); imgIdx++ )
{
const Point2f* imgPoints = &points[imgIdx][0];
const int* imgLabels = labels.empty() ? 0 : &labels[imgIdx][0];
for( size_t pointIdx = 0; pointIdx < points[imgIdx].size(); pointIdx++, globalPointIdx++ )
{
Point2f pt = points[i];
const Mat& src = *refimgs[i];
int classId = labels.empty() ? i : labels[i];
if( verbose && (i+1)*progressBarSize/nsamples != i*progressBarSize/nsamples )
Point2f pt = imgPoints[pointIdx];
const Mat& src = refimgs[imgIdx];
int classId = imgLabels==0 ? globalPointIdx : imgLabels[pointIdx];
if( verbose && (globalPointIdx+1)*progressBarSize/numPoints != globalPointIdx*progressBarSize/numPoints )
putchar('.');
CV_Assert( 0 <= classId && classId < nclasses );
classCounters[classId] += _nviews;
for( j = 0; j < _nviews; j++ )
for( int v = 0; v < _nviews; v++ )
{
patchGenerator(src, pt, patch, patchSize, rng);
for( int f = 0; f < nstructs; f++ )
posteriors[getLeaf(f, patch)*nclasses + classId]++;
}
}
}
if( verbose )
putchar('\n');
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment