Commit 71d7482a authored by Maria Dimashova's avatar Maria Dimashova

fixed nan in EM, added new test on EM

parent 94bcaeb2
......@@ -563,7 +563,7 @@ public:
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
// Default parameters
enum {DEFAULT_NCLUSTERS=10, DEFAULT_MAX_ITERS=100};
enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
// The initial step
enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
......@@ -635,7 +635,6 @@ protected:
Mat trainProbs;
Mat trainLogLikelihoods;
Mat trainLabels;
Mat trainCounts;
CV_PROP Mat weights;
CV_PROP Mat means;
......@@ -2035,7 +2034,7 @@ public:
// returns:
// 0 - OK
// 1 - file can not be opened or is not correct
// -1 - file can not be opened or is not correct
int read_csv( const char* filename );
const CvMat* get_values() const;
......
......@@ -44,7 +44,7 @@
namespace cv
{
const double minEigenValue = DBL_MIN;
const double minEigenValue = DBL_EPSILON;
///////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -67,7 +67,6 @@ void EM::clear()
trainProbs.release();
trainLogLikelihoods.release();
trainLabels.release();
trainCounts.release();
weights.release();
means.release();
......@@ -469,7 +468,6 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
trainProbs.release();
trainLabels.release();
trainLogLikelihoods.release();
trainCounts.release();
return true;
}
......@@ -556,97 +554,114 @@ void EM::eStep()
void EM::mStep()
{
trainCounts.create(1, nclusters, CV_32SC1);
trainCounts = Scalar(0);
// Update means_k, covs_k and weights_k from probs_ik
int dim = trainSamples.cols;
for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++)
trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++;
// Update weights
// not normalized first
reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
if(countNonZero(trainCounts) != (int)trainCounts.total())
{
clusterTrainSamples();
}
else
{
// Update means_k, covs_k and weights_k from probs_ik
int dim = trainSamples.cols;
// Update means
means.create(nclusters, dim, CV_64FC1);
means = Scalar(0);
// Update weights
// not normalized first
reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
const double minPosWeight = trainSamples.rows * DBL_EPSILON;
double minWeight = DBL_MAX;
int minWeightClusterIndex = -1;
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{
if(weights.at<double>(clusterIndex) <= minPosWeight)
continue;
// Update means
means.create(nclusters, dim, CV_64FC1);
means = Scalar(0);
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
if(weights.at<double>(clusterIndex) < minWeight)
{
Mat clusterMean = means.row(clusterIndex);
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
clusterMean /= weights.at<double>(clusterIndex);
minWeight = weights.at<double>(clusterIndex);
minWeightClusterIndex = clusterIndex;
}
// Update covsEigenValues and invCovsEigenValues
covs.resize(nclusters);
covsEigenValues.resize(nclusters);
Mat clusterMean = means.row(clusterIndex);
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
clusterMean /= weights.at<double>(clusterIndex);
}
// Update covsEigenValues and invCovsEigenValues
covs.resize(nclusters);
covsEigenValues.resize(nclusters);
if(covMatType == EM::COV_MAT_GENERIC)
covsRotateMats.resize(nclusters);
invCovsEigenValues.resize(nclusters);
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{
if(weights.at<double>(clusterIndex) <= minPosWeight)
continue;
if(covMatType != EM::COV_MAT_SPHERICAL)
covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
else
covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
if(covMatType == EM::COV_MAT_GENERIC)
covsRotateMats.resize(nclusters);
invCovsEigenValues.resize(nclusters);
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{
if(covMatType != EM::COV_MAT_SPHERICAL)
covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
else
covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
covs[clusterIndex].create(dim, dim, CV_64FC1);
if(covMatType == EM::COV_MAT_GENERIC)
covs[clusterIndex].create(dim, dim, CV_64FC1);
Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
covsEigenValues[clusterIndex] : covs[clusterIndex];
Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
covsEigenValues[clusterIndex] : covs[clusterIndex];
clusterCov = Scalar(0);
clusterCov = Scalar(0);
Mat centeredSample;
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{
centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
Mat centeredSample;
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
if(covMatType == EM::COV_MAT_GENERIC)
clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
else
{
centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
if(covMatType == EM::COV_MAT_GENERIC)
clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
else
double p = trainProbs.at<double>(sampleIndex, clusterIndex);
for(int di = 0; di < dim; di++ )
{
double p = trainProbs.at<double>(sampleIndex, clusterIndex);
for(int di = 0; di < dim; di++ )
{
double val = centeredSample.at<double>(di);
clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
}
double val = centeredSample.at<double>(di);
clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
}
}
}
if(covMatType == EM::COV_MAT_SPHERICAL)
clusterCov /= dim;
if(covMatType == EM::COV_MAT_SPHERICAL)
clusterCov /= dim;
clusterCov /= weights.at<double>(clusterIndex);
clusterCov /= weights.at<double>(clusterIndex);
// Update covsRotateMats for EM::COV_MAT_GENERIC only
if(covMatType == EM::COV_MAT_GENERIC)
{
SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
covsEigenValues[clusterIndex] = svd.w;
covsRotateMats[clusterIndex] = svd.u;
}
// Update covsRotateMats for EM::COV_MAT_GENERIC only
if(covMatType == EM::COV_MAT_GENERIC)
{
SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
covsEigenValues[clusterIndex] = svd.w;
covsRotateMats[clusterIndex] = svd.u;
}
max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
// update invCovsEigenValues
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
}
// update invCovsEigenValues
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
}
// Normalize weights
weights /= trainSamples.rows;
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{
if(weights.at<double>(clusterIndex) <= minPosWeight)
{
Mat clusterMean = means.row(clusterIndex);
means.row(minWeightClusterIndex).copyTo(clusterMean);
covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
if(covMatType == EM::COV_MAT_GENERIC)
covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
}
}
// Normalize weights
weights /= trainSamples.rows;
}
void EM::read(const FileNode& fn)
......
......@@ -572,7 +572,106 @@ protected:
}
};
class CV_EMTest_Classification : public cvtest::BaseTest
{
public:
CV_EMTest_Classification() {}
protected:
virtual void run(int)
{
// This test classifies spam by the following way:
// 1. estimates distributions of "spam" / "not spam"
// 2. predict classID using Bayes classifier for estimated distributions.
CvMLData data;
string dataFilename = string(ts->get_data_path()) + "spambase.data";
if(data.read_csv(dataFilename.c_str()) != 0)
{
ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n");
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
}
Mat values = data.get_values();
CV_Assert(values.cols == 58);
int responseIndex = 57;
Mat samples = values.colRange(0, responseIndex);
Mat responses = values.col(responseIndex);
vector<int> trainSamplesMask(samples.rows, 0);
int trainSamplesCount = (int)(0.5f * samples.rows);
for(int i = 0; i < trainSamplesCount; i++)
trainSamplesMask[i] = 1;
RNG rng(0);
for(size_t i = 0; i < trainSamplesMask.size(); i++)
{
int i1 = rng(trainSamplesMask.size());
int i2 = rng(trainSamplesMask.size());
std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
}
EM model0(3), model1(3);
Mat samples0, samples1;
for(int i = 0; i < samples.rows; i++)
{
if(trainSamplesMask[i])
{
Mat sample = samples.row(i);
int resp = (int)responses.at<float>(i);
if(resp == 0)
samples0.push_back(sample);
else
samples1.push_back(sample);
}
}
model0.train(samples0);
model1.train(samples1);
Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
testConfusionMat(2, 2, CV_32SC1, Scalar(0));
const double lambda = 1.;
for(int i = 0; i < samples.rows; i++)
{
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
Mat sample = samples.row(i);
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
if(trainSamplesMask[i])
trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
else
testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
}
// std::cout << trainConfusionMat << std::endl;
// std::cout << testConfusionMat << std::endl;
double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
const double maxTrainError = 0.16;
const double maxTestError = 0.19;
int code = cvtest::TS::OK;
if(trainError > maxTrainError)
{
ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
if(testError > maxTestError)
{
ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
ts->set_failed_test_info(code);
}
};
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment