Commit 0ce7c33b authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Torch's Concat and ConcatTable doesn't use Split layer

parent 8ffa2947
......@@ -75,7 +75,7 @@ public:
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
outputs, internals);
return true;
return false;
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
......@@ -86,7 +86,6 @@ public:
for (size_t i = 0; i < outputs.size(); i++)
{
CV_Assert(inputs[0]->total() == outputs[i].total());
if (outputs[i].data != inputs[0]->data)
inputs[0]->copyTo(outputs[i]);
}
}
......
......@@ -827,20 +827,18 @@ struct TorchImporter : public ::cv::dnn::Importer
}
else if (module->thName == "Concat")
{
int newId, splitId, mergeId;
LayerParams mergeParams, splitParams;
int newId, mergeId;
LayerParams mergeParams;
mergeParams.set("axis", module->params.get<int>("dimension") - 1);
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
std::vector<int> branchIds;
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
branchIds.push_back(newId);
}
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
for (int i = 0; i < branchIds.size(); i++)
......@@ -884,19 +882,12 @@ struct TorchImporter : public ::cv::dnn::Importer
return mergeId;
}
else if (module->thName == "ConcatTable") {
int newId = -1, splitId;
LayerParams splitParams;
splitId = net.addLayer(generateLayerName("torchSplit"), "Split", splitParams);
net.connect(prevLayerId, prevOutNum, splitId, 0);
addedModules.push_back(std::make_pair(splitId, module));
int newId = -1;
moduleCounter += 1; // Skip split layer creation. See https://github.com/opencv/opencv/pull/9384.
for (int i = 0; i < (int)module->modules.size(); i++)
{
newId = fill(module->modules[i], addedModules, splitId, i);
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
}
return newId;
}
else if (module->thName == "JoinTable") {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment