Commit 081d9bc7 authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Fix Identity Switch from Keras

parent e07ffe90
...@@ -601,7 +601,7 @@ public: ...@@ -601,7 +601,7 @@ public:
class UpsamplingKerasSubgraph : public Subgraph class UpsamplingKerasSubgraph : public Subgraph
{ {
public: public:
UpsamplingKerasSubgraph() UpsamplingKerasSubgraph(const std::string& type)
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int shape = addNodeToMatch("Shape", input); int shape = addNodeToMatch("Shape", input);
...@@ -611,8 +611,8 @@ public: ...@@ -611,8 +611,8 @@ public:
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
int factors = addNodeToMatch("Const"); int factors = addNodeToMatch("Const");
int mul = addNodeToMatch("Mul", strided_slice, factors); int mul = addNodeToMatch("Mul", strided_slice, factors);
addNodeToMatch("ResizeNearestNeighbor", input, mul); addNodeToMatch(type, input, mul);
setFusedNode("ResizeNearestNeighbor", input, factors); setFusedNode(type, input, factors);
} }
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode, virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
...@@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net) ...@@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeNearestNeighbor")));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeBilinear")));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph())); subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
...@@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) ...@@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
tensorflow::NodeDef* layer = net.mutable_node(li); tensorflow::NodeDef* layer = net.mutable_node(li);
for (int input_id = 0; input_id < layer->input_size(); input_id++) { for (int input_id = 0; input_id < layer->input_size(); input_id++) {
String input_op_name = layer->input(input_id); String input_op_name = layer->input(input_id);
input_op_name = input_op_name.substr(input_op_name.find('^') + 1,
input_op_name.rfind(':'));
IdentityOpsMap::iterator it = identity_ops.find(input_op_name); IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
if (it != identity_ops.end()) { if (it != identity_ops.end()) {
......
...@@ -186,6 +186,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm) ...@@ -186,6 +186,7 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
runTensorFlowNet("unfused_batch_norm_no_gamma"); runTensorFlowNet("unfused_batch_norm_no_gamma");
runTensorFlowNet("mvn_batch_norm"); runTensorFlowNet("mvn_batch_norm");
runTensorFlowNet("mvn_batch_norm_1x1"); runTensorFlowNet("mvn_batch_norm_1x1");
runTensorFlowNet("switch_identity");
} }
TEST_P(Test_TensorFlow_layers, batch_norm3D) TEST_P(Test_TensorFlow_layers, batch_norm3D)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment