Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
ba1a6ad4
Commit
ba1a6ad4
authored
7 years ago
by
Vadim Pisarevsky
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #11840 from dkurt:dnn_tf_nchw
parents
34ad9b8a
dbeb4a11
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
21 deletions
+42
-21
tf_importer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
+41
-21
test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+1
-0
No files found.
modules/dnn/src/tensorflow/tf_importer.cpp
View file @
ba1a6ad4
...
...
@@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons
return
layer
.
attr
().
at
(
name
);
}
static
int
getDataLayout
(
const
tensorflow
::
NodeDef
&
layer
)
{
if
(
hasLayerAttr
(
layer
,
"data_format"
))
{
std
::
string
format
=
getLayerAttr
(
layer
,
"data_format"
).
s
();
if
(
format
==
"NHWC"
||
format
==
"channels_last"
)
return
DATA_LAYOUT_NHWC
;
else
if
(
format
==
"NCHW"
||
format
==
"channels_first"
)
return
DATA_LAYOUT_NCHW
;
else
CV_Error
(
Error
::
StsParseError
,
"Unknown data_format value: "
+
format
);
}
return
DATA_LAYOUT_UNKNOWN
;
}
void
setStrides
(
LayerParams
&
layerParams
,
const
tensorflow
::
NodeDef
&
layer
)
{
if
(
hasLayerAttr
(
layer
,
"strides"
))
{
const
tensorflow
::
AttrValue
&
val
=
getLayerAttr
(
layer
,
"strides"
);
int
dimX
,
dimY
,
dimC
;
int
layout
=
getDataLayout
(
layer
);
if
(
layout
==
DATA_LAYOUT_NCHW
)
{
dimC
=
1
;
dimY
=
2
;
dimX
=
3
;
}
else
{
dimY
=
1
;
dimX
=
2
;
dimC
=
3
;
}
if
(
val
.
list
().
i_size
()
!=
4
||
val
.
list
().
i
(
0
)
!=
1
||
val
.
list
().
i
(
3
)
!=
1
)
val
.
list
().
i
(
0
)
!=
1
||
val
.
list
().
i
(
dimC
)
!=
1
)
CV_Error
(
Error
::
StsError
,
"Unsupported strides"
);
layerParams
.
set
(
"stride_h"
,
static_cast
<
int
>
(
val
.
list
().
i
(
1
)));
layerParams
.
set
(
"stride_w"
,
static_cast
<
int
>
(
val
.
list
().
i
(
2
)));
layerParams
.
set
(
"stride_h"
,
static_cast
<
int
>
(
val
.
list
().
i
(
dimY
)));
layerParams
.
set
(
"stride_w"
,
static_cast
<
int
>
(
val
.
list
().
i
(
dimX
)));
}
}
...
...
@@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
if
(
hasLayerAttr
(
layer
,
"ksize"
))
{
const
tensorflow
::
AttrValue
&
val
=
getLayerAttr
(
layer
,
"ksize"
);
int
dimX
,
dimY
,
dimC
;
int
layout
=
getDataLayout
(
layer
);
if
(
layout
==
DATA_LAYOUT_NCHW
)
{
dimC
=
1
;
dimY
=
2
;
dimX
=
3
;
}
else
{
dimY
=
1
;
dimX
=
2
;
dimC
=
3
;
}
if
(
val
.
list
().
i_size
()
!=
4
||
val
.
list
().
i
(
0
)
!=
1
||
val
.
list
().
i
(
3
)
!=
1
)
val
.
list
().
i
(
0
)
!=
1
||
val
.
list
().
i
(
dimC
)
!=
1
)
CV_Error
(
Error
::
StsError
,
"Unsupported ksize"
);
layerParams
.
set
(
"kernel_h"
,
static_cast
<
int
>
(
val
.
list
().
i
(
1
)));
layerParams
.
set
(
"kernel_w"
,
static_cast
<
int
>
(
val
.
list
().
i
(
2
)));
layerParams
.
set
(
"kernel_h"
,
static_cast
<
int
>
(
val
.
list
().
i
(
dimY
)));
layerParams
.
set
(
"kernel_w"
,
static_cast
<
int
>
(
val
.
list
().
i
(
dimX
)));
}
else
{
...
...
@@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
}
}
static
int
getDataLayout
(
const
tensorflow
::
NodeDef
&
layer
)
{
if
(
hasLayerAttr
(
layer
,
"data_format"
))
{
std
::
string
format
=
getLayerAttr
(
layer
,
"data_format"
).
s
();
if
(
format
==
"NHWC"
||
format
==
"channels_last"
)
return
DATA_LAYOUT_NHWC
;
else
if
(
format
==
"NCHW"
||
format
==
"channels_first"
)
return
DATA_LAYOUT_NCHW
;
else
CV_Error
(
Error
::
StsParseError
,
"Unknown data_format value: "
+
format
);
}
return
DATA_LAYOUT_UNKNOWN
;
}
static
inline
std
::
string
getNodeName
(
const
std
::
string
&
tensorName
)
{
return
tensorName
.
substr
(
0
,
tensorName
.
rfind
(
':'
));
...
...
This diff is collapsed.
Click to expand it.
modules/dnn/test/test_tf_importer.cpp
View file @
ba1a6ad4
...
...
@@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv)
runTensorFlowNet
(
"atrous_conv2d_same"
,
targetId
);
runTensorFlowNet
(
"depthwise_conv2d"
,
targetId
);
runTensorFlowNet
(
"keras_atrous_conv2d_same"
,
targetId
);
runTensorFlowNet
(
"conv_pool_nchw"
,
targetId
);
}
TEST_P
(
Test_TensorFlow_layers
,
padding
)
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment