Commit c40fbad1 authored by Lorenzo Lucignano's avatar Lorenzo Lucignano

Samples DNN: tf_text_graph_sd.py loads box coder variance and box NMS params from config file

parent 1f57eb93
......@@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
# Add layers that generate anchors (bounding boxes proposals).
priorBoxes = []
boxCoder = config['box_coder'][0]
fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
for i in range(num_layers):
priorBox = NodeDef()
priorBox.name = 'PriorBox_%d' % i
......@@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
priorBox.addAttr('width', widths)
priorBox.addAttr('height', heights)
priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
priorBox.addAttr('variance', boxCoderVariance)
graph_def.node.extend([priorBox])
priorBoxes.append(priorBox.name)
......@@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
detectionOut.addAttr('num_classes', num_classes + 1)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
postProcessing = config['post_processing'][0]
batchNMS = postProcessing['batch_non_max_suppression'][0]
if 'iou_threshold' in batchNMS:
detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
else:
detectionOut.addAttr('nms_threshold', 0.6)
if 'score_threshold' in batchNMS:
detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
else:
detectionOut.addAttr('confidence_threshold', 0.01)
if 'max_detections_per_class' in batchNMS:
detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
else:
detectionOut.addAttr('top_k', 100)
detectionOut.addAttr('code_type', "CENTER_SIZE")
if 'max_total_detections' in batchNMS:
detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
else:
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('confidence_threshold', 0.01)
detectionOut.addAttr('code_type', "CENTER_SIZE")
graph_def.node.extend([detectionOut])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment