Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
69a8f110
Commit
69a8f110
authored
Mar 01, 2018
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fuse subgraphs from Keras
parent
9457bf10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
7 deletions
+61
-7
tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
+0
-0
tf_graph_simplifier.hpp
modules/dnn/src/tensorflow/tf_graph_simplifier.hpp
+0
-0
tf_importer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
+45
-7
test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+16
-0
No files found.
modules/dnn/src/tensorflow/tf_graph_
edito
r.cpp
→
modules/dnn/src/tensorflow/tf_graph_
simplifie
r.cpp
View file @
69a8f110
This diff is collapsed.
Click to expand it.
modules/dnn/src/tensorflow/tf_graph_
edito
r.hpp
→
modules/dnn/src/tensorflow/tf_graph_
simplifie
r.hpp
View file @
69a8f110
File moved
modules/dnn/src/tensorflow/tf_importer.cpp
View file @
69a8f110
...
...
@@ -22,7 +22,7 @@ Implementation of Tensorflow models parser
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "tf_io.hpp"
#include "tf_graph_
edito
r.hpp"
#include "tf_graph_
simplifie
r.hpp"
#endif
namespace
cv
{
...
...
@@ -715,9 +715,9 @@ void TFImporter::populateNet(Net dstNet)
if
(
hasLayerAttr
(
layer
,
"data_format"
))
{
std
::
string
format
=
getLayerAttr
(
layer
,
"data_format"
).
s
();
if
(
format
==
"NHWC"
)
if
(
format
==
"NHWC"
||
format
==
"channels_last"
)
data_layouts
[
name
]
=
DATA_LAYOUT_NHWC
;
else
if
(
format
==
"NCHW"
)
else
if
(
format
==
"NCHW"
||
format
==
"channels_first"
)
data_layouts
[
name
]
=
DATA_LAYOUT_NCHW
;
else
CV_Error
(
Error
::
StsParseError
,
"Unknown data_format value: "
+
format
);
...
...
@@ -804,9 +804,9 @@ void TFImporter::populateNet(Net dstNet)
else
if
(
type
==
"Reshape"
)
{
Pin
inpId
=
parsePin
(
layer
.
input
(
0
));
DictValue
newShape
=
parseDims
(
getConstBlob
(
layer
,
value_id
,
1
));
Mat
newShape
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
if
(
newShape
.
size
()
!=
4
&&
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
if
(
newShape
.
total
()
!=
4
&&
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
{
LayerParams
permLP
;
int
order
[]
=
{
0
,
2
,
3
,
1
};
// From OpenCV's NCHW to NHWC.
...
...
@@ -819,14 +819,19 @@ void TFImporter::populateNet(Net dstNet)
connect
(
layer_id
,
dstNet
,
inpId
,
permId
,
0
);
inpId
=
Pin
(
permName
);
}
layerParams
.
set
(
"dim"
,
newShape
);
else
if
(
newShape
.
total
()
==
4
&&
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
{
// NHWC->NCHW
std
::
swap
(
*
newShape
.
ptr
<
int32_t
>
(
0
,
2
),
*
newShape
.
ptr
<
int32_t
>
(
0
,
3
));
std
::
swap
(
*
newShape
.
ptr
<
int32_t
>
(
0
,
1
),
*
newShape
.
ptr
<
int32_t
>
(
0
,
2
));
}
layerParams
.
set
(
"dim"
,
DictValue
::
arrayInt
<
int
*>
(
newShape
.
ptr
<
int
>
(),
newShape
.
total
()));
int
id
=
dstNet
.
addLayer
(
name
,
"Reshape"
,
layerParams
);
layer_id
[
name
]
=
id
;
// one input only
connect
(
layer_id
,
dstNet
,
inpId
,
id
,
0
);
data_layouts
[
name
]
=
DATA_LAYOUT_UNKNOWN
;
}
else
if
(
type
==
"Flatten"
||
type
==
"Squeeze"
)
{
...
...
@@ -1488,6 +1493,39 @@ void TFImporter::populateNet(Net dstNet)
layer_id
[
name
]
=
id
;
connectToAllBlobs
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
0
)),
id
,
layer
.
input_size
());
}
else
if
(
type
==
"Mean"
)
{
Mat
indices
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
CV_Assert
(
indices
.
type
()
==
CV_32SC1
);
if
(
indices
.
total
()
!=
2
||
indices
.
at
<
int
>
(
0
)
!=
1
||
indices
.
at
<
int
>
(
1
)
!=
2
)
CV_Error
(
Error
::
StsNotImplemented
,
"Unsupported mode of reduce_mean operation."
);
layerParams
.
set
(
"pool"
,
"ave"
);
layerParams
.
set
(
"global_pooling"
,
true
);
int
id
=
dstNet
.
addLayer
(
name
,
"Pooling"
,
layerParams
);
layer_id
[
name
]
=
id
;
connect
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
0
)),
id
,
0
);
// There are two attributes, "keepdims" and a deprecated "keep_dims".
bool
keepDims
=
false
;
if
(
hasLayerAttr
(
layer
,
"keepdims"
))
keepDims
=
getLayerAttr
(
layer
,
"keepdims"
).
b
();
else
if
(
hasLayerAttr
(
layer
,
"keep_dims"
))
keepDims
=
getLayerAttr
(
layer
,
"keep_dims"
).
b
();
if
(
!
keepDims
)
{
LayerParams
flattenLp
;
std
::
string
flattenName
=
name
+
"/flatten"
;
CV_Assert
(
layer_id
.
find
(
flattenName
)
==
layer_id
.
end
());
int
flattenId
=
dstNet
.
addLayer
(
flattenName
,
"Flatten"
,
flattenLp
);
layer_id
[
flattenName
]
=
flattenId
;
connect
(
layer_id
,
dstNet
,
Pin
(
name
),
flattenId
,
0
);
}
}
else
if
(
type
==
"Abs"
||
type
==
"Tanh"
||
type
==
"Sigmoid"
||
type
==
"Relu"
||
type
==
"Elu"
||
type
==
"Identity"
||
type
==
"Relu6"
)
...
...
modules/dnn/test/test_tf_importer.cpp
View file @
69a8f110
...
...
@@ -162,6 +162,7 @@ TEST_P(Test_TensorFlow_layers, pooling)
runTensorFlowNet
(
"max_pool_odd_valid"
,
targetId
);
runTensorFlowNet
(
"ave_pool_same"
,
targetId
);
runTensorFlowNet
(
"max_pool_odd_same"
,
targetId
);
runTensorFlowNet
(
"reduce_mean"
,
targetId
);
// an average pooling over all spatial dimensions.
}
TEST_P
(
Test_TensorFlow_layers
,
deconvolution
)
...
...
@@ -337,6 +338,21 @@ TEST(Test_TensorFlow, slice)
runTensorFlowNet
(
"slice_4d"
);
}
TEST
(
Test_TensorFlow
,
softmax
)
{
runTensorFlowNet
(
"keras_softmax"
);
}
TEST
(
Test_TensorFlow
,
relu6
)
{
runTensorFlowNet
(
"keras_relu6"
);
}
TEST
(
Test_TensorFlow
,
keras_mobilenet_head
)
{
runTensorFlowNet
(
"keras_mobilenet_head"
);
}
TEST
(
Test_TensorFlow
,
memory_read
)
{
double
l1
=
1e-5
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment