Commit 6bf8fe81 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #9384 from dkurt:torch_split

parents a391871c 0ce7c33b
...@@ -75,7 +75,7 @@ public: ...@@ -75,7 +75,7 @@ public:
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs), Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
outputs, internals); outputs, internals);
return true; return false;
} }
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
...@@ -86,8 +86,7 @@ public: ...@@ -86,8 +86,7 @@ public:
for (size_t i = 0; i < outputs.size(); i++) for (size_t i = 0; i < outputs.size(); i++)
{ {
CV_Assert(inputs[0]->total() == outputs[i].total()); CV_Assert(inputs[0]->total() == outputs[i].total());
if (outputs[i].data != inputs[0]->data) inputs[0]->copyTo(outputs[i]);
inputs[0]->copyTo(outputs[i]);
} }
} }
}; };
......
...@@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer ...@@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer
} }
else if (module->thName == "Concat") else if (module->thName == "Concat")
{ {
int newId, splitId, mergeId; int newId, mergeId;
LayerParams mergeParams, splitParams; LayerParams mergeParams;
mergeParams.set("axis", module->params.get<int>("dimension") - 1); mergeParams.set("axis", module->params.get<int>("dimension") - 1);
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
std::vector<int> branchIds; std::vector<int> branchIds;
for (int i = 0; i < (int)module->modules.size(); i++) for (int i = 0; i < (int)module->modules.size(); i++)
{ {
newId = fill(module->modules[i], addedModules, splitId, i); newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
branchIds.push_back(newId); branchIds.push_back(newId);
} }
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams); mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
for (int i = 0; i < branchIds.size(); i++) for (int i = 0; i < branchIds.size(); i++)
...@@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer ...@@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer
return mergeId; return mergeId;
} }
else if (module->thName == "ConcatTable") { else if (module->thName == "ConcatTable") {
int newId = -1, splitId; int newId = -1;
LayerParams splitParams; moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
addedModules.push_back(std::make_pair(splitId, module));
for (int i = 0; i < (int)module->modules.size(); i++) for (int i = 0; i < (int)module->modules.size(); i++)
{ {
newId = fill(module->modules[i], addedModules, splitId, i); newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
} }
return newId; return newId;
} }
else if (module->thName == "JoinTable") { else if (module->thName == "JoinTable") {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment