Commit c12e26ff authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #15071 from l-bat:tf_split

parents e4e0bb53 0d2bc7b5
...@@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN ...@@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/ */
std::vector<std::vector<Range> > sliceRanges; std::vector<std::vector<Range> > sliceRanges;
int axis; int axis;
int num_split;
static Ptr<SliceLayer> create(const LayerParams &params); static Ptr<SliceLayer> create(const LayerParams &params);
}; };
......
...@@ -61,6 +61,7 @@ public: ...@@ -61,6 +61,7 @@ public:
{ {
setParamsFrom(params); setParamsFrom(params);
axis = params.get<int>("axis", 1); axis = params.get<int>("axis", 1);
num_split = params.get<int>("num_split", 0);
if (params.has("slice_point")) if (params.has("slice_point"))
{ {
CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end")); CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end"));
...@@ -141,9 +142,10 @@ public: ...@@ -141,9 +142,10 @@ public:
else // Divide input blob on equal parts by axis. else // Divide input blob on equal parts by axis.
{ {
CV_Assert(0 <= axis && axis < inpShape.size()); CV_Assert(0 <= axis && axis < inpShape.size());
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0); int splits = num_split ? num_split : requiredOutputs;
inpShape[axis] /= requiredOutputs; CV_Assert(splits > 0 && inpShape[axis] % splits == 0);
outputs.resize(requiredOutputs, inpShape); inpShape[axis] /= splits;
outputs.resize(splits, inpShape);
} }
return false; return false;
} }
......
...@@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet)
axis = toNCHW(axis); axis = toNCHW(axis);
layerParams.set("axis", axis); layerParams.set("axis", axis);
if (hasLayerAttr(layer, "num_split"))
layerParams.set("num_split", getLayerAttr(layer, "num_split").i());
int id = dstNet.addLayer(name, "Slice", layerParams); int id = dstNet.addLayer(name, "Slice", layerParams);
layer_id[name] = id; layer_id[name] = id;
......
...@@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d) ...@@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d)
runTensorFlowNet("l2_normalize_3d"); runTensorFlowNet("l2_normalize_3d");
} }
TEST_P(Test_TensorFlow_layers, Split)
{
runTensorFlowNet("split");
}
class Test_TensorFlow_nets : public DNNTestLayer {}; class Test_TensorFlow_nets : public DNNTestLayer {};
TEST_P(Test_TensorFlow_nets, MobileNet_SSD) TEST_P(Test_TensorFlow_nets, MobileNet_SSD)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment