Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
fc5bba66
Commit
fc5bba66
authored
Apr 24, 2018
by
berak
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ml: refactor non-virtual methods
parent
4d7d630e
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
75 deletions
+38
-75
ml.hpp
modules/ml/include/opencv2/ml.hpp
+7
-7
data.cpp
modules/ml/src/data.cpp
+27
-30
rtrees.cpp
modules/ml/src/rtrees.cpp
+1
-15
svm.cpp
modules/ml/src/svm.cpp
+3
-23
No files found.
modules/ml/include/opencv2/ml.hpp
View file @
fc5bba66
...
...
@@ -198,7 +198,7 @@ public:
CV_WRAP
virtual
Mat
getTestSampleWeights
()
const
=
0
;
CV_WRAP
virtual
Mat
getVarIdx
()
const
=
0
;
CV_WRAP
virtual
Mat
getVarType
()
const
=
0
;
CV_WRAP
Mat
getVarSymbolFlags
()
const
;
CV_WRAP
virtual
Mat
getVarSymbolFlags
()
const
=
0
;
CV_WRAP
virtual
int
getResponseType
()
const
=
0
;
CV_WRAP
virtual
Mat
getTrainSampleIdx
()
const
=
0
;
CV_WRAP
virtual
Mat
getTestSampleIdx
()
const
=
0
;
...
...
@@ -234,10 +234,10 @@ public:
CV_WRAP
virtual
void
shuffleTrainTest
()
=
0
;
/** @brief Returns matrix of test samples */
CV_WRAP
Mat
getTestSamples
()
const
;
CV_WRAP
virtual
Mat
getTestSamples
()
const
=
0
;
/** @brief Returns vector of symbolic names captured in loadFromCSV() */
CV_WRAP
v
oid
getNames
(
std
::
vector
<
String
>&
names
)
const
;
CV_WRAP
v
irtual
void
getNames
(
std
::
vector
<
String
>&
names
)
const
=
0
;
CV_WRAP
static
Mat
getSubVector
(
const
Mat
&
vec
,
const
Mat
&
idx
);
...
...
@@ -727,7 +727,7 @@ public:
regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
the usual %SVM with parameters specified in params is executed.
*/
CV_WRAP
bool
trainAuto
(
InputArray
samples
,
CV_WRAP
virtual
bool
trainAuto
(
InputArray
samples
,
int
layout
,
InputArray
responses
,
int
kFold
=
10
,
...
...
@@ -737,7 +737,7 @@ public:
Ptr
<
ParamGrid
>
nuGrid
=
SVM
::
getDefaultGridPtr
(
SVM
::
NU
),
Ptr
<
ParamGrid
>
coeffGrid
=
SVM
::
getDefaultGridPtr
(
SVM
::
COEF
),
Ptr
<
ParamGrid
>
degreeGrid
=
SVM
::
getDefaultGridPtr
(
SVM
::
DEGREE
),
bool
balanced
=
false
);
bool
balanced
=
false
)
=
0
;
/** @brief Retrieves all the support vectors
...
...
@@ -752,7 +752,7 @@ public:
support vector, used for prediction, was derived from. They are returned in a floating-point
matrix, where the support vectors are stored as matrix rows.
*/
CV_WRAP
Mat
getUncompressedSupportVectors
()
const
;
CV_WRAP
virtual
Mat
getUncompressedSupportVectors
()
const
=
0
;
/** @brief Retrieves the decision function
...
...
@@ -1273,7 +1273,7 @@ public:
@param results Array where the result of the calculation will be written.
@param flags Flags for defining the type of RTrees.
*/
CV_WRAP
v
oid
getVotes
(
InputArray
samples
,
OutputArray
results
,
int
flags
)
const
;
CV_WRAP
v
irtual
void
getVotes
(
InputArray
samples
,
OutputArray
results
,
int
flags
)
const
=
0
;
/** Creates the empty model.
Use StatModel::train to train the model, StatModel::train to create and train the model,
...
...
modules/ml/src/data.cpp
View file @
fc5bba66
...
...
@@ -50,13 +50,6 @@ static const int VAR_MISSED = VAR_ORDERED;
TrainData
::~
TrainData
()
{}
Mat
TrainData
::
getTestSamples
()
const
{
Mat
idx
=
getTestSampleIdx
();
Mat
samples
=
getSamples
();
return
idx
.
empty
()
?
Mat
()
:
getSubVector
(
samples
,
idx
);
}
Mat
TrainData
::
getSubVector
(
const
Mat
&
vec
,
const
Mat
&
idx
)
{
if
(
idx
.
empty
()
)
...
...
@@ -119,6 +112,7 @@ Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
return
subvec
;
}
class
TrainDataImpl
CV_FINAL
:
public
TrainData
{
public
:
...
...
@@ -155,6 +149,12 @@ public:
return
layout
==
ROW_SAMPLE
?
samples
.
cols
:
samples
.
rows
;
}
Mat
getTestSamples
()
const
CV_OVERRIDE
{
Mat
idx
=
getTestSampleIdx
();
return
idx
.
empty
()
?
Mat
()
:
getSubVector
(
samples
,
idx
);
}
Mat
getSamples
()
const
CV_OVERRIDE
{
return
samples
;
}
Mat
getResponses
()
const
CV_OVERRIDE
{
return
responses
;
}
Mat
getMissing
()
const
CV_OVERRIDE
{
return
missing
;
}
...
...
@@ -987,22 +987,11 @@ public:
}
}
FILE
*
file
;
int
layout
;
Mat
samples
,
missing
,
varType
,
varIdx
,
varSymbolFlags
,
responses
,
missingSubst
;
Mat
sampleIdx
,
trainSampleIdx
,
testSampleIdx
;
Mat
sampleWeights
,
catMap
,
catOfs
;
Mat
normCatResponses
,
classLabels
,
classCounters
;
MapType
nameMap
;
};
void
TrainData
::
getNames
(
std
::
vector
<
String
>&
names
)
const
{
const
TrainDataImpl
*
impl
=
dynamic_cast
<
const
TrainDataImpl
*>
(
this
);
CV_Assert
(
impl
!=
0
);
size_t
n
=
impl
->
nameMap
.
size
();
TrainDataImpl
::
MapType
::
const_iterator
it
=
impl
->
nameMap
.
begin
(),
it_end
=
impl
->
nameMap
.
end
();
void
getNames
(
std
::
vector
<
String
>&
names
)
const
CV_OVERRIDE
{
size_t
n
=
nameMap
.
size
();
TrainDataImpl
::
MapType
::
const_iterator
it
=
nameMap
.
begin
(),
it_end
=
nameMap
.
end
();
names
.
resize
(
n
+
1
);
names
[
0
]
=
"?"
;
for
(
;
it
!=
it_end
;
++
it
)
...
...
@@ -1012,14 +1001,22 @@ void TrainData::getNames(std::vector<String>& names) const
CV_Assert
(
label
>
0
&&
label
<=
(
int
)
n
);
names
[
label
]
=
s
;
}
}
}
Mat
getVarSymbolFlags
()
const
CV_OVERRIDE
{
return
varSymbolFlags
;
}
FILE
*
file
;
int
layout
;
Mat
samples
,
missing
,
varType
,
varIdx
,
varSymbolFlags
,
responses
,
missingSubst
;
Mat
sampleIdx
,
trainSampleIdx
,
testSampleIdx
;
Mat
sampleWeights
,
catMap
,
catOfs
;
Mat
normCatResponses
,
classLabels
,
classCounters
;
MapType
nameMap
;
};
Mat
TrainData
::
getVarSymbolFlags
()
const
{
const
TrainDataImpl
*
impl
=
dynamic_cast
<
const
TrainDataImpl
*>
(
this
);
CV_Assert
(
impl
!=
0
);
return
impl
->
varSymbolFlags
;
}
Ptr
<
TrainData
>
TrainData
::
loadFromCSV
(
const
String
&
filename
,
int
headerLines
,
...
...
modules/ml/src/rtrees.cpp
View file @
fc5bba66
...
...
@@ -453,6 +453,7 @@ public:
inline
void
setRegressionAccuracy
(
float
val
)
CV_OVERRIDE
{
impl
.
params
.
setRegressionAccuracy
(
val
);
}
inline
cv
::
Mat
getPriors
()
const
CV_OVERRIDE
{
return
impl
.
params
.
getPriors
();
}
inline
void
setPriors
(
const
cv
::
Mat
&
val
)
CV_OVERRIDE
{
impl
.
params
.
setPriors
(
val
);
}
inline
void
getVotes
(
InputArray
input
,
OutputArray
output
,
int
flags
)
const
CV_OVERRIDE
{
return
impl
.
getVotes
(
input
,
output
,
flags
);}
RTreesImpl
()
{}
virtual
~
RTreesImpl
()
CV_OVERRIDE
{}
...
...
@@ -485,12 +486,6 @@ public:
impl
.
read
(
fn
);
}
void
getVotes_
(
InputArray
samples
,
OutputArray
results
,
int
flags
)
const
{
CV_TRACE_FUNCTION
();
impl
.
getVotes
(
samples
,
results
,
flags
);
}
Mat
getVarImportance
()
const
CV_OVERRIDE
{
return
Mat_
<
float
>
(
impl
.
varImportance
,
true
);
}
int
getVarCount
()
const
CV_OVERRIDE
{
return
impl
.
getVarCount
();
}
...
...
@@ -519,15 +514,6 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return
Algorithm
::
load
<
RTrees
>
(
filepath
,
nodeName
);
}
void
RTrees
::
getVotes
(
InputArray
input
,
OutputArray
output
,
int
flags
)
const
{
CV_TRACE_FUNCTION
();
const
RTreesImpl
*
this_
=
dynamic_cast
<
const
RTreesImpl
*>
(
this
);
if
(
!
this_
)
CV_Error
(
Error
::
StsNotImplemented
,
"the class is not RTreesImpl"
);
return
this_
->
getVotes_
(
input
,
output
,
flags
);
}
}}
// End of file.
modules/ml/src/svm.cpp
View file @
fc5bba66
...
...
@@ -1250,7 +1250,7 @@ public:
uncompressed_sv
.
release
();
}
Mat
getUncompressedSupportVectors
_
()
const
Mat
getUncompressedSupportVectors
()
const
CV_OVERRIDE
{
return
uncompressed_sv
;
}
...
...
@@ -1982,10 +1982,10 @@ public:
bool
returnDFVal
;
};
bool
trainAuto
_
(
InputArray
samples
,
int
layout
,
bool
trainAuto
(
InputArray
samples
,
int
layout
,
InputArray
responses
,
int
kfold
,
Ptr
<
ParamGrid
>
Cgrid
,
Ptr
<
ParamGrid
>
gammaGrid
,
Ptr
<
ParamGrid
>
pGrid
,
Ptr
<
ParamGrid
>
nuGrid
,
Ptr
<
ParamGrid
>
coeffGrid
,
Ptr
<
ParamGrid
>
degreeGrid
,
bool
balanced
)
Ptr
<
ParamGrid
>
coeffGrid
,
Ptr
<
ParamGrid
>
degreeGrid
,
bool
balanced
)
CV_OVERRIDE
{
Ptr
<
TrainData
>
data
=
TrainData
::
create
(
samples
,
layout
,
responses
);
return
this
->
trainAuto
(
...
...
@@ -2353,26 +2353,6 @@ Ptr<SVM> SVM::load(const String& filepath)
return
svm
;
}
Mat
SVM
::
getUncompressedSupportVectors
()
const
{
const
SVMImpl
*
this_
=
dynamic_cast
<
const
SVMImpl
*>
(
this
);
if
(
!
this_
)
CV_Error
(
Error
::
StsNotImplemented
,
"the class is not SVMImpl"
);
return
this_
->
getUncompressedSupportVectors_
();
}
bool
SVM
::
trainAuto
(
InputArray
samples
,
int
layout
,
InputArray
responses
,
int
kfold
,
Ptr
<
ParamGrid
>
Cgrid
,
Ptr
<
ParamGrid
>
gammaGrid
,
Ptr
<
ParamGrid
>
pGrid
,
Ptr
<
ParamGrid
>
nuGrid
,
Ptr
<
ParamGrid
>
coeffGrid
,
Ptr
<
ParamGrid
>
degreeGrid
,
bool
balanced
)
{
SVMImpl
*
this_
=
dynamic_cast
<
SVMImpl
*>
(
this
);
if
(
!
this_
)
{
CV_Error
(
Error
::
StsNotImplemented
,
"the class is not SVMImpl"
);
}
return
this_
->
trainAuto_
(
samples
,
layout
,
responses
,
kfold
,
Cgrid
,
gammaGrid
,
pGrid
,
nuGrid
,
coeffGrid
,
degreeGrid
,
balanced
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment