Commit fc04b7ab authored by Maria Dimashova

minor refactoring of CvMLData interface

parent 77be493e
...@@ -2061,10 +2061,9 @@ CVAPI(void) cvCreateTestSet( int type, CvMat** samples, ...@@ -2061,10 +2061,9 @@ CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
struct CV_EXPORTS CvTrainTestSplit struct CV_EXPORTS CvTrainTestSplit
{ {
public:
CvTrainTestSplit(); CvTrainTestSplit();
CvTrainTestSplit( int _train_sample_count, bool _mix = true); CvTrainTestSplit( int train_sample_count, bool mix = true);
CvTrainTestSplit( float _train_sample_portion, bool _mix = true); CvTrainTestSplit( float train_sample_portion, bool mix = true);
union union
{ {
...@@ -2073,14 +2072,7 @@ public: ...@@ -2073,14 +2072,7 @@ public:
} train_sample_part; } train_sample_part;
int train_sample_part_mode; int train_sample_part_mode;
union bool mix;
{
int *count;
float *portion;
} *class_part;
int class_part_mode;
bool mix;
}; };
class CV_EXPORTS CvMLData class CV_EXPORTS CvMLData
...@@ -2094,24 +2086,24 @@ public: ...@@ -2094,24 +2086,24 @@ public:
// 1 - file can not be opened or is not correct // 1 - file can not be opened or is not correct
int read_csv( const char* filename ); int read_csv( const char* filename );
const CvMat* get_values(); const CvMat* get_values() const;
const CvMat* get_responses(); const CvMat* get_responses();
const CvMat* get_missing(); const CvMat* get_missing() const;
void set_response_idx( int idx ); // old response become predictors, new response_idx = idx void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
// if idx < 0 there will be no response // if idx < 0 there will be no response
int get_response_idx(); int get_response_idx() const;
const CvMat* get_train_sample_idx();
const CvMat* get_test_sample_idx();
void mix_train_and_test_idx();
void set_train_test_split( const CvTrainTestSplit * spl ); void set_train_test_split( const CvTrainTestSplit * spl );
const CvMat* get_train_sample_idx() const;
const CvMat* get_test_sample_idx() const;
void mix_train_and_test_idx();
const CvMat* get_var_idx(); const CvMat* get_var_idx();
void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
const CvMat* get_var_types(); const CvMat* get_var_types();
int get_var_type( int var_idx ); int get_var_type( int var_idx ) const;
// following 2 methods enable to change vars type // following 2 methods enable to change vars type
// use these methods to assign CV_VAR_CATEGORICAL type for categorical variable // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
// with numerical labels; in the other cases var types are correctly determined automatically // with numerical labels; in the other cases var types are correctly determined automatically
...@@ -2121,11 +2113,13 @@ public: ...@@ -2121,11 +2113,13 @@ public:
void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL } void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
void set_delimiter( char ch ); void set_delimiter( char ch );
char get_delimiter(); char get_delimiter() const;
void set_miss_ch( char ch ); void set_miss_ch( char ch );
char get_miss_ch(); char get_miss_ch() const;
const std::map<std::string, int>& get_class_labels_map() const;
protected: protected:
virtual void clear(); virtual void clear();
...@@ -2151,7 +2145,7 @@ protected: ...@@ -2151,7 +2145,7 @@ protected:
bool mix; bool mix;
int total_class_count; int total_class_count;
std::map<std::string, int> *class_map; std::map<std::string, int> class_map;
CvMat* train_sample_idx; CvMat* train_sample_idx;
CvMat* test_sample_idx; CvMat* test_sample_idx;
......
...@@ -48,7 +48,6 @@ CvTrainTestSplit::CvTrainTestSplit() ...@@ -48,7 +48,6 @@ CvTrainTestSplit::CvTrainTestSplit()
{ {
train_sample_part_mode = CV_COUNT; train_sample_part_mode = CV_COUNT;
train_sample_part.count = -1; train_sample_part.count = -1;
class_part = 0;
mix = false; mix = false;
} }
...@@ -56,7 +55,6 @@ CvTrainTestSplit::CvTrainTestSplit( int _train_sample_count, bool _mix ) ...@@ -56,7 +55,6 @@ CvTrainTestSplit::CvTrainTestSplit( int _train_sample_count, bool _mix )
{ {
train_sample_part_mode = CV_COUNT; train_sample_part_mode = CV_COUNT;
train_sample_part.count = _train_sample_count; train_sample_part.count = _train_sample_count;
class_part = 0;
mix = _mix; mix = _mix;
} }
...@@ -64,7 +62,6 @@ CvTrainTestSplit::CvTrainTestSplit( float _train_sample_portion, bool _mix ) ...@@ -64,7 +62,6 @@ CvTrainTestSplit::CvTrainTestSplit( float _train_sample_portion, bool _mix )
{ {
train_sample_part_mode = CV_PORTION; train_sample_part_mode = CV_PORTION;
train_sample_part.portion = _train_sample_portion; train_sample_part.portion = _train_sample_portion;
class_part = 0;
mix = _mix; mix = _mix;
} }
...@@ -83,14 +80,12 @@ CvMLData::CvMLData() ...@@ -83,14 +80,12 @@ CvMLData::CvMLData()
miss_ch = '?'; miss_ch = '?';
//flt_separator = '.'; //flt_separator = '.';
class_map = new std::map<std::string, int>();
rng = &cv::theRNG(); rng = &cv::theRNG();
} }
CvMLData::~CvMLData() CvMLData::~CvMLData()
{ {
clear(); clear();
delete class_map;
} }
void CvMLData::free_train_test_idx() void CvMLData::free_train_test_idx()
...@@ -102,8 +97,7 @@ void CvMLData::free_train_test_idx() ...@@ -102,8 +97,7 @@ void CvMLData::free_train_test_idx()
void CvMLData::clear() void CvMLData::clear()
{ {
if ( !class_map->empty() ) class_map.clear();
class_map->clear();
cvReleaseMat( &values ); cvReleaseMat( &values );
cvReleaseMat( &missing ); cvReleaseMat( &missing );
...@@ -244,16 +238,29 @@ int CvMLData::read_csv(const char* filename) ...@@ -244,16 +238,29 @@ int CvMLData::read_csv(const char* filename)
return 0; return 0;
} }
const CvMat* CvMLData::get_values() const CvMat* CvMLData::get_values() const
{ {
return values; return values;
} }
const CvMat* CvMLData::get_missing() const CvMat* CvMLData::get_missing() const
{ {
CV_FUNCNAME( "CvMLData::get_missing" );
__BEGIN__;
if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" );
__END__;
return missing; return missing;
} }
const std::map<std::string, int>& CvMLData::get_class_labels_map() const
{
return class_map;
}
void CvMLData::str_to_flt_elem( const char* token, float& flt_elem, int& type) void CvMLData::str_to_flt_elem( const char* token, float& flt_elem, int& type)
{ {
...@@ -270,12 +277,12 @@ void CvMLData::str_to_flt_elem( const char* token, float& flt_elem, int& type) ...@@ -270,12 +277,12 @@ void CvMLData::str_to_flt_elem( const char* token, float& flt_elem, int& type)
{ {
if ( (*stopstring != 0) && (*stopstring != '\n') && (strcmp(stopstring, "\r\n") != 0) ) // class label if ( (*stopstring != 0) && (*stopstring != '\n') && (strcmp(stopstring, "\r\n") != 0) ) // class label
{ {
int idx = (*class_map)[token]; int idx = class_map[token];
if ( idx == 0) if ( idx == 0)
{ {
total_class_count++; total_class_count++;
idx = total_class_count; idx = total_class_count;
(*class_map)[token] = idx; class_map[token] = idx;
} }
flt_elem = (float)idx; flt_elem = (float)idx;
type = CV_VAR_CATEGORICAL; type = CV_VAR_CATEGORICAL;
...@@ -296,7 +303,7 @@ void CvMLData::set_delimiter(char ch) ...@@ -296,7 +303,7 @@ void CvMLData::set_delimiter(char ch)
__END__; __END__;
} }
char CvMLData::get_delimiter() char CvMLData::get_delimiter() const
{ {
return delimiter; return delimiter;
} }
...@@ -314,7 +321,7 @@ void CvMLData::set_miss_ch(char ch) ...@@ -314,7 +321,7 @@ void CvMLData::set_miss_ch(char ch)
__END__; __END__;
} }
char CvMLData::get_miss_ch() char CvMLData::get_miss_ch() const
{ {
return miss_ch; return miss_ch;
} }
...@@ -339,8 +346,14 @@ void CvMLData::set_response_idx( int idx ) ...@@ -339,8 +346,14 @@ void CvMLData::set_response_idx( int idx )
__END__; __END__;
} }
int CvMLData::get_response_idx() int CvMLData::get_response_idx() const
{ {
CV_FUNCNAME( "CvMLData::get_response_idx" );
__BEGIN__;
if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" );
__END__;
return response_idx; return response_idx;
} }
...@@ -536,7 +549,7 @@ const CvMat* CvMLData::get_var_types() ...@@ -536,7 +549,7 @@ const CvMat* CvMLData::get_var_types()
return var_types_out; return var_types_out;
} }
int CvMLData::get_var_type( int var_idx ) int CvMLData::get_var_type( int var_idx ) const
{ {
return var_types->data.ptr[var_idx]; return var_types->data.ptr[var_idx];
} }
...@@ -572,9 +585,6 @@ void CvMLData::set_train_test_split( const CvTrainTestSplit * spl) ...@@ -572,9 +585,6 @@ void CvMLData::set_train_test_split( const CvTrainTestSplit * spl)
int sample_count = 0; int sample_count = 0;
if ( spl->class_part )
CV_ERROR( CV_StsBadArg, "this division type is not supported yet" );
if ( !values ) if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" ); CV_ERROR( CV_StsInternal, "data is empty" );
...@@ -627,19 +637,41 @@ void CvMLData::set_train_test_split( const CvTrainTestSplit * spl) ...@@ -627,19 +637,41 @@ void CvMLData::set_train_test_split( const CvTrainTestSplit * spl)
__END__; __END__;
} }
const CvMat* CvMLData::get_train_sample_idx() const CvMat* CvMLData::get_train_sample_idx() const
{ {
CV_FUNCNAME( "CvMLData::get_train_sample_idx" );
__BEGIN__;
if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" );
__END__;
return train_sample_idx; return train_sample_idx;
} }
const CvMat* CvMLData::get_test_sample_idx() const CvMat* CvMLData::get_test_sample_idx() const
{ {
CV_FUNCNAME( "CvMLData::get_test_sample_idx" );
__BEGIN__;
if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" );
__END__;
return test_sample_idx; return test_sample_idx;
} }
void CvMLData::mix_train_and_test_idx() void CvMLData::mix_train_and_test_idx()
{ {
if ( !values || !sample_idx) return; CV_FUNCNAME( "CvMLData::mix_train_and_test_idx" );
__BEGIN__;
if ( !values )
CV_ERROR( CV_StsInternal, "data is empty" );
__END__;
if ( !sample_idx)
return;
if ( train_sample_count > 0 && train_sample_count < values->rows ) if ( train_sample_count > 0 && train_sample_count < values->rows )
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment