Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
e18d5e94
Commit
e18d5e94
authored
Mar 01, 2020
by
ashishiva3@gmail.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Gather-Cast, Mul-Cast fusion
parent
4d0f1354
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
15 deletions
+43
-15
graph_simplifier.cpp
modules/dnn/src/graph_simplifier.cpp
+2
-3
onnx_graph_simplifier.cpp
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+40
-12
test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+1
-0
No files found.
modules/dnn/src/graph_simplifier.cpp
View file @
e18d5e94
...
@@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
...
@@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
{
{
int
numNodes
=
net
->
getNumNodes
();
int
numNodes
=
net
->
getNumNodes
();
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
std
::
vector
<
int
>
matchedNodesIds
,
targetNodesIds
;
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
for
(
int
j
=
0
;
j
<
patterns
.
size
();
++
j
)
{
{
for
(
int
j
=
0
;
j
<
patterns
.
size
();
++
j
)
for
(
int
i
=
0
;
i
<
numNodes
;
++
i
)
{
{
if
(
patterns
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
if
(
patterns
[
j
]
->
match
(
net
,
i
,
matchedNodesIds
,
targetNodesIds
))
{
{
patterns
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
patterns
[
j
]
->
replace
(
net
,
matchedNodesIds
,
targetNodesIds
);
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
numNodes
-=
matchedNodesIds
.
size
()
-
1
;
// #matchedNodes removed and one added.
break
;
}
}
}
}
}
}
...
...
modules/dnn/src/onnx/onnx_graph_simplifier.cpp
View file @
e18d5e94
...
@@ -154,6 +154,32 @@ private:
...
@@ -154,6 +154,32 @@ private:
int
axis
;
int
axis
;
};
};
class
GatherCastSubgraph
:
public
Subgraph
{
public
:
GatherCastSubgraph
()
{
int
input
=
addNodeToMatch
(
""
);
int
index
=
addNodeToMatch
(
"Constant"
);
int
gather
=
addNodeToMatch
(
"Gather"
,
input
,
index
);
addNodeToMatch
(
"Cast"
,
gather
);
setFusedNode
(
"Gather"
,
input
,
index
);
}
};
class
MulCastSubgraph
:
public
Subgraph
{
public
:
MulCastSubgraph
()
{
int
input
=
addNodeToMatch
(
""
);
int
scaleNode
=
addNodeToMatch
(
"Constant"
);
int
mul
=
addNodeToMatch
(
"Mul"
,
input
,
scaleNode
);
addNodeToMatch
(
"Cast"
,
mul
);
setFusedNode
(
"Mul"
,
input
,
scaleNode
);
}
};
class
ExtractScalesSubgraph
:
public
Subgraph
class
ExtractScalesSubgraph
:
public
Subgraph
{
{
public
:
public
:
...
@@ -164,20 +190,16 @@ public:
...
@@ -164,20 +190,16 @@ public:
int
indexH
=
addNodeToMatch
(
"Constant"
);
int
indexH
=
addNodeToMatch
(
"Constant"
);
int
shape1
=
addNodeToMatch
(
"Shape"
,
input
);
int
shape1
=
addNodeToMatch
(
"Shape"
,
input
);
int
gather1
=
addNodeToMatch
(
"Gather"
,
shape1
,
indexH
);
int
gather1
=
addNodeToMatch
(
"Gather"
,
shape1
,
indexH
);
int
castG1
=
addNodeToMatch
(
"Cast"
,
gather1
);
scaleHNode
=
addNodeToMatch
(
"Constant"
);
scaleHNode
=
addNodeToMatch
(
"Constant"
);
int
mul1
=
addNodeToMatch
(
"Mul"
,
castG1
,
scaleHNode
);
int
mul1
=
addNodeToMatch
(
"Mul"
,
gather1
,
scaleHNode
);
int
castM1
=
addNodeToMatch
(
"Cast"
,
mul1
);
int
floor1
=
addNodeToMatch
(
"Floor"
,
mul1
);
int
floor1
=
addNodeToMatch
(
"Floor"
,
castM1
);
int
indexW
=
addNodeToMatch
(
"Constant"
);
int
indexW
=
addNodeToMatch
(
"Constant"
);
int
shape2
=
addNodeToMatch
(
"Shape"
,
input
);
int
shape2
=
addNodeToMatch
(
"Shape"
,
input
);
int
gather2
=
addNodeToMatch
(
"Gather"
,
shape2
,
indexW
);
int
gather2
=
addNodeToMatch
(
"Gather"
,
shape2
,
indexW
);
int
castG2
=
addNodeToMatch
(
"Cast"
,
gather2
);
scaleWNode
=
addNodeToMatch
(
"Constant"
);
scaleWNode
=
addNodeToMatch
(
"Constant"
);
int
mul2
=
addNodeToMatch
(
"Mul"
,
castG2
,
scaleWNode
);
int
mul2
=
addNodeToMatch
(
"Mul"
,
gather2
,
scaleWNode
);
int
castM2
=
addNodeToMatch
(
"Cast"
,
mul2
);
int
floor2
=
addNodeToMatch
(
"Floor"
,
mul2
);
int
floor2
=
addNodeToMatch
(
"Floor"
,
castM2
);
int
unsqueeze1
=
addNodeToMatch
(
"Unsqueeze"
,
floor1
);
int
unsqueeze1
=
addNodeToMatch
(
"Unsqueeze"
,
floor1
);
int
unsqueeze2
=
addNodeToMatch
(
"Unsqueeze"
,
floor2
);
int
unsqueeze2
=
addNodeToMatch
(
"Unsqueeze"
,
floor2
);
...
@@ -190,19 +212,23 @@ public:
...
@@ -190,19 +212,23 @@ public:
{
{
opencv_onnx
::
NodeProto
*
constant_node
=
inputs
[
1
].
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
NodeProto
*
constant_node
=
inputs
[
1
].
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
TensorProto
tensor_proto
=
constant_node
->
attribute
(
0
).
t
();
opencv_onnx
::
TensorProto
tensor_proto
=
constant_node
->
attribute
(
0
).
t
();
float
scaleW
=
getMatFromTensor
(
tensor_proto
).
at
<
float
>
(
0
);
Mat
scaleW
=
getMatFromTensor
(
tensor_proto
);
CV_Assert
(
scaleW
.
total
()
==
1
);
scaleW
.
convertTo
(
scaleW
,
CV_32F
);
constant_node
=
inputs
[
2
].
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
constant_node
=
inputs
[
2
].
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
tensor_proto
=
constant_node
->
attribute
(
0
).
t
();
tensor_proto
=
constant_node
->
attribute
(
0
).
t
();
float
scaleH
=
getMatFromTensor
(
tensor_proto
).
at
<
float
>
(
0
);
Mat
scaleH
=
getMatFromTensor
(
tensor_proto
);
CV_Assert
(
scaleH
.
total
()
==
1
);
scaleH
.
convertTo
(
scaleH
,
CV_32F
);
opencv_onnx
::
NodeProto
*
node
=
fusedNode
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
NodeProto
*
node
=
fusedNode
.
dynamicCast
<
ONNXNodeWrapper
>
()
->
node
;
opencv_onnx
::
AttributeProto
*
attrH
=
node
->
add_attribute
();
opencv_onnx
::
AttributeProto
*
attrH
=
node
->
add_attribute
();
attrH
->
set_name
(
"height_scale"
);
attrH
->
set_name
(
"height_scale"
);
attrH
->
set_i
(
scaleH
);
attrH
->
set_i
(
scaleH
.
at
<
float
>
(
0
)
);
opencv_onnx
::
AttributeProto
*
attrW
=
node
->
add_attribute
();
opencv_onnx
::
AttributeProto
*
attrW
=
node
->
add_attribute
();
attrW
->
set_name
(
"width_scale"
);
attrW
->
set_name
(
"width_scale"
);
attrW
->
set_i
(
scaleW
);
attrW
->
set_i
(
scaleW
.
at
<
float
>
(
0
)
);
node
->
mutable_input
()
->
DeleteSubrange
(
1
,
2
);
// Remove two last inputs
node
->
mutable_input
()
->
DeleteSubrange
(
1
,
2
);
// Remove two last inputs
}
}
...
@@ -267,6 +293,8 @@ public:
...
@@ -267,6 +293,8 @@ public:
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
)
void
simplifySubgraphs
(
opencv_onnx
::
GraphProto
&
net
)
{
{
std
::
vector
<
Ptr
<
Subgraph
>
>
subgraphs
;
std
::
vector
<
Ptr
<
Subgraph
>
>
subgraphs
;
subgraphs
.
push_back
(
makePtr
<
GatherCastSubgraph
>
());
subgraphs
.
push_back
(
makePtr
<
MulCastSubgraph
>
());
subgraphs
.
push_back
(
makePtr
<
UpsampleSubgraph
>
());
subgraphs
.
push_back
(
makePtr
<
UpsampleSubgraph
>
());
subgraphs
.
push_back
(
makePtr
<
ResizeSubgraph1
>
());
subgraphs
.
push_back
(
makePtr
<
ResizeSubgraph1
>
());
subgraphs
.
push_back
(
makePtr
<
ResizeSubgraph2
>
());
subgraphs
.
push_back
(
makePtr
<
ResizeSubgraph2
>
());
...
...
modules/dnn/test/test_onnx_importer.cpp
View file @
e18d5e94
...
@@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
...
@@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
{
{
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
)
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
)
applyTestTag
(
CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER
);
applyTestTag
(
CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER
);
testONNXModels
(
"upsample_unfused_torch1.2"
);
testONNXModels
(
"upsample_unfused_opset9_torch1.4"
);
testONNXModels
(
"upsample_unfused_opset9_torch1.4"
);
testONNXModels
(
"resize_nearest_unfused_opset11_torch1.4"
);
testONNXModels
(
"resize_nearest_unfused_opset11_torch1.4"
);
testONNXModels
(
"resize_nearest_unfused_opset11_torch1.3"
);
testONNXModels
(
"resize_nearest_unfused_opset11_torch1.3"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment