Commit f9f16040 authored by Dmitry Kurtaev

Add support for slice from ONNX with multiple outputs

parent 2693ed9b
@@ -465,6 +465,20 @@ void ONNXImporter::populateNet(Net dstNet)
             }
             layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
             layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
+        }
+        else if (layer_type == "Split")
+        {
+            DictValue splits = layerParams.get("split");
+            const int numSplits = splits.size();
+            CV_Assert(numSplits > 1);
+
+            std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
+            for (int i = 1; i < splits.size() - 1; ++i)
+            {
+                slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
+            }
+            layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
+            layerParams.type = "Slice";
         }
         else if (layer_type == "Add" || layer_type == "Sum")
         {
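The new Split branch above reuses OpenCV's Slice layer, which follows the Caffe convention: "slice_point" holds the cumulative indices at which the input is cut along the split axis, so N outputs need N-1 boundaries. Below is a minimal standalone sketch of that conversion for equal-sized chunks (the case exercised by the split_max test added further down); the values and variable names are illustrative and not part of the commit:

    #include <iostream>
    #include <vector>

    int main()
    {
        // ONNX "split" attribute: three chunks of 4 elements each along the split axis.
        std::vector<int> split = {4, 4, 4};

        // Caffe-style slice points: cumulative offsets where each output ends.
        // Mirrors the importer loop above; N chunks give N-1 boundaries.
        std::vector<int> slicePoints(split.size() - 1, split[0]);
        for (size_t i = 1; i + 1 < split.size(); ++i)
            slicePoints[i] = slicePoints[i - 1] + split[i - 1];

        for (size_t i = 0; i < slicePoints.size(); ++i)
            std::cout << slicePoints[i] << ' ';   // prints: 4 8
        std::cout << std::endl;
        return 0;
    }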
@@ -486,6 +500,11 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.type = "Eltwise";
             }
         }
+        else if (layer_type == "Max")
+        {
+            layerParams.type = "Eltwise";
+            layerParams.set("operation", "max");
+        }
         else if (layer_type == "Sub")
         {
             Mat blob = getBlob(node_proto, constBlobs, 1);
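ONNX Max is mapped onto the existing Eltwise layer with its "max" operation. The following is a minimal sketch, assuming a hand-built two-input net (the blob names "a"/"b" and the layer name "max0" are illustrative), of what that layer computes:

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        // Element-wise maximum over two inputs, as created for an ONNX Max node.
        LayerParams lp;
        lp.set("operation", "max");

        Net net;
        net.setInputsNames({"a", "b"});
        int id = net.addLayer("max0", "Eltwise", lp);
        net.connect(0, 0, id, 0);   // input "a" -> first Eltwise input
        net.connect(0, 1, id, 1);   // input "b" -> second Eltwise input

        int shape[] = {1, 1, 1, 3};
        float aData[] = {1.f, 5.f, 2.f}, bData[] = {4.f, 0.f, 3.f};
        Mat a(4, shape, CV_32F, aData), b(4, shape, CV_32F, bData);
        net.setInput(a, "a");
        net.setInput(b, "b");

        Mat out = net.forward("max0");
        for (int i = 0; i < 3; ++i)
            std::cout << out.ptr<float>()[i] << ' ';   // expected: 4 5 3
        std::cout << std::endl;
        return 0;
    }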
@@ -741,6 +760,16 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             layerParams.type = "Permute";
             replaceLayerParam(layerParams, "perm", "order");
+
+            CV_Assert(node_proto.input_size() == 1);
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
+                runLayer(layerParams, inputs, transposed);
+                CV_Assert(transposed.size() == 1);
+                constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
+                continue;
+            }
         }
         else if (layer_type == "Unsqueeze")
         {
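When the input of a Transpose node is already a constant (an ONNX initializer), the block above runs the Permute layer eagerly through the importer's runLayer helper and stores the result in constBlobs, so no runtime layer is added to the net. Below is a rough sketch of such an "execute one layer on constant inputs" helper built only from public dnn APIs; the helper name and the CV_32F-only assumption are mine, not the importer's exact internals:

    #include <opencv2/dnn.hpp>
    #include <opencv2/dnn/layer.hpp>
    #include <opencv2/dnn/shape_utils.hpp>

    // Instantiate a single layer from LayerParams and run it once on constant inputs.
    static void runLayerEagerly(cv::dnn::LayerParams& params,
                                const std::vector<cv::Mat>& inputs,
                                std::vector<cv::Mat>& outputs)
    {
        using namespace cv::dnn;
        cv::Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);

        // Ask the layer for its output shapes given the input shapes.
        std::vector<MatShape> inpShapes, outShapes, internalShapes;
        for (size_t i = 0; i < inputs.size(); ++i)
            inpShapes.push_back(shape(inputs[i]));
        layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);

        // Allocate output and internal blobs (assuming CV_32F) and run the layer.
        std::vector<cv::Mat> internals;
        outputs.resize(outShapes.size());
        for (size_t i = 0; i < outShapes.size(); ++i)
            outputs[i].create(outShapes[i], CV_32F);
        for (size_t i = 0; i < internalShapes.size(); ++i)
            internals.push_back(cv::Mat(internalShapes[i], CV_32F));

        layer->finalize(inputs, outputs);
        layer->forward(inputs, outputs, internals);
    }

For the Transpose case above, params.type would be "Permute" with the "order" attribute set, and inputs would hold the single constant blob; the transposed result is then cached instead of creating a layer.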
@@ -906,8 +935,10 @@ void ONNXImporter::populateNet(Net dstNet)
         }

         int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
-        layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
+        for (int i = 0; i < node_proto.output_size(); ++i)
+        {
+            layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
+        }

         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
         for (int j = 0; j < node_proto.input_size(); j++) {
@@ -924,8 +955,10 @@ void ONNXImporter::populateNet(Net dstNet)
         // Compute shape of output blob for this layer.
         Ptr<Layer> layer = dstNet.getLayer(id);
         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
-        CV_Assert(!layerOutShapes.empty());
-        outShapes[layerParams.name] = layerOutShapes[0];
+        for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
+        {
+            outShapes[node_proto.output(i)] = layerOutShapes[i];
+        }
     }
 }
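These last two hunks are what make multiple outputs per ONNX node visible to callers: every name in node_proto.output() is registered as its own blob bound to (layer id, output index), and the per-output shapes are stored under those names instead of a single entry keyed by the layer name. From the user side, each ONNX output name can then be requested directly. A minimal sketch, assuming a model file split_max.onnx whose Split node declares outputs "out1" and "out2" (file name, blob names and input shape are illustrative):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        Net net = readNetFromONNX("split_max.onnx");

        // Any float blob matching the model's expected input shape.
        int shape[] = {1, 2, 4, 4};
        Mat input(4, shape, CV_32F);
        randu(input, -1.0, 1.0);
        net.setInput(input);

        // Fetch both outputs of the Split node by their ONNX output names.
        std::vector<Mat> outs;
        std::vector<String> names;
        names.push_back("out1");
        names.push_back("out2");
        net.forward(outs, names);

        std::cout << "out1 dims: " << outs[0].dims
                  << ", out2 dims: " << outs[1].dims << std::endl;
        return 0;
    }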
......
@@ -348,6 +348,13 @@ TEST_P(Test_ONNX_layers, Softmax)
     testONNXModels("log_softmax", npy, 0, 0, false, false);
 }

+TEST_P(Test_ONNX_layers, Split_EltwiseMax)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
+    testONNXModels("split_max");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

 class Test_ONNX_nets : public Test_ONNX_layers
......
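The new Split_EltwiseMax test loads the split_max model from the test data and compares against a reference output. For orientation, roughly the net that the importer builds for such a graph can also be assembled by hand: a Slice layer with two outputs feeding an Eltwise "max" layer. This is a sketch under assumed shapes (layer names, axis and input size are illustrative; the real split_max model is defined by the test data, not by this snippet):

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        Net net;

        // Slice a 4-channel input into channels [0,2) and [2,4): one slice point at 2.
        LayerParams sliceLp;
        sliceLp.set("axis", 1);
        int slicePoint = 2;
        sliceLp.set("slice_point", DictValue::arrayInt(&slicePoint, 1));
        int sliceId = net.addLayerToPrev("split0", "Slice", sliceLp);

        // Element-wise maximum of the two halves.
        LayerParams maxLp;
        maxLp.set("operation", "max");
        int maxId = net.addLayer("max0", "Eltwise", maxLp);
        net.connect(sliceId, 0, maxId, 0);
        net.connect(sliceId, 1, maxId, 1);

        int shape[] = {1, 4, 3, 3};
        Mat input(4, shape, CV_32F);
        randu(input, 0.0, 1.0);
        net.setInput(input);

        Mat out = net.forward("max0");   // 1x2x3x3: max of the two channel halves
        std::cout << "output dims: " << out.dims << std::endl;
        return 0;
    }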