Commit 73dc666d authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #1212 from dkurt:torch_softmax_layer

parents 6c9d6d50 78ff9d93
......@@ -251,6 +251,8 @@ namespace dnn
class CV_EXPORTS SoftmaxLayer : public Layer
{
public:
bool logSoftMax;
static Ptr<SoftmaxLayer> create(const LayerParams& params);
};
......
......@@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
* - nn.SpatialMaxPooling, nn.SpatialAveragePooling
* - nn.ReLU, nn.TanH, nn.Sigmoid
* - nn.Reshape
* - nn.SoftMax, nn.LogSoftMax
*
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/
......
......@@ -57,6 +57,7 @@ public:
SoftMaxLayerImpl(const LayerParams& params)
{
axisRaw = params.get<int>("axis", 1);
logSoftMax = params.get<int>("log_softmax", false);
setParamsFrom(params);
}
......@@ -143,6 +144,14 @@ public:
for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i];
}
if (logSoftMax)
{
for (size_t cnDim = 0; cnDim < channels; cnDim++)
{
for (size_t i = 0; i < innerSize; i++)
dstPtr[srcOffset + cnDim * cnStep + i] = log(dstPtr[srcOffset + cnDim * cnStep + i]);
}
}
}
}
......
......@@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer
layerParams.set("indices_blob_id", tensorParams["indices"].first);
curModule->modules.push_back(newModule);
}
else if (nnName == "SoftMax")
{
newModule->apiType = "SoftMax";
curModule->modules.push_back(newModule);
}
else if (nnName == "LogSoftMax")
{
newModule->apiType = "SoftMax";
layerParams.set("log_softmax", true);
curModule->modules.push_back(newModule);
}
else
{
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
......
......@@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table)
runTorchNet("net_cadd_table");
}
TEST(Torch_Importer, net_softmax)
{
runTorchNet("net_softmax");
runTorchNet("net_softmax_spatial");
}
TEST(Torch_Importer, net_logsoftmax)
{
runTorchNet("net_logsoftmax");
runTorchNet("net_logsoftmax_spatial");
}
TEST(Torch_Importer, ENet_accuracy)
{
Net net;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment