Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
20400aa9
Commit
20400aa9
authored
6 years ago
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Import Upsample and Unsqueeze from ONNX
parent
1db5d82b
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
3 deletions
+51
-3
onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+41
-3
test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+10
-0
No files found.
modules/dnn/src/onnx/onnx_importer.cpp
View file @
20400aa9
...
...
@@ -392,10 +392,10 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams
.
set
(
"ceil_mode"
,
isCeilMode
(
layerParams
));
layerParams
.
set
(
"ave_pool_padded_area"
,
framework_name
==
"pytorch"
);
}
else
if
(
layer_type
==
"GlobalAveragePool"
)
else
if
(
layer_type
==
"GlobalAveragePool"
||
layer_type
==
"GlobalMaxPool"
)
{
layerParams
.
type
=
"Pooling"
;
layerParams
.
set
(
"pool"
,
"AVE
"
);
layerParams
.
set
(
"pool"
,
layer_type
==
"GlobalAveragePool"
?
"AVE"
:
"MAX
"
);
layerParams
.
set
(
"global_pooling"
,
true
);
}
else
if
(
layer_type
==
"Add"
||
layer_type
==
"Sum"
)
...
...
@@ -448,6 +448,11 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams
.
set
(
"bias_term"
,
false
);
}
}
else
if
(
layer_type
==
"Neg"
)
{
layerParams
.
type
=
"Power"
;
layerParams
.
set
(
"scale"
,
-
1
);
}
else
if
(
layer_type
==
"Constant"
)
{
CV_Assert
(
node_proto
.
input_size
()
==
0
);
...
...
@@ -595,9 +600,12 @@ void ONNXImporter::populateNet(Net dstNet)
else
if
(
layer_type
==
"Unsqueeze"
)
{
CV_Assert
(
node_proto
.
input_size
()
==
1
);
DictValue
axes
=
layerParams
.
get
(
"axes"
);
if
(
constBlobs
.
find
(
node_proto
.
input
(
0
))
!=
constBlobs
.
end
())
{
// Constant input.
Mat
input
=
getBlob
(
node_proto
,
constBlobs
,
0
);
DictValue
axes
=
layerParams
.
get
(
"axes"
);
std
::
vector
<
int
>
dims
;
for
(
int
j
=
0
;
j
<
input
.
dims
;
j
++
)
{
dims
.
push_back
(
input
.
size
[
j
]);
...
...
@@ -611,6 +619,17 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs
.
insert
(
std
::
make_pair
(
layerParams
.
name
,
out
));
continue
;
}
// Variable input.
if
(
axes
.
size
()
!=
1
)
CV_Error
(
Error
::
StsNotImplemented
,
"Multidimensional unsqueeze"
);
int
dims
[]
=
{
1
,
-
1
};
layerParams
.
type
=
"Reshape"
;
layerParams
.
set
(
"axis"
,
axes
.
getIntValue
(
0
));
layerParams
.
set
(
"num_axes"
,
1
);
layerParams
.
set
(
"dim"
,
DictValue
::
arrayInt
(
&
dims
[
0
],
2
));
}
else
if
(
layer_type
==
"Reshape"
)
{
CV_Assert
(
node_proto
.
input_size
()
==
2
||
layerParams
.
has
(
"shape"
));
...
...
@@ -707,6 +726,25 @@ void ONNXImporter::populateNet(Net dstNet)
continue
;
}
}
else
if
(
layer_type
==
"Upsample"
)
{
layerParams
.
type
=
"Resize"
;
if
(
layerParams
.
has
(
"scales"
))
{
// Pytorch layer
DictValue
scales
=
layerParams
.
get
(
"scales"
);
CV_Assert
(
scales
.
size
()
==
4
);
layerParams
.
set
(
"zoom_factor_y"
,
scales
.
getIntValue
(
2
));
layerParams
.
set
(
"zoom_factor_x"
,
scales
.
getIntValue
(
3
));
}
else
{
// Caffe2 layer
replaceLayerParam
(
layerParams
,
"height_scale"
,
"zoom_factor_y"
);
replaceLayerParam
(
layerParams
,
"width_scale"
,
"zoom_factor_x"
);
}
replaceLayerParam
(
layerParams
,
"mode"
,
"interpolation"
);
}
else
{
for
(
int
j
=
0
;
j
<
node_proto
.
input_size
();
j
++
)
{
...
...
This diff is collapsed.
Click to expand it.
modules/dnn/test/test_onnx_importer.cpp
View file @
20400aa9
...
...
@@ -140,6 +140,11 @@ TEST_P(Test_ONNX_layers, Padding)
testONNXModels
(
"padding"
);
}
TEST_P
(
Test_ONNX_layers
,
Resize
)
{
testONNXModels
(
"resize_nearest"
);
}
TEST_P
(
Test_ONNX_layers
,
MultyInputs
)
{
const
String
model
=
_tf
(
"models/multy_inputs.onnx"
);
...
...
@@ -169,6 +174,11 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
testONNXModels
(
"dynamic_reshape"
);
}
TEST_P
(
Test_ONNX_layers
,
Reshape
)
{
testONNXModels
(
"unsqueeze"
);
}
INSTANTIATE_TEST_CASE_P
(
/*nothing*/
,
Test_ONNX_layers
,
dnnBackendsAndTargets
());
class
Test_ONNX_nets
:
public
Test_ONNX_layers
{};
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment