Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
8d9d9645
Commit
8d9d9645
authored
Mar 29, 2012
by
Maria Dimashova
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
added smoke test on EM, fixed EM reading #1570 (thanks to mr.pppoe),
parent
ec793df3
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
21 deletions
+91
-21
em.cpp
modules/ml/src/em.cpp
+15
-21
test_emknearestkmeans.cpp
modules/ml/test/test_emknearestkmeans.cpp
+76
-0
No files found.
modules/ml/src/em.cpp
View file @
8d9d9645
...
...
@@ -141,8 +141,6 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
CvFileNode
*
em_node
=
0
;
CvFileNode
*
tmp_node
=
0
;
CvSeq
*
seq
=
0
;
CvMat
**
tmp_covs
=
0
;
CvMat
**
tmp_cov_rotate_mats
=
0
;
read_params
(
fs
,
node
);
...
...
@@ -156,13 +154,10 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
CV_CALL
(
inv_eigen_values
=
(
CvMat
*
)
cvReadByName
(
fs
,
em_node
,
"inv_eigen_values"
));
// Size of all the following data
data_size
=
params
.
nclusters
*
2
*
sizeof
(
CvMat
*
);
CV_CALL
(
tmp_covs
=
(
CvMat
**
)
cvAlloc
(
data_size
));
memset
(
tmp_covs
,
0
,
data_size
);
tmp_cov_rotate_mats
=
tmp_covs
+
params
.
nclusters
;
data_size
=
params
.
nclusters
*
sizeof
(
CvMat
*
);
CV_CALL
(
covs
=
(
CvMat
**
)
cvAlloc
(
data_size
));
memset
(
covs
,
0
,
data_size
);
CV_CALL
(
tmp_node
=
cvGetFileNodeByName
(
fs
,
em_node
,
"covs"
));
seq
=
tmp_node
->
data
.
seq
;
if
(
!
CV_NODE_IS_SEQ
(
tmp_node
->
tag
)
||
seq
->
total
!=
params
.
nclusters
)
...
...
@@ -170,24 +165,23 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
CV_CALL
(
cvStartReadSeq
(
seq
,
&
reader
,
0
));
for
(
int
i
=
0
;
i
<
params
.
nclusters
;
i
++
)
{
CV_CALL
(
tmp_
covs
[
i
]
=
(
CvMat
*
)
cvRead
(
fs
,
(
CvFileNode
*
)
reader
.
ptr
));
CV_CALL
(
covs
[
i
]
=
(
CvMat
*
)
cvRead
(
fs
,
(
CvFileNode
*
)
reader
.
ptr
));
CV_NEXT_SEQ_ELEM
(
seq
->
elem_size
,
reader
);
}
CV_CALL
(
cov_rotate_mats
=
(
CvMat
**
)
cvAlloc
(
data_size
));
memset
(
cov_rotate_mats
,
0
,
data_size
);
CV_CALL
(
tmp_node
=
cvGetFileNodeByName
(
fs
,
em_node
,
"cov_rotate_mats"
));
seq
=
tmp_node
->
data
.
seq
;
if
(
!
CV_NODE_IS_SEQ
(
tmp_node
->
tag
)
||
seq
->
total
!=
params
.
nclusters
)
CV_ERROR
(
CV_StsParseError
,
"Missing or invalid sequence of
rotated cov.
matrices"
);
CV_ERROR
(
CV_StsParseError
,
"Missing or invalid sequence of
covariance
matrices"
);
CV_CALL
(
cvStartReadSeq
(
seq
,
&
reader
,
0
));
for
(
int
i
=
0
;
i
<
params
.
nclusters
;
i
++
)
{
CV_CALL
(
tmp_
cov_rotate_mats
[
i
]
=
(
CvMat
*
)
cvRead
(
fs
,
(
CvFileNode
*
)
reader
.
ptr
));
CV_CALL
(
cov_rotate_mats
[
i
]
=
(
CvMat
*
)
cvRead
(
fs
,
(
CvFileNode
*
)
reader
.
ptr
));
CV_NEXT_SEQ_ELEM
(
seq
->
elem_size
,
reader
);
}
covs
=
tmp_covs
;
cov_rotate_mats
=
tmp_cov_rotate_mats
;
ok
=
true
;
__END__
;
...
...
@@ -862,10 +856,10 @@ void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
{
int
i
,
nsamples
=
train_data
.
count
,
dims
=
train_data
.
dims
;
cv
::
Ptr
<
CvMat
>
temp_mat
=
cvCreateMat
(
nsamples
,
dims
,
CV_32F
);
for
(
i
=
0
;
i
<
nsamples
;
i
++
)
memcpy
(
temp_mat
->
data
.
ptr
+
temp_mat
->
step
*
i
,
train_data
.
data
.
fl
[
i
],
dims
*
sizeof
(
float
));
cvKMeans2
(
temp_mat
,
nclusters
,
labels
,
termcrit
,
10
);
}
...
...
@@ -1240,20 +1234,20 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
{
means
=
weights
=
probs
=
inv_eigen_values
=
log_weight_div_det
=
0
;
covs
=
cov_rotate_mats
=
0
;
// just invoke the train() method
train
(
samples
,
sample_idx
,
params
);
}
}
bool
CvEM
::
train
(
const
Mat
&
_samples
,
const
Mat
&
_sample_idx
,
CvEMParams
_params
,
Mat
*
_labels
)
{
CvMat
samples
=
_samples
,
sidx
=
_sample_idx
,
labels
,
*
plabels
=
0
;
if
(
_labels
)
{
int
nsamples
=
sidx
.
data
.
ptr
?
sidx
.
rows
:
samples
.
rows
;
if
(
!
(
_labels
->
data
&&
_labels
->
type
()
==
CV_32SC1
&&
(
_labels
->
cols
==
1
||
_labels
->
rows
==
1
)
&&
_labels
->
cols
+
_labels
->
rows
-
1
==
nsamples
)
)
...
...
@@ -1267,7 +1261,7 @@ float
CvEM
::
predict
(
const
Mat
&
_sample
,
Mat
*
_probs
)
const
{
CvMat
sample
=
_sample
,
probs
,
*
pprobs
=
0
;
if
(
_probs
)
{
int
nclusters
=
params
.
nclusters
;
...
...
modules/ml/test/test_emknearestkmeans.cpp
View file @
8d9d9645
...
...
@@ -332,6 +332,82 @@ void CV_EMTest::run( int /*start_from*/ )
ts
->
set_failed_test_info
(
code
);
}
class
CV_EMTest_Smoke
:
public
cvtest
::
BaseTest
{
public
:
CV_EMTest_Smoke
()
{}
protected
:
virtual
void
run
(
int
/*start_from*/
)
{
int
code
=
cvtest
::
TS
::
OK
;
CvEM
em
;
Mat
samples
=
Mat
(
3
,
2
,
CV_32F
);
samples
.
at
<
float
>
(
0
,
0
)
=
1
;
samples
.
at
<
float
>
(
1
,
0
)
=
2
;
samples
.
at
<
float
>
(
2
,
0
)
=
3
;
CvEMParams
params
;
params
.
nclusters
=
2
;
Mat
labels
;
em
.
train
(
samples
,
Mat
(),
params
,
&
labels
);
Mat
firstResult
(
samples
.
rows
,
1
,
CV_32FC1
);
for
(
int
i
=
0
;
i
<
samples
.
rows
;
i
++
)
firstResult
.
at
<
float
>
(
i
)
=
em
.
predict
(
samples
.
row
(
i
)
);
// Write out
string
filename
=
tempfile
()
+
".xml"
;
{
FileStorage
fs
=
FileStorage
(
filename
,
FileStorage
::
WRITE
);
try
{
em
.
write
(
fs
.
fs
,
"EM"
);
}
catch
(...)
{
ts
->
printf
(
cvtest
::
TS
::
LOG
,
"Crash in write method.
\n
"
);
ts
->
set_failed_test_info
(
cvtest
::
TS
::
FAIL_EXCEPTION
);
}
}
em
.
clear
();
// Read in
{
FileStorage
fs
=
FileStorage
(
filename
,
FileStorage
::
READ
);
FileNode
fileNode
=
fs
[
"EM"
];
try
{
em
.
read
(
const_cast
<
CvFileStorage
*>
(
fileNode
.
fs
),
const_cast
<
CvFileNode
*>
(
fileNode
.
node
));
}
catch
(...)
{
ts
->
printf
(
cvtest
::
TS
::
LOG
,
"Crash in read method.
\n
"
);
ts
->
set_failed_test_info
(
cvtest
::
TS
::
FAIL_EXCEPTION
);
}
}
remove
(
filename
.
c_str
()
);
int
errCaseCount
=
0
;
for
(
int
i
=
0
;
i
<
samples
.
rows
;
i
++
)
errCaseCount
=
std
::
abs
(
em
.
predict
(
samples
.
row
(
i
))
-
firstResult
.
at
<
float
>
(
i
))
<
FLT_EPSILON
?
0
:
1
;
if
(
errCaseCount
>
0
)
{
ts
->
printf
(
cvtest
::
TS
::
LOG
,
"Different prediction results before writeing and after reading (errCaseCount=%d).
\n
"
,
errCaseCount
);
code
=
cvtest
::
TS
::
FAIL_BAD_ACCURACY
;
}
ts
->
set_failed_test_info
(
code
);
}
};
TEST
(
ML_KMeans
,
accuracy
)
{
CV_KMeansTest
test
;
test
.
safe_run
();
}
TEST
(
ML_KNearest
,
accuracy
)
{
CV_KNearestTest
test
;
test
.
safe_run
();
}
TEST
(
ML_EM
,
accuracy
)
{
CV_EMTest
test
;
test
.
safe_run
();
}
TEST
(
ML_EM
,
smoke
)
{
CV_EMTest_Smoke
test
;
test
.
safe_run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment