Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
c1c84d2f
Commit
c1c84d2f
authored
Jan 06, 2020
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ONNX graphs simplifier
parent
74bc8d35
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
565 additions
and
192 deletions
+565
-192
graph_simplifier.cpp
modules/dnn/src/graph_simplifier.cpp
+207
-0
graph_simplifier.hpp
modules/dnn/src/graph_simplifier.hpp
+100
-0
onnx_graph_simplifier.cpp
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+157
-0
onnx_graph_simplifier.hpp
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
+30
-0
onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+5
-0
tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
+65
-192
test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+1
-0
No files found.
modules/dnn/src/graph_simplifier.cpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "precomp.hpp"
#include "graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
Subgraph
::~
Subgraph
()
{}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
,
int
input_4
,
int
input_5
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
fusedNodeInputs
=
inputs_
;
fusedNodeOp
=
op
;
}
int
Subgraph
::
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
)
{
CV_Assert
(
inpId
<
node
->
getNumInputs
());
std
::
string
name
=
node
->
getInputName
(
inpId
);
// If operation produces several tensors, they are specified by index
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
->
getNumNodes
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
->
getNodeName
(
i
)
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
}
bool
Subgraph
::
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
nodeToMatch
);
if
(
node
->
getType
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
->
getNumInputs
())
return
false
;
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
continue
;
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
Ptr
<
ImportNodeWrapper
>
inpNode
=
net
->
getNode
(
nodeId
);
if
(
inpNode
->
getType
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
const
int
n
=
matchedNodesIds
.
size
();
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
}
void
Subgraph
::
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
{
// Extract names of input nodes.
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
->
getNumInputs
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
->
getInputName
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
->
removeNode
(
matchedNodesIds
[
i
]);
// Modify the last node to be a fused one.
node
->
setType
(
fusedNodeOp
);
node
->
setInputNames
(
inputsNames
);
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>
inputNodes
(
inputsNames
.
size
());
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
inputNodes
[
i
]
=
net
->
getNode
(
getInputNodeId
(
net
,
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
}
void
Subgraph
::
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
{}
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
)
{
int
numNodes
=
net
->
getNumNodes
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
patterns
.
size
();
++
j
)
{
if
(
patterns
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
patterns
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
}
}}
// namespace cv::dnn
modules/dnn/src/graph_simplifier.hpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#include <string>
#include <opencv2/core.hpp>
namespace
cv
{
namespace
dnn
{
class
ImportNodeWrapper
{
public
:
virtual
~
ImportNodeWrapper
()
{};
virtual
int
getNumInputs
()
const
=
0
;
virtual
std
::
string
getInputName
(
int
idx
)
const
=
0
;
virtual
std
::
string
getType
()
const
=
0
;
virtual
void
setType
(
const
std
::
string
&
type
)
=
0
;
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
=
0
;
};
class
ImportGraphWrapper
{
public
:
virtual
~
ImportGraphWrapper
()
{};
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
=
0
;
virtual
int
getNumNodes
()
const
=
0
;
virtual
std
::
string
getNodeName
(
int
idx
)
const
=
0
;
virtual
void
removeNode
(
int
idx
)
=
0
;
};
class
Subgraph
// Interface to match and replace subgraphs.
{
public
:
virtual
~
Subgraph
();
// Add a node to be matched in the origin graph. Specify ids of nodes that
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
);
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
// Specify resulting node. All the matched nodes in subgraph excluding
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
);
void
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
static
int
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
);
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
);
// Fuse matched subgraph.
void
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
);
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
);
private
:
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
};
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
);
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "../precomp.hpp"
#include "../graph_simplifier.hpp"
#include "onnx_graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
// This wrapper can behave differently for fake input nodes and real graph nodes.
class
ONNXNodeWrapper
:
public
ImportNodeWrapper
{
public
:
ONNXNodeWrapper
(
opencv_onnx
::
NodeProto
*
_node
=
0
)
:
node
(
_node
)
{}
virtual
int
getNumInputs
()
const
CV_OVERRIDE
{
return
node
?
node
->
input_size
()
:
0
;
}
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
CV_Assert_N
(
node
,
idx
<
node
->
input_size
());
return
node
->
input
(
idx
);
}
virtual
std
::
string
getType
()
const
CV_OVERRIDE
{
return
node
?
node
->
op_type
()
:
""
;
}
virtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
set_op_type
(
type
);
}
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
node
->
add_input
(
inputs
[
i
]);
}
opencv_onnx
::
NodeProto
*
node
;
};
// ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
class
ONNXGraphWrapper
:
public
ImportGraphWrapper
{
public
:
ONNXGraphWrapper
(
opencv_onnx
::
GraphProto
&
_net
)
:
net
(
_net
)
{
numInputs
=
net
.
input_size
();
}
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
0
;
if
(
idx
>=
numInputs
)
node
=
net
.
mutable_node
(
idx
-
numInputs
);
return
makePtr
<
ONNXNodeWrapper
>
(
node
);
}
virtual
int
getNumNodes
()
const
CV_OVERRIDE
{
return
numInputs
+
net
.
node_size
();
}
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
{
if
(
idx
<
numInputs
)
return
net
.
input
(
idx
).
name
();
else
return
net
.
node
(
idx
-
numInputs
).
output
(
0
);
}
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
{
CV_Assert
(
idx
>=
numInputs
);
net
.
mutable_node
()
->
DeleteSubrange
(
idx
-
numInputs
,
1
);
}
private
:
int
numInputs
;
opencv_onnx
::
GraphProto
&
net
;
};
class
SoftMaxSubgraph
:
public
Subgraph
{
public
:
SoftMaxSubgraph
()
{
int
input
=
addNodeToMatch
(
""
);
int
inpExp
=
addNodeToMatch
(
"Exp"
,
input
);
int
sum
=
addNodeToMatch
(
"ReduceSum"
,
inpExp
);
addNodeToMatch
(
"Div"
,
inpExp
,
sum
);
setFusedNode
(
"Softmax"
,
input
);
}
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
if
(
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
{
Ptr
<
ImportNodeWrapper
>
sum
=
net
->
getNode
(
matchedNodesIds
[
1
]);
opencv_onnx
::
NodeProto
*
node
=
sum
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
for
(
int
i
=
0
;
i
<
node
->
attribute_size
();
i
++
)
{
opencv_onnx
::
AttributeProto
attr
=
node
->
attribute
(
i
);
if
(
attr
.
name
()
!=
"axes"
)
continue
;
if
(
attr
.
ints_size
()
!=
1
)
CV_Error
(
Error
::
StsNotImplemented
,
format
(
"Unexpected number of axes: %d"
,
attr
.
ints_size
()));
axis
=
attr
.
ints
(
0
);
return
true
;
}
CV_Error
(
Error
::
StsNotImplemented
,
"Missed axes attribute"
);
}
return
false
;
}
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
)
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
fusedNode
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
AttributeProto
*
attr
=
node
->
add_attribute
();
attr
->
set_name
(
"axis"
);
attr
->
set_i
(
axis
);
}
private
:
int
axis
;
};
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
)
{
std
::
vector
<
Ptr
<
Subgraph
>
>
subgraphs
;
subgraphs
.
push_back
(
makePtr
<
SoftMaxSubgraph
>
());
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
ONNXGraphWrapper
(
net
)),
subgraphs
);
}
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace cv::dnn
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#define __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#include "../precomp.hpp"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "opencv-onnx.pb.h"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic pop
#endif
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
);
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_importer.cpp
View file @
c1c84d2f
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop
#endif
#endif
#include "onnx_graph_simplifier.hpp"
namespace
cv
{
namespace
cv
{
namespace
dnn
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
CV__DNN_EXPERIMENTAL_NS_BEGIN
...
@@ -326,6 +328,9 @@ void ONNXImporter::populateNet(Net dstNet)
...
@@ -326,6 +328,9 @@ void ONNXImporter::populateNet(Net dstNet)
{
{
CV_Assert
(
model_proto
.
has_graph
());
CV_Assert
(
model_proto
.
has_graph
());
opencv_onnx
::
GraphProto
graph_proto
=
model_proto
.
graph
();
opencv_onnx
::
GraphProto
graph_proto
=
model_proto
.
graph
();
simplifySubgraphs
(
graph_proto
);
std
::
map
<
std
::
string
,
Mat
>
constBlobs
=
getGraphTensors
(
graph_proto
);
std
::
map
<
std
::
string
,
Mat
>
constBlobs
=
getGraphTensors
(
graph_proto
);
// List of internal blobs shapes.
// List of internal blobs shapes.
std
::
map
<
std
::
string
,
MatShape
>
outShapes
;
std
::
map
<
std
::
string
,
MatShape
>
outShapes
;
...
...
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
View file @
c1c84d2f
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#ifdef HAVE_PROTOBUF
#ifdef HAVE_PROTOBUF
#include "../graph_simplifier.hpp"
#include "tf_graph_simplifier.hpp"
#include "tf_graph_simplifier.hpp"
#include <queue>
#include <queue>
...
@@ -18,203 +19,87 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
...
@@ -18,203 +19,87 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
using
::
google
::
protobuf
::
RepeatedField
;
using
::
google
::
protobuf
::
RepeatedField
;
using
::
google
::
protobuf
::
MapPair
;
using
::
google
::
protobuf
::
MapPair
;
class
Subgraph
// Interface to match and replace TensorFlow subgraphs.
class
TFNodeWrapper
:
public
ImportNodeWrapper
{
{
public
:
public
:
virtual
~
Subgraph
(
)
{}
TFNodeWrapper
(
tensorflow
::
NodeDef
*
_node
)
:
node
(
_node
)
{}
// Add a node to be matched in the origin graph. Specify ids of nodes that
virtual
int
getNumInputs
()
const
CV_OVERRIDE
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
)
{
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
return
node
->
input_size
();
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
}
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
return
node
->
input
(
idx
);
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
}
}
// Specify resulting node. All the matched nodes in subgraph excluding
virtual
std
::
string
getType
()
const
CV_OVERRIDE
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
)
{
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
return
node
->
op
();
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
}
v
oid
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
v
irtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
{
fusedNodeInputs
=
inputs_
;
node
->
set_op
(
type
);
fusedNodeOp
=
op
;
}
}
static
int
getInputNodeId
(
const
tensorflow
::
GraphDef
&
net
,
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
const
tensorflow
::
NodeDef
&
node
,
int
inpId
)
{
{
CV_Assert
(
inpId
<
node
.
input_size
());
node
->
clear_input
();
std
::
string
name
=
node
.
input
(
inpId
);
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
// If operation produces several tensors, they are specified by index
node
->
add_input
(
inputs
[
i
]);
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
.
node_size
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
.
node
(
i
).
name
()
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
}
}
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
tensorflow
::
NodeDef
*
node
;
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
};
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
nodeToMatch
);
if
(
node
.
op
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
.
input_size
())
return
false
;
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
class
TFGraphWrapper
:
public
ImportGraphWrapper
{
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
public
:
continue
;
TFGraphWrapper
(
tensorflow
::
GraphDef
&
_net
)
:
net
(
_net
)
{}
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
tensorflow
::
NodeDef
&
inpNode
=
net
.
node
(
nodeId
);
if
(
inpNode
.
op
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
const
int
n
=
matchedNodesIds
.
size
();
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
return
makePtr
<
TFNodeWrapper
>
(
net
.
mutable_node
(
idx
));
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
}
}
// Fuse matched subgraph.
virtual
int
getNumNodes
()
const
CV_OVERRIDE
void
replace
(
tensorflow
::
GraphDef
&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
{
{
// Extract names of input nodes.
return
net
.
node_size
();
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
}
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
.
input_size
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
.
input
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
tensorflow
::
NodeDef
*
node
=
net
.
mutable_node
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
.
mutable_node
()
->
DeleteSubrange
(
matchedNodesIds
[
i
],
1
);
// Modify the last node to be a fused one.
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
node
->
set_op
(
fusedNodeOp
);
{
node
->
clear_input
();
return
net
.
node
(
idx
).
name
();
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
}
{
node
->
add_input
(
inputsNames
[
i
]);
}
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputsNames
.
size
());
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
{
net
.
mutable_node
()
->
DeleteSubrange
(
idx
,
1
);
inputNodes
[
i
]
=
net
.
mutable_node
(
getInputNodeId
(
net
,
*
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
}
}
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
,
tensorflow
::
GraphDef
&
net
;
std
::
vector
<
tensorflow
::
NodeDef
*>&
)
{}
};
private
:
class
TFSubgraph
:
public
Subgraph
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
{
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
netWrapper
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNodeWrapper
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
CV_OVERRIDE
{
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputs
.
size
());
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
inputNodes
[
i
]
=
inputs
[
i
].
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
finalize
(
netWrapper
.
dynamicCast
<
TFGraphWrapper
>
()
->
net
,
fusedNodeWrapper
.
dynamicCast
<
TFNodeWrapper
>
()
->
node
,
inputNodes
);
}
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
fusedNode
,
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
std
::
vector
<
tensorflow
::
NodeDef
*>&
inputNodes
)
{}
};
};
class
BatchNormSubgraph
:
public
Subgraph
class
BatchNormSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
BatchNormSubgraph
()
BatchNormSubgraph
()
...
@@ -250,7 +135,7 @@ public:
...
@@ -250,7 +135,7 @@ public:
}
}
};
};
class
BatchNormNoGammaSubgraph
:
public
Subgraph
class
BatchNormNoGammaSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
BatchNormNoGammaSubgraph
()
BatchNormNoGammaSubgraph
()
...
@@ -366,20 +251,21 @@ public:
...
@@ -366,20 +251,21 @@ public:
setFusedNode
(
"Relu6"
,
input
);
setFusedNode
(
"Relu6"
,
input
);
}
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
{
if
(
!
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
if
(
!
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
return
false
;
return
false
;
Mat
maxValue
=
getTensorContent
(
net
.
node
(
matchedNodesIds
.
front
()
+
1
).
attr
().
at
(
"value"
).
tensor
());
tensorflow
::
NodeDef
*
node
=
net
->
getNode
(
matchedNodesIds
.
front
()
+
1
).
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
Mat
maxValue
=
getTensorContent
(
node
->
attr
().
at
(
"value"
).
tensor
());
return
maxValue
.
type
()
==
CV_32FC1
&&
maxValue
.
total
()
==
1
&&
maxValue
.
at
<
float
>
(
0
)
==
6
;
return
maxValue
.
type
()
==
CV_32FC1
&&
maxValue
.
total
()
==
1
&&
maxValue
.
at
<
float
>
(
0
)
==
6
;
}
}
};
};
// Keras' reshape stores output shape in separate Const nodes by one value.
// Keras' reshape stores output shape in separate Const nodes by one value.
// Need to merge them into a single Const node.
// Need to merge them into a single Const node.
class
ReshapeKerasSubgraph
:
public
Subgraph
class
ReshapeKerasSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
ReshapeKerasSubgraph
(
int
_numOutDims
)
:
numOutDims
(
_numOutDims
)
ReshapeKerasSubgraph
(
int
_numOutDims
)
:
numOutDims
(
_numOutDims
)
...
@@ -402,15 +288,15 @@ public:
...
@@ -402,15 +288,15 @@ public:
setFusedNode
(
"Reshape"
,
ids
);
setFusedNode
(
"Reshape"
,
ids
);
}
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
{
const
tensorflow
::
NodeDef
&
node
=
net
.
n
ode
(
nodeId
);
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getN
ode
(
nodeId
);
if
(
node
.
input_size
()
==
0
)
if
(
node
->
getNumInputs
()
==
0
)
return
false
;
return
false
;
inpName
=
node
.
input
(
0
);
inpName
=
node
->
getInputName
(
0
);
return
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
);
return
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
);
}
}
...
@@ -457,7 +343,7 @@ public:
...
@@ -457,7 +343,7 @@ public:
}
}
};
};
class
DeconvolutionValidKerasSubgraph
:
public
Subgraph
class
DeconvolutionValidKerasSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
DeconvolutionValidKerasSubgraph
()
DeconvolutionValidKerasSubgraph
()
...
@@ -518,7 +404,7 @@ public:
...
@@ -518,7 +404,7 @@ public:
}
}
};
};
class
DeconvolutionSameKerasSubgraph
:
public
Subgraph
class
DeconvolutionSameKerasSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
DeconvolutionSameKerasSubgraph
()
DeconvolutionSameKerasSubgraph
()
...
@@ -608,7 +494,7 @@ public:
...
@@ -608,7 +494,7 @@ public:
};
};
// In case of resizing by factor.
// In case of resizing by factor.
class
UpsamplingKerasSubgraph
:
public
Subgraph
class
UpsamplingKerasSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
UpsamplingKerasSubgraph
(
const
std
::
string
&
type
)
UpsamplingKerasSubgraph
(
const
std
::
string
&
type
)
...
@@ -703,7 +589,7 @@ public:
...
@@ -703,7 +589,7 @@ public:
}
}
};
};
class
KerasMVNSubgraph
:
public
Subgraph
class
KerasMVNSubgraph
:
public
TF
Subgraph
{
{
public
:
public
:
KerasMVNSubgraph
()
KerasMVNSubgraph
()
...
@@ -758,20 +644,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
...
@@ -758,20 +644,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
ReshapeAsShapeSubgraph
()));
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
ReshapeAsShapeSubgraph
()));
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
KerasMVNSubgraph
()));
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
KerasMVNSubgraph
()));
int
numNodes
=
net
.
node_size
();
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
TFGraphWrapper
(
net
)),
subgraphs
);
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
subgraphs
.
size
();
++
j
)
{
if
(
subgraphs
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
subgraphs
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
}
}
void
RemoveIdentityOps
(
tensorflow
::
GraphDef
&
net
)
void
RemoveIdentityOps
(
tensorflow
::
GraphDef
&
net
)
...
...
modules/dnn/test/test_onnx_importer.cpp
View file @
c1c84d2f
...
@@ -396,6 +396,7 @@ TEST_P(Test_ONNX_layers, Softmax)
...
@@ -396,6 +396,7 @@ TEST_P(Test_ONNX_layers, Softmax)
{
{
testONNXModels
(
"softmax"
);
testONNXModels
(
"softmax"
);
testONNXModels
(
"log_softmax"
,
npy
,
0
,
0
,
false
,
false
);
testONNXModels
(
"log_softmax"
,
npy
,
0
,
0
,
false
,
false
);
testONNXModels
(
"softmax_unfused"
);
}
}
TEST_P
(
Test_ONNX_layers
,
Split_EltwiseMax
)
TEST_P
(
Test_ONNX_layers
,
Split_EltwiseMax
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment