/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "../precomp.hpp"
#include "recurrent_layers.hpp"
#include "op_blas.hpp"
#include <iostream>
#include <cmath>
#include <opencv2/dnn/shape_utils.hpp>

namespace cv
{
namespace dnn
{

template<typename Dtype>
static void tanh(const Mat &src, Mat &dst)
{
    MatConstIterator_<Dtype> itSrc = src.begin<Dtype>();
    MatIterator_<Dtype> itDst = dst.begin<Dtype>();

    for (; itSrc != src.end<Dtype>(); itSrc++, itDst++)
        *itDst = std::tanh(*itSrc);
}

//TODO: make utils method
static void tanh(const Mat &src, Mat &dst)
{
    dst.create(src.dims, (const int*)src.size, src.type());

    if (src.type() == CV_32F)
        tanh<float>(src, dst);
    else if (src.type() == CV_64F)
        tanh<double>(src, dst);
    else
        CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");
}

//element-wise logistic sigmoid: dst = 1/(1 + exp(-src))
static void sigmoid(const Mat &src, Mat &dst)
{
    cv::exp(-src, dst);
    cv::pow(1 + dst, -1, dst);
}
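
// LSTMLayerImpl below computes, for each timestamp t (gate order in the packed
// weight matrices is I, F, O, G; (*) is element-wise multiplication):
//   i_t = sigmoid(W_xi x_t + W_hi h_{t-1} + b_i)
//   f_t = sigmoid(W_xf x_t + W_hf h_{t-1} + b_f)
//   o_t = sigmoid(W_xo x_t + W_ho h_{t-1} + b_o)
//   g_t = tanh(W_xg x_t + W_hg h_{t-1} + b_g)
//   c_t = f_t (*) c_{t-1} + i_t (*) g_t
//   h_t = o_t (*) tanh(c_t)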

class LSTMLayerImpl : public LSTMLayer
{
    int numOut, numTimeStamps, numSamples, numInp;
    Mat hInternal, cInternal;
    Mat gates, dummyOnes;
    int dtype;
    bool allocated;

    Shape outTailShape;                 //shape of single output sample
    Shape outTsMatShape, outTsShape;    //shape of N output samples
    Shape outResShape;                  //shape of T timestamps and N output samples

    bool useTimestampDim;
    bool produceCellOutput;

public:

    LSTMLayerImpl()
    {
        type = "LSTM";
        useTimestampDim = true;
        produceCellOutput = false;
        allocated = false;
        outTailShape = Shape::empty();
    }

    void setUseTimstampsDim(bool use)
    {
        CV_Assert(!allocated);
        useTimestampDim = use;
    }

    void setProduceCellOutput(bool produce)
    {
        CV_Assert(!allocated);
        produceCellOutput = produce;
    }

    void setC(const Blob &C)
    {
        CV_Assert(cInternal.empty() || C.total() == cInternal.total());
        if (!cInternal.empty())
            C.reshaped(Shape::like(cInternal)).matRefConst().copyTo(cInternal);
        else
            C.matRefConst().copyTo(cInternal);
    }

    void setH(const Blob &H)
    {
        CV_Assert(hInternal.empty() || H.total() == hInternal.total());
        if (!hInternal.empty())
            H.reshaped(Shape::like(hInternal)).matRefConst().copyTo(hInternal);
        else
            H.matRefConst().copyTo(hInternal);
    }
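
    // Note: setH()/setC() seed hInternal/cInternal, which persist across forward()
    // calls; allocate() zero-initializes them only while they are still empty.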

    Blob getC() const
    {
        CV_Assert(!cInternal.empty());

        //TODO: add convenient Mat -> Blob constructor
        Blob res(outTsShape, cInternal.type());
        res.fill(res.shape(), res.type(), cInternal.data);
        return res;
    }

    Blob getH() const
    {
        CV_Assert(!hInternal.empty());

        Blob res(outTsShape, hInternal.type());
        res.fill(res.shape(), res.type(), hInternal.data);
        return res;
    }

    void setOutShape(const Shape &outTailShape_)
    {
        CV_Assert(!allocated || outTailShape_.total() == outTailShape.total());
        outTailShape = outTailShape_;
    }

    void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias)
    {
        CV_Assert(Wh.dims() == 2 && Wx.dims() == 2);
        CV_Assert(Wh.size(0) == Wx.size(0));
        CV_Assert(Wh.size(0) == 4*Wh.size(1));
        CV_Assert(Wh.size(0) == (int)bias.total());
        CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());

        blobs.resize(3);
        blobs[0] = Wh;
        blobs[1] = Wx;
        blobs[2] = bias;
        blobs[2].reshape(Shape(1, (int)bias.total()));
    }

    void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output)
    {
        CV_Assert(blobs.size() == 3);
        CV_Assert(input.size() == 1);

        Blob &Wh = blobs[0], &Wx = blobs[1];
        numOut = Wh.size(1);
        numInp = Wx.size(1);

        if (!outTailShape.isEmpty())
            CV_Assert(outTailShape.total() == numOut);
        else
            outTailShape = Shape(numOut);

        if (useTimestampDim)
        {
            CV_Assert(input[0]->dims() >= 2 && (int)input[0]->total(2) == numInp);
            numTimeStamps = input[0]->size(0);
            numSamples = input[0]->size(1);
            outResShape = Shape(numTimeStamps, numSamples) + outTailShape;
        }
        else
        {
            CV_Assert(input[0]->dims() >= 1 && (int)input[0]->total(1) == numInp);
            numTimeStamps = 1;
            numSamples = input[0]->size(0);
            outResShape = Shape(numSamples) + outTailShape;
        }
        outTsMatShape = Shape(numSamples, numOut);
        outTsShape = Shape(numSamples) + outTailShape;

        dtype = input[0]->type();
        CV_Assert(dtype == CV_32F || dtype == CV_64F);
        CV_Assert(Wh.type() == dtype);

        output.resize( (produceCellOutput) ? 2 : 1 );
        output[0].create(outResShape, dtype);
        if (produceCellOutput)
            output[1].create(outResShape, dtype);

        if (hInternal.empty())
        {
            hInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype);
            hInternal.setTo(0);
        }
        else
        {
            CV_Assert((int)hInternal.total() == numSamples*numOut);
            hInternal = hInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr());
        }

        if (cInternal.empty())
        {
            cInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype);
            cInternal.setTo(0);
        }
        else
        {
            CV_Assert((int)cInternal.total() == numSamples*numOut);
            cInternal = cInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr());
        }

        gates.create(numSamples, 4*numOut, dtype);

        //column of ones: gemm with the 1 x (4*numOut) bias row adds b to every sample
        dummyOnes.create(numSamples, 1, dtype);
        dummyOnes.setTo(1);

        allocated = true;
    }

    void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
    {
        const Mat &Wh = blobs[0].getRefConst<Mat>();
        const Mat &Wx = blobs[1].getRefConst<Mat>();
        const Mat &bias = blobs[2].getRefConst<Mat>();

        int numSamplesTotal = numTimeStamps*numSamples;
        Mat xTs = reshaped(input[0]->getRefConst<Mat>(), Shape(numSamplesTotal, numInp));
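        // Time-major layout: reshaping T x N x numInp into (T*N) x numInp makes
        // every timestamp a contiguous block of numSamples rows (see rowRange below).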

        Shape outMatShape(numSamplesTotal, numOut);
        Mat hOutTs = reshaped(output[0].getRef<Mat>(), outMatShape);
        Mat cOutTs = (produceCellOutput) ? reshaped(output[1].getRef<Mat>(), outMatShape) : Mat();

        for (int ts = 0; ts < numTimeStamps; ts++)
        {
            Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
            Mat xCurr = xTs.rowRange(curRowRange);

            dnn::gemm(xCurr, Wx, 1, gates, 0, GEMM_2_T);      // Wx * x_t
            dnn::gemm(hInternal, Wh, 1, gates, 1, GEMM_2_T);  //+Wh * h_{t-1}
            dnn::gemm(dummyOnes, bias, 1, gates, 1);          //+b
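            // gates now holds the pre-activations for all samples at this timestamp,
            // laid out column-wise as [ i | f | o | g ].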

            Mat gatesIFO = gates.colRange(0, 3*numOut);
            Mat gateI = gates.colRange(0*numOut, 1*numOut);
            Mat gateF = gates.colRange(1*numOut, 2*numOut);
            Mat gateO = gates.colRange(2*numOut, 3*numOut);
            Mat gateG = gates.colRange(3*numOut, 4*numOut);

            sigmoid(gatesIFO, gatesIFO);
            tanh(gateG, gateG);

            //compute c_t
            cv::multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
            cv::multiply(gateI, gateG, gateI);      // i_t (*) g_t
            cv::add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t

            //compute h_t
            tanh(cInternal, hInternal);
            cv::multiply(gateO, hInternal, hInternal);

            //save results in output blobs
            hInternal.copyTo(hOutTs.rowRange(curRowRange));
            if (produceCellOutput)
                cInternal.copyTo(cOutTs.rowRange(curRowRange));
        }
    }
};

Ptr<LSTMLayer> LSTMLayer::create()
{
    return Ptr<LSTMLayer>(new LSTMLayerImpl());
}
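
// A minimal usage sketch (hedged: shapes follow the asserts in setWeights() and
// allocate(); Wh is 4*numOut x numOut, Wx is 4*numOut x numInp, bias has
// 4*numOut elements, and the input blob x is T x N x numInp):
//
//   Ptr<LSTMLayer> lstm = LSTMLayer::create();
//   lstm->setWeights(Wh, Wx, bias);       // packed gate order: I, F, O, G
//   lstm->setProduceCellOutput(true);     // also emit c_t as outputs[1]
//   std::vector<Blob*> inputs(1, &x);
//   std::vector<Blob> outputs;
//   lstm->allocate(inputs, outputs);      // h_0 and c_0 are zero-initialized
//   lstm->forward(inputs, outputs);       // outputs[0]: h_t for every timestamp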

void LSTMLayer::forward(std::vector<Blob*>&, std::vector<Blob>&)
{
    CV_Error(Error::StsInternal, "This function should never be reached");
}

int LSTMLayer::inputNameToIndex(String inputName)
{
    if (inputName.toLowerCase() == "x")
        return 0;
    return -1;
}

int LSTMLayer::outputNameToIndex(String outputName)
{
    if (outputName.toLowerCase() == "h")
        return 0;
    else if (outputName.toLowerCase() == "c")
        return 1;
    return -1;
}
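
// RNNLayerImpl below computes the classic Elman recurrence for each timestamp t:
//   h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)
//   o_t = tanh(W_ho * h_t + b_o)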


class RNNLayerImpl : public RNNLayer
{
    int numX, numH, numO;
    int numSamples, numTimestamps, numSamplesTotal;
    int dtype;
    Mat Whh, Wxh, bh;
    Mat Who, bo;
    Mat hCurr, hPrev, dummyBiasOnes;
    bool produceH;

public:

    RNNLayerImpl()
    {
        type = "RNN";
        produceH = false;
    }

    void setProduceHiddenOutput(bool produce = false)
    {
        produceH = produce;
    }

    void setWeights(const Blob &W_xh, const Blob &b_h, const Blob &W_hh, const Blob &W_ho, const Blob &b_o)
    {
        CV_Assert(W_hh.dims() == 2 && W_xh.dims() == 2);
        CV_Assert(W_hh.size(0) == W_xh.size(0) && W_hh.size(0) == W_hh.size(1) && (int)b_h.total() == W_xh.size(0));
        CV_Assert(W_ho.size(0) == (int)b_o.total());
        CV_Assert(W_ho.size(1) == W_hh.size(1));

        blobs.resize(5);
        blobs[0] = W_xh;
        blobs[1] = b_h;
        blobs[2] = W_hh;
        blobs[3] = W_ho;
        blobs[4] = b_o;
    }

    void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output)
    {
        CV_Assert(input.size() >= 1 && input.size() <= 2);

        Wxh = blobs[0].matRefConst();
        bh  = blobs[1].matRefConst();
        Whh = blobs[2].matRefConst();
        Who = blobs[3].matRefConst();
        bo  = blobs[4].matRefConst();

        numH = Wxh.rows;
        numX = Wxh.cols;
        numO = Who.rows;

        CV_Assert(input[0]->dims() >= 2);
        CV_Assert((int)input[0]->total(2) == numX);
        CV_Assert(input[0]->type() == CV_32F || input[0]->type() == CV_64F);
        dtype = input[0]->type();
        numTimestamps = input[0]->size(0);
        numSamples = input[0]->size(1);
        numSamplesTotal = numTimestamps * numSamples;

        hCurr.create(numSamples, numH, dtype);
        hPrev.create(numSamples, numH, dtype);
        hPrev.setTo(0);

        dummyBiasOnes.create(numSamples, 1, dtype);
        dummyBiasOnes.setTo(1);

        bh = bh.reshape(1, 1); //is 1 x numH Mat
        bo = bo.reshape(1, 1); //is 1 x numO Mat

        reshapeOutput(output);
    }

    void reshapeOutput(std::vector<Blob> &output)
    {
        output.resize((produceH) ? 2 : 1);
        output[0].create(Shape(numTimestamps, numSamples, numO), dtype);
        if (produceH)
            output[1].create(Shape(numTimestamps, numSamples, numH), dtype);
    }

    void forward(std::vector<Blob*> &input, std::vector<Blob> &output)
    {
        Mat xTs = reshaped(input[0]->getRefConst<Mat>(), Shape(numSamplesTotal, numX));
        Mat oTs = reshaped(output[0].getRef<Mat>(), Shape(numSamplesTotal, numO));
        Mat hTs = (produceH) ? reshaped(output[1].getRef<Mat>(), Shape(numSamplesTotal, numH)) : Mat();

        for (int ts = 0; ts < numTimestamps; ts++)
        {
            Range curRowRange = Range(ts * numSamples, (ts + 1) * numSamples);
            Mat xCurr = xTs.rowRange(curRowRange);

            dnn::gemm(hPrev, Whh, 1, hCurr, 0, GEMM_2_T); // W_{hh} * h_{prev}
            dnn::gemm(xCurr, Wxh, 1, hCurr, 1, GEMM_2_T); //+W_{xh} * x_{curr}
            dnn::gemm(dummyBiasOnes, bh, 1, hCurr, 1);    //+bh
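            // hCurr now holds the pre-activation; tanh() writes the activated state
            // into hPrev, which the next iteration reads as h_{t-1}.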
            tanh(hCurr, hPrev);

            Mat oCurr = oTs.rowRange(curRowRange);
            dnn::gemm(hPrev, Who, 1, oCurr, 0, GEMM_2_T); // W_{ho} * h_{prev}
            dnn::gemm(dummyBiasOnes, bo, 1, oCurr, 1);    //+b_o
            tanh(oCurr, oCurr);

            if (produceH)
                hPrev.copyTo(hTs.rowRange(curRowRange));
        }
    }
};

void RNNLayer::forward(std::vector<Blob*>&, std::vector<Blob>&)
{
    CV_Error(Error::StsInternal, "This function should never be reached");
}

CV_EXPORTS_W Ptr<RNNLayer> RNNLayer::create()
{
    return Ptr<RNNLayer>(new RNNLayerImpl());
}
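
// A minimal usage sketch (hedged: shapes follow the asserts in setWeights() and
// allocate(); W_xh is numH x numX, W_hh is numH x numH, W_ho is numO x numH,
// and the input blob x is T x N x numX):
//
//   Ptr<RNNLayer> rnn = RNNLayer::create();
//   rnn->setWeights(W_xh, b_h, W_hh, W_ho, b_o);
//   std::vector<Blob*> inputs(1, &x);
//   std::vector<Blob> outputs;
//   rnn->allocate(inputs, outputs);   // h_0 is zero-initialized
//   rnn->forward(inputs, outputs);    // outputs[0]: o_t for every timestamp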

}
}