Commit 5620306c authored by Feng Chen's avatar Feng Chen Committed by Alexander Alekhin

Merge pull request #14845 from vonchenplus:ocv_mirrorpad

* tensorflow support mirror pad

* revert macro define

* revert macro define

* reduce code duplication

* revert macro define
parent f8c96cb1
...@@ -792,7 +792,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -792,7 +792,7 @@ void TFImporter::populateNet(Net dstNet)
int predictedLayout = predictOutputDataLayout(net, layer, data_layouts); int predictedLayout = predictOutputDataLayout(net, layer, data_layouts);
data_layouts[name] = predictedLayout; data_layouts[name] = predictedLayout;
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative" || type == "Pad" || type == "Conv3D") if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative" || type == "Pad" || type == "MirrorPad" || type == "Conv3D")
{ {
// The first node of dilated convolution subgraph. // The first node of dilated convolution subgraph.
// Extract input node, dilation rate and paddings. // Extract input node, dilation rate and paddings.
...@@ -804,6 +804,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -804,6 +804,7 @@ void TFImporter::populateNet(Net dstNet)
if (next_layers.empty()) if (next_layers.empty())
next_layers = getNextLayers(net, name, "DepthwiseConv2dNative"); next_layers = getNextLayers(net, name, "DepthwiseConv2dNative");
} }
if (type == "SpaceToBatchND") if (type == "SpaceToBatchND")
{ {
// op: "SpaceToBatchND" // op: "SpaceToBatchND"
...@@ -830,7 +831,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -830,7 +831,7 @@ void TFImporter::populateNet(Net dstNet)
name = layer.name(); name = layer.name();
type = layer.op(); type = layer.op();
} }
else if (type == "Pad") else if (type == "Pad" || type == "MirrorPad")
{ {
Mat paddings = getTensorContent(getConstBlob(layer, value_id, 1)); Mat paddings = getTensorContent(getConstBlob(layer, value_id, 1));
CV_Assert(paddings.type() == CV_32SC1); CV_Assert(paddings.type() == CV_32SC1);
...@@ -848,12 +849,15 @@ void TFImporter::populateNet(Net dstNet) ...@@ -848,12 +849,15 @@ void TFImporter::populateNet(Net dstNet)
// N C H W // N C H W
// 0 1 2 3 4 5 6 7 // 0 1 2 3 4 5 6 7
} }
if (next_layers.empty() || paddings.total() != 8 || if (next_layers.empty() || paddings.total() != 8 ||
paddings.at<int32_t>(4) != paddings.at<int32_t>(5) || paddings.at<int32_t>(4) != paddings.at<int32_t>(5) ||
paddings.at<int32_t>(6) != paddings.at<int32_t>(7)) paddings.at<int32_t>(6) != paddings.at<int32_t>(7) || type == "MirrorPad")
{ {
// Just a single padding layer. // Just a single padding layer.
layerParams.set("paddings", DictValue::arrayInt<int*>((int*)paddings.data, paddings.total())); layerParams.set("paddings", DictValue::arrayInt<int*>((int*)paddings.data, paddings.total()));
if (type == "MirrorPad")
layerParams.set("type", "reflect");
int id = dstNet.addLayer(name, "Padding", layerParams); int id = dstNet.addLayer(name, "Padding", layerParams);
layer_id[name] = id; layer_id[name] = id;
......
...@@ -146,6 +146,7 @@ TEST_P(Test_TensorFlow_layers, padding) ...@@ -146,6 +146,7 @@ TEST_P(Test_TensorFlow_layers, padding)
runTensorFlowNet("padding_valid"); runTensorFlowNet("padding_valid");
runTensorFlowNet("spatial_padding"); runTensorFlowNet("spatial_padding");
runTensorFlowNet("keras_pad_concat"); runTensorFlowNet("keras_pad_concat");
runTensorFlowNet("mirror_pad");
} }
TEST_P(Test_TensorFlow_layers, padding_same) TEST_P(Test_TensorFlow_layers, padding_same)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment