Unverified Commit a6f3a212 authored by Gagandeep Singh's avatar Gagandeep Singh Committed by GitHub

Merge pull request #16424 from czgdp1807:issue-16370

* fixed Split layer in ONNXImporter

* added test for fix of split layer

* fixed tests for Split layer

* applied reviews

* updated tests

* fixed paths in tests
parent a8c257ce
...@@ -485,16 +485,23 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -485,16 +485,23 @@ void ONNXImporter::populateNet(Net dstNet)
} }
else if (layer_type == "Split") else if (layer_type == "Split")
{ {
DictValue splits = layerParams.get("split"); if (layerParams.has("split"))
const int numSplits = splits.size(); {
CV_Assert(numSplits > 1); DictValue splits = layerParams.get("split");
const int numSplits = splits.size();
CV_Assert(numSplits > 1);
std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0)); std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
for (int i = 1; i < splits.size() - 1; ++i) for (int i = 1; i < splits.size() - 1; ++i)
{
slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
}
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
}
else
{ {
slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1); layerParams.set("num_split", node_proto.output_size());
} }
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
layerParams.type = "Slice"; layerParams.type = "Slice";
} }
else if (layer_type == "Add" || layer_type == "Sum") else if (layer_type == "Add" || layer_type == "Sum")
......
...@@ -386,6 +386,18 @@ TEST_P(Test_ONNX_layers, ReduceL2) ...@@ -386,6 +386,18 @@ TEST_P(Test_ONNX_layers, ReduceL2)
testONNXModels("reduceL2"); testONNXModels("reduceL2");
} }
TEST_P(Test_ONNX_layers, Split)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("split_1");
testONNXModels("split_2");
testONNXModels("split_3");
testONNXModels("split_4");
}
TEST_P(Test_ONNX_layers, Slice) TEST_P(Test_ONNX_layers, Slice)
{ {
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000) #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment