Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
e0ee2f76
Commit
e0ee2f76
authored
Mar 02, 2017
by
Vadim Pisarevsky
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #8116 from mrquorr:master
parents
f46fa6e0
d8425d88
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
0 deletions
+123
-0
ml.hpp
modules/ml/include/opencv2/ml.hpp
+11
-0
rtrees.cpp
modules/ml/src/rtrees.cpp
+67
-0
test_mltests.cpp
modules/ml/test/test_mltests.cpp
+45
-0
No files found.
modules/ml/include/opencv2/ml.hpp
View file @
e0ee2f76
...
...
@@ -1206,6 +1206,17 @@ public:
*/
CV_WRAP
virtual
Mat
getVarImportance
()
const
=
0
;
/** Returns the result of each individual tree in the forest.
In case the model is a regression problem, the method will return each of the trees'
results for each of the sample cases. If the model is a classifier, it will return
a Mat with samples + 1 rows, where the first row gives the class number and the
following rows return the votes each class had for each sample.
@param samples Array containg the samples for which votes will be calculated.
@param results Array where the result of the calculation will be written.
@param flags Flags for defining the type of RTrees.
*/
CV_WRAP
void
getVotes
(
InputArray
samples
,
OutputArray
results
,
int
flags
)
const
;
/** Creates the empty model.
Use StatModel::train to train the model, StatModel::train to create and train the model,
Algorithm::load to load the pre-trained model.
...
...
modules/ml/src/rtrees.cpp
View file @
e0ee2f76
...
...
@@ -349,6 +349,60 @@ public:
}
}
void
getVotes
(
InputArray
input
,
OutputArray
output
,
int
flags
)
const
{
CV_Assert
(
!
roots
.
empty
()
);
int
nclasses
=
(
int
)
classLabels
.
size
(),
ntrees
=
(
int
)
roots
.
size
();
Mat
samples
=
input
.
getMat
(),
results
;
int
i
,
j
,
nsamples
=
samples
.
rows
;
int
predictType
=
flags
&
PREDICT_MASK
;
if
(
predictType
==
PREDICT_AUTO
)
{
predictType
=
!
_isClassifier
||
(
classLabels
.
size
()
==
2
&&
(
flags
&
RAW_OUTPUT
)
!=
0
)
?
PREDICT_SUM
:
PREDICT_MAX_VOTE
;
}
if
(
predictType
==
PREDICT_SUM
)
{
output
.
create
(
nsamples
,
ntrees
,
CV_32F
);
results
=
output
.
getMat
();
for
(
i
=
0
;
i
<
nsamples
;
i
++
)
{
for
(
j
=
0
;
j
<
ntrees
;
j
++
)
{
float
val
=
predictTrees
(
Range
(
j
,
j
+
1
),
samples
.
row
(
i
),
flags
);
results
.
at
<
float
>
(
i
,
j
)
=
val
;
}
}
}
else
{
vector
<
int
>
votes
;
output
.
create
(
nsamples
+
1
,
nclasses
,
CV_32S
);
results
=
output
.
getMat
();
for
(
j
=
0
;
j
<
nclasses
;
j
++
)
{
results
.
at
<
int
>
(
0
,
j
)
=
classLabels
[
j
];
}
for
(
i
=
0
;
i
<
nsamples
;
i
++
)
{
votes
.
clear
();
for
(
j
=
0
;
j
<
ntrees
;
j
++
)
{
int
val
=
(
int
)
predictTrees
(
Range
(
j
,
j
+
1
),
samples
.
row
(
i
),
flags
);
votes
.
push_back
(
val
);
}
for
(
j
=
0
;
j
<
nclasses
;
j
++
)
{
results
.
at
<
int
>
(
i
+
1
,
j
)
=
(
int
)
std
::
count
(
votes
.
begin
(),
votes
.
end
(),
classLabels
[
j
]);
}
}
}
}
RTreeParams
rparams
;
double
oobError
;
vector
<
float
>
varImportance
;
...
...
@@ -401,6 +455,11 @@ public:
impl
.
read
(
fn
);
}
void
getVotes_
(
InputArray
samples
,
OutputArray
results
,
int
flags
)
const
{
impl
.
getVotes
(
samples
,
results
,
flags
);
}
Mat
getVarImportance
()
const
{
return
Mat_
<
float
>
(
impl
.
varImportance
,
true
);
}
int
getVarCount
()
const
{
return
impl
.
getVarCount
();
}
...
...
@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return
Algorithm
::
load
<
RTrees
>
(
filepath
,
nodeName
);
}
void
RTrees
::
getVotes
(
InputArray
input
,
OutputArray
output
,
int
flags
)
const
{
const
RTreesImpl
*
this_
=
dynamic_cast
<
const
RTreesImpl
*>
(
this
);
if
(
!
this_
)
CV_Error
(
Error
::
StsNotImplemented
,
"the class is not RTreesImpl"
);
return
this_
->
getVotes_
(
input
,
output
,
flags
);
}
}}
// End of file.
modules/ml/test/test_mltests.cpp
View file @
e0ee2f76
...
...
@@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
EXPECT_EQ
(
sum
(
P1
==
P3
)[
0
],
255
*
P3
.
total
());
}
TEST
(
ML_RTrees
,
getVotes
)
{
int
n
=
12
;
int
count
,
i
;
int
label_size
=
3
;
int
predicted_class
=
0
;
int
max_votes
=
-
1
;
int
val
;
// RTrees for classification
Ptr
<
ml
::
RTrees
>
rt
=
cv
::
ml
::
RTrees
::
create
();
//data
Mat
data
(
n
,
4
,
CV_32F
);
randu
(
data
,
0
,
10
);
//labels
Mat
labels
=
(
Mat_
<
int
>
(
n
,
1
)
<<
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
2
,
2
,
2
,
2
);
rt
->
train
(
data
,
ml
::
ROW_SAMPLE
,
labels
);
//run function
Mat
test
(
1
,
4
,
CV_32F
);
Mat
result
;
randu
(
test
,
0
,
10
);
rt
->
getVotes
(
test
,
result
,
0
);
//count vote amount and find highest vote
count
=
0
;
const
int
*
result_row
=
result
.
ptr
<
int
>
(
1
);
for
(
i
=
0
;
i
<
label_size
;
i
++
)
{
val
=
result_row
[
i
];
//predicted_class = max_votes < val? i;
if
(
max_votes
<
val
)
{
max_votes
=
val
;
predicted_class
=
i
;
}
count
+=
val
;
}
EXPECT_EQ
(
count
,
(
int
)
rt
->
getRoots
().
size
());
EXPECT_EQ
(
result
.
at
<
float
>
(
0
,
predicted_class
),
rt
->
predict
(
test
));
}
/* End of file. */
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment