Commit e8c7d617 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #16817 from dkurt:dnn_onnx_lstm

parents b1f390b1 467c3ef0
...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer ...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
float forgetBias, cellClip; float forgetBias, cellClip;
bool useCellClip, usePeephole; bool useCellClip, usePeephole;
bool reverse; // If true, go in negative direction along the time axis bool reverse; // If true, go in negative direction along the time axis
bool bidirectional; // If true, produces both forward and reversed directions along time axis
public: public:
...@@ -101,6 +102,7 @@ public: ...@@ -101,6 +102,7 @@ public:
{ {
setParamsFrom(params); setParamsFrom(params);
bidirectional = params.get<bool>("bidirectional", false);
if (!blobs.empty()) if (!blobs.empty())
{ {
CV_Assert(blobs.size() >= 3); CV_Assert(blobs.size() >= 3);
...@@ -110,10 +112,11 @@ public: ...@@ -110,10 +112,11 @@ public:
const Mat& Wh = blobs[0]; const Mat& Wh = blobs[0];
const Mat& Wx = blobs[1]; const Mat& Wx = blobs[1];
const Mat& bias = blobs[2]; const Mat& bias = blobs[2];
CV_Assert(Wh.dims == 2 && Wx.dims == 2); CV_CheckEQ(Wh.dims, 2, "");
CV_Assert(Wh.rows == Wx.rows); CV_CheckEQ(Wx.dims, 2, "");
CV_Assert(Wh.rows == 4*Wh.cols); CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_Assert(Wh.rows == (int)bias.total()); CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
// Peephole weights. // Peephole weights.
...@@ -135,6 +138,7 @@ public: ...@@ -135,6 +138,7 @@ public:
useCellClip = params.get<bool>("use_cell_clip", false); useCellClip = params.get<bool>("use_cell_clip", false);
usePeephole = params.get<bool>("use_peephole", false); usePeephole = params.get<bool>("use_peephole", false);
reverse = params.get<bool>("reverse", false); reverse = params.get<bool>("reverse", false);
CV_Assert(!reverse || !bidirectional);
allocated = false; allocated = false;
outTailShape.clear(); outTailShape.clear();
...@@ -206,6 +210,7 @@ public: ...@@ -206,6 +210,7 @@ public:
outResShape.push_back(_numSamples); outResShape.push_back(_numSamples);
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end()); outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
outResShape.back() *= (1 + static_cast<int>(bidirectional));
size_t noutputs = produceCellOutput ? 2 : 1; size_t noutputs = produceCellOutput ? 2 : 1;
outputs.assign(noutputs, outResShape); outputs.assign(noutputs, outResShape);
...@@ -252,6 +257,7 @@ public: ...@@ -252,6 +257,7 @@ public:
outTsShape.clear(); outTsShape.clear();
outTsShape.push_back(numSamples); outTsShape.push_back(numSamples);
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end()); outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
outTsShape.back() *= (1 + static_cast<int>(bidirectional));
allocated = true; allocated = true;
} }
...@@ -272,9 +278,12 @@ public: ...@@ -272,9 +278,12 @@ public:
outputs_arr.getMatVector(output); outputs_arr.getMatVector(output);
internals_arr.getMatVector(internals); internals_arr.getMatVector(internals);
const Mat &Wh = blobs[0]; const int numDirs = 1 + static_cast<int>(bidirectional);
const Mat &Wx = blobs[1]; for (int i = 0; i < numDirs; ++i)
const Mat &bias = blobs[2]; {
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
int numOut = Wh.size[1]; int numOut = Wh.size[1];
...@@ -288,10 +297,11 @@ public: ...@@ -288,10 +297,11 @@ public:
Mat xTs = input[0].reshape(1, numSamplesTotal); Mat xTs = input[0].reshape(1, numSamplesTotal);
Mat hOutTs = output[0].reshape(1, numSamplesTotal); Mat hOutTs = output[0].reshape(1, numSamplesTotal);
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat(); Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
int tsStart, tsEnd, tsInc; int tsStart, tsEnd, tsInc;
if (reverse) { if (reverse || i == 1) {
tsStart = numTimeStamps - 1; tsStart = numTimeStamps - 1;
tsEnd = -1; tsEnd = -1;
tsInc = -1; tsInc = -1;
...@@ -359,6 +369,7 @@ public: ...@@ -359,6 +369,7 @@ public:
cInternal.copyTo(cOutTs.rowRange(curRowRange)); cInternal.copyTo(cOutTs.rowRange(curRowRange));
} }
} }
}
}; };
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params) Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
......
...@@ -49,6 +49,11 @@ class ONNXImporter ...@@ -49,6 +49,11 @@ class ONNXImporter
LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto); LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
bool isCeilMode(const LayerParams& layerParams); bool isCeilMode(const LayerParams& layerParams);
void addLayer(Net& dstNet, LayerParams& layerParams,
const opencv_onnx::NodeProto& node_proto,
std::map<std::string, LayerInfo>& layer_id,
std::map<std::string, MatShape>& outShapes);
public: public:
ONNXImporter(const char *onnxFile) ONNXImporter(const char *onnxFile)
...@@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto, ...@@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
return constBlob->second; return constBlob->second;
} }
void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
const opencv_onnx::NodeProto& node_proto,
std::map<std::string, LayerInfo>& layer_id,
std::map<std::string, MatShape>& outShapes)
{
std::map<std::string, LayerInfo>::iterator layerId;
std::map<std::string, MatShape>::iterator shapeIt;
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
for (int i = 0; i < node_proto.output_size(); ++i)
{
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
}
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
int inpNum = 0;
for (int j = 0; j < node_proto.input_size(); j++) {
layerId = layer_id.find(node_proto.input(j));
if (layerId != layer_id.end()) {
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
++inpNum;
// Collect input shapes.
shapeIt = outShapes.find(node_proto.input(j));
CV_Assert(shapeIt != outShapes.end());
layerInpShapes.push_back(shapeIt->second);
}
}
// Compute shape of output blob for this layer.
Ptr<Layer> layer = dstNet.getLayer(id);
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
{
outShapes[node_proto.output(i)] = layerOutShapes[i];
}
}
void ONNXImporter::populateNet(Net dstNet) void ONNXImporter::populateNet(Net dstNet)
{ {
CV_Assert(model_proto.has_graph()); CV_Assert(model_proto.has_graph());
...@@ -455,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -455,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet)
runLayer(layerParams, inputs, sliced); runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1); CV_Assert(sliced.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, sliced[0])); constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
outShapes[layerParams.name] = shape(sliced[0]);
continue; continue;
} }
} }
...@@ -579,6 +621,70 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -579,6 +621,70 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0])); constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
continue; continue;
} }
else if (layer_type == "LSTM")
{
LayerParams lstmParams = layerParams;
lstmParams.name += "/lstm";
// https://pytorch.org/docs/stable/nn.html#lstm
CV_Assert(node_proto.input_size() == 7);
Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3);
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 5)), 0, "Unsupported non zero initial_h");
CV_CheckEQ(countNonZero(getBlob(node_proto, constBlobs, 6)), 0, "Unsupported non zero initial_c");
b = b.reshape(1, b.size[0]);
const int numHidden = lstmParams.get<int>("hidden_size");
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
const int numFeatures = Wx.size[2];
Mat bx = b.colRange(0, b.cols / 2);
Mat bh = b.colRange(b.cols / 2, b.cols);
b = bx + bh;
// IFGO->IGFO
for (int k = 0; k < numDirs; ++k)
{
float* WxData = Wx.ptr<float>(k);
float* WhData = Wh.ptr<float>(k);
float* biasData = b.ptr<float>(k);
for (int j = 0; j < numHidden; ++j)
{
for (int i = 0; i < numFeatures; ++i)
{
std::swap(WxData[(numHidden + j) * numFeatures + i],
WxData[(numHidden * 2 + j) * numFeatures + i]);
}
for (int i = 0; i < numHidden; ++i)
{
std::swap(WhData[(numHidden + j) * numHidden + i],
WhData[(numHidden * 2 + j) * numHidden + i]);
}
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
}
}
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
lstmParams.blobs.resize(3);
lstmParams.blobs[0] = Wh;
lstmParams.blobs[1] = Wx;
lstmParams.blobs[2] = b;
lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
MatShape lstmShape = outShapes[node_proto.output(0)];
// Add fake 1 as it is done in ONNX
lstmShape.insert(lstmShape.begin() + 1, 1);
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
node_proto.set_input(0, lstmParams.name); // redirect input to LSTM
node_proto.set_output(0, layerParams.name); // keep origin LSTM's name
}
else if (layer_type == "ImageScaler") else if (layer_type == "ImageScaler")
{ {
const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f; const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
...@@ -882,13 +988,38 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -882,13 +988,38 @@ void ONNXImporter::populateNet(Net dstNet)
{ {
CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes")); CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
DictValue axes_dict = layerParams.get("axes"); DictValue axes_dict = layerParams.get("axes");
if (axes_dict.size() != 1) MatShape inpShape = outShapes[node_proto.input(0)];
CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
int axis = axes_dict.getIntValue(0); std::vector<bool> maskedAxes(inpShape.size(), false);
layerParams.set("axis", axis - 1); for (int i = 0; i < axes_dict.size(); ++i)
layerParams.set("end_axis", axis); {
layerParams.type = "Flatten"; int axis = axes_dict.getIntValue(i);
CV_CheckLE(axis, static_cast<int>(inpShape.size()), "Squeeze axis");
maskedAxes[axis] = inpShape[axis] == 1;
}
MatShape outShape;
for (int i = 0; i < inpShape.size(); ++i)
{
if (!maskedAxes[i])
outShape.push_back(inpShape[i]);
}
if (outShape.size() != inpShape.size())
{
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
}
else
layerParams.type = "Identity";
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat inp = getBlob(node_proto, constBlobs, 0);
Mat out = inp.reshape(1, outShape);
out.dims = outShape.size(); // to workaround dims == 1
constBlobs.insert(std::make_pair(layerParams.name, out));
outShapes[layerParams.name] = shape(out);
continue;
}
} }
else if (layer_type == "Flatten") else if (layer_type == "Flatten")
{ {
...@@ -1018,9 +1149,17 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -1018,9 +1149,17 @@ void ONNXImporter::populateNet(Net dstNet)
else else
layerParams.type = "Identity"; layerParams.type = "Identity";
} }
else if (layer_type == "ConstantOfShape") else if (layer_type == "ConstantOfShape" || layer_type == "ConstantFill")
{
float fill_value;
if (!layerParams.blobs.empty())
{ {
float fill_value = layerParams.blobs.empty() ? 0 : layerParams.blobs[0].at<float>(0, 0); CV_Assert(!layerParams.has("value"));
fill_value = layerParams.blobs[0].at<float>(0, 0);
}
else
fill_value = layerParams.get("value", 0);
MatShape inpShape = getBlob(node_proto, constBlobs, 0); MatShape inpShape = getBlob(node_proto, constBlobs, 0);
for (int i = 0; i < inpShape.size(); i++) for (int i = 0; i < inpShape.size(); i++)
CV_CheckGT(inpShape[i], 0, ""); CV_CheckGT(inpShape[i], 0, "");
...@@ -1032,17 +1171,30 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -1032,17 +1171,30 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Gather") else if (layer_type == "Gather")
{ {
CV_Assert(node_proto.input_size() == 2); CV_Assert(node_proto.input_size() == 2);
CV_Assert(layerParams.has("axis"));
Mat input = getBlob(node_proto, constBlobs, 0); Mat input = getBlob(node_proto, constBlobs, 0);
Mat indexMat = getBlob(node_proto, constBlobs, 1); Mat indexMat = getBlob(node_proto, constBlobs, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1); CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
int index = indexMat.at<int>(0); int index = indexMat.at<int>(0);
Mat out;
if (layerParams.has("axis"))
{
int axis = layerParams.get<int>("axis"); int axis = layerParams.get<int>("axis");
std::vector<cv::Range> ranges(input.dims, Range::all()); std::vector<cv::Range> ranges(input.dims, Range::all());
ranges[axis] = Range(index, index + 1); ranges[axis] = Range(index, index + 1);
Mat out = input(ranges); out = input(ranges);
}
else
{
CV_Assert(index < input.total());
const int dims = input.dims;
input = input.reshape(1, 1);
input.dims = 2;
out = input.reshape(1, 1).colRange(index, index + 1);
out.dims = dims;
}
constBlobs.insert(std::make_pair(layerParams.name, out)); constBlobs.insert(std::make_pair(layerParams.name, out));
continue; continue;
} }
...@@ -1145,34 +1297,7 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -1145,34 +1297,7 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j)); layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
} }
} }
addLayer(dstNet, layerParams, node_proto, layer_id, outShapes);
int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
for (int i = 0; i < node_proto.output_size(); ++i)
{
layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
}
std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
int inpNum = 0;
for (int j = 0; j < node_proto.input_size(); j++) {
layerId = layer_id.find(node_proto.input(j));
if (layerId != layer_id.end()) {
dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
++inpNum;
// Collect input shapes.
shapeIt = outShapes.find(node_proto.input(j));
CV_Assert(shapeIt != outShapes.end());
layerInpShapes.push_back(shapeIt->second);
}
}
// Compute shape of output blob for this layer.
Ptr<Layer> layer = dstNet.getLayer(id);
layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
{
outShapes[node_proto.output(i)] = layerOutShapes[i];
}
} }
} }
......
...@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape) ...@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Squeeze) TEST_P(Test_ONNX_layers, Squeeze)
{ {
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("squeeze"); testONNXModels("squeeze");
} }
...@@ -451,6 +453,16 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax) ...@@ -451,6 +453,16 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
testONNXModels("split_max"); testONNXModels("split_max");
} }
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers class Test_ONNX_nets : public Test_ONNX_layers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment