Commit 099a16bd authored by Alexander Alekhin

Merge pull request #11198 from dkurt:torch_embedded_residuals

parents dbcb4549 598039c0
@@ -101,6 +101,8 @@ struct TorchImporter
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Mat> tensors;
+    // Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
+    std::vector<int> numUnconnectedLayers;

     struct Module
     {
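The numUnconnectedLayers vector introduced above behaves as a stack: when the importer leaves a container whose branches produce parallel outputs (a ConcatTable inside a residual block, for instance), it pushes the number of outputs that still lack a consumer, and the matching merge layer later pops exactly that count. A minimal standalone sketch of the bookkeeping, with illustrative helper names that are not part of the patch:

    #include <cassert>
    #include <vector>

    std::vector<int> numUnconnectedLayers;  // one pending-output counter per open scope

    // After all branches of a ConcatTable are imported, remember how many
    // of its outputs are still unconnected (hypothetical helper).
    void onScopeFinished(int numBranches)
    {
        numUnconnectedLayers.push_back(numBranches);
    }

    // A merge layer (JoinTable / CAddTable) must consume only the outputs of
    // the innermost scope, so it pops the top counter (hypothetical helper).
    int popScopeInputs()
    {
        assert(!numUnconnectedLayers.empty());
        const int n = numUnconnectedLayers.back();
        numUnconnectedLayers.pop_back();
        return n;
    }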
@@ -489,15 +491,7 @@ struct TorchImporter
         layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
         layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
     }
-    if (nnName == "Concat")
-    {
-        layerParams.set("dimension", scalarParams.get<int>("dimension"));
-    }
-    if (nnName == "JoinTable")
-    {
-        layerParams.set("dimension", scalarParams.get<int>("dimension"));
-    }
-    if (nnName == "DepthConcat")
+    else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
     {
         layerParams.set("dimension", scalarParams.get<int>("dimension"));
     }
@@ -1096,6 +1090,7 @@ struct TorchImporter
             {
                 newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
             }
+            numUnconnectedLayers.push_back(module->modules.size());
             return newId;
         }
         else if (module->thName == "JoinTable") {
@@ -1108,9 +1103,14 @@ struct TorchImporter
             mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
             addedModules.push_back(std::make_pair(mergeId, module));
-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, mergeId, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
             }
             return mergeId;
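Both merge layers now take their inputs from the tail of ids instead of from index 0, so outputs that belong to an enclosing scope are left alone; the same pattern is repeated for CAddTable in the next hunk. A self-contained sketch of the tail-slice indexing, with made-up layer ids:

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> ids = {10, 11, 12, 13};  // all pending layer ids
        const int numInputs = 2;                  // popped from numUnconnectedLayers
        for (int i = 0; i < numInputs; i++)
            std::printf("connect layer %d -> merge input %d\n",
                        ids[ids.size() - numInputs + i], i);
        // Prints ids 12 and 13; ids 10 and 11 stay available for an
        // outer JoinTable or CAddTable.
        return 0;
    }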
@@ -1124,9 +1124,14 @@ struct TorchImporter
             int id = net.addLayer(name, "Eltwise", params);
-            for (int i = 0; i < ids.size(); i++)
+            // Connect to the last number of unconnected layers.
+            CV_Assert(!numUnconnectedLayers.empty());
+            const int numInputs = numUnconnectedLayers.back();
+            numUnconnectedLayers.pop_back();
+            CV_Assert(numInputs <= ids.size());
+            for (int i = 0; i < numInputs; i++)
             {
-                net.connect(ids[i], 0, id, i);
+                net.connect(ids[ids.size() - numInputs + i], 0, id, i);
             }
             addedModules.push_back(std::make_pair(id, module));
...
@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
     runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
 }

+TEST(Torch_Importer, net_residual)
+{
+    runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
+}
+
 }
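Once the importer handles nested residual blocks, such a model loads through the regular Torch path. A minimal usage sketch; the file name and input shape are placeholders, not taken from the test data:

    #include <opencv2/dnn.hpp>

    int main()
    {
        // Hypothetical model file; any Torch net with embedded residual
        // connections would exercise the code path added in this patch.
        cv::dnn::Net net = cv::dnn::readNetFromTorch("net_residual.t7");
        int sz[] = {1, 3, 32, 32};                 // placeholder NCHW shape
        cv::Mat blob(4, sz, CV_32F, cv::Scalar(0));
        net.setInput(blob);
        cv::Mat out = net.forward();
        (void)out;
        return 0;
    }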