Commit 9510551c authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Multiple inputs for TensorFlow models

parent ab8022f7
...@@ -375,6 +375,8 @@ private: ...@@ -375,6 +375,8 @@ private:
// and may be used to build the network using binary format only as a weights storage. // and may be used to build the network using binary format only as a weights storage.
// This approach is similar to Caffe's `.prorotxt` and `.caffemodel`. // This approach is similar to Caffe's `.prorotxt` and `.caffemodel`.
tensorflow::GraphDef netTxt; tensorflow::GraphDef netTxt;
std::vector<String> netInputsNames;
}; };
TFImporter::TFImporter(const char *model, const char *config) TFImporter::TFImporter(const char *model, const char *config)
...@@ -442,7 +444,14 @@ void TFImporter::connect(const std::map<String, int>& layers_name_id_map, Net& n ...@@ -442,7 +444,14 @@ void TFImporter::connect(const std::map<String, int>& layers_name_id_map, Net& n
std::map<String, int>::const_iterator it = layers_name_id_map.find(outPin.name); std::map<String, int>::const_iterator it = layers_name_id_map.find(outPin.name);
if (it == layers_name_id_map.end()) if (it == layers_name_id_map.end())
CV_Error(Error::StsError, "Input layer not found: " + outPin.name); CV_Error(Error::StsError, "Input layer not found: " + outPin.name);
network.connect(it->second, outPin.blobIndex, input_layer_id, input_blob_id);
std::vector<String>::iterator inpNameIt = std::find(netInputsNames.begin(), netInputsNames.end(), outPin.name);
int blobIndex;
if (inpNameIt == netInputsNames.end())
blobIndex = outPin.blobIndex;
else
blobIndex = inpNameIt - netInputsNames.begin();
network.connect(it->second, blobIndex, input_layer_id, input_blob_id);
} }
void TFImporter::connectToAllBlobs(const std::map<String, int>& layer_id, Net& network, const Pin& outPin, void TFImporter::connectToAllBlobs(const std::map<String, int>& layer_id, Net& network, const Pin& outPin,
...@@ -778,7 +787,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -778,7 +787,7 @@ void TFImporter::populateNet(Net dstNet)
Pin inp = parsePin(layer.input(ii)); Pin inp = parsePin(layer.input(ii));
if (layer_id.find(inp.name) == layer_id.end()) if (layer_id.find(inp.name) == layer_id.end())
CV_Error(Error::StsError, "Input layer not found: " + inp.name); CV_Error(Error::StsError, "Input layer not found: " + inp.name);
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); connect(layer_id, dstNet, inp, id, ii);
} }
} }
} }
...@@ -1028,7 +1037,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1028,7 +1037,7 @@ void TFImporter::populateNet(Net dstNet)
Pin inp = parsePin(layer.input(ii)); Pin inp = parsePin(layer.input(ii));
if (layer_id.find(inp.name) == layer_id.end()) if (layer_id.find(inp.name) == layer_id.end())
CV_Error(Error::StsError, "Input layer not found: " + inp.name); CV_Error(Error::StsError, "Input layer not found: " + inp.name);
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii - from); connect(layer_id, dstNet, inp, id, ii - from);
} }
} }
else if (type == "MaxPool") else if (type == "MaxPool")
...@@ -1060,10 +1069,12 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1060,10 +1069,12 @@ void TFImporter::populateNet(Net dstNet)
} }
else if (type == "Placeholder") else if (type == "Placeholder")
{ {
std::vector<String> netInputs(1); if (!hasLayerAttr(layer, "dtype") ||
netInputs[0] = name; getLayerAttr(layer, "dtype").type() != tensorflow::DT_BOOL) // If input is not a train/test flag.
layer_id[name] = 0; {
dstNet.setInputsNames(netInputs); netInputsNames.push_back(name);
layer_id[name] = 0;
}
} }
else if (type == "Split") { else if (type == "Split") {
// TODO: determining axis index remapping by input dimensions order of input blob // TODO: determining axis index remapping by input dimensions order of input blob
...@@ -1201,7 +1212,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1201,7 +1212,7 @@ void TFImporter::populateNet(Net dstNet)
Pin inp = parsePin(layer.input(ii)); Pin inp = parsePin(layer.input(ii));
if (layer_id.find(inp.name) == layer_id.end()) if (layer_id.find(inp.name) == layer_id.end())
CV_Error(Error::StsError, "Input layer not found: " + inp.name); CV_Error(Error::StsError, "Input layer not found: " + inp.name);
dstNet.connect(layer_id.at(inp.name), inp.blobIndex, id, ii); connect(layer_id, dstNet, inp, id, ii);
} }
} }
} }
...@@ -1719,6 +1730,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1719,6 +1730,7 @@ void TFImporter::populateNet(Net dstNet)
} }
} }
} }
dstNet.setInputsNames(netInputsNames);
} }
} // namespace } // namespace
......
...@@ -440,4 +440,20 @@ TEST(Test_TensorFlow, resize_bilinear) ...@@ -440,4 +440,20 @@ TEST(Test_TensorFlow, resize_bilinear)
runTensorFlowNet("resize_bilinear_factor"); runTensorFlowNet("resize_bilinear_factor");
} }
TEST(Test_TensorFlow, two_inputs)
{
Net net = readNet(path("two_inputs_net.pbtxt"));
net.setPreferableBackend(DNN_BACKEND_OPENCV);
Mat firstInput(2, 3, CV_32FC1), secondInput(2, 3, CV_32FC1);
randu(firstInput, -1, 1);
randu(secondInput, -1, 1);
net.setInput(firstInput, "first_input");
net.setInput(secondInput, "second_input");
Mat out = net.forward();
normAssert(out, firstInput + secondInput);
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment