Commit 2c4d3d92 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #10221 from dkurt:non_spatial_torch_layers

parents 0042bacd bbbec300
...@@ -119,8 +119,9 @@ public: ...@@ -119,8 +119,9 @@ public:
CV_Assert(inputs.size() == 1); CV_Assert(inputs.size() == 1);
Mat &inpBlob = *inputs[0]; Mat &inpBlob = *inputs[0];
int rows = inpBlob.size[2]; CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
int cols = inpBlob.size[3]; int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
for (size_t ii = 0; ii < outputs.size(); ii++) for (size_t ii = 0; ii < outputs.size(); ii++)
{ {
......
...@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer ...@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid"))); curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
readObject(); readObject();
} }
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization") else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
nnName == "BatchNormalization")
{ {
newModule->apiType = "BatchNorm"; newModule->apiType = "BatchNorm";
readTorchTable(scalarParams, tensorParams); readTorchTable(scalarParams, tensorParams);
...@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer ...@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(newModule); curModule->modules.push_back(newModule);
} }
else if (nnName == "SpatialDropout") else if (nnName == "SpatialDropout" || nnName == "Dropout")
{ {
readTorchTable(scalarParams, tensorParams); readTorchTable(scalarParams, tensorParams);
CV_Assert(scalarParams.has("p")); CV_Assert(scalarParams.has("p"));
float scale = 1 - scalarParams.get<double>("p"); if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
{
newModule->apiType = "Identity";
}
else
{
float scale = 1 - scalarParams.get<double>("p");
CV_Assert(scale > 0); CV_Assert(scale > 0);
newModule->apiType = "Power"; newModule->apiType = "Power";
layerParams.set("scale", scale); layerParams.set("scale", scale);
}
curModule->modules.push_back(newModule); curModule->modules.push_back(newModule);
} }
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style // TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style
......
...@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding) ...@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true); runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
} }
TEST(Torch_Importer, net_non_spatial)
{
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, ENet_accuracy) TEST(Torch_Importer, ENet_accuracy)
{ {
Net net; Net net;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment