Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
c1c84d2f
Commit
c1c84d2f
authored
Jan 06, 2020
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ONNX graphs simplifier
parent
74bc8d35
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
565 additions
and
192 deletions
+565
-192
graph_simplifier.cpp
modules/dnn/src/graph_simplifier.cpp
+207
-0
graph_simplifier.hpp
modules/dnn/src/graph_simplifier.hpp
+100
-0
onnx_graph_simplifier.cpp
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+157
-0
onnx_graph_simplifier.hpp
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
+30
-0
onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+5
-0
tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
+65
-192
test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+1
-0
No files found.
modules/dnn/src/graph_simplifier.cpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "precomp.hpp"
#include "graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
Subgraph
::~
Subgraph
()
{}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
int
Subgraph
::
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
,
int
input_1
,
int
input_2
,
int
input_3
,
int
input_4
,
int
input_5
)
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
}
void
Subgraph
::
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
{
fusedNodeInputs
=
inputs_
;
fusedNodeOp
=
op
;
}
int
Subgraph
::
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
)
{
CV_Assert
(
inpId
<
node
->
getNumInputs
());
std
::
string
name
=
node
->
getInputName
(
inpId
);
// If operation produces several tensors, they are specified by index
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
->
getNumNodes
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
->
getNodeName
(
i
)
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
}
bool
Subgraph
::
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
nodeToMatch
);
if
(
node
->
getType
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
->
getNumInputs
())
return
false
;
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
continue
;
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
Ptr
<
ImportNodeWrapper
>
inpNode
=
net
->
getNode
(
nodeId
);
if
(
inpNode
->
getType
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
const
int
n
=
matchedNodesIds
.
size
();
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
}
void
Subgraph
::
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
{
// Extract names of input nodes.
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
->
getNumInputs
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
->
getInputName
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getNode
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
->
removeNode
(
matchedNodesIds
[
i
]);
// Modify the last node to be a fused one.
node
->
setType
(
fusedNodeOp
);
node
->
setInputNames
(
inputsNames
);
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>
inputNodes
(
inputsNames
.
size
());
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
inputNodes
[
i
]
=
net
->
getNode
(
getInputNodeId
(
net
,
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
}
void
Subgraph
::
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
{}
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
)
{
int
numNodes
=
net
->
getNumNodes
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
patterns
.
size
();
++
j
)
{
if
(
patterns
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
patterns
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
}
}}
// namespace cv::dnn
modules/dnn/src/graph_simplifier.hpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#define __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
#include <string>
#include <opencv2/core.hpp>
namespace
cv
{
namespace
dnn
{
class
ImportNodeWrapper
{
public
:
virtual
~
ImportNodeWrapper
()
{};
virtual
int
getNumInputs
()
const
=
0
;
virtual
std
::
string
getInputName
(
int
idx
)
const
=
0
;
virtual
std
::
string
getType
()
const
=
0
;
virtual
void
setType
(
const
std
::
string
&
type
)
=
0
;
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
=
0
;
};
class
ImportGraphWrapper
{
public
:
virtual
~
ImportGraphWrapper
()
{};
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
=
0
;
virtual
int
getNumNodes
()
const
=
0
;
virtual
std
::
string
getNodeName
(
int
idx
)
const
=
0
;
virtual
void
removeNode
(
int
idx
)
=
0
;
};
class
Subgraph
// Interface to match and replace subgraphs.
{
public
:
virtual
~
Subgraph
();
// Add a node to be matched in the origin graph. Specify ids of nodes that
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
);
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
// Specify resulting node. All the matched nodes in subgraph excluding
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
);
void
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
);
static
int
getInputNodeId
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
node
,
int
inpId
);
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
);
// Fuse matched subgraph.
void
replace
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
);
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
);
private
:
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
};
void
simplifySubgraphs
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
const
std
::
vector
<
Ptr
<
Subgraph
>
>&
patterns
);
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_GRAPH_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "../precomp.hpp"
#include "../graph_simplifier.hpp"
#include "onnx_graph_simplifier.hpp"
#include <queue>
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
// This wrapper can behave differently for fake input nodes and real graph nodes.
class
ONNXNodeWrapper
:
public
ImportNodeWrapper
{
public
:
ONNXNodeWrapper
(
opencv_onnx
::
NodeProto
*
_node
=
0
)
:
node
(
_node
)
{}
virtual
int
getNumInputs
()
const
CV_OVERRIDE
{
return
node
?
node
->
input_size
()
:
0
;
}
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
CV_Assert_N
(
node
,
idx
<
node
->
input_size
());
return
node
->
input
(
idx
);
}
virtual
std
::
string
getType
()
const
CV_OVERRIDE
{
return
node
?
node
->
op_type
()
:
""
;
}
virtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
set_op_type
(
type
);
}
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
{
CV_Assert
(
node
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
node
->
add_input
(
inputs
[
i
]);
}
opencv_onnx
::
NodeProto
*
node
;
};
// ONNX graph's inputs are separate from nodes so we index them before the rest of nodes.
class
ONNXGraphWrapper
:
public
ImportGraphWrapper
{
public
:
ONNXGraphWrapper
(
opencv_onnx
::
GraphProto
&
_net
)
:
net
(
_net
)
{
numInputs
=
net
.
input_size
();
}
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
0
;
if
(
idx
>=
numInputs
)
node
=
net
.
mutable_node
(
idx
-
numInputs
);
return
makePtr
<
ONNXNodeWrapper
>
(
node
);
}
virtual
int
getNumNodes
()
const
CV_OVERRIDE
{
return
numInputs
+
net
.
node_size
();
}
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
{
if
(
idx
<
numInputs
)
return
net
.
input
(
idx
).
name
();
else
return
net
.
node
(
idx
-
numInputs
).
output
(
0
);
}
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
{
CV_Assert
(
idx
>=
numInputs
);
net
.
mutable_node
()
->
DeleteSubrange
(
idx
-
numInputs
,
1
);
}
private
:
int
numInputs
;
opencv_onnx
::
GraphProto
&
net
;
};
class
SoftMaxSubgraph
:
public
Subgraph
{
public
:
SoftMaxSubgraph
()
{
int
input
=
addNodeToMatch
(
""
);
int
inpExp
=
addNodeToMatch
(
"Exp"
,
input
);
int
sum
=
addNodeToMatch
(
"ReduceSum"
,
inpExp
);
addNodeToMatch
(
"Div"
,
inpExp
,
sum
);
setFusedNode
(
"Softmax"
,
input
);
}
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
if
(
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
{
Ptr
<
ImportNodeWrapper
>
sum
=
net
->
getNode
(
matchedNodesIds
[
1
]);
opencv_onnx
::
NodeProto
*
node
=
sum
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
for
(
int
i
=
0
;
i
<
node
->
attribute_size
();
i
++
)
{
opencv_onnx
::
AttributeProto
attr
=
node
->
attribute
(
i
);
if
(
attr
.
name
()
!=
"axes"
)
continue
;
if
(
attr
.
ints_size
()
!=
1
)
CV_Error
(
Error
::
StsNotImplemented
,
format
(
"Unexpected number of axes: %d"
,
attr
.
ints_size
()));
axis
=
attr
.
ints
(
0
);
return
true
;
}
CV_Error
(
Error
::
StsNotImplemented
,
"Missed axes attribute"
);
}
return
false
;
}
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNode
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
)
CV_OVERRIDE
{
opencv_onnx
::
NodeProto
*
node
=
fusedNode
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
AttributeProto
*
attr
=
node
->
add_attribute
();
attr
->
set_name
(
"axis"
);
attr
->
set_i
(
axis
);
}
private
:
int
axis
;
};
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
)
{
std
::
vector
<
Ptr
<
Subgraph
>
>
subgraphs
;
subgraphs
.
push_back
(
makePtr
<
SoftMaxSubgraph
>
());
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
ONNXGraphWrapper
(
net
)),
subgraphs
);
}
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace cv::dnn
modules/dnn/src/onnx/onnx_graph_simplifier.hpp
0 → 100644
View file @
c1c84d2f
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#define __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
#include "../precomp.hpp"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsuggest-override"
#endif
#include "opencv-onnx.pb.h"
#if defined(__GNUC__) && __GNUC__ >= 5
#pragma GCC diagnostic pop
#endif
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
);
CV__DNN_EXPERIMENTAL_NS_END
}}
// namespace dnn, namespace cv
#endif // __OPENCV_DNN_ONNX_SIMPLIFIER_HPP__
modules/dnn/src/onnx/onnx_importer.cpp
View file @
c1c84d2f
...
...
@@ -26,6 +26,8 @@
#pragma GCC diagnostic pop
#endif
#include "onnx_graph_simplifier.hpp"
namespace
cv
{
namespace
dnn
{
CV__DNN_EXPERIMENTAL_NS_BEGIN
...
...
@@ -326,6 +328,9 @@ void ONNXImporter::populateNet(Net dstNet)
{
CV_Assert
(
model_proto
.
has_graph
());
opencv_onnx
::
GraphProto
graph_proto
=
model_proto
.
graph
();
simplifySubgraphs
(
graph_proto
);
std
::
map
<
std
::
string
,
Mat
>
constBlobs
=
getGraphTensors
(
graph_proto
);
// List of internal blobs shapes.
std
::
map
<
std
::
string
,
MatShape
>
outShapes
;
...
...
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
View file @
c1c84d2f
...
...
@@ -9,6 +9,7 @@
#ifdef HAVE_PROTOBUF
#include "../graph_simplifier.hpp"
#include "tf_graph_simplifier.hpp"
#include <queue>
...
...
@@ -18,203 +19,87 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
using
::
google
::
protobuf
::
RepeatedField
;
using
::
google
::
protobuf
::
MapPair
;
class
Subgraph
// Interface to match and replace TensorFlow subgraphs.
class
TFNodeWrapper
:
public
ImportNodeWrapper
{
public
:
virtual
~
Subgraph
(
)
{}
TFNodeWrapper
(
tensorflow
::
NodeDef
*
_node
)
:
node
(
_node
)
{}
// Add a node to be matched in the origin graph. Specify ids of nodes that
// are expected to be inputs. Returns id of a newly added node.
// TODO: Replace inputs to std::vector<int> in C++11
int
addNodeToMatch
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
)
virtual
int
getNumInputs
()
const
CV_OVERRIDE
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
return
addNodeToMatch
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
return
node
->
input_size
();
}
int
addNodeToMatch
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
virtual
std
::
string
getInputName
(
int
idx
)
const
CV_OVERRIDE
{
for
(
int
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
CV_Assert
(
inputs_
[
i
]
<
(
int
)
nodes
.
size
());
}
nodes
.
push_back
(
op
);
inputs
.
push_back
(
inputs_
);
return
nodes
.
size
()
-
1
;
return
node
->
input
(
idx
);
}
// Specify resulting node. All the matched nodes in subgraph excluding
// input nodes will be fused into this single node.
// TODO: Replace inputs to std::vector<int> in C++11
void
setFusedNode
(
const
std
::
string
&
op
,
int
input_0
=
-
1
,
int
input_1
=
-
1
,
int
input_2
=
-
1
,
int
input_3
=
-
1
,
int
input_4
=
-
1
,
int
input_5
=
-
1
)
virtual
std
::
string
getType
()
const
CV_OVERRIDE
{
int
nodeInputs
[]
=
{
input_0
,
input_1
,
input_2
,
input_3
,
input_4
,
input_5
};
int
numInputs
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
CV_Assert
(
nodeInputs
[
i
]
<
(
int
)
nodes
.
size
());
numInputs
+=
(
int
)(
nodeInputs
[
i
]
!=
-
1
);
}
setFusedNode
(
op
,
std
::
vector
<
int
>
(
&
nodeInputs
[
0
],
&
nodeInputs
[
0
]
+
numInputs
));
return
node
->
op
();
}
v
oid
setFusedNode
(
const
std
::
string
&
op
,
const
std
::
vector
<
int
>&
inputs_
)
v
irtual
void
setType
(
const
std
::
string
&
type
)
CV_OVERRIDE
{
fusedNodeInputs
=
inputs_
;
fusedNodeOp
=
op
;
node
->
set_op
(
type
);
}
static
int
getInputNodeId
(
const
tensorflow
::
GraphDef
&
net
,
const
tensorflow
::
NodeDef
&
node
,
int
inpId
)
virtual
void
setInputNames
(
const
std
::
vector
<
std
::
string
>&
inputs
)
CV_OVERRIDE
{
CV_Assert
(
inpId
<
node
.
input_size
());
std
::
string
name
=
node
.
input
(
inpId
);
// If operation produces several tensors, they are specified by index
// after ':' character. In example, "input:0".
name
=
name
.
substr
(
0
,
name
.
rfind
(
':'
));
const
int
numNodes
=
net
.
node_size
();
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
if
(
net
.
node
(
i
).
name
()
==
name
)
return
i
;
}
CV_Error
(
Error
::
StsParseError
,
"Input node with name "
+
name
+
" not found"
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
node
->
add_input
(
inputs
[
i
]);
}
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
{
matchedNodesIds
.
clear
();
targetNodesIds
.
clear
();
std
::
queue
<
int
>
nodesToMatch
;
std
::
queue
<
int
>
targetNodes
;
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
nodes
.
size
()
-
1
);
while
(
!
nodesToMatch
.
empty
())
{
int
nodeToMatch
=
nodesToMatch
.
front
();
int
targetNodeId
=
targetNodes
.
front
();
nodesToMatch
.
pop
();
targetNodes
.
pop
();
if
(
std
::
find
(
matchedNodesIds
.
begin
(),
matchedNodesIds
.
end
(),
nodeToMatch
)
!=
matchedNodesIds
.
end
())
continue
;
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
nodeToMatch
);
if
(
node
.
op
()
!=
nodes
[
targetNodeId
])
return
false
;
std
::
vector
<
int
>&
inputNodes
=
inputs
[
targetNodeId
];
if
(
inputNodes
.
size
()
!=
node
.
input_size
())
return
false
;
tensorflow
::
NodeDef
*
node
;
};
for
(
int
j
=
0
;
j
<
inputNodes
.
size
();
++
j
)
{
if
(
nodes
[
inputNodes
[
j
]].
empty
())
// Unknown input node type.
continue
;
nodeId
=
getInputNodeId
(
net
,
node
,
j
);
const
tensorflow
::
NodeDef
&
inpNode
=
net
.
node
(
nodeId
);
if
(
inpNode
.
op
()
!=
"Const"
)
{
nodesToMatch
.
push
(
nodeId
);
targetNodes
.
push
(
inputNodes
[
j
]);
}
else
if
(
nodes
[
inputNodes
[
j
]]
!=
"Const"
)
return
false
;
}
matchedNodesIds
.
push_back
(
nodeToMatch
);
targetNodesIds
.
push_back
(
targetNodeId
);
}
class
TFGraphWrapper
:
public
ImportGraphWrapper
{
public
:
TFGraphWrapper
(
tensorflow
::
GraphDef
&
_net
)
:
net
(
_net
)
{}
const
int
n
=
matchedNodesIds
.
size
();
std
::
vector
<
std
::
pair
<
int
,
int
>
>
elements
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
elements
[
i
]
=
std
::
make_pair
(
matchedNodesIds
[
i
],
targetNodesIds
[
i
]);
std
::
sort
(
elements
.
begin
(),
elements
.
end
());
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
matchedNodesIds
[
i
]
=
elements
[
i
].
first
;
targetNodesIds
[
i
]
=
elements
[
i
].
second
;
}
return
true
;
virtual
Ptr
<
ImportNodeWrapper
>
getNode
(
int
idx
)
const
CV_OVERRIDE
{
return
makePtr
<
TFNodeWrapper
>
(
net
.
mutable_node
(
idx
));
}
// Fuse matched subgraph.
void
replace
(
tensorflow
::
GraphDef
&
net
,
const
std
::
vector
<
int
>&
matchedNodesIds
,
const
std
::
vector
<
int
>&
targetNodesIds
)
virtual
int
getNumNodes
()
const
CV_OVERRIDE
{
// Extract names of input nodes.
std
::
vector
<
std
::
string
>
inputsNames
(
fusedNodeInputs
.
size
());
for
(
int
i
=
0
;
i
<
fusedNodeInputs
.
size
();
++
i
)
{
std
::
string
inpName
;
// Find input node name looking at inputs of fused nodes.
for
(
int
j
=
0
;
j
<
matchedNodesIds
.
size
()
&&
inpName
.
empty
();
++
j
)
{
const
tensorflow
::
NodeDef
&
node
=
net
.
node
(
matchedNodesIds
[
j
]);
std
::
vector
<
int
>&
inpIndices
=
inputs
[
targetNodesIds
[
j
]];
CV_Assert
(
node
.
input_size
()
==
inpIndices
.
size
());
for
(
int
k
=
0
;
k
<
inpIndices
.
size
();
++
k
)
{
if
(
inpIndices
[
k
]
==
fusedNodeInputs
[
i
])
{
inpName
=
node
.
input
(
k
);
break
;
}
}
}
CV_Assert
(
!
inpName
.
empty
());
inputsNames
[
i
]
=
inpName
;
}
// Remove matched nodes except the last one. Indices in ascending order are expected.
tensorflow
::
NodeDef
*
node
=
net
.
mutable_node
(
matchedNodesIds
.
back
());
for
(
int
i
=
matchedNodesIds
.
size
()
-
2
;
i
>=
0
;
--
i
)
net
.
mutable_node
()
->
DeleteSubrange
(
matchedNodesIds
[
i
],
1
);
return
net
.
node_size
();
}
// Modify the last node to be a fused one.
node
->
set_op
(
fusedNodeOp
);
node
->
clear_input
();
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
node
->
add_input
(
inputsNames
[
i
]);
}
virtual
std
::
string
getNodeName
(
int
idx
)
const
CV_OVERRIDE
{
return
net
.
node
(
idx
).
name
();
}
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputsNames
.
size
());
for
(
int
i
=
0
;
i
<
inputsNames
.
size
();
++
i
)
{
inputNodes
[
i
]
=
net
.
mutable_node
(
getInputNodeId
(
net
,
*
node
,
i
));
}
finalize
(
net
,
node
,
inputNodes
);
virtual
void
removeNode
(
int
idx
)
CV_OVERRIDE
{
net
.
mutable_node
()
->
DeleteSubrange
(
idx
,
1
);
}
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
,
std
::
vector
<
tensorflow
::
NodeDef
*>&
)
{}
tensorflow
::
GraphDef
&
net
;
};
private
:
std
::
vector
<
std
::
string
>
nodes
;
// Nodes to be matched in the origin graph.
std
::
vector
<
std
::
vector
<
int
>
>
inputs
;
// Connections of an every node to it's inputs.
class
TFSubgraph
:
public
Subgraph
{
virtual
void
finalize
(
const
Ptr
<
ImportGraphWrapper
>&
netWrapper
,
const
Ptr
<
ImportNodeWrapper
>&
fusedNodeWrapper
,
std
::
vector
<
Ptr
<
ImportNodeWrapper
>
>&
inputs
)
CV_OVERRIDE
{
std
::
vector
<
tensorflow
::
NodeDef
*>
inputNodes
(
inputs
.
size
());
for
(
int
i
=
0
;
i
<
inputs
.
size
();
++
i
)
inputNodes
[
i
]
=
inputs
[
i
].
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
finalize
(
netWrapper
.
dynamicCast
<
TFGraphWrapper
>
()
->
net
,
fusedNodeWrapper
.
dynamicCast
<
TFNodeWrapper
>
()
->
node
,
inputNodes
);
}
std
::
string
fusedNodeOp
;
// Operation name of resulting fused node.
std
::
vector
<
int
>
fusedNodeInputs
;
// Inputs of fused node.
virtual
void
finalize
(
tensorflow
::
GraphDef
&
,
tensorflow
::
NodeDef
*
fusedNode
,
std
::
vector
<
tensorflow
::
NodeDef
*>&
inputNodes
)
{}
};
class
BatchNormSubgraph
:
public
Subgraph
class
BatchNormSubgraph
:
public
TF
Subgraph
{
public
:
BatchNormSubgraph
()
...
...
@@ -250,7 +135,7 @@ public:
}
};
class
BatchNormNoGammaSubgraph
:
public
Subgraph
class
BatchNormNoGammaSubgraph
:
public
TF
Subgraph
{
public
:
BatchNormNoGammaSubgraph
()
...
...
@@ -366,20 +251,21 @@ public:
setFusedNode
(
"Relu6"
,
input
);
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
if
(
!
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
))
return
false
;
Mat
maxValue
=
getTensorContent
(
net
.
node
(
matchedNodesIds
.
front
()
+
1
).
attr
().
at
(
"value"
).
tensor
());
tensorflow
::
NodeDef
*
node
=
net
->
getNode
(
matchedNodesIds
.
front
()
+
1
).
dynamicCast
<
TFNodeWrapper
>
()
->
node
;
Mat
maxValue
=
getTensorContent
(
node
->
attr
().
at
(
"value"
).
tensor
());
return
maxValue
.
type
()
==
CV_32FC1
&&
maxValue
.
total
()
==
1
&&
maxValue
.
at
<
float
>
(
0
)
==
6
;
}
};
// Keras' reshape stores output shape in separate Const nodes by one value.
// Need to merge them into a single Const node.
class
ReshapeKerasSubgraph
:
public
Subgraph
class
ReshapeKerasSubgraph
:
public
TF
Subgraph
{
public
:
ReshapeKerasSubgraph
(
int
_numOutDims
)
:
numOutDims
(
_numOutDims
)
...
...
@@ -402,15 +288,15 @@ public:
setFusedNode
(
"Reshape"
,
ids
);
}
virtual
bool
match
(
const
tensorflow
::
GraphDef
&
net
,
int
nodeId
,
virtual
bool
match
(
const
Ptr
<
ImportGraphWrapper
>
&
net
,
int
nodeId
,
std
::
vector
<
int
>&
matchedNodesIds
,
std
::
vector
<
int
>&
targetNodesIds
)
CV_OVERRIDE
{
const
tensorflow
::
NodeDef
&
node
=
net
.
n
ode
(
nodeId
);
if
(
node
.
input_size
()
==
0
)
Ptr
<
ImportNodeWrapper
>
node
=
net
->
getN
ode
(
nodeId
);
if
(
node
->
getNumInputs
()
==
0
)
return
false
;
inpName
=
node
.
input
(
0
);
inpName
=
node
->
getInputName
(
0
);
return
Subgraph
::
match
(
net
,
nodeId
,
matchedNodesIds
,
targetNodesIds
);
}
...
...
@@ -457,7 +343,7 @@ public:
}
};
class
DeconvolutionValidKerasSubgraph
:
public
Subgraph
class
DeconvolutionValidKerasSubgraph
:
public
TF
Subgraph
{
public
:
DeconvolutionValidKerasSubgraph
()
...
...
@@ -518,7 +404,7 @@ public:
}
};
class
DeconvolutionSameKerasSubgraph
:
public
Subgraph
class
DeconvolutionSameKerasSubgraph
:
public
TF
Subgraph
{
public
:
DeconvolutionSameKerasSubgraph
()
...
...
@@ -608,7 +494,7 @@ public:
};
// In case of resizing by factor.
class
UpsamplingKerasSubgraph
:
public
Subgraph
class
UpsamplingKerasSubgraph
:
public
TF
Subgraph
{
public
:
UpsamplingKerasSubgraph
(
const
std
::
string
&
type
)
...
...
@@ -703,7 +589,7 @@ public:
}
};
class
KerasMVNSubgraph
:
public
Subgraph
class
KerasMVNSubgraph
:
public
TF
Subgraph
{
public
:
KerasMVNSubgraph
()
...
...
@@ -758,20 +644,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
ReshapeAsShapeSubgraph
()));
subgraphs
.
push_back
(
Ptr
<
Subgraph
>
(
new
KerasMVNSubgraph
()));
int
numNodes
=
net
.
node_size
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
for
(
int
j
=
0
;
j
<
subgraphs
.
size
();
++
j
)
{
if
(
subgraphs
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
subgraphs
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
simplifySubgraphs
(
Ptr
<
ImportGraphWrapper
>
(
new
TFGraphWrapper
(
net
)),
subgraphs
);
}
void
RemoveIdentityOps
(
tensorflow
::
GraphDef
&
net
)
...
...
modules/dnn/test/test_onnx_importer.cpp
View file @
c1c84d2f
...
...
@@ -396,6 +396,7 @@ TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels
(
"softmax"
);
testONNXModels
(
"log_softmax"
,
npy
,
0
,
0
,
false
,
false
);
testONNXModels
(
"softmax_unfused"
);
}
TEST_P
(
Test_ONNX_layers
,
Split_EltwiseMax
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment