Commit d8e3971e authored by niederb's avatar niederb Committed by Vadim Pisarevsky

Fixed variable importance in rtrees

parent bb4b4acc
...@@ -187,7 +187,7 @@ public: ...@@ -187,7 +187,7 @@ public:
oobidx.clear(); oobidx.clear();
for( i = 0; i < n; i++ ) for( i = 0; i < n; i++ )
{ {
if( !oobmask[i] ) if( oobmask[i] )
oobidx.push_back(i); oobidx.push_back(i);
} }
int n_oob = (int)oobidx.size(); int n_oob = (int)oobidx.size();
...@@ -217,6 +217,7 @@ public: ...@@ -217,6 +217,7 @@ public:
else else
{ {
int ival = cvRound(val); int ival = cvRound(val);
//Voting scheme to combine OOB errors of each tree
int* votes = &oobvotes[j*nclasses]; int* votes = &oobvotes[j*nclasses];
votes[ival]++; votes[ival]++;
int best_class = 0; int best_class = 0;
...@@ -235,35 +236,35 @@ public: ...@@ -235,35 +236,35 @@ public:
oobperm.resize(n_oob); oobperm.resize(n_oob);
for( i = 0; i < n_oob; i++ ) for( i = 0; i < n_oob; i++ )
oobperm[i] = oobidx[i]; oobperm[i] = oobidx[i];
for (i = n_oob - 1; i > 0; --i) //Randomly shuffle indices so we can permute features
{
int r_i = rng.uniform(0, i + 1);
std::swap(oobperm[i], oobperm[r_i]);
}
for( vi_ = 0; vi_ < nvars; vi_++ ) for( vi_ = 0; vi_ < nvars; vi_++ )
{ {
vi = vidx ? vidx[vi_] : vi_; vi = vidx ? vidx[vi_] : vi_; //Ensure that only the user specified predictors are used for training
double ncorrect_responses_permuted = 0; double ncorrect_responses_permuted = 0;
for( i = 0; i < n_oob; i++ )
{
int i1 = rng.uniform(0, n_oob);
int i2 = rng.uniform(0, n_oob);
std::swap(i1, i2);
}
for( i = 0; i < n_oob; i++ ) for( i = 0; i < n_oob; i++ )
{ {
j = oobidx[i]; j = oobidx[i];
int vj = oobperm[i]; int vj = oobperm[i];
sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) ); sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
for( k = 0; k < nallvars; k++ ) Mat sample_clone = sample0.clone(); //create a copy so we don't mess up the original data
sample.at<float>(k) = sample0.at<float>(k); sample_clone.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags); double val = predictTrees(Range(treeidx, treeidx+1), sample_clone, predictFlags);
if( !_isClassifier ) if( !_isClassifier )
{ {
val = (val - w->ord_responses[w->sidx[j]])/max_response; val = (val - w->ord_responses[w->sidx[j]])/max_response;
ncorrect_responses_permuted += exp( -val*val ); ncorrect_responses_permuted += exp( -val*val );
} }
else else
{
ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]]; ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
}
} }
varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted); varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
} }
......
...@@ -63,7 +63,6 @@ int main(int argc, char** argv) ...@@ -63,7 +63,6 @@ int main(int argc, char** argv)
const double train_test_split_ratio = 0.5; const double train_test_split_ratio = 0.5;
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec); Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
if( data.empty() ) if( data.empty() )
{ {
printf("ERROR: File %s can not be read\n", filename); printf("ERROR: File %s can not be read\n", filename);
...@@ -71,6 +70,7 @@ int main(int argc, char** argv) ...@@ -71,6 +70,7 @@ int main(int argc, char** argv)
} }
data->setTrainTestSplitRatio(train_test_split_ratio); data->setTrainTestSplitRatio(train_test_split_ratio);
std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples();
printf("======DTREE=====\n"); printf("======DTREE=====\n");
Ptr<DTrees> dtree = DTrees::create(); Ptr<DTrees> dtree = DTrees::create();
...@@ -106,10 +106,19 @@ int main(int argc, char** argv) ...@@ -106,10 +106,19 @@ int main(int argc, char** argv)
rtrees->setUseSurrogates(false); rtrees->setUseSurrogates(false);
rtrees->setMaxCategories(16); rtrees->setMaxCategories(16);
rtrees->setPriors(Mat()); rtrees->setPriors(Mat());
rtrees->setCalculateVarImportance(false); rtrees->setCalculateVarImportance(true);
rtrees->setActiveVarCount(0); rtrees->setActiveVarCount(0);
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0)); rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
train_and_print_errs(rtrees, data); train_and_print_errs(rtrees, data);
cv::Mat ref_labels = data->getClassLabels();
cv::Mat test_data = data->getTestSampleIdx();
cv::Mat predict_labels;
rtrees->predict(data->getSamples(), predict_labels);
cv::Mat variable_importance = rtrees->getVarImportance();
std::cout << "Estimated variable importance" << std::endl;
for (int i = 0; i < variable_importance.rows; i++) {
std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
}
return 0; return 0;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment