Commit a62f7e1d authored by Vitaliy Lyudvichenko

Small refactoring of RNN layers

parent dd9b2eb4
@@ -44,6 +44,7 @@
 #include "op_blas.hpp"
 #include <iostream>
 #include <cmath>
+#include <opencv2/dnn/shape_utils.hpp>
 
 namespace cv
 {
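Note: judging by the calls further down (`reshaped(input[0]->getRefConst<Mat>(), Shape(...))`), the new `shape_utils.hpp` header supplies a free `reshaped()` helper that returns a `Mat` view with new geometry. A minimal sketch of what such a helper could look like; the `Shape` accessors used here (`dims()`, `total()`, `ptr()`) are assumptions, not confirmed API:

```cpp
// Sketch only, not the actual header: a copy-free reshape helper.
// Shape::dims()/total()/ptr() are assumed accessors.
static inline cv::Mat reshaped(const cv::Mat &m, const Shape &shape)
{
    CV_Assert((int)m.total() == (int)shape.total()); // element count must be preserved
    return m.reshape(1, shape.dims(), shape.ptr());  // view with the new sizes, no copy
}
```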
@@ -60,6 +61,7 @@ static void tanh(const Mat &src, Mat &dst)
         *itDst = std::tanh(*itSrc);
 }
 
+//TODO: make utils method
 static void tanh(const Mat &src, Mat &dst)
 {
     dst.create(src.dims, (const int*)src.size, src.type());
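The new `//TODO: make utils method` comment points at the duplication between the element-wise activation loops in this file. One way the TODO could be resolved, sketched here for illustration only (not part of this commit):

```cpp
// Illustration of the TODO: one generic element-wise helper that the
// tanh/sigmoid loops could share. Assumes a continuous single-channel src.
template<typename T, typename F>
static void mapElements(const cv::Mat &src, cv::Mat &dst, F func)
{
    CV_Assert(src.isContinuous() && src.channels() == 1);
    dst.create(src.dims, (const int*)src.size, src.type()); // same geometry as src
    const T *itSrc = src.ptr<T>();
    T *itDst = dst.ptr<T>();
    for (size_t i = 0; i < src.total(); i++)
        itDst[i] = func(itSrc[i]); // e.g. func = [](T v) { return std::tanh(v); }
}
```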
@@ -86,9 +88,9 @@ class LSTMLayerImpl : public LSTMLayer
     int dtype;
     bool allocated;
 
-    BlobShape outTailShape;              //shape of single output sample
-    BlobShape outTsMatShape, outTsShape; //shape of N output samples
-    BlobShape outResShape;               //shape of T timestamps and N output samples
+    Shape outTailShape;                  //shape of single output sample
+    Shape outTsMatShape, outTsShape;     //shape of N output samples
+    Shape outResShape;                   //shape of T timestamps and N output samples
 
     bool useTimestampDim;
     bool produceCellOutput;
@@ -101,7 +103,7 @@ public:
         useTimestampDim = true;
         produceCellOutput = false;
         allocated = false;
-        outTailShape = BlobShape::empty();
+        outTailShape = Shape::empty();
     }
 
     void setUseTimstampsDim(bool use)
@@ -120,7 +122,7 @@ public:
     {
         CV_Assert(cInternal.empty() || C.total() == cInternal.total());
         if (!cInternal.empty())
-            C.reshaped(BlobShape::like(cInternal)).matRefConst().copyTo(cInternal);
+            C.reshaped(Shape::like(cInternal)).matRefConst().copyTo(cInternal);
         else
             C.matRefConst().copyTo(cInternal);
     }
@@ -129,7 +131,7 @@ public:
     {
         CV_Assert(hInternal.empty() || H.total() == hInternal.total());
         if (!hInternal.empty())
-            H.reshaped(BlobShape::like(hInternal)).matRefConst().copyTo(hInternal);
+            H.reshaped(Shape::like(hInternal)).matRefConst().copyTo(hInternal);
         else
             H.matRefConst().copyTo(hInternal);
     }
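`setC()` and `setH()` follow the same pattern: copy a user-supplied blob into the layer's internal cell/hidden buffer, reshaping to the internal layout when that buffer already exists. A hedged usage sketch; the blob names and the zero initialization are hypothetical, and `LSTMLayer::create()` is assumed from the surrounding API:

```cpp
// Hypothetical usage: priming the recurrent state before the first forward().
Ptr<LSTMLayer> lstm = LSTMLayer::create();
Blob h0(Shape(numSamples, numOut)); // initial hidden state, one row per sample
Blob c0(Shape(numSamples, numOut)); // initial cell state, same geometry
h0.matRef().setTo(0);
c0.matRef().setTo(0);
lstm->setH(h0); // copied into hInternal (reshaped to its layout if already allocated)
lstm->setC(c0); // copied into cInternal
```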
@@ -153,7 +155,7 @@ public:
         return res;
     }
 
-    void setOutShape(const BlobShape &outTailShape_)
+    void setOutShape(const Shape &outTailShape_)
     {
         CV_Assert(!allocated || outTailShape_.total() == outTailShape.total());
         outTailShape = outTailShape_;
@@ -171,7 +173,7 @@ public:
         blobs[0] = Wh;
         blobs[1] = Wx;
         blobs[2] = bias;
-        blobs[2].reshape(BlobShape(1, (int)bias.total()));
+        blobs[2].reshape(Shape(1, (int)bias.total()));
     }
 
     void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output)
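The `Shape(1, (int)bias.total())` reshape flattens the bias into a single row on purpose: in `forward()` below, the bias is added to every sample with one GEMM against a column of ones instead of a per-row loop. The trick in isolation, with plain `cv::Mat` and hypothetical sizes:

```cpp
// Bias broadcast via a rank-1 GEMM: ones(numSamples x 1) * bias(1 x 4*numOut)
// yields a numSamples x 4*numOut matrix whose every row is the bias.
cv::Mat dummyOnes = cv::Mat::ones(numSamples, 1, CV_32F);
cv::Mat biasRow   = bias.reshape(1, 1);                     // 1 x total() row vector
cv::Mat gates(numSamples, 4*numOut, CV_32F, cv::Scalar(0)); // gate pre-activations
gates += dummyOnes * biasRow; // same effect as the dnn::gemm(dummyOnes, bias, 1, gates, 1) call below
```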
@@ -186,24 +188,24 @@ public:
         if (!outTailShape.isEmpty())
             CV_Assert(outTailShape.total() == numOut);
         else
-            outTailShape = BlobShape(numOut);
+            outTailShape = Shape(numOut);
 
         if (useTimestampDim)
         {
             CV_Assert(input[0]->dims() >= 2 && (int)input[0]->total(2) == numInp);
             numTimeStamps = input[0]->size(0);
             numSamples = input[0]->size(1);
-            outResShape = BlobShape(numTimeStamps, numSamples) + outTailShape;
+            outResShape = Shape(numTimeStamps, numSamples) + outTailShape;
         }
         else
         {
             CV_Assert(input[0]->dims() >= 1 && (int)input[0]->total(1) == numInp);
             numTimeStamps = 1;
             numSamples = input[0]->size(0);
-            outResShape = BlobShape(numSamples) + outTailShape;
+            outResShape = Shape(numSamples) + outTailShape;
         }
-        outTsMatShape = BlobShape(numSamples, numOut);
-        outTsShape = BlobShape(numSamples) + outTailShape;
+        outTsMatShape = Shape(numSamples, numOut);
+        outTsShape = Shape(numSamples) + outTailShape;
 
         dtype = input[0]->type();
         CV_Assert(dtype == CV_32F || dtype == CV_64F);
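The `+` on shapes concatenates dimensions, so `allocate()` is simply assembling the output geometry from the time, sample, and tail axes. A worked example with made-up sizes:

```cpp
// Made-up sizes: T=5 timestamps, N=16 samples, numOut=128.
Shape outTailShape(128);                          // shape of a single output sample
Shape outResShape  = Shape(5, 16) + outTailShape; // (5, 16, 128): T x N x tail
Shape outTsMatShape(16, 128);                     // one timestamp as a 2D matrix
Shape outTsShape   = Shape(16) + outTailShape;    // (16, 128): N samples x tail
```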
@@ -246,25 +248,25 @@ public:
     void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
     {
-        const Mat &Wh = blobs[0].matRefConst();
-        const Mat &Wx = blobs[1].matRefConst();
-        const Mat &bias = blobs[2].matRefConst();
+        const Mat &Wh = blobs[0].getRefConst<Mat>();
+        const Mat &Wx = blobs[1].getRefConst<Mat>();
+        const Mat &bias = blobs[2].getRefConst<Mat>();
 
         int numSamplesTotal = numTimeStamps*numSamples;
-        Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numInp)).matRefConst();
+        Mat xTs = reshaped(input[0]->getRefConst<Mat>(), Shape(numSamplesTotal, numInp));
 
-        BlobShape outMatShape(numSamplesTotal, numOut);
-        Mat hOutTs = output[0].reshaped(outMatShape).matRef();
-        Mat cOutTs = (produceCellOutput) ? output[1].reshaped(outMatShape).matRef() : Mat();
+        Shape outMatShape(numSamplesTotal, numOut);
+        Mat hOutTs = reshaped(output[0].getRef<Mat>(), outMatShape);
+        Mat cOutTs = (produceCellOutput) ? reshaped(output[1].getRef<Mat>(), outMatShape) : Mat();
 
         for (int ts = 0; ts < numTimeStamps; ts++)
         {
             Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
             Mat xCurr = xTs.rowRange(curRowRange);
 
-            gemmCPU(xCurr, Wx, 1, gates, 0, GEMM_2_T);     // Wx * x_t
-            gemmCPU(hInternal, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1}
-            gemmCPU(dummyOnes, bias, 1, gates, 1);         //+b
+            dnn::gemm(xCurr, Wx, 1, gates, 0, GEMM_2_T);     // Wx * x_t
+            dnn::gemm(hInternal, Wh, 1, gates, 1, GEMM_2_T); //+Wh * h_{t-1}
+            dnn::gemm(dummyOnes, bias, 1, gates, 1);         //+b
 
             Mat getesIFO = gates.colRange(0, 3*numOut);
             Mat gateI = gates.colRange(0*numOut, 1*numOut);
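The hunk ends right after the gate matrix is carved up; the rest of the loop is not shown here, but it necessarily follows the standard LSTM cell update. A hedged sketch of that continuation, not a verbatim copy of the file (`sigmoid()` is assumed to be the element-wise helper defined next to `tanh()` in this file):

```cpp
// Textbook LSTM update that the loop performs next (sketch only).
Mat gateF = gates.colRange(1*numOut, 2*numOut);
Mat gateO = gates.colRange(2*numOut, 3*numOut);
Mat gateG = gates.colRange(3*numOut, 4*numOut);

sigmoid(getesIFO, getesIFO); // squash i, f, o gates into [0, 1]
tanh(gateG, gateG);          // candidate cell state in [-1, 1]

cInternal = gateF.mul(cInternal) + gateI.mul(gateG); // c_t = f*c_{t-1} + i*g
tanh(cInternal, hInternal);                          // h_t = o * tanh(c_t) ...
hInternal = gateO.mul(hInternal);                    // ... gated by the output gate
```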
@@ -394,30 +396,30 @@ public:
     void reshapeOutput(std::vector<Blob> &output)
     {
         output.resize((produceH) ? 2 : 1);
-        output[0].create(BlobShape(numTimestamps, numSamples, numO), dtype);
+        output[0].create(Shape(numTimestamps, numSamples, numO), dtype);
         if (produceH)
-            output[1].create(BlobShape(numTimestamps, numSamples, numH), dtype);
+            output[1].create(Shape(numTimestamps, numSamples, numH), dtype);
     }
 
     void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
     {
-        Mat xTs = input[0]->reshaped(BlobShape(numSamplesTotal, numX)).matRefConst();
-        Mat oTs = output[0].reshaped(BlobShape(numSamplesTotal, numO)).matRef();
-        Mat hTs = (produceH) ? output[1].reshaped(BlobShape(numSamplesTotal, numH)).matRef() : Mat();
+        Mat xTs = reshaped(input[0]->getRefConst<Mat>(), Shape(numSamplesTotal, numX));
+        Mat oTs = reshaped(output[0].getRef<Mat>(), Shape(numSamplesTotal, numO));
+        Mat hTs = (produceH) ? reshaped(output[1].getRef<Mat>(), Shape(numSamplesTotal, numH)) : Mat();
 
         for (int ts = 0; ts < numTimestamps; ts++)
         {
             Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples);
             Mat xCurr = xTs.rowRange(curRowRange);
 
-            gemmCPU(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
-            gemmCPU(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
-            gemmCPU(dummyBiasOnes, bh, 1, hCurr, 1);    //+bh
+            dnn::gemm(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
+            dnn::gemm(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
+            dnn::gemm(dummyBiasOnes, bh, 1, hCurr, 1);    //+bh
             tanh(hCurr, hPrev);
 
             Mat oCurr = oTs.rowRange(curRowRange);
-            gemmCPU(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
-            gemmCPU(dummyBiasOnes, bo, 1, oCurr, 1);    //+b_o
+            dnn::gemm(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
+            dnn::gemm(dummyBiasOnes, bo, 1, oCurr, 1);    //+b_o
             tanh(oCurr, oCurr);
 
             if (produceH)
...
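For reference, the two GEMM groups in this loop implement the classic Elman-style recurrence, with `tanh` applied twice per timestamp, exactly as the inline comments state:

```latex
h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)
o_t = \tanh(W_{ho} h_t + b_o)
```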