Commit c7cf8fb3 authored by Dmitry Kurtaev, committed by Vadim Pisarevsky

Import SSDs from TensorFlow by training config (#12188)

* Remove TensorFlow and protobuf dependencies from object detection scripts

* Create text graphs for TensorFlow object detection networks from sample
parent e3af72bb
......@@ -885,6 +885,14 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
const std::vector<String>& layersTypes = std::vector<String>());
/** @brief Create a text representation for a binary network stored in protocol buffer format.
* @param[in] model A path to the binary network.
* @param[in] output A path to the output text file to be created.
*
* @note To reduce the output file size, trained weights are not included.
*/
CV_EXPORTS_W void writeTextGraph(const String& model, const String& output);
/** @brief Performs non-maximum suppression given boxes and corresponding scores.
* @param bboxes a set of bounding boxes to apply NMS.
......
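The new writeTextGraph entry point is declared with CV_EXPORTS_W, so it is also reachable from the Python bindings. A minimal usage sketch (file paths are placeholders):

import cv2 as cv

# Dump a frozen TensorFlow graph to a human-readable .pbtxt file;
# trained weights are stripped from Const nodes to keep the file small.
cv.dnn.writeTextGraph('frozen_inference_graph.pb', 'graph_text.pbtxt')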
......@@ -782,6 +782,108 @@ void releaseTensor(tensorflow::TensorProto* tensor)
}
}
static void permute(google::protobuf::RepeatedPtrField<tensorflow::NodeDef>* data,
const std::vector<int>& indices)
{
const int num = data->size();
CV_Assert(num == indices.size());
std::vector<int> elemIdToPos(num);
std::vector<int> posToElemId(num);
for (int i = 0; i < num; ++i)
{
elemIdToPos[i] = i;
posToElemId[i] = i;
}
for (int i = 0; i < num; ++i)
{
int elemId = indices[i];
int pos = elemIdToPos[elemId];
if (pos != i)
{
data->SwapElements(i, pos);
const int swappedElemId = posToElemId[i];
elemIdToPos[elemId] = i;
elemIdToPos[swappedElemId] = pos;
posToElemId[i] = elemId;
posToElemId[pos] = swappedElemId;
}
}
}
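The index bookkeeping above is easiest to follow outside of protobuf. A plain-Python sketch of the same in-place permutation by pairwise swaps (names are ours; data is an ordinary list):

def permute(data, indices):
    # After the call, data[i] holds the element that started at indices[i].
    num = len(data)
    assert num == len(indices)
    elem_id_to_pos = list(range(num))  # element id -> current position
    pos_to_elem_id = list(range(num))  # current position -> element id
    for i in range(num):
        elem_id = indices[i]
        pos = elem_id_to_pos[elem_id]
        if pos != i:
            data[i], data[pos] = data[pos], data[i]  # mirrors SwapElements
            swapped_elem_id = pos_to_elem_id[i]
            elem_id_to_pos[elem_id] = i
            elem_id_to_pos[swapped_elem_id] = pos
            pos_to_elem_id[i] = elem_id
            pos_to_elem_id[pos] = swapped_elem_id

For example, permute(['a', 'b', 'c'], [2, 0, 1]) rearranges the list to ['c', 'a', 'b'].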
// Based on tensorflow::graph_transforms::SortByExecutionOrder
void sortByExecutionOrder(tensorflow::GraphDef& net)
{
// Maps node's name to index at net.node() list.
std::map<std::string, int> nodesMap;
std::map<std::string, int>::iterator nodesMapIt;
for (int i = 0; i < net.node_size(); ++i)
{
const tensorflow::NodeDef& node = net.node(i);
nodesMap.insert(std::make_pair(node.name(), i));
}
// Indices of nodes which use specific node as input.
std::vector<std::vector<int> > edges(nodesMap.size());
std::vector<int> numRefsToAdd(nodesMap.size(), 0);
std::vector<int> nodesToAdd;
for (int i = 0; i < net.node_size(); ++i)
{
const tensorflow::NodeDef& node = net.node(i);
for (int j = 0; j < node.input_size(); ++j)
{
std::string inpName = node.input(j);
inpName = inpName.substr(0, inpName.rfind(':'));
inpName = inpName.substr(inpName.find('^') + 1);
nodesMapIt = nodesMap.find(inpName);
CV_Assert(nodesMapIt != nodesMap.end());
edges[nodesMapIt->second].push_back(i);
}
if (node.input_size() == 0)
nodesToAdd.push_back(i);
else
{
if (node.op() == "Merge" || node.op() == "RefMerge")
{
int numControlEdges = 0;
for (int j = 0; j < node.input_size(); ++j)
numControlEdges += node.input(j)[0] == '^';
numRefsToAdd[i] = numControlEdges + 1;
}
else
numRefsToAdd[i] = node.input_size();
}
}
std::vector<int> permIds;
permIds.reserve(net.node_size());
while (!nodesToAdd.empty())
{
int nodeToAdd = nodesToAdd.back();
nodesToAdd.pop_back();
permIds.push_back(nodeToAdd);
for (int i = 0; i < edges[nodeToAdd].size(); ++i)
{
int consumerId = edges[nodeToAdd][i];
if (numRefsToAdd[consumerId] > 0)
{
if (numRefsToAdd[consumerId] == 1)
nodesToAdd.push_back(consumerId);
else
CV_Assert(numRefsToAdd[consumerId] >= 0);
numRefsToAdd[consumerId] -= 1;
}
}
}
CV_Assert(permIds.size() == net.node_size());
permute(net.mutable_node(), permIds);
}
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv
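sortByExecutionOrder is Kahn's topological sort driven by reference counts. A condensed Python sketch of the same idea, without the Merge/RefMerge special case (names are ours):

def sort_by_execution_order(nodes):
    # nodes: list of (name, [input names]); returns node indices in execution order.
    name_to_idx = {name: i for i, (name, _) in enumerate(nodes)}
    consumers = [[] for _ in nodes]                  # who reads each node's output
    num_refs = [len(inputs) for _, inputs in nodes]  # unresolved inputs per node
    for i, (_, inputs) in enumerate(nodes):
        for inp in inputs:
            inp = inp.split(':')[0].lstrip('^')      # drop output port and control mark
            consumers[name_to_idx[inp]].append(i)
    ready = [i for i, (_, inputs) in enumerate(nodes) if not inputs]
    order = []
    while ready:
        i = ready.pop()
        order.append(i)
        for j in consumers[i]:
            num_refs[j] -= 1
            if num_refs[j] == 0:
                ready.append(j)
    return order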
......
......@@ -25,6 +25,8 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor);
void releaseTensor(tensorflow::TensorProto* tensor);
void sortByExecutionOrder(tensorflow::GraphDef& net);
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv
......
......@@ -1950,5 +1950,34 @@ Net readNetFromTensorflow(const std::vector<uchar>& bufferModel, const std::vect
bufferConfigPtr, bufferConfig.size());
}
void writeTextGraph(const String& _model, const String& output)
{
String model = _model;
const std::string modelExt = model.substr(model.rfind('.') + 1);
if (modelExt != "pb")
CV_Error(Error::StsNotImplemented, "Only TensorFlow models support export to text file");
tensorflow::GraphDef net;
ReadTFNetParamsFromBinaryFileOrDie(model.c_str(), &net);
sortByExecutionOrder(net);
RepeatedPtrField<tensorflow::NodeDef>::iterator it;
for (it = net.mutable_node()->begin(); it != net.mutable_node()->end(); ++it)
{
if (it->op() == "Const")
{
it->mutable_attr()->at("value").mutable_tensor()->clear_tensor_content();
}
}
std::string content;
google::protobuf::TextFormat::PrintToString(net, &content);
std::ofstream ofs(output.c_str());
ofs << content;
ofs.close();
}
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace
......@@ -315,6 +315,29 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
}
TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD)
{
checkBackend();
std::string model = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pb", false);
std::string proto = findDataFile("dnn/ssd_mobilenet_v1_coco_2017_11_17.pbtxt", false);
Net net = readNetFromTensorflow(model, proto);
Mat img = imread(findDataFile("dnn/dog416.png", false));
Mat blob = blobFromImage(img, 1.0f, Size(300, 300), Scalar(), true, false);
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
net.setInput(blob);
Mat out = net.forward();
Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/ssd_mobilenet_v1_coco_2017_11_17.detection_out.npy"));
float scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 7e-3 : 1e-5;
float iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.0098 : 1e-3;
normAssertDetections(ref, out, "", 0.3, scoreDiff, iouDiff);
}
TEST_P(Test_TensorFlow_nets, Faster_RCNN)
{
static std::string names[] = {"faster_rcnn_inception_v2_coco_2018_01_28",
......@@ -360,7 +383,8 @@ TEST_P(Test_TensorFlow_nets, MobileNet_v1_SSD_PPN)
net.setInput(blob);
Mat out = net.forward();
double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.008 : default_l1;
double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.011 : default_l1;
double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.021 : default_lInf;
normAssertDetections(ref, out, "", 0.4, scoreDiff, iouDiff);
}
......
......@@ -3,6 +3,10 @@ import argparse
import sys
import numpy as np
from tf_text_graph_common import readTextMessage
from tf_text_graph_ssd import createSSDGraph
from tf_text_graph_faster_rcnn import createFasterRCNNGraph
backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD)
......@@ -11,11 +15,15 @@ parser.add_argument('--input', help='Path to input image or video file. Skip thi
parser.add_argument('--model', required=True,
help='Path to a binary file of the model that contains trained weights. '
'It could be a file with extensions .caffemodel (Caffe), '
'.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet)')
'.pb (TensorFlow), .t7 or .net (Torch), .weights (Darknet), .bin (OpenVINO)')
parser.add_argument('--config',
help='Path to a text file of the model that contains the network configuration. '
'It could be a file with extensions .prototxt (Caffe), .pbtxt (TensorFlow), .cfg (Darknet)')
parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
'It could be a file with extensions .prototxt (Caffe), .pbtxt or .config (TensorFlow), .cfg (Darknet), .xml (OpenVINO)')
parser.add_argument('--out_tf_graph', default='graph.pbtxt',
help='For models from the TensorFlow Object Detection API, you may '
'pass the .config file that was used for training through the --config '
'argument. In that case an additional .pbtxt file with a text graph of the network will be created.')
parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet', 'dldt'],
help='Optional name of an origin framework of the model. '
'Detect it automatically if it is not set.')
parser.add_argument('--classes', help='Optional path to a text file with names of classes to label detected objects.')
......@@ -46,6 +54,20 @@ parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU,
'%d: VPU' % targets)
args = parser.parse_args()
# If a config file is specified, try to parse it as a TensorFlow Object Detection API pipeline.
config = readTextMessage(args.config) if args.config else {}
if 'model' in config:
print('TensorFlow Object Detection API config detected')
if 'ssd' in config['model'][0]:
print('Preparing text graph representation for SSD model: ' + args.out_tf_graph)
createSSDGraph(args.model, args.config, args.out_tf_graph)
args.config = args.out_tf_graph
elif 'faster_rcnn' in config['model'][0]:
print('Preparing text graph representation for Faster-RCNN model: ' + args.out_tf_graph)
createFasterRCNNGraph(args.model, args.config, args.out_tf_graph)
args.config = args.out_tf_graph
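With this block in place, the detection sample accepts a TensorFlow Object Detection API training config directly, e.g. (file names are illustrative): python object_detection.py --model frozen_inference_graph.pb --config pipeline.config --input image.jpg. The generated graph.pbtxt then replaces args.config before the network is loaded.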
# Load names of classes
classes = None
if args.classes:
......
import tensorflow as tf
from tensorflow.core.framework.node_def_pb2 import NodeDef
from google.protobuf import text_format
def tokenize(s):
tokens = []
token = ""
isString = False
isComment = False
for symbol in s:
isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
if isComment:
continue
if symbol == ' ' or symbol == '\t' or symbol == '\r' or symbol == '\'' or \
symbol == '\n' or symbol == ':' or symbol == '\"' or symbol == ';' or \
symbol == ',':
if (symbol == '\"' or symbol == '\'') and isString:
tokens.append(token)
token = ""
else:
if isString:
token += symbol
elif token:
tokens.append(token)
token = ""
isString = (symbol == '\"' or symbol == '\'') ^ isString
elif symbol == '{' or symbol == '}' or symbol == '[' or symbol == ']':
if token:
tokens.append(token)
token = ""
tokens.append(symbol)
else:
token += symbol
if token:
tokens.append(token)
return tokens
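A hypothetical call shows the effect of the rules above (comments are dropped, whitespace and ':' separate tokens, braces and brackets become tokens of their own):

tokenize('model { ssd { num_classes: 90 } }  # comment')
# -> ['model', '{', 'ssd', '{', 'num_classes', '90', '}', '}']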
def parseMessage(tokens, idx):
msg = {}
assert(tokens[idx] == '{')
isArray = False
while True:
if not isArray:
idx += 1
if idx < len(tokens):
fieldName = tokens[idx]
else:
return None
if fieldName == '}':
break
idx += 1
fieldValue = tokens[idx]
if fieldValue == '{':
embeddedMsg, idx = parseMessage(tokens, idx)
if fieldName in msg:
msg[fieldName].append(embeddedMsg)
else:
msg[fieldName] = [embeddedMsg]
elif fieldValue == '[':
isArray = True
elif fieldValue == ']':
isArray = False
else:
if fieldName in msg:
msg[fieldName].append(fieldValue)
else:
msg[fieldName] = [fieldValue]
return msg, idx
def readTextMessage(filePath):
with open(filePath, 'rt') as f:
content = f.read()
tokens = tokenize('{' + content + '}')
msg = parseMessage(tokens, 0)
return msg[0] if msg else {}
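Every field value is wrapped in a list so that repeated and scalar fields are handled uniformly, and scalar values stay strings (hence the int()/float() casts downstream). For a file containing 'model { ssd { num_classes: 90 } }' (illustrative):

readTextMessage('pipeline.config')
# -> {'model': [{'ssd': [{'num_classes': ['90']}]}]}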
def listToTensor(values):
if all([isinstance(v, float) for v in values]):
dtype = 'DT_FLOAT'
field = 'float_val'
......@@ -12,16 +90,25 @@ def tensorMsg(values):
else:
raise Exception('Wrong values types')
msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
for value in values:
msg += '%s: %s ' % (field, str(value))
return msg + '}'
msg = {
'tensor': {
'dtype': dtype,
'tensor_shape': {
'dim': {
'size': len(values)
}
}
}
}
msg['tensor'][field] = values
return msg
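Instead of a text_format string, listToTensor now returns nested dicts mirroring a tensor proto; for example:

listToTensor([0.1, 0.1, 0.2, 0.2])
# -> {'tensor': {'dtype': 'DT_FLOAT',
#                'tensor_shape': {'dim': {'size': 4}},
#                'float_val': [0.1, 0.1, 0.2, 0.2]}}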
def addConstNode(name, values, graph_def):
node = NodeDef()
node.name = name
node.op = 'Const'
text_format.Merge(tensorMsg(values), node.attr["value"])
node.addAttr('value', values)
graph_def.node.extend([node])
......@@ -29,13 +116,13 @@ def addSlice(inp, out, begins, sizes, graph_def):
beginsNode = NodeDef()
beginsNode.name = out + '/begins'
beginsNode.op = 'Const'
text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
beginsNode.addAttr('value', begins)
graph_def.node.extend([beginsNode])
sizesNode = NodeDef()
sizesNode.name = out + '/sizes'
sizesNode.op = 'Const'
text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
sizesNode.addAttr('value', sizes)
graph_def.node.extend([sizesNode])
sliced = NodeDef()
......@@ -51,7 +138,7 @@ def addReshape(inp, out, shape, graph_def):
shapeNode = NodeDef()
shapeNode.name = out + '/shape'
shapeNode.op = 'Const'
text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
shapeNode.addAttr('value', shape)
graph_def.node.extend([shapeNode])
reshape = NodeDef()
......@@ -66,7 +153,7 @@ def addSoftMax(inp, out, graph_def):
softmax = NodeDef()
softmax.name = out
softmax.op = 'Softmax'
text_format.Merge('i: -1', softmax.attr['axis'])
softmax.addAttr('axis', -1)
softmax.input.append(inp)
graph_def.node.extend([softmax])
......@@ -79,6 +166,103 @@ def addFlatten(inp, out, graph_def):
graph_def.node.extend([flatten])
class NodeDef:
def __init__(self):
self.input = []
self.name = ""
self.op = ""
self.attr = {}
def addAttr(self, key, value):
assert key not in self.attr
if isinstance(value, bool):
self.attr[key] = {'b': value}
elif isinstance(value, int):
self.attr[key] = {'i': value}
elif isinstance(value, float):
self.attr[key] = {'f': value}
elif isinstance(value, str):
self.attr[key] = {'s': value}
elif isinstance(value, list):
self.attr[key] = listToTensor(value)
else:
raise Exception('Unknown type of attribute ' + key)
def Clear(self):
self.input = []
self.name = ""
self.op = ""
self.attr = {}
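A short sketch of how this stand-in class is used throughout the script: attribute values are boxed by type, and lists go through listToTensor:

node = NodeDef()
node.name = 'proposals'
node.op = 'PriorBox'
node.input.append('image_tensor')
node.addAttr('clip', True)                      # stored as {'b': True}
node.addAttr('step', 16.0)                      # stored as {'f': 16.0}
node.addAttr('variance', [0.1, 0.1, 0.2, 0.2])  # stored via listToTensor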
class GraphDef:
def __init__(self):
self.node = []
def save(self, filePath):
with open(filePath, 'wt') as f:
def printAttr(d, indent):
indent = ' ' * indent
for key, value in sorted(d.items(), key=lambda x:x[0].lower()):
value = value if isinstance(value, list) else [value]
for v in value:
if isinstance(v, dict):
f.write(indent + key + ' {\n')
printAttr(v, len(indent) + 2)
f.write(indent + '}\n')
else:
isString = False
if isinstance(v, str) and not v.startswith('DT_'):
try:
float(v)
except ValueError:
isString = True
if isinstance(v, bool):
printed = 'true' if v else 'false'
elif v == 'true' or v == 'false':
printed = 'true' if v == 'true' else 'false'
elif isString:
printed = '\"%s\"' % v
else:
printed = str(v)
f.write(indent + key + ': ' + printed + '\n')
for node in self.node:
f.write('node {\n')
f.write(' name: \"%s\"\n' % node.name)
f.write(' op: \"%s\"\n' % node.op)
for inp in node.input:
f.write(' input: \"%s\"\n' % inp)
for key, value in sorted(node.attr.items(), key=lambda x:x[0].lower()):
f.write(' attr {\n')
f.write(' key: \"%s\"\n' % key)
f.write(' value {\n')
printAttr(value, 6)
f.write(' }\n')
f.write(' }\n')
f.write('}\n')
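save() emits standard protobuf text format, so the result is directly consumable by readNetFromTensorflow; one serialized node looks like:

node {
  name: "proposals"
  op: "PriorBox"
  input: "image_tensor"
  attr {
    key: "clip"
    value {
      b: true
    }
  }
}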
def parseTextGraph(filePath):
msg = readTextMessage(filePath)
graph = GraphDef()
for node in msg['node']:
graphNode = NodeDef()
graphNode.name = node['name'][0]
graphNode.op = node['op'][0]
graphNode.input = node['input'] if 'input' in node else []
if 'attr' in node:
for attr in node['attr']:
graphNode.attr[attr['key'][0]] = attr['value'][0]
graph.node.append(graphNode)
return graph
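parseTextGraph pairs with the new cv.dnn.writeTextGraph: the C++ side dumps a binary model to text, and this function loads it back into the lightweight classes above. A minimal round-trip sketch (paths are placeholders):

import cv2 as cv
cv.dnn.writeTextGraph('frozen_inference_graph.pb', 'graph.pbtxt')
graph_def = parseTextGraph('graph.pbtxt')  # GraphDef holding NodeDef entries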
# Removes Identity nodes
def removeIdentity(graph_def):
identities = {}
......
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.core.framework.node_def_pb2 import NodeDef
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import cv2 as cv
from tf_text_graph_common import *
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
......@@ -13,13 +8,7 @@ parser = argparse.ArgumentParser(description='Run this script to get a text grap
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
help='Hyper-parameter of grid_anchor_generator from a config file.')
parser.add_argument('--features_stride', default=16, type=float, nargs='+',
help='Hyper-parameter from a config file.')
parser.add_argument('--config', required=True, help='Path to a *.config file used for training.')
args = parser.parse_args()
scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
......@@ -39,11 +28,28 @@ scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
'FirstStageFeatureExtractor/GreaterEqual',
'FirstStageFeatureExtractor/LogicalAnd')
# Load a config file.
config = readTextMessage(args.config)
config = config['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])
grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = [float(s) for s in grid_anchor_generator['scales']]
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
print('Number of classes: %d' % num_classes)
print('Scales: %s' % str(scales))
print('Aspect ratios: %s' % str(aspect_ratios))
print('Width stride: %f' % width_stride)
print('Height stride: %f' % height_stride)
print('Features stride: %f' % features_stride)
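These values come from the grid_anchor_generator section of the training pipeline; an illustrative fragment of such a .config file:

model {
  faster_rcnn {
    num_classes: 90
    feature_extractor {
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
  }
}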
# Read the graph.
with tf.gfile.FastGFile(args.input, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
cv.dnn.writeTextGraph(args.input, args.output)
graph_def = parseTextGraph(args.output)
removeIdentity(graph_def)
......@@ -87,22 +93,22 @@ proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name) # image_tensor
text_format.Merge('b: false', proposals.attr["flip"])
text_format.Merge('b: true', proposals.attr["clip"])
text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
text_format.Merge('f: 0.0', proposals.attr["offset"])
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
proposals.addAttr('flip', False)
proposals.addAttr('clip', True)
proposals.addAttr('step', features_stride)
proposals.addAttr('offset', 0.0)
proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
widths = []
heights = []
for a in args.aspect_ratios:
for s in args.scales:
for a in aspect_ratios:
for s in scales:
ar = np.sqrt(a)
heights.append((args.features_stride**2) * s / ar)
widths.append((args.features_stride**2) * s * ar)
heights.append((features_stride**2) * s / ar)
widths.append((features_stride**2) * s * ar)
text_format.Merge(tensorMsg(widths), proposals.attr["width"])
text_format.Merge(tensorMsg(heights), proposals.attr["height"])
proposals.addAttr('width', widths)
proposals.addAttr('height', heights)
graph_def.node.extend([proposals])
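For the default stride of 16 the base anchor area is 16^2 = 256, so, for example, scale 0.5 with aspect ratio 2.0 yields an anchor about 181 px wide (256 * 0.5 * sqrt(2)) and about 90.5 px tall (256 * 0.5 / sqrt(2)).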
......@@ -115,14 +121,14 @@ detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
detectionOut.input.append('proposals')
text_format.Merge('i: 2', detectionOut.attr['num_classes'])
text_format.Merge('b: true', detectionOut.attr['share_location'])
text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
text_format.Merge('i: 6000', detectionOut.attr['top_k'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
detectionOut.addAttr('num_classes', 2)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', 0.7)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('clip', True)
graph_def.node.extend([detectionOut])
......@@ -171,7 +177,7 @@ for node in graph_def.node:
if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
text_format.Merge('b: true', node.attr["loc_pred_transposed"])
node.addAttr('loc_pred_transposed', True)
################################################################################
### Postprocessing
......@@ -181,7 +187,7 @@ addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4],
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])
varianceEncoder = NodeDef()
......@@ -189,7 +195,7 @@ varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
text_format.Merge('i: 2', varianceEncoder.attr["axis"])
varianceEncoder.addAttr('axis', 2)
graph_def.node.extend([varianceEncoder])
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
......@@ -203,16 +209,16 @@ detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')
text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
text_format.Merge('b: false', detectionOut.attr['share_location'])
text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
text_format.Merge('b: true', detectionOut.attr['clip'])
text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
text_format.Merge('f: 0.3', detectionOut.attr['confidence_threshold'])
text_format.Merge('b: false', detectionOut.attr['group_by_classes'])
detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)
detectionOut.addAttr('background_label_id', num_classes + 1)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', 100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])
for node in reversed(topNodes):
......@@ -227,4 +233,5 @@ graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()
tf.train.write_graph(graph_def, "", args.output, as_text=True)
# Save as text.
graph_def.save(args.output)
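An example invocation with the new config-driven interface (paths are placeholders):

python tf_text_graph_faster_rcnn.py --input frozen_inference_graph.pb \
    --config pipeline.config --output faster_rcnn.pbtxt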