Commit 84336202 authored by Dmitry Kurtaev

Bidirectional LSTM

parent 11d565ca
...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer ...@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
float forgetBias, cellClip; float forgetBias, cellClip;
bool useCellClip, usePeephole; bool useCellClip, usePeephole;
bool reverse; // If true, go in negative direction along the time axis bool reverse; // If true, go in negative direction along the time axis
bool bidirectional; // If true, produces both forward and reversed directions along time axis
public: public:
...@@ -101,6 +102,7 @@ public: ...@@ -101,6 +102,7 @@ public:
{ {
setParamsFrom(params); setParamsFrom(params);
bidirectional = params.get<bool>("bidirectional", false);
if (!blobs.empty()) if (!blobs.empty())
{ {
CV_Assert(blobs.size() >= 3); CV_Assert(blobs.size() >= 3);
...@@ -113,7 +115,7 @@ public: ...@@ -113,7 +115,7 @@ public:
CV_CheckEQ(Wh.dims, 2, ""); CV_CheckEQ(Wh.dims, 2, "");
CV_CheckEQ(Wx.dims, 2, ""); CV_CheckEQ(Wx.dims, 2, "");
CV_CheckEQ(Wh.rows, Wx.rows, ""); CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_CheckEQ(Wh.rows, 4*Wh.cols, ""); CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), ""); CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type()); CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
...@@ -136,6 +138,7 @@ public: ...@@ -136,6 +138,7 @@ public:
useCellClip = params.get<bool>("use_cell_clip", false); useCellClip = params.get<bool>("use_cell_clip", false);
usePeephole = params.get<bool>("use_peephole", false); usePeephole = params.get<bool>("use_peephole", false);
reverse = params.get<bool>("reverse", false); reverse = params.get<bool>("reverse", false);
CV_Assert(!reverse || !bidirectional);
allocated = false; allocated = false;
outTailShape.clear(); outTailShape.clear();
...@@ -207,6 +210,7 @@ public: ...@@ -207,6 +210,7 @@ public:
outResShape.push_back(_numSamples); outResShape.push_back(_numSamples);
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end()); outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
outResShape.back() *= (1 + static_cast<int>(bidirectional));
size_t noutputs = produceCellOutput ? 2 : 1; size_t noutputs = produceCellOutput ? 2 : 1;
outputs.assign(noutputs, outResShape); outputs.assign(noutputs, outResShape);
...@@ -253,6 +257,7 @@ public: ...@@ -253,6 +257,7 @@ public:
outTsShape.clear(); outTsShape.clear();
outTsShape.push_back(numSamples); outTsShape.push_back(numSamples);
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end()); outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
outTsShape.back() *= (1 + static_cast<int>(bidirectional));
allocated = true; allocated = true;
} }
...@@ -273,9 +278,12 @@ public: ...@@ -273,9 +278,12 @@ public:
outputs_arr.getMatVector(output); outputs_arr.getMatVector(output);
internals_arr.getMatVector(internals); internals_arr.getMatVector(internals);
const Mat &Wh = blobs[0]; const int numDirs = 1 + static_cast<int>(bidirectional);
const Mat &Wx = blobs[1]; for (int i = 0; i < numDirs; ++i)
const Mat &bias = blobs[2]; {
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
int numOut = Wh.size[1]; int numOut = Wh.size[1];
...@@ -289,10 +297,11 @@ public: ...@@ -289,10 +297,11 @@ public:
Mat xTs = input[0].reshape(1, numSamplesTotal); Mat xTs = input[0].reshape(1, numSamplesTotal);
Mat hOutTs = output[0].reshape(1, numSamplesTotal); Mat hOutTs = output[0].reshape(1, numSamplesTotal);
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat(); Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
int tsStart, tsEnd, tsInc; int tsStart, tsEnd, tsInc;
if (reverse) { if (reverse || i == 1) {
tsStart = numTimeStamps - 1; tsStart = numTimeStamps - 1;
tsEnd = -1; tsEnd = -1;
tsInc = -1; tsInc = -1;
...@@ -360,6 +369,7 @@ public: ...@@ -360,6 +369,7 @@ public:
cInternal.copyTo(cOutTs.rowRange(curRowRange)); cInternal.copyTo(cOutTs.rowRange(curRowRange));
} }
} }
}
}; };
Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params) Ptr<LSTMLayer> LSTMLayer::create(const LayerParams& params)
......
...@@ -630,37 +630,44 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -630,37 +630,44 @@ void ONNXImporter::populateNet(Net dstNet)
Mat Wx = getBlob(node_proto, constBlobs, 1); Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2); Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3); Mat b = getBlob(node_proto, constBlobs, 3);
b = b.reshape(1, b.size[0]);
const int numHidden = lstmParams.get<int>("hidden_size"); const int numHidden = lstmParams.get<int>("hidden_size");
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
Wx = Wx.reshape(1, Wx.size[1]); const int numFeatures = Wx.size[2];
Wh = Wh.reshape(1, Wh.size[1]); Mat bx = b.colRange(0, b.cols / 2);
b = b.reshape(1, 2); Mat bh = b.colRange(b.cols / 2, b.cols);
reduce(b, b, 0, REDUCE_SUM); b = bx + bh;
// IFGO->IGFO // IFGO->IGFO
float* WxData = (float*)Wx.data; for (int k = 0; k < numDirs; ++k)
float* WhData = (float*)Wh.data; {
float* biasData = (float*)b.data; float* WxData = Wx.ptr<float>(k);
float* WhData = Wh.ptr<float>(k);
float* biasData = b.ptr<float>(k);
for (int j = 0; j < numHidden; ++j) for (int j = 0; j < numHidden; ++j)
{ {
for (int i = 0; i < Wx.cols; ++i) for (int i = 0; i < numFeatures; ++i)
{ {
std::swap(WxData[(numHidden + j) * Wx.cols + i], std::swap(WxData[(numHidden + j) * numFeatures + i],
WxData[(numHidden * 2 + j) * Wx.cols + i]); WxData[(numHidden * 2 + j) * numFeatures + i]);
} }
for (int i = 0; i < Wh.cols; ++i) for (int i = 0; i < numHidden; ++i)
{ {
std::swap(WhData[(numHidden + j) * Wh.cols + i], std::swap(WhData[(numHidden + j) * numHidden + i],
WhData[(numHidden * 2 + j) * Wh.cols + i]); WhData[(numHidden * 2 + j) * numHidden + i]);
} }
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]); std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
} }
}
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
lstmParams.blobs.resize(3); lstmParams.blobs.resize(3);
lstmParams.blobs[0] = Wh; lstmParams.blobs[0] = Wh;
lstmParams.blobs[1] = Wx; lstmParams.blobs[1] = Wx;
lstmParams.blobs[2] = b; lstmParams.blobs[2] = b;
lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes); addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
......
...@@ -456,6 +456,11 @@ TEST_P(Test_ONNX_layers, LSTM) ...@@ -456,6 +456,11 @@ TEST_P(Test_ONNX_layers, LSTM)
testONNXModels("lstm"); testONNXModels("lstm");
} }
// Parameterized test case of the Test_ONNX_layers fixture: passes the model
// name "lstm_bidirectional" to the fixture's testONNXModels helper.
// NOTE(review): presumably this covers the bidirectional LSTM import path added
// in this commit and relies on test data of that name — confirm.
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
{
testONNXModels("lstm_bidirectional");
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers class Test_ONNX_nets : public Test_ONNX_layers
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment