Commit e0ee2f76 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #8116 from mrquorr:master

parents f46fa6e0 d8425d88
......@@ -1206,6 +1206,17 @@ public:
*/
CV_WRAP virtual Mat getVarImportance() const = 0;
/** Returns the result of each individual tree in the forest.
In case the model is a regression problem, the method will return each of the trees'
results for each of the sample cases. If the model is a classifier, it will return
a Mat with samples + 1 rows, where the first row gives the class number and the
following rows return the votes each class had for each sample.
@param samples Array containg the samples for which votes will be calculated.
@param results Array where the result of the calculation will be written.
@param flags Flags for defining the type of RTrees.
*/
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
/** Creates the empty model.
Use StatModel::train to train the model, StatModel::train to create and train the model,
Algorithm::load to load the pre-trained model.
......
......@@ -349,6 +349,60 @@ public:
}
}
void getVotes( InputArray input, OutputArray output, int flags ) const
{
CV_Assert( !roots.empty() );
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
Mat samples = input.getMat(), results;
int i, j, nsamples = samples.rows;
int predictType = flags & PREDICT_MASK;
if( predictType == PREDICT_AUTO )
{
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
PREDICT_SUM : PREDICT_MAX_VOTE;
}
if( predictType == PREDICT_SUM )
{
output.create(nsamples, ntrees, CV_32F);
results = output.getMat();
for( i = 0; i < nsamples; i++ )
{
for( j = 0; j < ntrees; j++ )
{
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
results.at<float> (i, j) = val;
}
}
} else
{
vector<int> votes;
output.create(nsamples+1, nclasses, CV_32S);
results = output.getMat();
for ( j = 0; j < nclasses; j++)
{
results.at<int> (0, j) = classLabels[j];
}
for( i = 0; i < nsamples; i++ )
{
votes.clear();
for( j = 0; j < ntrees; j++ )
{
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
votes.push_back(val);
}
for ( j = 0; j < nclasses; j++)
{
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
}
}
}
}
RTreeParams rparams;
double oobError;
vector<float> varImportance;
......@@ -401,6 +455,11 @@ public:
impl.read(fn);
}
void getVotes_( InputArray samples, OutputArray results, int flags ) const
{
impl.getVotes(samples, results, flags);
}
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
int getVarCount() const { return impl.getVarCount(); }
......@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return Algorithm::load<RTrees>(filepath, nodeName);
}
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
{
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
return this_->getVotes_(input, output, flags);
}
}}
// End of file.
......@@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
}
TEST(ML_RTrees, getVotes)
{
int n = 12;
int count, i;
int label_size = 3;
int predicted_class = 0;
int max_votes = -1;
int val;
// RTrees for classification
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
//data
Mat data(n, 4, CV_32F);
randu(data, 0, 10);
//labels
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
rt->train(data, ml::ROW_SAMPLE, labels);
//run function
Mat test(1, 4, CV_32F);
Mat result;
randu(test, 0, 10);
rt->getVotes(test, result, 0);
//count vote amount and find highest vote
count = 0;
const int* result_row = result.ptr<int>(1);
for( i = 0; i < label_size; i++ )
{
val = result_row[i];
//predicted_class = max_votes < val? i;
if( max_votes < val )
{
max_votes = val;
predicted_class = i;
}
count += val;
}
EXPECT_EQ(count, (int)rt->getRoots().size());
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
}
/* End of file. */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment