Commit c40fbad1 authored by Lorenzo Lucignano's avatar Lorenzo Lucignano

Samples DNN: tf_text_graph_sd.py loads box coder variance and box NMS params from config file

parent 1f57eb93
...@@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath): ...@@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
# Add layers that generate anchors (bounding boxes proposals). # Add layers that generate anchors (bounding boxes proposals).
priorBoxes = [] priorBoxes = []
boxCoder = config['box_coder'][0]
fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
for i in range(num_layers): for i in range(num_layers):
priorBox = NodeDef() priorBox = NodeDef()
priorBox.name = 'PriorBox_%d' % i priorBox.name = 'PriorBox_%d' % i
...@@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath): ...@@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
priorBox.addAttr('width', widths) priorBox.addAttr('width', widths)
priorBox.addAttr('height', heights) priorBox.addAttr('height', heights)
priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2]) priorBox.addAttr('variance', boxCoderVariance)
graph_def.node.extend([priorBox]) graph_def.node.extend([priorBox])
priorBoxes.append(priorBox.name) priorBoxes.append(priorBox.name)
...@@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath): ...@@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
detectionOut.addAttr('num_classes', num_classes + 1) detectionOut.addAttr('num_classes', num_classes + 1)
detectionOut.addAttr('share_location', True) detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0) detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('top_k', 100) postProcessing = config['post_processing'][0]
batchNMS = postProcessing['batch_non_max_suppression'][0]
if 'iou_threshold' in batchNMS:
detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
else:
detectionOut.addAttr('nms_threshold', 0.6)
if 'score_threshold' in batchNMS:
detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
else:
detectionOut.addAttr('confidence_threshold', 0.01)
if 'max_detections_per_class' in batchNMS:
detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
else:
detectionOut.addAttr('top_k', 100)
if 'max_total_detections' in batchNMS:
detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
else:
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('code_type', "CENTER_SIZE") detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('confidence_threshold', 0.01)
graph_def.node.extend([detectionOut]) graph_def.node.extend([detectionOut])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment