Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
O
opencv
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
opencv
Commits
f056c0f1
Commit
f056c0f1
authored
Aug 13, 2018
by
Dmitry Kurtaev
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
UINT8 face detection network using Intel's Inference Engine backend
parent
8b2122e1
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
27 deletions
+80
-27
quantize_face_detector.py
modules/dnn/misc/quantize_face_detector.py
+15
-4
normalize_bbox_layer.cpp
modules/dnn/src/layers/normalize_bbox_layer.cpp
+59
-14
tf_importer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
+4
-3
test_tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
+2
-6
No files found.
modules/dnn/misc/quantize_face_detector.py
View file @
f056c0f1
...
...
@@ -104,8 +104,12 @@ def l2norm(x, name):
inp
=
tf
.
placeholder
(
dtype
,
[
1
,
300
,
300
,
3
],
'data'
)
data_bn
=
batch_norm
(
inp
,
'data_bn'
)
data_scale
=
scale
(
data_bn
,
'data_scale'
)
data_scale
=
tf
.
pad
(
data_scale
,
[[
0
,
0
],
[
3
,
3
],
[
3
,
3
],
[
0
,
0
]])
# Instead of tf.pad we use tf.space_to_batch_nd layers which override convolution's padding strategy to explicit numbers
# data_scale = tf.pad(data_scale, [[0, 0], [3, 3], [3, 3], [0, 0]])
data_scale
=
tf
.
space_to_batch_nd
(
data_scale
,
[
1
,
1
],
[[
3
,
3
],
[
3
,
3
]],
name
=
'Pad'
)
conv1_h
=
conv
(
data_scale
,
stride
=
2
,
pad
=
'VALID'
,
name
=
'conv1_h'
)
conv1_bn_h
=
batch_norm
(
conv1_h
,
'conv1_bn_h'
)
conv1_scale_h
=
scale
(
conv1_bn_h
,
'conv1_scale_h'
)
conv1_relu
=
tf
.
nn
.
relu
(
conv1_scale_h
)
...
...
@@ -133,8 +137,11 @@ layer_128_1_sum = layer_128_1_conv2 + layer_128_1_conv_expand_h
layer_256_1_bn1
=
batch_norm
(
layer_128_1_sum
,
'layer_256_1_bn1'
)
layer_256_1_scale1
=
scale
(
layer_256_1_bn1
,
'layer_256_1_scale1'
)
layer_256_1_relu1
=
tf
.
nn
.
relu
(
layer_256_1_scale1
)
layer_256_1_conv1
=
tf
.
pad
(
layer_256_1_relu1
,
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
# layer_256_1_conv1 = tf.pad(layer_256_1_relu1, [[0, 0], [1, 1], [1, 1], [0, 0]])
layer_256_1_conv1
=
tf
.
space_to_batch_nd
(
layer_256_1_relu1
,
[
1
,
1
],
[[
1
,
1
],
[
1
,
1
]],
name
=
'Pad_1'
)
layer_256_1_conv1
=
conv
(
layer_256_1_conv1
,
stride
=
2
,
pad
=
'VALID'
,
name
=
'layer_256_1_conv1'
)
layer_256_1_bn2
=
batch_norm
(
layer_256_1_conv1
,
'layer_256_1_bn2'
)
layer_256_1_scale2
=
scale
(
layer_256_1_bn2
,
'layer_256_1_scale2'
)
layer_256_1_relu2
=
tf
.
nn
.
relu
(
layer_256_1_scale2
)
...
...
@@ -160,8 +167,11 @@ fc7 = tf.nn.relu(last_scale_h, name='last_relu')
conv6_1_h
=
conv
(
fc7
,
'conv6_1_h'
,
activ
=
tf
.
nn
.
relu
)
conv6_2_h
=
conv
(
conv6_1_h
,
stride
=
2
,
name
=
'conv6_2_h'
,
activ
=
tf
.
nn
.
relu
)
conv7_1_h
=
conv
(
conv6_2_h
,
'conv7_1_h'
,
activ
=
tf
.
nn
.
relu
)
conv7_2_h
=
tf
.
pad
(
conv7_1_h
,
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]])
# conv7_2_h = tf.pad(conv7_1_h, [[0, 0], [1, 1], [1, 1], [0, 0]])
conv7_2_h
=
tf
.
space_to_batch_nd
(
conv7_1_h
,
[
1
,
1
],
[[
1
,
1
],
[
1
,
1
]],
name
=
'Pad_2'
)
conv7_2_h
=
conv
(
conv7_2_h
,
stride
=
2
,
pad
=
'VALID'
,
name
=
'conv7_2_h'
,
activ
=
tf
.
nn
.
relu
)
conv8_1_h
=
conv
(
conv7_2_h
,
pad
=
'SAME'
,
name
=
'conv8_1_h'
,
activ
=
tf
.
nn
.
relu
)
conv8_2_h
=
conv
(
conv8_1_h
,
pad
=
'SAME'
,
name
=
'conv8_2_h'
,
activ
=
tf
.
nn
.
relu
)
conv9_1_h
=
conv
(
conv8_2_h
,
'conv9_1_h'
,
activ
=
tf
.
nn
.
relu
)
...
...
@@ -201,6 +211,7 @@ with tf.Session() as sess:
inputData
=
np
.
random
.
standard_normal
([
1
,
3
,
300
,
300
])
.
astype
(
np
.
float32
)
cvNet
.
setInput
(
inputData
)
cvNet
.
setPreferableBackend
(
cv
.
dnn
.
DNN_BACKEND_OPENCV
)
outDNN
=
cvNet
.
forward
(
out_nodes
)
outTF
=
sess
.
run
([
mbox_loc
,
mbox_conf_flatten
],
feed_dict
=
{
inp
:
inputData
.
transpose
(
0
,
2
,
3
,
1
)})
...
...
@@ -254,7 +265,7 @@ for i in reversed(range(len(graph_def.node))):
del
graph_def
.
node
[
i
]
for
attr
in
[
'T'
,
'data_format'
,
'Tshape'
,
'N'
,
'Tidx'
,
'Tdim'
,
'use_cudnn_on_gpu'
,
'Index'
,
'Tperm'
,
'is_training'
,
'Tpaddings'
]:
'Tpaddings'
,
'Tblock_shape'
,
'Tcrops'
]:
if
attr
in
graph_def
.
node
[
i
]
.
attr
:
del
graph_def
.
node
[
i
]
.
attr
[
attr
]
...
...
modules/dnn/src/layers/normalize_bbox_layer.cpp
View file @
f056c0f1
...
...
@@ -63,9 +63,18 @@ public:
virtual
bool
supportBackend
(
int
backendId
)
CV_OVERRIDE
{
return
backendId
==
DNN_BACKEND_OPENCV
||
backendId
==
DNN_BACKEND_INFERENCE_ENGINE
&&
haveInfEngine
()
&&
pnorm
==
2
&&
!
blobs
.
empty
();
if
(
backendId
==
DNN_BACKEND_INFERENCE_ENGINE
)
{
if
(
pnorm
!=
2
)
return
false
;
if
(
!
blobs
.
empty
())
return
true
;
if
(
preferableTarget
==
DNN_TARGET_MYRIAD
)
return
!
acrossSpatial
;
return
startAxis
==
1
&&
(
!
acrossSpatial
||
endAxis
>
1
);
}
else
return
backendId
==
DNN_BACKEND_OPENCV
;
}
bool
getMemoryShapes
(
const
std
::
vector
<
MatShape
>
&
inputs
,
...
...
@@ -80,6 +89,14 @@ public:
return
true
;
}
void
finalize
(
const
std
::
vector
<
Mat
*>
&
inputs
,
std
::
vector
<
Mat
>
&
outputs
)
CV_OVERRIDE
{
CV_Assert
(
inputs
.
size
()
==
1
);
endAxis
=
endAxis
==
-
1
?
(
inputs
[
0
]
->
dims
-
1
)
:
endAxis
;
startAxis
=
startAxis
==
-
1
?
(
inputs
[
0
]
->
dims
-
1
)
:
startAxis
;
acrossSpatial
=
(
startAxis
==
1
&&
endAxis
==
inputs
[
0
]
->
dims
-
1
);
}
#ifdef HAVE_OPENCL
bool
forward_ocl
(
InputArrayOfArrays
inputs_
,
OutputArrayOfArrays
outputs_
,
OutputArrayOfArrays
internals_
)
{
...
...
@@ -240,24 +257,52 @@ public:
}
}
virtual
Ptr
<
BackendNode
>
initInfEngine
(
const
std
::
vector
<
Ptr
<
BackendWrapper
>
>&
)
CV_OVERRIDE
virtual
Ptr
<
BackendNode
>
initInfEngine
(
const
std
::
vector
<
Ptr
<
BackendWrapper
>
>&
inputs
)
CV_OVERRIDE
{
#ifdef HAVE_INF_ENGINE
InferenceEngine
::
DataPtr
input
=
infEngineDataNode
(
inputs
[
0
]);
InferenceEngine
::
LayerParams
lp
;
lp
.
name
=
name
;
lp
.
type
=
"Normalize"
;
lp
.
precision
=
InferenceEngine
::
Precision
::
FP32
;
std
::
shared_ptr
<
InferenceEngine
::
CNNLayer
>
ieLayer
(
new
InferenceEngine
::
CNNLayer
(
lp
));
CV_Assert
(
!
blobs
.
empty
());
i
eLayer
->
params
[
"eps"
]
=
format
(
"%f"
,
epsilon
);
ieLayer
->
params
[
"across_spatial"
]
=
acrossSpatial
?
"1"
:
"0"
;
ieLayer
->
params
[
"channel_shared"
]
=
blobs
[
0
].
total
()
==
1
?
"1"
:
"0"
;
i
f
(
input
->
dims
.
size
()
==
4
)
{
const
int
numChannels
=
input
->
dims
[
2
];
// NOTE: input->dims are reversed (whcn)
const
size_t
numChannels
=
blobs
[
0
].
total
();
ieLayer
->
blobs
[
"weights"
]
=
wrapToInfEngineBlob
(
blobs
[
0
],
{
numChannels
},
InferenceEngine
::
Layout
::
C
);
return
Ptr
<
BackendNode
>
(
new
InfEngineBackendNode
(
ieLayer
));
lp
.
type
=
"Normalize"
;
std
::
shared_ptr
<
InferenceEngine
::
CNNLayer
>
ieLayer
(
new
InferenceEngine
::
CNNLayer
(
lp
));
if
(
blobs
.
empty
())
{
auto
weights
=
InferenceEngine
::
make_shared_blob
<
float
>
(
InferenceEngine
::
Precision
::
FP32
,
InferenceEngine
::
Layout
::
C
,
{
numChannels
});
weights
->
allocate
();
std
::
vector
<
float
>
ones
(
numChannels
,
1
);
weights
->
set
(
ones
);
ieLayer
->
blobs
[
"weights"
]
=
weights
;
ieLayer
->
params
[
"channel_shared"
]
=
"0"
;
}
else
{
CV_Assert
(
numChannels
==
blobs
[
0
].
total
());
ieLayer
->
blobs
[
"weights"
]
=
wrapToInfEngineBlob
(
blobs
[
0
],
{
numChannels
},
InferenceEngine
::
Layout
::
C
);
ieLayer
->
params
[
"channel_shared"
]
=
blobs
[
0
].
total
()
==
1
?
"1"
:
"0"
;
}
ieLayer
->
params
[
"eps"
]
=
format
(
"%f"
,
epsilon
);
ieLayer
->
params
[
"across_spatial"
]
=
acrossSpatial
?
"1"
:
"0"
;
return
Ptr
<
BackendNode
>
(
new
InfEngineBackendNode
(
ieLayer
));
}
else
{
InferenceEngine
::
LayerParams
lp
;
lp
.
name
=
name
;
lp
.
type
=
"GRN"
;
lp
.
precision
=
InferenceEngine
::
Precision
::
FP32
;
std
::
shared_ptr
<
InferenceEngine
::
CNNLayer
>
ieLayer
(
new
InferenceEngine
::
CNNLayer
(
lp
));
ieLayer
->
params
[
"bias"
]
=
format
(
"%f"
,
epsilon
);
return
Ptr
<
BackendNode
>
(
new
InfEngineBackendNode
(
ieLayer
));
}
#endif // HAVE_INF_ENGINE
return
Ptr
<
BackendNode
>
();
}
...
...
modules/dnn/src/tensorflow/tf_importer.cpp
View file @
f056c0f1
...
...
@@ -529,7 +529,8 @@ const tensorflow::TensorProto& TFImporter::getConstBlob(const tensorflow::NodeDe
Pin
kernel_inp
=
parsePin
(
layer
.
input
(
input_blob_index
));
if
(
const_layers
.
find
(
kernel_inp
.
name
)
==
const_layers
.
end
())
CV_Error
(
Error
::
StsError
,
"Const kernel input not found"
);
CV_Error
(
Error
::
StsError
,
"Input ["
+
layer
.
input
(
input_blob_index
)
+
"] for node ["
+
layer
.
name
()
+
"] not found"
);
if
(
kernel_inp
.
blobIndex
!=
0
)
CV_Error
(
Error
::
StsError
,
"Unsupported kernel input"
);
...
...
@@ -867,13 +868,13 @@ void TFImporter::populateNet(Net dstNet)
layerParams
.
set
(
"num_output"
,
layerParams
.
blobs
[
0
].
size
[
0
]);
setStrides
(
layerParams
,
layer
);
setPadding
(
layerParams
,
layer
);
if
(
!
layerParams
.
has
(
"pad_w"
)
&&
!
layerParams
.
has
(
"pad_h"
))
setPadding
(
layerParams
,
layer
);
// The final node of dilated convolution subgraph.
next_layers
=
getNextLayers
(
net
,
name
,
"BatchToSpaceND"
);
if
(
!
next_layers
.
empty
())
{
layerParams
.
set
(
"pad_mode"
,
""
);
// We use padding values.
CV_Assert
(
next_layers
.
size
()
==
1
);
ExcludeLayer
(
net
,
next_layers
[
0
].
second
,
0
,
false
);
layers_to_ignore
.
insert
(
next_layers
[
0
].
first
);
...
...
modules/dnn/test/test_tf_importer.cpp
View file @
f056c0f1
...
...
@@ -239,7 +239,7 @@ TEST_P(Test_TensorFlow_layers, l2_normalize)
// TODO: fix it and add to l2_normalize
TEST_P
(
Test_TensorFlow_layers
,
l2_normalize_3d
)
{
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE
&&
target
==
DNN_TARGET_MYRIAD
)
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE
&&
target
!=
DNN_TARGET_CPU
)
throw
SkipTestException
(
""
);
runTensorFlowNet
(
"l2_normalize_3d"
);
}
...
...
@@ -360,10 +360,6 @@ TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)
TEST_P
(
Test_TensorFlow_nets
,
opencv_face_detector_uint8
)
{
checkBackend
();
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE
&&
(
target
==
DNN_TARGET_OPENCL
||
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
))
throw
SkipTestException
(
""
);
std
::
string
proto
=
findDataFile
(
"dnn/opencv_face_detector.pbtxt"
,
false
);
std
::
string
model
=
findDataFile
(
"dnn/opencv_face_detector_uint8.pb"
,
false
);
...
...
@@ -386,7 +382,7 @@ TEST_P(Test_TensorFlow_nets, opencv_face_detector_uint8)
0
,
1
,
0.97203469
,
0.67965847
,
0.06876482
,
0.73999709
,
0.1513494
,
0
,
1
,
0.95097077
,
0.51901293
,
0.45863652
,
0.5777427
,
0.5347801
);
double
scoreDiff
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
4e-3
:
3.4e-3
;
double
iouDiff
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.0
17
:
1e-2
;
double
iouDiff
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.0
24
:
1e-2
;
normAssertDetections
(
ref
,
out
,
""
,
0.9
,
scoreDiff
,
iouDiff
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment