Commit 2c4d3d92 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #10221 from dkurt:non_spatial_torch_layers

parents 0042bacd bbbec300
......@@ -119,8 +119,9 @@ public:
CV_Assert(inputs.size() == 1);
Mat &inpBlob = *inputs[0];
int rows = inpBlob.size[2];
int cols = inpBlob.size[3];
CV_Assert(inpBlob.dims == 2 || inpBlob.dims == 4);
int rows = inpBlob.dims > 2 ? inpBlob.size[2] : 1;
int cols = inpBlob.dims > 2 ? inpBlob.size[3] : 1;
for (size_t ii = 0; ii < outputs.size(); ii++)
{
......
......@@ -617,7 +617,8 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(cv::Ptr<Module>(new Module(nnName, "Sigmoid")));
readObject();
}
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization")
else if (nnName == "SpatialBatchNormalization" || nnName == "InstanceNormalization" ||
nnName == "BatchNormalization")
{
newModule->apiType = "BatchNorm";
readTorchTable(scalarParams, tensorParams);
......@@ -700,17 +701,24 @@ struct TorchImporter : public ::cv::dnn::Importer
curModule->modules.push_back(newModule);
}
else if (nnName == "SpatialDropout")
else if (nnName == "SpatialDropout" || nnName == "Dropout")
{
readTorchTable(scalarParams, tensorParams);
CV_Assert(scalarParams.has("p"));
float scale = 1 - scalarParams.get<double>("p");
if (scalarParams.has("v2") && scalarParams.get<bool>("v2"))
{
newModule->apiType = "Identity";
}
else
{
float scale = 1 - scalarParams.get<double>("p");
CV_Assert(scale > 0);
CV_Assert(scale > 0);
newModule->apiType = "Power";
layerParams.set("scale", scale);
newModule->apiType = "Power";
layerParams.set("scale", scale);
}
curModule->modules.push_back(newModule);
}
// TotalVariation layer is from fast-neural-style project: https://github.com/jcjohnson/fast-neural-style
......
......@@ -234,6 +234,11 @@ TEST(Torch_Importer, net_padding)
runTorchNet("net_spatial_reflection_padding", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, net_non_spatial)
{
runTorchNet("net_non_spatial", DNN_TARGET_CPU, "", false, true);
}
TEST(Torch_Importer, ENet_accuracy)
{
Net net;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment