Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
7ed5d85f
Commit
7ed5d85f
authored
Jul 03, 2018
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add Reshape layer tests
parent
f73eff75
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
17 deletions
+44
-17
reshape_layer.cpp
modules/dnn/src/layers/reshape_layer.cpp
+10
-1
tf_importer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
+24
-14
torch_importer.cpp
modules/dnn/src/torch/torch_importer.cpp
+2
-2
test_layers.cpp
modules/dnn/test/test_layers.cpp
+7
-0
test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+1
-0
No files found.
modules/dnn/src/layers/reshape_layer.cpp
View file @
7ed5d85f
...
@@ -82,17 +82,26 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
...
@@ -82,17 +82,26 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
{
{
if
(
matched
)
if
(
matched
)
{
{
if
(
i
==
0
||
total
(
srcShape
,
i
,
srcRange
.
end
)
!=
maskTotal
)
if
(
total
(
srcShape
,
i
,
srcRange
.
end
)
!=
maskTotal
)
{
{
srcRange
.
start
=
i
+
1
;
srcRange
.
start
=
i
+
1
;
break
;
break
;
}
}
else
if
(
i
==
0
)
{
srcRange
.
start
=
0
;
break
;
}
}
}
else
else
{
{
matched
=
total
(
srcShape
,
i
,
srcRange
.
end
)
==
maskTotal
;
matched
=
total
(
srcShape
,
i
,
srcRange
.
end
)
==
maskTotal
;
}
}
}
}
while
(
total
(
srcShape
,
srcRange
.
start
,
srcRange
.
end
)
!=
maskTotal
&&
srcRange
.
start
>
0
)
{
srcRange
.
start
-=
1
;
}
CV_Assert
(
total
(
srcShape
,
srcRange
.
start
,
srcRange
.
end
)
==
maskTotal
);
CV_Assert
(
total
(
srcShape
,
srcRange
.
start
,
srcRange
.
end
)
==
maskTotal
);
}
}
...
...
modules/dnn/src/tensorflow/tf_importer.cpp
View file @
7ed5d85f
...
@@ -262,6 +262,18 @@ static int getDataLayout(const tensorflow::NodeDef& layer)
...
@@ -262,6 +262,18 @@ static int getDataLayout(const tensorflow::NodeDef& layer)
return
DATA_LAYOUT_UNKNOWN
;
return
DATA_LAYOUT_UNKNOWN
;
}
}
static
inline
std
::
string
getNodeName
(
const
std
::
string
&
tensorName
)
{
return
tensorName
.
substr
(
0
,
tensorName
.
rfind
(
':'
));
}
static
inline
int
getDataLayout
(
const
std
::
string
&
layerName
,
const
std
::
map
<
String
,
int
>&
data_layouts
)
{
std
::
map
<
String
,
int
>::
const_iterator
it
=
data_layouts
.
find
(
getNodeName
(
layerName
));
return
it
!=
data_layouts
.
end
()
?
it
->
second
:
DATA_LAYOUT_UNKNOWN
;
}
void
setStrides
(
LayerParams
&
layerParams
,
const
tensorflow
::
NodeDef
&
layer
)
void
setStrides
(
LayerParams
&
layerParams
,
const
tensorflow
::
NodeDef
&
layer
)
{
{
if
(
hasLayerAttr
(
layer
,
"strides"
))
if
(
hasLayerAttr
(
layer
,
"strides"
))
...
@@ -604,11 +616,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
...
@@ -604,11 +616,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
}
}
}
}
static
inline
std
::
string
getNodeName
(
const
std
::
string
&
tensorName
)
{
return
tensorName
.
substr
(
0
,
tensorName
.
rfind
(
':'
));
}
// If all inputs of specific layer have the same data layout we can say that
// If all inputs of specific layer have the same data layout we can say that
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
static
int
predictOutputDataLayout
(
const
tensorflow
::
GraphDef
&
net
,
static
int
predictOutputDataLayout
(
const
tensorflow
::
GraphDef
&
net
,
...
@@ -830,7 +837,8 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -830,7 +837,8 @@ void TFImporter::populateNet(Net dstNet)
// one input only
// one input only
connect
(
layer_id
,
dstNet
,
parsePin
(
input
),
id
,
0
);
connect
(
layer_id
,
dstNet
,
parsePin
(
input
),
id
,
0
);
if
(
data_layouts
[
name
]
==
DATA_LAYOUT_UNKNOWN
)
if
(
getDataLayout
(
name
,
data_layouts
)
==
DATA_LAYOUT_UNKNOWN
)
data_layouts
[
name
]
=
DATA_LAYOUT_NHWC
;
data_layouts
[
name
]
=
DATA_LAYOUT_NHWC
;
}
}
else
if
(
type
==
"BiasAdd"
||
type
==
"Add"
)
else
if
(
type
==
"BiasAdd"
||
type
==
"Add"
)
...
@@ -956,7 +964,8 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -956,7 +964,8 @@ void TFImporter::populateNet(Net dstNet)
Pin
inpId
=
parsePin
(
layer
.
input
(
0
));
Pin
inpId
=
parsePin
(
layer
.
input
(
0
));
Mat
newShape
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
Mat
newShape
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
if
(
newShape
.
total
()
!=
4
&&
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
int
inpLayout
=
getDataLayout
(
layer
.
input
(
0
),
data_layouts
);
if
(
newShape
.
total
()
!=
4
&&
inpLayout
==
DATA_LAYOUT_NHWC
)
{
{
LayerParams
permLP
;
LayerParams
permLP
;
int
order
[]
=
{
0
,
2
,
3
,
1
};
// From OpenCV's NCHW to NHWC.
int
order
[]
=
{
0
,
2
,
3
,
1
};
// From OpenCV's NCHW to NHWC.
...
@@ -969,7 +978,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -969,7 +978,7 @@ void TFImporter::populateNet(Net dstNet)
connect
(
layer_id
,
dstNet
,
inpId
,
permId
,
0
);
connect
(
layer_id
,
dstNet
,
inpId
,
permId
,
0
);
inpId
=
Pin
(
permName
);
inpId
=
Pin
(
permName
);
}
}
else
if
(
newShape
.
total
()
==
4
&&
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
else
if
(
newShape
.
total
()
==
4
&&
inpLayout
==
DATA_LAYOUT_NHWC
)
{
{
// NHWC->NCHW
// NHWC->NCHW
std
::
swap
(
*
newShape
.
ptr
<
int32_t
>
(
0
,
2
),
*
newShape
.
ptr
<
int32_t
>
(
0
,
3
));
std
::
swap
(
*
newShape
.
ptr
<
int32_t
>
(
0
,
2
),
*
newShape
.
ptr
<
int32_t
>
(
0
,
3
));
...
@@ -987,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -987,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
else
if
(
type
==
"Flatten"
||
type
==
"Squeeze"
)
else
if
(
type
==
"Flatten"
||
type
==
"Squeeze"
)
{
{
Pin
inpId
=
parsePin
(
layer
.
input
(
0
));
Pin
inpId
=
parsePin
(
layer
.
input
(
0
));
int
inpLayout
=
data_layouts
[
layer
.
input
(
0
)]
;
int
inpLayout
=
getDataLayout
(
layer
.
input
(
0
),
data_layouts
)
;
if
(
type
==
"Squeeze"
)
if
(
type
==
"Squeeze"
)
{
{
CV_Assert
(
hasLayerAttr
(
layer
,
"squeeze_dims"
));
CV_Assert
(
hasLayerAttr
(
layer
,
"squeeze_dims"
));
...
@@ -1032,7 +1041,8 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1032,7 +1041,8 @@ void TFImporter::populateNet(Net dstNet)
{
{
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
// keep NCHW layout this way.
// keep NCHW layout this way.
if
(
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NHWC
)
int
inpLayout
=
getDataLayout
(
layer
.
input
(
0
),
data_layouts
);
if
(
inpLayout
==
DATA_LAYOUT_NHWC
)
{
{
if
(
permData
[
0
]
==
0
&&
permData
[
1
]
==
3
&&
permData
[
2
]
==
1
&&
permData
[
3
]
==
2
)
if
(
permData
[
0
]
==
0
&&
permData
[
1
]
==
3
&&
permData
[
2
]
==
1
&&
permData
[
3
]
==
2
)
{
{
...
@@ -1049,7 +1059,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1049,7 +1059,7 @@ void TFImporter::populateNet(Net dstNet)
else
else
CV_Error
(
Error
::
StsParseError
,
"Only NHWC <-> NCHW permutations are allowed."
);
CV_Error
(
Error
::
StsParseError
,
"Only NHWC <-> NCHW permutations are allowed."
);
}
}
else
if
(
data_layouts
[
layer
.
input
(
0
)]
==
DATA_LAYOUT_NCHW
)
else
if
(
inpLayout
==
DATA_LAYOUT_NCHW
)
{
{
if
(
permData
[
0
]
==
0
&&
permData
[
1
]
==
2
&&
permData
[
2
]
==
3
&&
permData
[
3
]
==
1
)
if
(
permData
[
0
]
==
0
&&
permData
[
1
]
==
2
&&
permData
[
2
]
==
3
&&
permData
[
3
]
==
1
)
{
{
...
@@ -1112,7 +1122,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1112,7 +1122,7 @@ void TFImporter::populateNet(Net dstNet)
int
axisId
=
(
type
==
"Concat"
?
0
:
layer
.
input_size
()
-
1
);
int
axisId
=
(
type
==
"Concat"
?
0
:
layer
.
input_size
()
-
1
);
int
axis
=
getConstBlob
(
layer
,
value_id
,
axisId
).
int_val
().
Get
(
0
);
int
axis
=
getConstBlob
(
layer
,
value_id
,
axisId
).
int_val
().
Get
(
0
);
if
(
data_layouts
[
name
]
==
DATA_LAYOUT_NHWC
)
if
(
getDataLayout
(
name
,
data_layouts
)
==
DATA_LAYOUT_NHWC
)
axis
=
toNCHW
(
axis
);
axis
=
toNCHW
(
axis
);
layerParams
.
set
(
"axis"
,
axis
);
layerParams
.
set
(
"axis"
,
axis
);
...
@@ -1197,7 +1207,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1197,7 +1207,7 @@ void TFImporter::populateNet(Net dstNet)
CV_Assert
(
!
begins
.
empty
(),
!
sizes
.
empty
(),
begins
.
type
()
==
CV_32SC1
,
CV_Assert
(
!
begins
.
empty
(),
!
sizes
.
empty
(),
begins
.
type
()
==
CV_32SC1
,
sizes
.
type
()
==
CV_32SC1
);
sizes
.
type
()
==
CV_32SC1
);
if
(
begins
.
total
()
==
4
&&
data_layouts
[
name
]
==
DATA_LAYOUT_NHWC
)
if
(
begins
.
total
()
==
4
&&
getDataLayout
(
name
,
data_layouts
)
==
DATA_LAYOUT_NHWC
)
{
{
// Swap NHWC parameters' order to NCHW.
// Swap NHWC parameters' order to NCHW.
std
::
swap
(
*
begins
.
ptr
<
int32_t
>
(
0
,
2
),
*
begins
.
ptr
<
int32_t
>
(
0
,
3
));
std
::
swap
(
*
begins
.
ptr
<
int32_t
>
(
0
,
2
),
*
begins
.
ptr
<
int32_t
>
(
0
,
3
));
...
@@ -1597,7 +1607,7 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1597,7 +1607,7 @@ void TFImporter::populateNet(Net dstNet)
CV_Assert
(
reductionIndices
.
type
()
==
CV_32SC1
);
CV_Assert
(
reductionIndices
.
type
()
==
CV_32SC1
);
const
int
numAxes
=
reductionIndices
.
total
();
const
int
numAxes
=
reductionIndices
.
total
();
if
(
data_layouts
[
name
]
==
DATA_LAYOUT_NHWC
)
if
(
getDataLayout
(
name
,
data_layouts
)
==
DATA_LAYOUT_NHWC
)
for
(
int
i
=
0
;
i
<
numAxes
;
++
i
)
for
(
int
i
=
0
;
i
<
numAxes
;
++
i
)
reductionIndices
.
at
<
int
>
(
i
)
=
toNCHW
(
reductionIndices
.
at
<
int
>
(
i
));
reductionIndices
.
at
<
int
>
(
i
)
=
toNCHW
(
reductionIndices
.
at
<
int
>
(
i
));
...
...
modules/dnn/src/torch/torch_importer.cpp
View file @
7ed5d85f
...
@@ -592,8 +592,8 @@ struct TorchImporter
...
@@ -592,8 +592,8 @@ struct TorchImporter
DictValue
dimParam
=
scalarParams
.
get
(
"size"
);
DictValue
dimParam
=
scalarParams
.
get
(
"size"
);
layerParams
.
set
(
"dim"
,
dimParam
);
layerParams
.
set
(
"dim"
,
dimParam
);
i
f
(
scalarParams
.
has
(
"batchMode"
)
&&
scalarParams
.
get
<
bool
>
(
"batchMode"
))
i
nt
axis
=
(
int
)
scalarParams
.
get
<
bool
>
(
"batchMode"
,
true
);
layerParams
.
set
(
"axis"
,
1
);
layerParams
.
set
(
"axis"
,
axis
);
curModule
->
modules
.
push_back
(
newModule
);
curModule
->
modules
.
push_back
(
newModule
);
}
}
...
...
modules/dnn/test/test_layers.cpp
View file @
7ed5d85f
...
@@ -201,6 +201,13 @@ TEST(Layer_Test_Reshape, Accuracy)
...
@@ -201,6 +201,13 @@ TEST(Layer_Test_Reshape, Accuracy)
testReshape
(
MatShape
(
inp
,
inp
+
4
),
MatShape
(
out
,
out
+
2
),
0
,
-
1
,
testReshape
(
MatShape
(
inp
,
inp
+
4
),
MatShape
(
out
,
out
+
2
),
0
,
-
1
,
MatShape
(
mask
,
mask
+
2
));
MatShape
(
mask
,
mask
+
2
));
}
}
{
int
inp
[]
=
{
1
,
2
,
3
};
int
out
[]
=
{
3
,
1
,
2
};
int
mask
[]
=
{
3
,
1
,
2
};
testReshape
(
MatShape
(
inp
,
inp
+
3
),
MatShape
(
out
,
out
+
3
),
0
,
-
1
,
MatShape
(
mask
,
mask
+
3
));
}
}
}
TEST
(
Layer_Test_BatchNorm
,
Accuracy
)
TEST
(
Layer_Test_BatchNorm
,
Accuracy
)
...
...
modules/dnn/test/test_tf_importer.cpp
View file @
7ed5d85f
...
@@ -198,6 +198,7 @@ TEST_P(Test_TensorFlow_layers, reshape)
...
@@ -198,6 +198,7 @@ TEST_P(Test_TensorFlow_layers, reshape)
{
{
int
targetId
=
GetParam
();
int
targetId
=
GetParam
();
runTensorFlowNet
(
"shift_reshape_no_reorder"
,
targetId
);
runTensorFlowNet
(
"shift_reshape_no_reorder"
,
targetId
);
runTensorFlowNet
(
"reshape_no_reorder"
,
targetId
);
runTensorFlowNet
(
"reshape_reduce"
,
targetId
);
runTensorFlowNet
(
"reshape_reduce"
,
targetId
);
runTensorFlowNet
(
"flatten"
,
targetId
,
true
);
runTensorFlowNet
(
"flatten"
,
targetId
,
true
);
runTensorFlowNet
(
"unfused_flatten"
,
targetId
);
runTensorFlowNet
(
"unfused_flatten"
,
targetId
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment