Commit 6bf8fe81 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #9384 from dkurt:torch_split

parents a391871c 0ce7c33b
......@@ -75,7 +75,7 @@ public:
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
outputs, internals);
return true;
return false;
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
......@@ -86,8 +86,7 @@ public:
for (size_t i = 0; i < outputs.size(); i++)
{
CV_Assert(inputs[0]->total() == outputs[i].total());
if (outputs[i].data != inputs[0]->data)
inputs[0]->copyTo(outputs[i]);
inputs[0]->copyTo(outputs[i]);
}
}
};
......
......@@ -934,20 +934,18 @@ struct TorchImporter : public ::cv::dnn::Importer
}
else if (module->thName == "Concat")
{
int newId, splitId, mergeId;
LayerParams mergeParams, splitParams;
int newId, mergeId;
LayerParams mergeParams;
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
std::vector<int> branchIds;
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
branchIds.push_back(newId);
}
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
for (int i = 0; i < branchIds.size(); i++)
......@@ -1015,19 +1013,12 @@ struct TorchImporter : public ::cv::dnn::Importer
return mergeId;
}
else if (module->thName == "ConcatTable") {
int newId = -1, splitId;
LayerParams splitParams;
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
addedModules.push_back(std::make_pair(splitId, module));
int newId = -1;
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
}
return newId;
}
else if (module->thName == "JoinTable") {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment