Commit 3346a7ac authored by Dmitry Kurtaev

Fix batch normalization fusion from TensorFlow's SSDs

parent 2f9b4439
......@@ -64,36 +64,51 @@ removedNodes = []
# Detect unfused batch normalization nodes and fuse them.
def fuse_batch_normalization():
    """Collapse each unfused batch-normalization subgraph into a single node.

    Scans the module-level ``graph_def``. Every node that roots the pattern
    described below is rewritten in place as a ``FusedBatchNorm`` op (keeping
    its name), and the remaining nodes of the matched subgraph are removed
    from ``graph_def.node``.

    Matching follows input edges by node name, so it does not rely on the
    order in which nodes are listed in the graph. Operands are matched
    positionally, so a subgraph whose Add/Mul operands are swapped relative
    to the pattern is not recognised.
    """
    # Expected topology (node <-- inputs):
    # Add_0 <-- moving_variance, add_y
    # Rsqrt <-- Add_0
    # Mul_0 <-- Rsqrt, gamma
    # Mul_1 <-- input, Mul_0
    # Mul_2 <-- moving_mean, Mul_0
    # Sub_0 <-- beta, Mul_2
    # Add_1 <-- Mul_1, Sub_0
    nodesMap = {node.name: node for node in graph_def.node}

    # Nested description of the pattern, rooted at Add_1. A list is
    # [op, input_0, input_1, ...]; a nested list is an input that must itself
    # match a sub-pattern, a string is a placeholder under which the actual
    # input name is captured.
    subgraph = ['Add',
        ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
        ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]

    def checkSubgraph(node, targetNode, inputs, fusedNodes):
        """Recursively test whether `node` roots the pattern `targetNode`.

        node:       graph node to test.
        targetNode: pattern list ``[op, input_0, ...]`` (see `subgraph`).
        inputs:     dict filled with placeholder -> actual input name.
        fusedNodes: list extended with every node visited that matched its
                    op; only meaningful when the overall result is True.

        Returns True when the whole pattern matches, False otherwise.
        """
        if node.op != targetNode[0] or len(node.input) < len(targetNode) - 1:
            return False

        fusedNodes.append(node)
        for i, inpOp in enumerate(targetNode[1:]):
            if isinstance(inpOp, list):
                # The input must be produced by a node present in the graph
                # that in turn matches the nested pattern.
                if node.input[i] not in nodesMap or \
                   not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
                    return False
            else:
                inputs[inpOp] = node.input[i]
        return True

    # Removal is deferred so that graph_def.node is not mutated while it is
    # being iterated.
    nodesToRemove = []
    for node in graph_def.node:
        inputs = {}
        fusedNodes = []
        if checkSubgraph(node, subgraph, inputs, fusedNodes):
            # Reuse the root Add node so consumers that reference it by name
            # keep working.
            name = node.name
            node.Clear()
            node.name = name
            node.op = 'FusedBatchNorm'
            for key in ('input', 'gamma', 'beta', 'moving_mean', 'moving_variance'):
                node.input.append(inputs[key])
            # Epsilon is fixed here; the matched 'add_y' input is captured
            # but its value is not read.
            text_format.Merge('f: 0.001', node.attr["epsilon"])
            # fusedNodes[0] is the root that was just rewritten; drop the rest.
            nodesToRemove += fusedNodes[1:]

    for node in nodesToRemove:
        graph_def.node.remove(node)

fuse_batch_normalization()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment