Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
bbbec300
Commit
bbbec300
authored
7 years ago
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
nn.BatchNormalization and nn.Dropout layers from Torch
parent
fc9e0314
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
4 deletions
+18
-4
batch_norm_layer.cpp
modules/dnn/src/layers/batch_norm_layer.cpp
+3
-2
torch_importer.cpp
modules/dnn/src/torch/torch_importer.cpp
+10
-2
test_torch_importer.cpp
modules/dnn/test/test_torch_importer.cpp
+5
-0
No files found.
modules/dnn/src/layers/batch_norm_layer.cpp
View file @
bbbec300
...
...
@@ -119,8 +119,9 @@ public:
CV_Assert
(
inputs
.
size
()
==
1
);
Mat
&
inpBlob
=
*
inputs
[
0
];
int
rows
=
inpBlob
.
size
[
2
];
int
cols
=
inpBlob
.
size
[
3
];
CV_Assert
(
inpBlob
.
dims
==
2
||
inpBlob
.
dims
==
4
);
int
rows
=
inpBlob
.
dims
>
2
?
inpBlob
.
size
[
2
]
:
1
;
int
cols
=
inpBlob
.
dims
>
2
?
inpBlob
.
size
[
3
]
:
1
;
for
(
size_t
ii
=
0
;
ii
<
outputs
.
size
();
ii
++
)
{
...
...
This diff is collapsed.
Click to expand it.
modules/dnn/src/torch/torch_importer.cpp
View file @
bbbec300
...
...
@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule
->
modules
.
push_back
(
cv
::
Ptr
<
Module
>
(
new
Module
(
nnName
,
"Sigmoid"
)));
readObject
();
}
else
if
(
nnName
==
"SpatialBatchNormalization"
||
nnName
==
"InstanceNormalization"
)
else
if
(
nnName
==
"SpatialBatchNormalization"
||
nnName
==
"InstanceNormalization"
||
nnName
==
"BatchNormalization"
)
{
newModule
->
apiType
=
"BatchNorm"
;
readTorchTable
(
scalarParams
,
tensorParams
);
...
...
@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule
->
modules
.
push_back
(
newModule
);
}
else
if
(
nnName
==
"SpatialDropout"
)
else
if
(
nnName
==
"SpatialDropout"
||
nnName
==
"Dropout"
)
{
readTorchTable
(
scalarParams
,
tensorParams
);
CV_Assert
(
scalarParams
.
has
(
"p"
));
if
(
scalarParams
.
has
(
"v2"
)
&&
scalarParams
.
get
<
bool
>
(
"v2"
))
{
newModule
->
apiType
=
"Identity"
;
}
else
{
float
scale
=
1
-
scalarParams
.
get
<
double
>
(
"p"
);
CV_Assert
(
scale
>
0
);
newModule
->
apiType
=
"Power"
;
layerParams
.
set
(
"scale"
,
scale
);
}
curModule
->
modules
.
push_back
(
newModule
);
}
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style
...
...
This diff is collapsed.
Click to expand it.
modules/dnn/test/test_torch_importer.cpp
View file @
bbbec300
...
...
@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
runTorchNet
(
"net_spatial_reflection_padding"
,
DNN_TARGET_CPU
,
""
,
false
,
true
);
}
TEST
(
Torch_Importer
,
net_non_spatial
)
{
runTorchNet
(
"net_non_spatial"
,
DNN_TARGET_CPU
,
""
,
false
,
true
);
}
TEST
(
Torch_Importer
,
ENet_accuracy
)
{
Net
net
;
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment