Commit 4d676165 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

incorporated several critical fixes in EM implementation from Albert G (ticket #264)

parent 7174957f
...@@ -789,8 +789,9 @@ double CvEM::run_em( const CvVectors& train_data ) ...@@ -789,8 +789,9 @@ double CvEM::run_em( const CvVectors& train_data )
int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters; int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
double min_variation = FLT_EPSILON; double min_variation = FLT_EPSILON;
double min_det_value = MAX( DBL_MIN, pow( min_variation, dims )); double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX; double _log_likelihood = -DBL_MAX;
int start_step = params.start_step; int start_step = params.start_step;
double sum_max_val;
int i, j, k, n; int i, j, k, n;
int is_general = 0, is_diagonal = 0, is_spherical = 0; int is_general = 0, is_diagonal = 0, is_spherical = 0;
...@@ -912,6 +913,7 @@ double CvEM::run_em( const CvVectors& train_data ) ...@@ -912,6 +913,7 @@ double CvEM::run_em( const CvVectors& train_data )
// e-step: compute probs_ik from means_k, covs_k and weights_k. // e-step: compute probs_ik from means_k, covs_k and weights_k.
CV_CALL(cvLog( weights, log_weights )); CV_CALL(cvLog( weights, log_weights ));
sum_max_val = 0.;
// S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k) // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
for( k = 0; k < nclusters; k++ ) for( k = 0; k < nclusters; k++ )
{ {
...@@ -934,14 +936,16 @@ double CvEM::run_em( const CvVectors& train_data ) ...@@ -934,14 +936,16 @@ double CvEM::run_em( const CvVectors& train_data )
cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T ); cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
for( j = 0; j < dims; j++ ) for( j = 0; j < dims; j++ )
p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j]; p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
pp[k] = -0.5*p + log_weights->data.db[k]; //pp[k] = -0.5*p + log_weights->data.db[k];
pp[k] = -0.5*(p+CV_LOG2PI * (double)dims) + log_weights->data.db[k];
// S_ik <- S_ik - max_j S_ij // S_ik <- S_ik - max_j S_ij
if( k == nclusters - 1 ) if( k == nclusters - 1 )
{ {
double max_val = 0; double max_val = pp[0];
for( j = 0; j < nclusters; j++ ) for( j = 1; j < nclusters; j++ )
max_val = MAX( max_val, pp[j] ); max_val = MAX( max_val, pp[j] );
sum_max_val += max_val;
for( j = 0; j < nclusters; j++ ) for( j = 0; j < nclusters; j++ )
pp[j] -= max_val; pp[j] -= max_val;
} }
...@@ -953,7 +957,7 @@ double CvEM::run_em( const CvVectors& train_data ) ...@@ -953,7 +957,7 @@ double CvEM::run_em( const CvVectors& train_data )
// alpha_ik = exp( S_ik ) / sum_j exp( S_ij ), // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
// log_likelihood = sum_i log (sum_j exp(S_ij)) // log_likelihood = sum_i log (sum_j exp(S_ij))
for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ ) for( i = 0, _log_likelihood = 0; i < nsamples; i++ )
{ {
double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0; double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
for( j = 0; j < nclusters; j++ ) for( j = 0; j < nclusters; j++ )
...@@ -966,9 +970,11 @@ double CvEM::run_em( const CvVectors& train_data ) ...@@ -966,9 +970,11 @@ double CvEM::run_em( const CvVectors& train_data )
} }
_log_likelihood -= log( sum ); _log_likelihood -= log( sum );
} }
_log_likelihood+=sum_max_val;
// check termination criteria // check termination criteria
if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon ) //if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
if( fabs( (_log_likelihood - prev_log_likelihood) ) < params.term_crit.epsilon )
break; break;
prev_log_likelihood = _log_likelihood; prev_log_likelihood = _log_likelihood;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment