Commit 5d2edced authored by Daniil Osokin's avatar Daniil Osokin

Added throwing exception when saving untrained SVM model

parent 890f1baf
...@@ -2298,14 +2298,24 @@ void CvSVM::write_params( CvFileStorage* fs ) const ...@@ -2298,14 +2298,24 @@ void CvSVM::write_params( CvFileStorage* fs ) const
} }
static bool isSvmModelApplicable(int sv_total, int var_all, int var_count, int class_count)
{
return (sv_total > 0 && var_count > 0 && var_count <= var_all && class_count >= 0);
}
void CvSVM::write( CvFileStorage* fs, const char* name ) const void CvSVM::write( CvFileStorage* fs, const char* name ) const
{ {
CV_FUNCNAME( "CvSVM::write" ); CV_FUNCNAME( "CvSVM::write" );
__BEGIN__; __BEGIN__;
int i, var_count = get_var_count(), df_count, class_count; int i, var_count = get_var_count(), df_count;
int class_count = class_labels ? class_labels->cols :
params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
const CvSVMDecisionFunc* df = decision_func; const CvSVMDecisionFunc* df = decision_func;
if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM ); cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
...@@ -2314,9 +2324,6 @@ void CvSVM::write( CvFileStorage* fs, const char* name ) const ...@@ -2314,9 +2324,6 @@ void CvSVM::write( CvFileStorage* fs, const char* name ) const
cvWriteInt( fs, "var_all", var_all ); cvWriteInt( fs, "var_all", var_all );
cvWriteInt( fs, "var_count", var_count ); cvWriteInt( fs, "var_count", var_count );
class_count = class_labels ? class_labels->cols :
params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
if( class_count ) if( class_count )
{ {
cvWriteInt( fs, "class_count", class_count ); cvWriteInt( fs, "class_count", class_count );
...@@ -2454,7 +2461,6 @@ void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node ) ...@@ -2454,7 +2461,6 @@ void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
__END__; __END__;
} }
void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node ) void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
{ {
const double not_found_dbl = DBL_MAX; const double not_found_dbl = DBL_MAX;
...@@ -2483,7 +2489,7 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node ) ...@@ -2483,7 +2489,7 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
var_count = cvReadIntByName( fs, svm_node, "var_count", var_all ); var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
class_count = cvReadIntByName( fs, svm_node, "class_count", 0 ); class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 ) if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" ); CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" )); CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
......
...@@ -155,6 +155,14 @@ TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } ...@@ -155,6 +155,14 @@ TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
TEST(ML_SVM, throw_exception_when_save_untrained_model)
{
SVM svm;
string filename = tempfile("svm.xml");
ASSERT_THROW(svm.save(filename.c_str()), Exception);
remove(filename.c_str());
}
TEST(DISABLED_ML_SVM, linear_save_load) TEST(DISABLED_ML_SVM, linear_save_load)
{ {
CvSVM svm1, svm2, svm3; CvSVM svm1, svm2, svm3;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment