Commit e8c7d617 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #16817 from dkurt:dnn_onnx_lstm

parents b1f390b1 467c3ef0
...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer ...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
float forgetBias, cellClip; float forgetBias, cellClip;
bool useCellClip, usePeephole; bool useCellClip, usePeephole;
bool reverse; // If true, go in negative direction along the time axis bool reverse; // If true, go in negative direction along the time axis
bool bidirectional; // If true, produces both forward and reversed directions along time axis
public: public:
...@@ -101,6 +102,7 @@ public: ...@@ -101,6 +102,7 @@ public:
{ {
setParamsFrom(params); setParamsFrom(params);
bidirectional = params.get<bool>("bidirectional", false);
if (!blobs.empty()) if (!blobs.empty())
{ {
CV_Assert(blobs.size() >= 3); CV_Assert(blobs.size() >= 3);
...@@ -110,10 +112,11 @@ public: ...@@ -110,10 +112,11 @@ public:
const Mat& Wh = blobs[0]; const Mat& Wh = blobs[0];
const Mat& Wx = blobs[1]; const Mat& Wx = blobs[1];
const Mat& bias = blobs[2]; const Mat& bias = blobs[2];
CV_Assert(Wh.dims == 2 && Wx.dims == 2); CV_CheckEQ(Wh.dims, 2, "");
CV_Assert(Wh.rows == Wx.rows); CV_CheckEQ(Wx.dims, 2, "");
CV_Assert(Wh.rows == 4*Wh.cols); CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_Assert(Wh.rows == (int)bias.total()); CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
// Peephole weights. // Peephole weights.
...@@ -135,6 +138,7 @@ public: ...@@ -135,6 +138,7 @@ public:
useCellClip = params.get<bool>("use_cell_clip", false); useCellClip = params.get<bool>("use_cell_clip", false);
usePeephole = params.get<bool>("use_peephole", false); usePeephole = params.get<bool>("use_peephole", false);
reverse = params.get<bool>("reverse", false); reverse = params.get<bool>("reverse", false);
CV_Assert(!reverse || !bidirectional);
allocated = false; allocated = false;
outTailShape.clear(); outTailShape.clear();
...@@ -206,6 +210,7 @@ public: ...@@ -206,6 +210,7 @@ public:
outResShape.push_back(_numSamples); outResShape.push_back(_numSamples);
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end()); outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
outResShape.back() *= (1 + static_cast<int>(bidirectional));
size_t noutputs = produceCellOutput ? 2 : 1; size_t noutputs = produceCellOutput ? 2 : 1;
outputs.assign(noutputs, outResShape); outputs.assign(noutputs, outResShape);
...@@ -252,6 +257,7 @@ public: ...@@ -252,6 +257,7 @@ public:
outTsShape.clear(); outTsShape.clear();
outTsShape.push_back(numSamples); outTsShape.push_back(numSamples);
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end()); outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
outTsShape.back() *= (1 + static_cast<int>(bidirectional));
allocated = true; allocated = true;
} }
...@@ -272,91 +278,96 @@ public: ...@@ -272,91 +278,96 @@ public:
outputs_arr.getMatVector(output); outputs_arr.getMatVector(output);
internals_arr.getMatVector(internals); internals_arr.getMatVector(internals);
const Mat &Wh = blobs[0]; const int numDirs = 1 + static_cast<int>(bidirectional);
const Mat &Wx = blobs[1]; for (int i = 0; i < numDirs; ++i)
const Mat &bias = blobs[2];
int numOut = Wh.size[1];
Mat hInternal = internals[0], cInternal = internals[1],
dummyOnes = internals[2], gates = internals[3];
hInternal.setTo(0.);
cInternal.setTo(0.);
dummyOnes.setTo(1.);
int numSamplesTotal = numTimeStamps*numSamples;
Mat xTs = input[0].reshape(1, numSamplesTotal);
Mat hOutTs = output[0].reshape(1, numSamplesTotal);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
int tsStart, tsEnd, tsInc;
if (reverse) {
tsStart = numTimeStamps - 1;
tsEnd = -1;
tsInc = -1;
}
else {
tsStart = 0;
tsEnd = numTimeStamps;
tsInc = 1;
}
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
{ {
Range curRowRange(ts*numSamples, (ts + 1)*numSamples); const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
Mat xCurr = xTs.rowRange(curRowRange); const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
int numOut = Wh.size[1];
Mat hInternal = internals[0], cInternal = internals[1],
dummyOnes = internals[2], gates = internals[3];
hInternal.setTo(0.);
cInternal.setTo(0.);
dummyOnes.setTo(1.);
int numSamplesTotal = numTimeStamps*numSamples;
Mat xTs = input[0].reshape(1, numSamplesTotal);
Mat hOutTs = output[0].reshape(1, numSamplesTotal);
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
int tsStart, tsEnd, tsInc;
if (reverse || i == 1) {
tsStart = numTimeStamps - 1;
tsEnd = -1;
tsInc = -1;
}
else {
tsStart = 0;
tsEnd = numTimeStamps;
tsInc = 1;
}
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
{
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
Mat xCurr = xTs.rowRange(curRowRange);
gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1} gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
Mat gateI = gates.colRange(0*numOut, 1*numOut); Mat gateI = gates.colRange(0*numOut, 1*numOut);
Mat gateF = gates.colRange(1*numOut, 2*numOut); Mat gateF = gates.colRange(1*numOut, 2*numOut);
Mat gateO = gates.colRange(2*numOut, 3*numOut); Mat gateO = gates.colRange(2*numOut, 3*numOut);
Mat gateG = gates.colRange(3*numOut, 4*numOut); Mat gateG = gates.colRange(3*numOut, 4*numOut);
if (forgetBias) if (forgetBias)
add(gateF, forgetBias, gateF); add(gateF, forgetBias, gateF);
if (usePeephole) if (usePeephole)
{ {
Mat gatesIF = gates.colRange(0, 2*numOut); Mat gatesIF = gates.colRange(0, 2*numOut);
gemm(cInternal, blobs[3], 1, gateI, 1, gateI); gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
gemm(cInternal, blobs[4], 1, gateF, 1, gateF); gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
sigmoid(gatesIF, gatesIF); sigmoid(gatesIF, gatesIF);
} }
else else
{ {
Mat gatesIFO = gates.colRange(0, 3*numOut); Mat gatesIFO = gates.colRange(0, 3*numOut);
sigmoid(gatesIFO, gatesIFO); sigmoid(gatesIFO, gatesIFO);
} }
tanh(gateG, gateG); tanh(gateG, gateG);
//compute c_t //compute c_t
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1} multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
multiply(gateI, gateG, gateI); // i_t (*) g_t multiply(gateI, gateG, gateI); // i_t (*) g_t
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
if (useCellClip) if (useCellClip)
{ {
min(cInternal, cellClip, cInternal); min(cInternal, cellClip, cInternal);
max(cInternal, -cellClip, cInternal); max(cInternal, -cellClip, cInternal);
} }
if (usePeephole) if (usePeephole)
{ {
gemm(cInternal, blobs[5], 1, gateO, 1, gateO); gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
sigmoid(gateO, gateO); sigmoid(gateO, gateO);
} }
//compute h_t //compute h_t
tanh(cInternal, hInternal); tanh(cInternal, hInternal);
multiply(gateO, hInternal, hInternal); multiply(gateO, hInternal, hInternal);
//save results in output blobs //save results in output blobs
hInternal.copyTo(hOutTs.rowRange(curRowRange)); hInternal.copyTo(hOutTs.rowRange(curRowRange));
if (produceCellOutput) if (produceCellOutput)
cInternal.copyTo(cOutTs.rowRange(curRowRange)); cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
} }
} }
}; };
......
This diff is collapsed.
...@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape) ...@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Squeeze) TEST_P(Test_ONNX_layers, Squeeze)
{ {
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("squeeze"); testONNXModels("squeeze");
} }
...@@ -451,6 +453,16 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax) ...@@ -451,6 +453,16 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
testONNXModels("split_max"); testONNXModels("split_max");
} }
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional", npy, 0, 0, false, false);
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers class Test_ONNX_nets : public Test_ONNX_layers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment