Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
11d565ca
Commit
11d565ca
authored
Mar 17, 2020
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix LSTM from ONNX with batch==1
parent
8d69dbdf
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
37 deletions
+69
-37
recurrent_layers.cpp
modules/dnn/src/layers/recurrent_layers.cpp
+5
-4
onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+64
-33
No files found.
modules/dnn/src/layers/recurrent_layers.cpp
View file @
11d565ca
...
...
@@ -110,10 +110,11 @@ public:
const
Mat
&
Wh
=
blobs
[
0
];
const
Mat
&
Wx
=
blobs
[
1
];
const
Mat
&
bias
=
blobs
[
2
];
CV_Assert
(
Wh
.
dims
==
2
&&
Wx
.
dims
==
2
);
CV_Assert
(
Wh
.
rows
==
Wx
.
rows
);
CV_Assert
(
Wh
.
rows
==
4
*
Wh
.
cols
);
CV_Assert
(
Wh
.
rows
==
(
int
)
bias
.
total
());
CV_CheckEQ
(
Wh
.
dims
,
2
,
""
);
CV_CheckEQ
(
Wx
.
dims
,
2
,
""
);
CV_CheckEQ
(
Wh
.
rows
,
Wx
.
rows
,
""
);
CV_CheckEQ
(
Wh
.
rows
,
4
*
Wh
.
cols
,
""
);
CV_CheckEQ
(
Wh
.
rows
,
(
int
)
bias
.
total
(),
""
);
CV_Assert
(
Wh
.
type
()
==
Wx
.
type
()
&&
Wx
.
type
()
==
bias
.
type
());
// Peephole weights.
...
...
modules/dnn/src/onnx/onnx_importer.cpp
View file @
11d565ca
...
...
@@ -49,6 +49,11 @@ class ONNXImporter
LayerParams
getLayerParams
(
const
opencv_onnx
::
NodeProto
&
node_proto
);
bool
isCeilMode
(
const
LayerParams
&
layerParams
);
void
addLayer
(
Net
&
dstNet
,
LayerParams
&
layerParams
,
const
opencv_onnx
::
NodeProto
&
node_proto
,
std
::
map
<
std
::
string
,
LayerInfo
>&
layer_id
,
std
::
map
<
std
::
string
,
MatShape
>&
outShapes
);
public
:
ONNXImporter
(
const
char
*
onnxFile
)
...
...
@@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
return
constBlob
->
second
;
}
// Adds `layerParams` as a new layer of `dstNet`, connects every input of
// `node_proto` that was produced by a previously registered layer, and
// publishes the layer's inferred output shapes into `outShapes` keyed by
// the node's output names.
//
// NOTE(review): inputs absent from `layer_id` are silently skipped —
// presumably they are constant initializers already folded into
// `layerParams.blobs` by the caller; confirm against populateNet.
void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
                            const opencv_onnx::NodeProto& node_proto,
                            std::map<std::string, LayerInfo>& layer_id,
                            std::map<std::string, MatShape>& outShapes)
{
    const int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);

    // Record which (layer id, output port) produces each declared output name.
    for (int outIdx = 0; outIdx < node_proto.output_size(); ++outIdx)
    {
        layer_id.insert(std::make_pair(node_proto.output(outIdx), LayerInfo(id, outIdx)));
    }

    std::vector<MatShape> inputShapes, outputShapes, internalShapes;
    int connectedInputs = 0;
    for (int inpIdx = 0; inpIdx < node_proto.input_size(); inpIdx++)
    {
        std::map<std::string, LayerInfo>::iterator producer = layer_id.find(node_proto.input(inpIdx));
        if (producer == layer_id.end())
            continue;  // not produced by a layer — nothing to wire up

        dstNet.connect(producer->second.layerId, producer->second.outputId, id, connectedInputs);
        ++connectedInputs;

        // Collect the known shape of this input; it must have been registered
        // when the producing layer was added.
        std::map<std::string, MatShape>::iterator knownShape = outShapes.find(node_proto.input(inpIdx));
        CV_Assert(knownShape != outShapes.end());
        inputShapes.push_back(knownShape->second);
    }

    // Ask the freshly added layer to infer its output shapes, then expose
    // them under the node's output names for downstream layers.
    Ptr<Layer> layer = dstNet.getLayer(id);
    layer->getMemoryShapes(inputShapes, 0, outputShapes, internalShapes);
    for (int outIdx = 0; outIdx < node_proto.output_size() && outIdx < (int)outputShapes.size(); ++outIdx)
    {
        outShapes[node_proto.output(outIdx)] = outputShapes[outIdx];
    }
}
void
ONNXImporter
::
populateNet
(
Net
dstNet
)
{
CV_Assert
(
model_proto
.
has_graph
());
...
...
@@ -581,13 +622,16 @@ void ONNXImporter::populateNet(Net dstNet)
}
else
if
(
layer_type
==
"LSTM"
)
{
LayerParams
lstmParams
=
layerParams
;
lstmParams
.
name
+=
"/lstm"
;
// https://pytorch.org/docs/stable/nn.html#lstm
CV_Assert
(
node_proto
.
input_size
()
==
7
);
Mat
Wx
=
getBlob
(
node_proto
,
constBlobs
,
1
);
Mat
Wh
=
getBlob
(
node_proto
,
constBlobs
,
2
);
Mat
b
=
getBlob
(
node_proto
,
constBlobs
,
3
);
const
int
numHidden
=
Wh
.
size
[
2
]
;
const
int
numHidden
=
lstmParams
.
get
<
int
>
(
"hidden_size"
)
;
Wx
=
Wx
.
reshape
(
1
,
Wx
.
size
[
1
]);
Wh
=
Wh
.
reshape
(
1
,
Wh
.
size
[
1
]);
...
...
@@ -612,10 +656,24 @@ void ONNXImporter::populateNet(Net dstNet)
}
std
::
swap
(
biasData
[
numHidden
+
j
],
biasData
[
numHidden
*
2
+
j
]);
}
layerParams
.
blobs
.
resize
(
3
);
layerParams
.
blobs
[
0
]
=
Wh
;
layerParams
.
blobs
[
1
]
=
Wx
;
layerParams
.
blobs
[
2
]
=
b
;
lstmParams
.
blobs
.
resize
(
3
);
lstmParams
.
blobs
[
0
]
=
Wh
;
lstmParams
.
blobs
[
1
]
=
Wx
;
lstmParams
.
blobs
[
2
]
=
b
;
node_proto
.
set_output
(
0
,
lstmParams
.
name
);
// set different name so output shapes will be registered on that name
addLayer
(
dstNet
,
lstmParams
,
node_proto
,
layer_id
,
outShapes
);
MatShape
lstmShape
=
outShapes
[
node_proto
.
output
(
0
)];
// Add fake 1 as it is done in ONNX
lstmShape
.
insert
(
lstmShape
.
begin
()
+
1
,
1
);
layerParams
.
type
=
"Reshape"
;
layerParams
.
set
(
"dim"
,
DictValue
::
arrayInt
(
&
lstmShape
[
0
],
lstmShape
.
size
()));
node_proto
.
set_input
(
0
,
lstmParams
.
name
);
// redirect input to LSTM
node_proto
.
set_output
(
0
,
layerParams
.
name
);
// keep origin LSTM's name
}
else
if
(
layer_type
==
"ImageScaler"
)
{
...
...
@@ -1228,34 +1286,7 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams
.
blobs
.
push_back
(
getBlob
(
node_proto
,
constBlobs
,
j
));
}
}
int
id
=
dstNet
.
addLayer
(
layerParams
.
name
,
layerParams
.
type
,
layerParams
);
for
(
int
i
=
0
;
i
<
node_proto
.
output_size
();
++
i
)
{
layer_id
.
insert
(
std
::
make_pair
(
node_proto
.
output
(
i
),
LayerInfo
(
id
,
i
)));
}
std
::
vector
<
MatShape
>
layerInpShapes
,
layerOutShapes
,
layerInternalShapes
;
int
inpNum
=
0
;
for
(
int
j
=
0
;
j
<
node_proto
.
input_size
();
j
++
)
{
layerId
=
layer_id
.
find
(
node_proto
.
input
(
j
));
if
(
layerId
!=
layer_id
.
end
())
{
dstNet
.
connect
(
layerId
->
second
.
layerId
,
layerId
->
second
.
outputId
,
id
,
inpNum
);
++
inpNum
;
// Collect input shapes.
shapeIt
=
outShapes
.
find
(
node_proto
.
input
(
j
));
CV_Assert
(
shapeIt
!=
outShapes
.
end
());
layerInpShapes
.
push_back
(
shapeIt
->
second
);
}
}
// Compute shape of output blob for this layer.
Ptr
<
Layer
>
layer
=
dstNet
.
getLayer
(
id
);
layer
->
getMemoryShapes
(
layerInpShapes
,
0
,
layerOutShapes
,
layerInternalShapes
);
for
(
int
i
=
0
;
i
<
node_proto
.
output_size
()
&&
i
<
(
int
)
layerOutShapes
.
size
();
++
i
)
{
outShapes
[
node_proto
.
output
(
i
)]
=
layerOutShapes
[
i
];
}
addLayer
(
dstNet
,
layerParams
,
node_proto
,
layer_id
,
outShapes
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment