Commit 78ff9d93 authored by dkurt's avatar dkurt

Import SoftMax, LogSoftMax layers from Torch

parent 9638b145
...@@ -251,6 +251,8 @@ namespace dnn ...@@ -251,6 +251,8 @@ namespace dnn
class CV_EXPORTS SoftmaxLayer : public Layer class CV_EXPORTS SoftmaxLayer : public Layer
{ {
public: public:
bool logSoftMax;
static Ptr<SoftmaxLayer> create(const LayerParams& params); static Ptr<SoftmaxLayer> create(const LayerParams& params);
}; };
......
...@@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionality. ...@@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionality.
* - nn.SpatialMaxPooling, nn.SpatialAveragePooling * - nn.SpatialMaxPooling, nn.SpatialAveragePooling
* - nn.ReLU, nn.TanH, nn.Sigmoid * - nn.ReLU, nn.TanH, nn.Sigmoid
* - nn.Reshape * - nn.Reshape
* - nn.SoftMax, nn.LogSoftMax
* *
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported. * Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/ */
......
...@@ -57,6 +57,7 @@ public: ...@@ -57,6 +57,7 @@ public:
SoftMaxLayerImpl(const LayerParams& params) SoftMaxLayerImpl(const LayerParams& params)
{ {
axisRaw = params.get<int>("axis", 1); axisRaw = params.get<int>("axis", 1);
logSoftMax = params.get<int>("log_softmax", false);
setParamsFrom(params); setParamsFrom(params);
} }
...@@ -143,6 +144,14 @@ public: ...@@ -143,6 +144,14 @@ public:
for (size_t i = 0; i < innerSize; i++) for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i]; dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i];
} }
if (logSoftMax)
{
for (size_t cnDim = 0; cnDim < channels; cnDim++)
{
for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] = log(dstPtr[srcOffset + cnDim * cnStep + i]);
}
}
} }
} }
......
...@@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer ...@@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer
layerParams.set("indices_blob_id", tensorParams["indices"].first); layerParams.set("indices_blob_id", tensorParams["indices"].first);
curModule->modules.push_back(newModule); curModule->modules.push_back(newModule);
} }
else if (nnName == "SoftMax")
{
newModule->apiType = "SoftMax";
curModule->modules.push_back(newModule);
}
else if (nnName == "LogSoftMax")
{
newModule->apiType = "SoftMax";
layerParams.set("log_softmax", true);
curModule->modules.push_back(newModule);
}
else else
{ {
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\""); CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
......
...@@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table) ...@@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table)
runTorchNet("net_cadd_table"); runTorchNet("net_cadd_table");
} }
// Regression test for the Torch importer's nn.SoftMax support: runs both the
// plain and the spatial softmax reference networks through runTorchNet, which
// compares the imported net's output against Torch-generated reference blobs.
TEST(Torch_Importer, net_softmax)
{
runTorchNet("net_softmax");
runTorchNet("net_softmax_spatial");
}
// Regression test for the Torch importer's nn.LogSoftMax support (imported as
// a SoftMax layer with the "log_softmax" parameter set): checks both the plain
// and the spatial variants against Torch-generated reference outputs.
TEST(Torch_Importer, net_logsoftmax)
{
runTorchNet("net_logsoftmax");
runTorchNet("net_logsoftmax_spatial");
}
TEST(Torch_Importer, ENet_accuracy) TEST(Torch_Importer, ENet_accuracy)
{ {
Net net; Net net;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment