Commit 76cfa65d authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

AddV2 from TensorFlow

parent b4759d72
...@@ -996,7 +996,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -996,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN) if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN)
data_layouts[name] = DATA_LAYOUT_NHWC; data_layouts[name] = DATA_LAYOUT_NHWC;
} }
else if (type == "BiasAdd" || type == "Add" || type == "Sub" || type=="AddN") else if (type == "BiasAdd" || type == "Add" || type == "AddV2" || type == "Sub" || type=="AddN")
{ {
bool haveConst = false; bool haveConst = false;
for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii) for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii)
......
...@@ -62,7 +62,7 @@ class MultiscaleAnchorGenerator: ...@@ -62,7 +62,7 @@ class MultiscaleAnchorGenerator:
def createSSDGraph(modelPath, configPath, outputPath): def createSSDGraph(modelPath, configPath, outputPath):
# Nodes that should be kept. # Nodes that should be kept.
keepOps = ['Conv2D', 'BiasAdd', 'Add', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm', keepOps = ['Conv2D', 'BiasAdd', 'Add', 'AddV2', 'Relu', 'Relu6', 'Placeholder', 'FusedBatchNorm',
'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity', 'DepthwiseConv2dNative', 'ConcatV2', 'Mul', 'MaxPool', 'AvgPool', 'Identity',
'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3'] 'Sub', 'ResizeNearestNeighbor', 'Pad', 'FusedBatchNormV3']
...@@ -151,6 +151,9 @@ def createSSDGraph(modelPath, configPath, outputPath): ...@@ -151,6 +151,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
subgraphBatchNorm = ['Add', subgraphBatchNorm = ['Add',
['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']], ['Mul', 'input', ['Mul', ['Rsqrt', ['Add', 'moving_variance', 'add_y']], 'gamma']],
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]] ['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
subgraphBatchNormV2 = ['AddV2',
['Mul', 'input', ['Mul', ['Rsqrt', ['AddV2', 'moving_variance', 'add_y']], 'gamma']],
['Sub', 'beta', ['Mul', 'moving_mean', 'Mul_0']]]
# Detect unfused nearest neighbor resize. # Detect unfused nearest neighbor resize.
subgraphResizeNN = ['Reshape', subgraphResizeNN = ['Reshape',
['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']], ['Mul', ['Reshape', 'input', ['Pack', 'shape_1', 'shape_2', 'shape_3', 'shape_4', 'shape_5']],
...@@ -177,7 +180,8 @@ def createSSDGraph(modelPath, configPath, outputPath): ...@@ -177,7 +180,8 @@ def createSSDGraph(modelPath, configPath, outputPath):
for node in graph_def.node: for node in graph_def.node:
inputs = {} inputs = {}
fusedNodes = [] fusedNodes = []
if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes): if checkSubgraph(node, subgraphBatchNorm, inputs, fusedNodes) or \
checkSubgraph(node, subgraphBatchNormV2, inputs, fusedNodes):
name = node.name name = node.name
node.Clear() node.Clear()
node.name = name node.name = name
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment