Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
5429b1f5
Commit
5429b1f5
authored
Jan 27, 2020
by
Alexander Alekhin
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #16223 from l-bat:lip_jppnet
parents
02f8a947
24166ac4
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
364 additions
and
40 deletions
+364
-40
all_layers.hpp
modules/dnn/include/opencv2/dnn/all_layers.hpp
+2
-1
dnn.hpp
modules/dnn/include/opencv2/dnn/dnn.hpp
+2
-2
layers_common.cpp
modules/dnn/src/layers/layers_common.cpp
+21
-10
layers_common.hpp
modules/dnn/src/layers/layers_common.hpp
+1
-1
pooling_layer.cpp
modules/dnn/src/layers/pooling_layer.cpp
+24
-12
slice_layer.cpp
modules/dnn/src/layers/slice_layer.cpp
+2
-1
tf_importer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
+127
-13
test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+7
-0
human_parsing.py
samples/dnn/human_parsing.py
+178
-0
No files found.
modules/dnn/include/opencv2/dnn/all_layers.hpp
View file @
5429b1f5
...
@@ -250,7 +250,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
...
@@ -250,7 +250,8 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
std
::
vector
<
size_t
>
pads_begin
,
pads_end
;
std
::
vector
<
size_t
>
pads_begin
,
pads_end
;
CV_DEPRECATED_EXTERNAL
Size
kernel
,
stride
,
pad
;
CV_DEPRECATED_EXTERNAL
Size
kernel
,
stride
,
pad
;
CV_DEPRECATED_EXTERNAL
int
pad_l
,
pad_t
,
pad_r
,
pad_b
;
CV_DEPRECATED_EXTERNAL
int
pad_l
,
pad_t
,
pad_r
,
pad_b
;
bool
globalPooling
;
bool
globalPooling
;
//!< Flag is true if at least one of the axes is global pooled.
std
::
vector
<
bool
>
isGlobalPooling
;
bool
computeMaxIdx
;
bool
computeMaxIdx
;
String
padMode
;
String
padMode
;
bool
ceilMode
;
bool
ceilMode
;
...
...
modules/dnn/include/opencv2/dnn/dnn.hpp
View file @
5429b1f5
...
@@ -47,9 +47,9 @@
...
@@ -47,9 +47,9 @@
#include "opencv2/core/async.hpp"
#include "opencv2/core/async.hpp"
#if !defined CV_DOXYGEN && !defined CV_STATIC_ANALYSIS && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
#if !defined CV_DOXYGEN && !defined CV_STATIC_ANALYSIS && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v1
5
{
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v1
6
{
#define CV__DNN_EXPERIMENTAL_NS_END }
#define CV__DNN_EXPERIMENTAL_NS_END }
namespace
cv
{
namespace
dnn
{
namespace
experimental_dnn_34_v1
5
{
}
using
namespace
experimental_dnn_34_v15
;
}}
namespace
cv
{
namespace
dnn
{
namespace
experimental_dnn_34_v1
6
{
}
using
namespace
experimental_dnn_34_v16
;
}}
#else
#else
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
#define CV__DNN_EXPERIMENTAL_NS_END
#define CV__DNN_EXPERIMENTAL_NS_END
...
...
modules/dnn/src/layers/layers_common.cpp
View file @
5429b1f5
...
@@ -144,26 +144,37 @@ void getStrideAndPadding(const LayerParams ¶ms, std::vector<size_t>& pads_be
...
@@ -144,26 +144,37 @@ void getStrideAndPadding(const LayerParams ¶ms, std::vector<size_t>& pads_be
}
}
}
}
void
getPoolingKernelParams
(
const
LayerParams
&
params
,
std
::
vector
<
size_t
>&
kernel
,
bool
&
globalPooling
,
void
getPoolingKernelParams
(
const
LayerParams
&
params
,
std
::
vector
<
size_t
>&
kernel
,
std
::
vector
<
bool
>&
globalPooling
,
std
::
vector
<
size_t
>&
pads_begin
,
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
pads_begin
,
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
strides
,
cv
::
String
&
padMode
)
std
::
vector
<
size_t
>&
strides
,
cv
::
String
&
padMode
)
{
{
globalPooling
=
params
.
has
(
"global_pooling"
)
&&
bool
is_global
=
params
.
get
<
bool
>
(
"global_pooling"
,
false
);
params
.
get
<
bool
>
(
"global_pooling"
);
globalPooling
.
resize
(
3
);
globalPooling
[
0
]
=
params
.
get
<
bool
>
(
"global_pooling_d"
,
is_global
);
globalPooling
[
1
]
=
params
.
get
<
bool
>
(
"global_pooling_h"
,
is_global
);
globalPooling
[
2
]
=
params
.
get
<
bool
>
(
"global_pooling_w"
,
is_global
);
if
(
globalPooling
)
if
(
globalPooling
[
0
]
||
globalPooling
[
1
]
||
globalPooling
[
2
]
)
{
{
util
::
getStrideAndPadding
(
params
,
pads_begin
,
pads_end
,
strides
,
padMode
);
util
::
getStrideAndPadding
(
params
,
pads_begin
,
pads_end
,
strides
,
padMode
);
if
(
params
.
has
(
"kernel_h"
)
||
params
.
has
(
"kernel_w"
)
||
params
.
has
(
"kernel_size"
))
if
((
globalPooling
[
0
]
&&
params
.
has
(
"kernel_d"
))
||
{
(
globalPooling
[
1
]
&&
params
.
has
(
"kernel_h"
))
||
(
globalPooling
[
2
]
&&
params
.
has
(
"kernel_w"
))
||
params
.
has
(
"kernel_size"
))
{
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, kernel_size (or kernel_h and kernel_w) cannot be specified"
);
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, kernel_size (or kernel_h and kernel_w) cannot be specified"
);
}
}
for
(
int
i
=
0
;
i
<
pads_begin
.
size
();
i
++
)
{
if
(
pads_begin
[
i
]
!=
0
||
pads_end
[
i
]
!=
0
)
kernel
.
resize
(
3
);
kernel
[
0
]
=
params
.
get
<
int
>
(
"kernel_d"
,
1
);
kernel
[
1
]
=
params
.
get
<
int
>
(
"kernel_h"
,
1
);
kernel
[
2
]
=
params
.
get
<
int
>
(
"kernel_w"
,
1
);
for
(
int
i
=
0
,
j
=
globalPooling
.
size
()
-
pads_begin
.
size
();
i
<
pads_begin
.
size
();
i
++
,
j
++
)
{
if
((
pads_begin
[
i
]
!=
0
||
pads_end
[
i
]
!=
0
)
&&
globalPooling
[
j
])
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, pads must be = 0"
);
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, pads must be = 0"
);
}
}
for
(
int
i
=
0
;
i
<
strides
.
size
();
i
++
)
{
for
(
int
i
=
0
,
j
=
globalPooling
.
size
()
-
strides
.
size
();
i
<
strides
.
size
();
i
++
,
j
++
)
{
if
(
strides
[
i
]
!=
1
)
if
(
strides
[
i
]
!=
1
&&
globalPooling
[
j
]
)
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, strides must be = 1"
);
CV_Error
(
cv
::
Error
::
StsBadArg
,
"In global_pooling mode, strides must be = 1"
);
}
}
}
}
...
...
modules/dnn/src/layers/layers_common.hpp
View file @
5429b1f5
...
@@ -63,7 +63,7 @@ void getConvolutionKernelParams(const LayerParams ¶ms, std::vector<size_t>&
...
@@ -63,7 +63,7 @@ void getConvolutionKernelParams(const LayerParams ¶ms, std::vector<size_t>&
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
strides
,
std
::
vector
<
size_t
>&
dilations
,
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
strides
,
std
::
vector
<
size_t
>&
dilations
,
cv
::
String
&
padMode
,
std
::
vector
<
size_t
>&
adjust_pads
);
cv
::
String
&
padMode
,
std
::
vector
<
size_t
>&
adjust_pads
);
void
getPoolingKernelParams
(
const
LayerParams
&
params
,
std
::
vector
<
size_t
>&
kernel
,
bool
&
globalPooling
,
void
getPoolingKernelParams
(
const
LayerParams
&
params
,
std
::
vector
<
size_t
>&
kernel
,
std
::
vector
<
bool
>&
globalPooling
,
std
::
vector
<
size_t
>&
pads_begin
,
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
strides
,
cv
::
String
&
padMode
);
std
::
vector
<
size_t
>&
pads_begin
,
std
::
vector
<
size_t
>&
pads_end
,
std
::
vector
<
size_t
>&
strides
,
cv
::
String
&
padMode
);
void
getConvPoolOutParams
(
const
std
::
vector
<
int
>&
inp
,
const
std
::
vector
<
size_t
>&
kernel
,
void
getConvPoolOutParams
(
const
std
::
vector
<
int
>&
inp
,
const
std
::
vector
<
size_t
>&
kernel
,
...
...
modules/dnn/src/layers/pooling_layer.cpp
View file @
5429b1f5
...
@@ -79,6 +79,7 @@ public:
...
@@ -79,6 +79,7 @@ public:
{
{
computeMaxIdx
=
true
;
computeMaxIdx
=
true
;
globalPooling
=
false
;
globalPooling
=
false
;
isGlobalPooling
=
std
::
vector
<
bool
>
(
3
,
false
);
stride
=
Size
(
1
,
1
);
stride
=
Size
(
1
,
1
);
pad_t
=
pad_l
=
pad_b
=
pad_r
=
0
;
pad_t
=
pad_l
=
pad_b
=
pad_r
=
0
;
...
@@ -95,7 +96,8 @@ public:
...
@@ -95,7 +96,8 @@ public:
else
else
CV_Error
(
Error
::
StsBadArg
,
"Unknown pooling type
\"
"
+
pool
+
"
\"
"
);
CV_Error
(
Error
::
StsBadArg
,
"Unknown pooling type
\"
"
+
pool
+
"
\"
"
);
getPoolingKernelParams
(
params
,
kernel_size
,
globalPooling
,
pads_begin
,
pads_end
,
strides
,
padMode
);
getPoolingKernelParams
(
params
,
kernel_size
,
isGlobalPooling
,
pads_begin
,
pads_end
,
strides
,
padMode
);
globalPooling
=
isGlobalPooling
[
0
]
||
isGlobalPooling
[
1
]
||
isGlobalPooling
[
2
];
if
(
kernel_size
.
size
()
==
2
)
{
if
(
kernel_size
.
size
()
==
2
)
{
kernel
=
Size
(
kernel_size
[
1
],
kernel_size
[
0
]);
kernel
=
Size
(
kernel_size
[
1
],
kernel_size
[
0
]);
stride
=
Size
(
strides
[
1
],
strides
[
0
]);
stride
=
Size
(
strides
[
1
],
strides
[
0
]);
...
@@ -147,9 +149,14 @@ public:
...
@@ -147,9 +149,14 @@ public:
out
.
push_back
(
outputs
[
0
].
size
[
i
]);
out
.
push_back
(
outputs
[
0
].
size
[
i
]);
}
}
if
(
globalPooling
)
{
if
(
globalPooling
)
{
kernel
=
Size
(
inp
[
1
],
inp
[
0
]);
std
::
vector
<
size_t
>
finalKernel
;
kernel_size
=
std
::
vector
<
size_t
>
(
inp
.
begin
(),
inp
.
end
());
for
(
int
i
=
0
;
i
<
inp
.
size
();
i
++
)
{
}
int
idx
=
isGlobalPooling
.
size
()
-
inp
.
size
()
+
i
;
finalKernel
.
push_back
(
isGlobalPooling
[
idx
]
?
inp
[
i
]
:
kernel_size
[
idx
]);
}
kernel_size
=
finalKernel
;
kernel
=
Size
(
kernel_size
[
1
],
kernel_size
[
0
]);
}
getConvPoolPaddings
(
inp
,
kernel_size
,
strides
,
padMode
,
pads_begin
,
pads_end
);
getConvPoolPaddings
(
inp
,
kernel_size
,
strides
,
padMode
,
pads_begin
,
pads_end
);
if
(
pads_begin
.
size
()
==
2
)
{
if
(
pads_begin
.
size
()
==
2
)
{
...
@@ -995,20 +1002,25 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
...
@@ -995,20 +1002,25 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
std
::
vector
<
int
>
inpShape
(
inputs
[
0
].
begin
()
+
2
,
inputs
[
0
].
end
());
std
::
vector
<
int
>
inpShape
(
inputs
[
0
].
begin
()
+
2
,
inputs
[
0
].
end
());
std
::
vector
<
int
>
outShape
(
inputs
[
0
].
begin
(),
inputs
[
0
].
begin
()
+
2
);
std
::
vector
<
int
>
outShape
(
inputs
[
0
].
begin
(),
inputs
[
0
].
begin
()
+
2
);
if
(
globalPooling
)
std
::
vector
<
size_t
>
local_kernel
;
{
if
(
globalPooling
)
{
outShape
.
push_back
(
1
);
for
(
int
i
=
0
;
i
<
inpShape
.
size
();
i
++
)
{
outShape
.
push_back
(
1
);
int
idx
=
isGlobalPooling
.
size
()
-
inpShape
.
size
()
+
i
;
local_kernel
.
push_back
(
isGlobalPooling
[
idx
]
?
inpShape
[
i
]
:
kernel_size
[
idx
]);
}
}
else
{
local_kernel
=
kernel_size
;
}
}
else
if
(
type
==
ROI
||
type
==
PSROI
)
if
(
type
==
ROI
||
type
==
PSROI
)
{
{
outShape
.
push_back
(
pooledSize
.
height
);
outShape
.
push_back
(
pooledSize
.
height
);
outShape
.
push_back
(
pooledSize
.
width
);
outShape
.
push_back
(
pooledSize
.
width
);
}
}
else
if
(
padMode
.
empty
())
else
if
(
padMode
.
empty
())
{
{
for
(
int
i
=
0
;
i
<
kernel_size
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
local_kernel
.
size
();
i
++
)
{
float
dst
=
(
float
)(
inpShape
[
i
]
+
pads_begin
[
i
]
+
pads_end
[
i
]
-
kernel_size
[
i
])
/
strides
[
i
];
float
dst
=
(
float
)(
inpShape
[
i
]
+
pads_begin
[
i
]
+
pads_end
[
i
]
-
local_kernel
[
i
])
/
strides
[
i
];
outShape
.
push_back
(
1
+
(
ceilMode
?
ceil
(
dst
)
:
floor
(
dst
)));
outShape
.
push_back
(
1
+
(
ceilMode
?
ceil
(
dst
)
:
floor
(
dst
)));
}
}
...
@@ -1023,7 +1035,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
...
@@ -1023,7 +1035,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
}
}
else
else
{
{
getConvPoolOutParams
(
inpShape
,
kernel_size
,
strides
,
padMode
,
std
::
vector
<
size_t
>
(
kernel_size
.
size
(),
1
),
outShape
);
getConvPoolOutParams
(
inpShape
,
local_kernel
,
strides
,
padMode
,
std
::
vector
<
size_t
>
(
local_kernel
.
size
(),
1
),
outShape
);
}
}
if
(
type
==
ROI
)
if
(
type
==
ROI
)
{
{
...
...
modules/dnn/src/layers/slice_layer.cpp
View file @
5429b1f5
...
@@ -114,7 +114,8 @@ public:
...
@@ -114,7 +114,8 @@ public:
virtual
bool
supportBackend
(
int
backendId
)
CV_OVERRIDE
virtual
bool
supportBackend
(
int
backendId
)
CV_OVERRIDE
{
{
return
backendId
==
DNN_BACKEND_OPENCV
||
return
backendId
==
DNN_BACKEND_OPENCV
||
((
backendId
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
||
backendId
==
DNN_BACKEND_INFERENCE_ENGINE_NGRAPH
)
&&
(
backendId
==
DNN_BACKEND_INFERENCE_ENGINE_NGRAPH
&&
sliceRanges
.
size
()
==
1
)
||
(
backendId
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
&&
#ifdef HAVE_INF_ENGINE
#ifdef HAVE_INF_ENGINE
INF_ENGINE_VER_MAJOR_GE
(
INF_ENGINE_RELEASE_2019R1
)
&&
INF_ENGINE_VER_MAJOR_GE
(
INF_ENGINE_RELEASE_2019R1
)
&&
#endif
#endif
...
...
modules/dnn/src/tensorflow/tf_importer.cpp
View file @
5429b1f5
...
@@ -1936,20 +1936,22 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1936,20 +1936,22 @@ void TFImporter::populateNet(Net dstNet)
}
}
else
if
(
type
==
"Mean"
)
else
if
(
type
==
"Mean"
)
{
{
// Computes the mean of elements across dimensions of a tensor.
// If keepdims is false (default) reduces input_tensor along the dimensions given in axis,
// else the reduced dimensions are retained with length 1.
// if indices = [1, 2] in NHWC layout we use global pooling: NxCxHxW --Pooling--> NxCx1x1
// if keepdims is false we use Flatten after Pooling: out_shape = NxC
// if indices = [0] we use a global pooling by indices.
// To return the correct shape, we use Reshape after Pooling. The output shape is determined by applying Slice to the input;
// if keepdims is false we use Flatten after Slice.
// Example: input_shape = NxCxHxW
// determine out shape: NxCxHxW --Slice--> 1xCxHxW
// out_shape = 1xCxHxW if keepDims else (1xCxHxW --Flatten--> CxHxW)
// global pool: NxCxHxW --Flatten--> Nx(C*H*W) --Reshape--> 1x1xNx(C*H*W) --Pooling--> 1x1x1x(C*H*W) --Reshape--> out_shape
Mat
indices
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
Mat
indices
=
getTensorContent
(
getConstBlob
(
layer
,
value_id
,
1
));
CV_Assert
(
indices
.
type
()
==
CV_32SC1
);
CV_Assert
(
indices
.
type
()
==
CV_32SC1
);
if
(
indices
.
total
()
!=
2
||
indices
.
at
<
int
>
(
0
)
!=
1
||
indices
.
at
<
int
>
(
1
)
!=
2
)
CV_Error
(
Error
::
StsNotImplemented
,
"Unsupported mode of reduce_mean operation."
);
layerParams
.
set
(
"pool"
,
"ave"
);
layerParams
.
set
(
"global_pooling"
,
true
);
int
id
=
dstNet
.
addLayer
(
name
,
"Pooling"
,
layerParams
);
layer_id
[
name
]
=
id
;
connect
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
0
)),
id
,
0
);
// There are two attributes, "keepdims" and a deprecated "keep_dims".
// There are two attributes, "keepdims" and a deprecated "keep_dims".
bool
keepDims
=
false
;
bool
keepDims
=
false
;
if
(
hasLayerAttr
(
layer
,
"keepdims"
))
if
(
hasLayerAttr
(
layer
,
"keepdims"
))
...
@@ -1957,16 +1959,128 @@ void TFImporter::populateNet(Net dstNet)
...
@@ -1957,16 +1959,128 @@ void TFImporter::populateNet(Net dstNet)
else
if
(
hasLayerAttr
(
layer
,
"keep_dims"
))
else
if
(
hasLayerAttr
(
layer
,
"keep_dims"
))
keepDims
=
getLayerAttr
(
layer
,
"keep_dims"
).
b
();
keepDims
=
getLayerAttr
(
layer
,
"keep_dims"
).
b
();
if
(
!
keepDims
)
if
(
indices
.
total
()
==
1
&&
indices
.
at
<
int
>
(
0
)
==
0
)
{
{
LayerParams
flattenLp
;
LayerParams
flattenLp
;
std
::
string
flattenName
=
name
+
"/flatten"
;
std
::
string
flattenName
=
name
+
"/flatten"
;
CV_Assert
(
layer_id
.
find
(
flattenName
)
==
layer_id
.
end
());
CV_Assert
(
layer_id
.
find
(
flattenName
)
==
layer_id
.
end
());
int
flattenId
=
dstNet
.
addLayer
(
flattenName
,
"Flatten"
,
flattenLp
);
int
flattenId
=
dstNet
.
addLayer
(
flattenName
,
"Flatten"
,
flattenLp
);
layer_id
[
flattenName
]
=
flattenId
;
layer_id
[
flattenName
]
=
flattenId
;
connect
(
layer_id
,
dstNet
,
Pin
(
name
),
flattenId
,
0
);
connect
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
0
)),
flattenId
,
0
);
LayerParams
reshapeLp
;
std
::
string
reshapeName
=
name
+
"/reshape"
;
CV_Assert
(
layer_id
.
find
(
reshapeName
)
==
layer_id
.
end
());
reshapeLp
.
set
(
"axis"
,
0
);
reshapeLp
.
set
(
"num_axes"
,
1
);
int
newShape
[]
=
{
1
,
1
,
-
1
};
reshapeLp
.
set
(
"dim"
,
DictValue
::
arrayInt
(
&
newShape
[
0
],
3
));
int
reshapeId
=
dstNet
.
addLayer
(
reshapeName
,
"Reshape"
,
reshapeLp
);
layer_id
[
reshapeName
]
=
reshapeId
;
connect
(
layer_id
,
dstNet
,
Pin
(
flattenName
),
reshapeId
,
0
);
LayerParams
avgLp
;
std
::
string
avgName
=
name
+
"/avg"
;
CV_Assert
(
layer_id
.
find
(
avgName
)
==
layer_id
.
end
());
avgLp
.
set
(
"pool"
,
"ave"
);
// pooling kernel H x 1
avgLp
.
set
(
"global_pooling_h"
,
true
);
avgLp
.
set
(
"kernel_w"
,
1
);
int
avgId
=
dstNet
.
addLayer
(
avgName
,
"Pooling"
,
avgLp
);
layer_id
[
avgName
]
=
avgId
;
connect
(
layer_id
,
dstNet
,
Pin
(
reshapeName
),
avgId
,
0
);
LayerParams
sliceLp
;
std
::
string
layerShapeName
=
name
+
"/slice"
;
CV_Assert
(
layer_id
.
find
(
layerShapeName
)
==
layer_id
.
end
());
sliceLp
.
set
(
"axis"
,
0
);
int
begin
[]
=
{
0
};
int
size
[]
=
{
1
};
sliceLp
.
set
(
"begin"
,
DictValue
::
arrayInt
(
&
begin
[
0
],
1
));
sliceLp
.
set
(
"size"
,
DictValue
::
arrayInt
(
&
size
[
0
],
1
));
int
sliceId
=
dstNet
.
addLayer
(
layerShapeName
,
"Slice"
,
sliceLp
);
layer_id
[
layerShapeName
]
=
sliceId
;
connect
(
layer_id
,
dstNet
,
Pin
(
layer
.
input
(
0
)),
sliceId
,
0
);
if
(
!
keepDims
)
{
LayerParams
squeezeLp
;
std
::
string
squeezeName
=
name
+
"/squeeze"
;
CV_Assert
(
layer_id
.
find
(
squeezeName
)
==
layer_id
.
end
());
squeezeLp
.
set
(
"axis"
,
0
);
squeezeLp
.
set
(
"end_axis"
,
1
);
int
squeezeId
=
dstNet
.
addLayer
(
squeezeName
,
"Flatten"
,
squeezeLp
);
layer_id
[
squeezeName
]
=
squeezeId
;
connect
(
layer_id
,
dstNet
,
Pin
(
layerShapeName
),
squeezeId
,
0
);
layerShapeName
=
squeezeName
;
}
int
id
=
dstNet
.
addLayer
(
name
,
"Reshape"
,
layerParams
);
layer_id
[
name
]
=
id
;
connect
(
layer_id
,
dstNet
,
Pin
(
avgName
),
id
,
0
);
connect
(
layer_id
,
dstNet
,
Pin
(
layerShapeName
),
id
,
1
);
}
else
{
if
(
indices
.
total
()
!=
2
||
indices
.
at
<
int
>
(
0
)
!=
1
||
indices
.
at
<
int
>
(
1
)
!=
2
)
CV_Error
(
Error
::
StsNotImplemented
,
"Unsupported mode of reduce_mean operation."
);
layerParams
.
set
(
"pool"
,
"ave"
);
layerParams
.
set
(
"global_pooling"
,
true
);
int
id
=
dstNet
.
addLayer
(
name
,
"Pooling"
,
layerParams
);
layer_id
[
name
]
=
id
;
connect
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
0
)),
id
,
0
);
if
(
!
keepDims
)
{
LayerParams
flattenLp
;
std
::
string
flattenName
=
name
+
"/flatten"
;
CV_Assert
(
layer_id
.
find
(
flattenName
)
==
layer_id
.
end
());
int
flattenId
=
dstNet
.
addLayer
(
flattenName
,
"Flatten"
,
flattenLp
);
layer_id
[
flattenName
]
=
flattenId
;
connect
(
layer_id
,
dstNet
,
Pin
(
name
),
flattenId
,
0
);
}
}
}
}
}
else
if
(
type
==
"Pack"
)
{
// op: tf.stack(list of tensors, axis=0)
// Join a list of inputs along a new axis.
// The "axis" specifies the index of the new axis in the dimensions of the output.
// Example: given a list with "N" tensors of shape (C, H, W):
// if axis == 0 then the output tensor will have the shape (N, C, H, W),
// if axis == 1 then the output tensor will have the shape (C, N, H, W).
CV_Assert
(
hasLayerAttr
(
layer
,
"axis"
));
int
dim
=
(
int
)
getLayerAttr
(
layer
,
"axis"
).
i
();
if
(
dim
!=
0
)
CV_Error
(
Error
::
StsNotImplemented
,
"Unsupported mode of pack operation."
);
CV_Assert
(
hasLayerAttr
(
layer
,
"N"
));
int
num
=
(
int
)
getLayerAttr
(
layer
,
"N"
).
i
();
CV_Assert
(
layer
.
input_size
()
==
num
);
std
::
string
base_name
=
name
+
"/reshape_"
;
std
::
vector
<
int
>
reshape_ids
;
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
std
::
ostringstream
ss
;
ss
<<
i
;
std
::
string
reshape_name
=
base_name
+
ss
.
str
();
LayerParams
reshapeLP
;
reshapeLP
.
set
(
"axis"
,
dim
);
reshapeLP
.
set
(
"num_axes"
,
1
);
int
outShape
[]
=
{
1
,
-
1
};
reshapeLP
.
set
(
"dim"
,
DictValue
::
arrayInt
(
&
outShape
[
0
],
2
));
int
id
=
dstNet
.
addLayer
(
reshape_name
,
"Reshape"
,
reshapeLP
);
layer_id
[
reshape_name
]
=
id
;
reshape_ids
.
push_back
(
id
);
connect
(
layer_id
,
dstNet
,
parsePin
(
layer
.
input
(
i
)),
id
,
0
);
}
layerParams
.
set
(
"axis"
,
dim
);
int
id
=
dstNet
.
addLayer
(
name
,
"Concat"
,
layerParams
);
layer_id
[
name
]
=
id
;
for
(
int
li
=
0
;
li
<
num
;
li
++
)
dstNet
.
connect
(
reshape_ids
[
li
],
0
,
id
,
li
);
}
else
if
(
type
==
"ClipByValue"
)
else
if
(
type
==
"ClipByValue"
)
{
{
// op: "ClipByValue"
// op: "ClipByValue"
...
...
modules/dnn/test/test_tf_importer.cpp
View file @
5429b1f5
...
@@ -121,6 +121,13 @@ public:
...
@@ -121,6 +121,13 @@ public:
}
}
};
};
TEST_P
(
Test_TensorFlow_layers
,
reduce_mean
)
{
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
)
applyTestTag
(
CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER
);
runTensorFlowNet
(
"global_pool_by_axis"
);
}
TEST_P
(
Test_TensorFlow_layers
,
conv
)
TEST_P
(
Test_TensorFlow_layers
,
conv
)
{
{
runTensorFlowNet
(
"single_conv"
);
runTensorFlowNet
(
"single_conv"
);
...
...
samples/dnn/human_parsing.py
0 → 100644
View file @
5429b1f5
#!/usr/bin/env python
'''
You can download the converted pb model from https://www.dropbox.com/s/qag9vzambhhkvxr/lip_jppnet_384.pb?dl=0
or convert the model yourself.
Follow these steps if you want to convert the original model yourself:
To get original .meta pre-trained model download https://drive.google.com/file/d/1BFVXgeln-bek8TCbRjN6utPAgRE0LJZg/view
To convert the .meta model to a .pb model correctly, download the original repository https://github.com/Engineering-Course/LIP_JPPNet
Modify the script evaluate_parsing_JPPNet-s2.py for human parsing:
1. Remove preprocessing to create image_batch_origin:
with tf.name_scope("create_inputs"):
...
Add
image_batch_origin = tf.placeholder(tf.float32, shape=(2, None, None, 3), name='input')
2. Create input
image = cv2.imread(path/to/image)
image_rev = np.flip(image, axis=1)
input = np.stack([image, image_rev], axis=0)
3. Hardcode image_h and image_w shapes to determine output shapes.
We use default INPUT_SIZE = (384, 384) from evaluate_parsing_JPPNet-s2.py.
parsing_out1 = tf.reduce_mean(tf.stack([tf.image.resize_images(parsing_out1_100, INPUT_SIZE),
tf.image.resize_images(parsing_out1_075, INPUT_SIZE),
tf.image.resize_images(parsing_out1_125, INPUT_SIZE)]), axis=0)
Do similarly with parsing_out2, parsing_out3
4. Remove postprocessing. Last net operation:
raw_output = tf.reduce_mean(tf.stack([parsing_out1, parsing_out2, parsing_out3]), axis=0)
Change:
parsing_ = sess.run(raw_output, feed_dict={'input:0': input})
5. To save model after sess.run(...) add:
input_graph_def = tf.get_default_graph().as_graph_def()
output_node = "Mean_3"
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node)
output_graph = "LIP_JPPNet.pb"
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
'''
import
argparse
import
numpy
as
np
import
cv2
as
cv
# Backend identifiers offered by the --backend option; the tuple order matches
# the %d placeholders in that option's help text.
backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
# Target identifiers offered by the --target option; the tuple order matches
# the %d placeholders in that option's help text.
targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD)
def preprocess(image_path):
    """
    Create 4-dimensional blob from image and flip image
    :param image_path: path to input image
    :return: blob built by cv.dnn.blobFromImages from the image and its copy flipped along axis 1
    :raises ValueError: if the image cannot be read from image_path
    """
    image = cv.imread(image_path)
    if image is None:
        # cv.imread reports a failed read by returning None rather than raising;
        # fail here with a clear message instead of an obscure numpy axis error below.
        raise ValueError("Cannot read image: %s" % image_path)
    image_rev = np.flip(image, axis=1)
    # Mean values are subtracted per channel by blobFromImages.
    blob = cv.dnn.blobFromImages([image, image_rev], mean=(104.00698793, 116.66876762, 122.67891434))
    return blob
def run_net(input, model_path, backend, target):
    """
    Load the network, configure it and run one forward pass
    :param input: blob handed to the network through setInput
    :param model_path: path to JPPNet model
    :param backend: computation backend
    :param target: computation device
    :return: the value returned by the network's forward()
    """
    network = cv.dnn.readNet(model_path)
    network.setPreferableBackend(backend)
    network.setPreferableTarget(target)
    network.setInput(input)
    return network.forward()
def postprocess(out, input_shape):
    """
    Create a grayscale human segmentation
    :param out: network output with exactly two entries along axis 0 and 20 class
                maps along axis 1 (presumably entry 0 is the image and entry 1 its
                mirrored copy, as built by preprocess -- confirm against caller)
    :param input_shape: input image width and height
    :return: array of shape (height, width, 1) holding the winning class index per pixel
    :raises ValueError: if out does not have two entries along axis 0 or 20 class maps
    """
    # LIP classes
    # 0 Background
    # 1 Hat
    # 2 Hair
    # 3 Glove
    # 4 Sunglasses
    # 5 UpperClothes
    # 6 Dress
    # 7 Coat
    # 8 Socks
    # 9 Pants
    # 10 Jumpsuits
    # 11 Scarf
    # 12 Skirt
    # 13 Face
    # 14 LeftArm
    # 15 RightArm
    # 16 LeftLeg
    # 17 RightLeg
    # 18 LeftShoe
    # 19 RightShoe
    num_classes = 20
    if out.shape[0] != 2:
        raise ValueError("Expected network output with 2 entries along axis 0, got %d" % out.shape[0])
    if out.shape[1] != num_classes:
        raise ValueError("Expected %d class maps along axis 1, got %d" % (num_classes, out.shape[1]))

    # Resize every class map of both entries to the requested size.
    head_output = np.stack([cv.resize(img, dsize=input_shape) for img in out[0]])
    tail_output = np.stack([cv.resize(img, dsize=input_shape) for img in out[1]])

    # For the second entry, exchange each Left*/Right* pair of class maps
    # (14<->15, 16<->17, 18<->19) and flip the maps along their last axis.
    swapped_order = list(range(14)) + [15, 14, 17, 16, 19, 18]
    tail_output_rev = np.flip(tail_output[swapped_order], axis=2)

    # Average both predictions, then pick the best class for every pixel.
    raw_output_all = np.mean(np.stack([head_output, tail_output_rev], axis=0), axis=0, keepdims=True)
    raw_output_all = np.argmax(raw_output_all, axis=1)
    return raw_output_all.transpose(1, 2, 0)
def decode_labels(gray_image):
    """
    Colorize image according to labels
    :param gray_image: grayscale human segmentation result; a 3-dimensional array of
                       integer class indices whose last dimension has size 1
    :return: uint8 array of shape (height, width, 3) with one color per label,
             after the channel order has been swapped by cv.cvtColor
    """
    height, width, _ = gray_image.shape
    colors = [(0, 0, 0), (128, 0, 0), (255, 0, 0), (0, 85, 0), (170, 0, 51), (255, 85, 0),
              (0, 0, 85), (0, 119, 221), (85, 85, 0), (0, 85, 85), (85, 51, 0), (52, 86, 128),
              (0, 128, 0), (0, 0, 255), (51, 170, 221), (0, 255, 255), (85, 255, 170),
              (170, 255, 85), (255, 255, 0), (255, 170, 0)]
    # Look all pixels up in the palette at once instead of looping over them in Python.
    palette = np.array(colors, dtype=np.uint8)
    segm = palette[gray_image.reshape(-1)].reshape(height, width, 3)
    segm = cv.cvtColor(segm, cv.COLOR_BGR2RGB)
    return segm
def parse_human(image_path, model_path, backend=cv.dnn.DNN_BACKEND_OPENCV, target=cv.dnn.DNN_TARGET_CPU):
    """
    Run the whole pipeline: preprocess the image, run the net, postprocess and colorize.
    :param image_path: path to input image
    :param model_path: path to JPPNet model
    :param backend: computation backend forwarded to run_net
    :param target: computation target forwarded to run_net
    :return: colorized segmentation produced by decode_labels
    """
    blob = preprocess(image_path)
    # The trailing two blob dimensions are read as height and width.
    blob_h, blob_w = blob.shape[2:]
    net_out = run_net(blob, model_path, backend, target)
    # postprocess expects the size as (width, height).
    label_map = postprocess(net_out, (blob_w, blob_h))
    return decode_labels(label_map)
if __name__ == '__main__':
    # Help texts are assembled first; adjacent literals are joined before the
    # %-formatting fills in the numeric backend/target identifiers.
    backend_help = ("Choose one of computation backends: "
                    "%d: automatically (by default), "
                    "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
                    "%d: OpenCV implementation" % backends)
    target_help = ('Choose one of target computation devices: '
                   '%d: CPU target (by default), '
                   '%d: OpenCL, '
                   '%d: OpenCL fp16 (half-float precision), '
                   '%d: VPU' % targets)

    parser = argparse.ArgumentParser(
        description='Use this script to run human parsing using JPPNet',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input', '-i', required=True, help='Path to input image.')
    parser.add_argument('--model', '-m', required=True, help='Path to pb model.')
    parser.add_argument('--backend', type=int, choices=backends,
                        default=cv.dnn.DNN_BACKEND_DEFAULT, help=backend_help)
    parser.add_argument('--target', type=int, choices=targets,
                        default=cv.dnn.DNN_TARGET_CPU, help=target_help)
    # Unknown command-line arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    output = parse_human(args.input, args.model, args.backend, args.target)

    # Show the result and wait for a key press.
    winName = 'Deep learning human parsing in OpenCV'
    cv.namedWindow(winName, cv.WINDOW_AUTOSIZE)
    cv.imshow(winName, output)
    cv.waitKey()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment