Commit 3346a7ac authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Fix batch normalization fusion from TensorFlow's SSDs

parent 2f9b4439
...@@ -64,36 +64,51 @@ removedNodes = [] ...@@ -64,36 +64,51 @@ removedNodes = []
# Detect unfused batch normalization nodes and fuse them. # Detect unfused batch normalization nodes and fuse them.
def fuse_batch_normalization(): def fuse_batch_normalization():
pattern = ['Add', 'Rsqrt', 'Mul', 'Mul', 'Mul', 'Sub', 'Add'] # Add_0 <-- moving_variance, add_y
candidates = [] # Rsqrt <-- Add_0
# Mul_0 <-- Rsqrt, gamma
for node in graph_def.node: # Mul_1 <-- input, Mul_0
if node.op == pattern[len(candidates)]: # Mul_2 <-- moving_mean, Mul_0
candidates.append(node) # Sub_0 <-- beta, Mul_2
# Add_1 <-- Mul_1, Sub_0
nodesMap = {node.name: node for node in graph_def.node}
subgraph = ['Add',
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
def checkSubgraph(node, targetNode, inputs, fusedNodes):
op = targetNode[0]
if node.op == op and (len(node.input) >= len(targetNode) - 1):
fusedNodes.append(node)
for i, inpOp in enumerate(targetNode[1:]):
if isinstance(inpOp, list):
if not node.input[i] in nodesMap or \
not checkSubgraph(nodesMap[node.input[i]], inpOp, inputs, fusedNodes):
return False
else: else:
candidates = [] inputs[inpOp] = node.input[i]
if len(candidates) == len(pattern): return True
inp = candidates[3].input[0] else:
gamma = candidates[2].input[1] return False
beta = candidates[5].input[0]
moving_mean = candidates[4].input[0]
moving_variance = candidates[0].input[0]
nodesToRemove = []
for node in graph_def.node:
inputs = {}
fusedNodes = []
if checkSubgraph(node, subgraph, inputs, fusedNodes):
name = node.name name = node.name
node.Clear() node.Clear()
node.name = name node.name = name
node.op = 'FusedBatchNorm' node.op = 'FusedBatchNorm'
node.input.append(inp) node.input.append(inputs['input'])
node.input.append(gamma) node.input.append(inputs['gamma'])
node.input.append(beta) node.input.append(inputs['beta'])
node.input.append(moving_mean) node.input.append(inputs['moving_mean'])
node.input.append(moving_variance) node.input.append(inputs['moving_variance'])
text_format.Merge('f: 0.001', node.attr["epsilon"]) text_format.Merge('f: 0.001', node.attr["epsilon"])
nodesToRemove += fusedNodes[1:]
for candidate in candidates[:-1]: for node in nodesToRemove:
graph_def.node.remove(candidate) graph_def.node.remove(node)
candidates = []
fuse_batch_normalization() fuse_batch_normalization()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment