Commit 45010af8 authored by Vitaliy Lyudvichenko's avatar Vitaliy Lyudvichenko

Extending of BlobShape and LSTMLayer API, adding of new test

parent a3c6f1dc
...@@ -16,7 +16,7 @@ if(${the_module}_WITH_BLAS) ...@@ -16,7 +16,7 @@ if(${the_module}_WITH_BLAS)
endif() endif()
if(NOT HAVE_BLAS) if(NOT HAVE_BLAS)
include(cmake/OpenCVFindMKL.cmake) include(cmake/OpenCVFindMKL.cmake)
if(MKL_FOUND AND FALSE) if(MKL_FOUND)
set(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIRS}) set(BLAS_INCLUDE_DIR ${MKL_INCLUDE_DIRS})
set(BLAS_LIBRARIES ${MKL_LIBRARIES} ) set(BLAS_LIBRARIES ${MKL_LIBRARIES} )
set(BLAS_CBLAS_H "mkl_cblas.h" ) set(BLAS_CBLAS_H "mkl_cblas.h" )
......
...@@ -98,12 +98,12 @@ namespace dnn ...@@ -98,12 +98,12 @@ namespace dnn
g_t &= tanh &(W_{xg} x_t + W_{hg} h_{t-1} + b_g), \\ g_t &= tanh &(W_{xg} x_t + W_{hg} h_{t-1} + b_g), \\
@f} @f}
where @f$W_{x?}@f$, @f$W_{h?}@f$ and @f$b_{?}@f$ are learned weights represented as matrices: where @f$W_{x?}@f$, @f$W_{h?}@f$ and @f$b_{?}@f$ are learned weights represented as matrices:
@f$W_{x?} \in R^{N_c \times N_x}@f$, @f$W_h? \in R^{N_c \times N_h}@f$, @f$b_? \in R^{N_c}@f$. @f$W_{x?} \in R^{N_h \times N_x}@f$, @f$W_{h?} \in R^{N_h \times N_h}@f$, @f$b_? \in R^{N_h}@f$.
For simplicity and performance purposes we use @f$ W_x = [W_{xi}; W_{xf}; W_{xo}, W_{xg}] @f$ For simplicity and performance purposes we use @f$ W_x = [W_{xi}; W_{xf}; W_{xo}, W_{xg}] @f$
(i.e. @f$W_x@f$ is vertical contacentaion of @f$ W_{x?} @f$), @f$ W_x \in R^{4N_c \times N_x} @f$. (i.e. @f$W_x@f$ is vertical contacentaion of @f$ W_{x?} @f$), @f$ W_x \in R^{4N_h \times N_x} @f$.
The same for @f$ W_h = [W_{hi}; W_{hf}; W_{ho}, W_{hg}], W_h \in R^{4N_c \times N_h} @f$ The same for @f$ W_h = [W_{hi}; W_{hf}; W_{ho}, W_{hg}], W_h \in R^{4N_h \times N_h} @f$
and for @f$ b = [b_i; b_f, b_o, b_g]@f$, @f$b \in R^{4N_c} @f$. and for @f$ b = [b_i; b_f, b_o, b_g]@f$, @f$b \in R^{4N_h} @f$.
@param Wh is matrix defining how previous output is transformed to internal gates (i.e. according to abovemtioned notation is @f$ W_h @f$) @param Wh is matrix defining how previous output is transformed to internal gates (i.e. according to abovemtioned notation is @f$ W_h @f$)
@param Wx is matrix defining how current input is transformed to internal gates (i.e. according to abovemtioned notation is @f$ W_x @f$) @param Wx is matrix defining how current input is transformed to internal gates (i.e. according to abovemtioned notation is @f$ W_x @f$)
...@@ -111,6 +111,12 @@ namespace dnn ...@@ -111,6 +111,12 @@ namespace dnn
*/ */
virtual void setWeights(const Blob &Wh, const Blob &Wx, const Blob &b) = 0; virtual void setWeights(const Blob &Wh, const Blob &Wx, const Blob &b) = 0;
/** @brief Specifies shape of output blob which will be [[`T`], `N`] + @p outTailShape.
* @details If this parameter is empty or unset then @p outTailShape = [`Wh`.size(0)] will be used,
* where `Wh` is parameter from setWeights().
*/
virtual void setOutShape(const BlobShape &outTailShape = BlobShape::empty()) = 0;
/** @brief Set @f$ h_{t-1} @f$ value that will be used in next forward() calls. /** @brief Set @f$ h_{t-1} @f$ value that will be used in next forward() calls.
* @details By-default @f$ h_{t-1} @f$ is inited by zeros and updated after each forward() call. * @details By-default @f$ h_{t-1} @f$ is inited by zeros and updated after each forward() call.
*/ */
...@@ -145,12 +151,16 @@ namespace dnn ...@@ -145,12 +151,16 @@ namespace dnn
* @param output contains computed outputs: @f$h_t@f$ (and @f$c_t@f$ if setProduceCellOutput() flag was set to true). * @param output contains computed outputs: @f$h_t@f$ (and @f$c_t@f$ if setProduceCellOutput() flag was set to true).
* *
* If setUseTimstampsDim() is set to true then @p input[0] should has at least two dimensions with the following shape: [`T`, `N`, `[data dims]`], * If setUseTimstampsDim() is set to true then @p input[0] should has at least two dimensions with the following shape: [`T`, `N`, `[data dims]`],
* where `T` specifies number of timpestamps, `N` is number of independent streams (i.e. x_{t_0 + t}^{stream} is @p input[0][t, stream, ...]). * where `T` specifies number of timpestamps, `N` is number of independent streams (i.e. @f$ x_{t_0 + t}^{stream} @f$ is stored inside @p input[0][t, stream, ...]).
* *
* If setUseTimstampsDim() is set to fase then @p input[0] should contain single timestamp, its shape should has form [`N`, `[data dims]`] with at least one dimension. * If setUseTimstampsDim() is set to fase then @p input[0] should contain single timestamp, its shape should has form [`N`, `[data dims]`] with at least one dimension.
* (i.e. x_{t}^{stream} = @p input[0][stream, ...]). * (i.e. @f$ x_{t}^{stream} @f$ is stored inside @p input[0][stream, ...]).
*/ */
void forward(std::vector<Blob*> &input, std::vector<Blob> &output); void forward(std::vector<Blob*> &input, std::vector<Blob> &output);
int inputNameToIndex(String inputName);
int outputNameToIndex(String outputName);
}; };
//! Classical recurrent layer //! Classical recurrent layer
......
...@@ -44,6 +44,7 @@ ...@@ -44,6 +44,7 @@
#include <opencv2/core.hpp> #include <opencv2/core.hpp>
#include <vector> #include <vector>
#include <ostream> #include <ostream>
#include <iostream>
namespace cv namespace cv
{ {
...@@ -56,7 +57,7 @@ namespace dnn ...@@ -56,7 +57,7 @@ namespace dnn
struct BlobShape struct BlobShape
{ {
BlobShape(); //!< Creates [1, 1, 1, 1] shape @todo Make more clearer behavior. BlobShape(); //!< Creates [1, 1, 1, 1] shape @todo Make more clearer behavior.
BlobShape(int s0); //!< Creates 1-dim shape [@p s0] explicit BlobShape(int s0); //!< Creates 1-dim shape [@p s0]
BlobShape(int s0, int s1); //!< @overload BlobShape(int s0, int s1); //!< @overload
BlobShape(int s0, int s1, int s2); //!< @overload BlobShape(int s0, int s1, int s2); //!< @overload
BlobShape(int num, int cn, int rows, int cols); //!< Creates 4-dim shape [@p num, @p cn, @p rows, @p cols] BlobShape(int num, int cn, int rows, int cols); //!< Creates 4-dim shape [@p num, @p cn, @p rows, @p cols]
...@@ -96,24 +97,35 @@ namespace dnn ...@@ -96,24 +97,35 @@ namespace dnn
*/ */
int xsize(int axis) const; int xsize(int axis) const;
/** @brief Converts @p axis index to canonical format (where 0 <= @p axis < dims()). */
int canonicalAxis(int axis) const;
/** @brief Returns the product of all sizes of axes. */ /** @brief Returns the product of all sizes of axes. */
ptrdiff_t total(); ptrdiff_t total() const;
/** @brief Computes the product of sizes of axes among the specified axes range [@p startAxis; @p endAxis).
* @details Negative axis indexing can be used. @sa Blob::total(int,int)
*/
ptrdiff_t total(int startAxis, int endAxis = INT_MAX) const;
/** @brief Constructs new shape from axes in range [@p startAxis; @p endAxis).
* @details Negative axis indexing can be used. @sa Blob::total(int,int)
*/
BlobShape slice(int startAxis, int endAxis = INT_MAX) const;
/** @brief Returns pointer to the first element of continuous size array. */ /** @brief Returns pointer to the first element of continuous size array. */
const int *ptr() const; const int *ptr() const;
/** @brief Checks equality of two shapes. */ bool equal(const BlobShape &other) const; //!< Checks equality of two shapes.
bool equal(const BlobShape &other) const; bool operator== (const BlobShape &r) const; //!< @sa equal()
bool operator== (const BlobShape &r) const; BlobShape operator+ (const BlobShape &r) const; //!< Contacenates two shapes.
/** @brief Contacenates two shapes */ static BlobShape like(const Mat &m); //!< Returns shape of passed Mat.
BlobShape operator+ (const BlobShape &r) const; static BlobShape like(const UMat &m); //!< Returns shape of passed UMat.
/** @brief Returns shape of passed Mat. */ static BlobShape empty(); //!< Returns empty shape [].
static BlobShape like(const Mat &m); bool isEmpty() const; //!< Returns true if shape is empty (i.e []).
/** @brief Returns shape of passed Mat. */
static BlobShape like(const UMat &m);
#ifdef CV_CXX_MOVE_SEMANTICS #ifdef CV_CXX_MOVE_SEMANTICS
//TBD //TBD
...@@ -183,7 +195,7 @@ namespace dnn ...@@ -183,7 +195,7 @@ namespace dnn
*/ */
size_t total(int startAxis = 0, int endAxis = INT_MAX) const; size_t total(int startAxis = 0, int endAxis = INT_MAX) const;
/** @brief Converts @p axis index to canonical format (where 0 <= axis < dims()). */ /** @brief Converts @p axis index to canonical format (where 0 <= @p axis < dims()). */
int canonicalAxis(int axis) const; int canonicalAxis(int axis) const;
/** @brief Returns shape of the blob. */ /** @brief Returns shape of the blob. */
......
...@@ -150,7 +150,13 @@ inline int &BlobShape::operator[] (int axis) ...@@ -150,7 +150,13 @@ inline int &BlobShape::operator[] (int axis)
return sz[(axis < 0) ? axis + dims() : axis]; return sz[(axis < 0) ? axis + dims() : axis];
} }
inline ptrdiff_t BlobShape::total() inline int BlobShape::canonicalAxis(int axis) const
{
CV_Assert(-dims() <= axis && axis < dims());
return (axis < 0) ? axis + dims() : axis;
}
inline ptrdiff_t BlobShape::total() const
{ {
if (dims() == 0) if (dims() == 0)
return 0; return 0;
...@@ -161,6 +167,42 @@ inline ptrdiff_t BlobShape::total() ...@@ -161,6 +167,42 @@ inline ptrdiff_t BlobShape::total()
return res; return res;
} }
inline ptrdiff_t BlobShape::total(int startAxis, int endAxis) const
{
if (isEmpty())
return 0;
if (endAxis == INT_MAX)
endAxis = dims();
else if (endAxis < 0)
endAxis += dims();
startAxis = (startAxis < 0) ? startAxis + dims() : startAxis;
CV_Assert(0 <= startAxis && startAxis <= endAxis && endAxis <= dims());
ptrdiff_t res = 1;
for (int i = startAxis; i < endAxis; i++)
res *= sz[i];
return res;
}
inline BlobShape BlobShape::slice(int startAxis, int endAxis) const
{
if (isEmpty())
return BlobShape::empty();
if (endAxis == INT_MAX)
endAxis = dims();
else if (endAxis < 0)
endAxis += dims();
startAxis = (startAxis < 0) ? startAxis + dims() : startAxis;
CV_Assert(0 <= startAxis && startAxis <= endAxis && endAxis <= dims());
BlobShape res(endAxis - startAxis, (const int*)NULL);
for (int i = startAxis; i < endAxis; i++)
res[i - startAxis] = sz[i];
return res;
}
inline const int *BlobShape::ptr() const inline const int *BlobShape::ptr() const
{ {
return sz; return sz;
...@@ -195,6 +237,16 @@ inline BlobShape BlobShape::like(const UMat &m) ...@@ -195,6 +237,16 @@ inline BlobShape BlobShape::like(const UMat &m)
return BlobShape(m.dims, (const int*)m.size); return BlobShape(m.dims, (const int*)m.size);
} }
inline BlobShape BlobShape::empty()
{
return BlobShape(0, (const int*)NULL);
}
inline bool BlobShape::isEmpty() const
{
return dims() == 0;
}
CV_EXPORTS std::ostream &operator<< (std::ostream &stream, const BlobShape &shape); CV_EXPORTS std::ostream &operator<< (std::ostream &stream, const BlobShape &shape);
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
......
...@@ -69,7 +69,7 @@ static void tanh(const Mat &src, Mat &dst) ...@@ -69,7 +69,7 @@ static void tanh(const Mat &src, Mat &dst)
else if (src.type() == CV_64F) else if (src.type() == CV_64F)
tanh<double>(src, dst); tanh<double>(src, dst);
else else
CV_Error(Error::StsUnsupportedFormat, "Functions supports only floating point types"); CV_Error(Error::StsUnsupportedFormat, "Function supports only floating point types");
} }
static void sigmoid(const Mat &src, Mat &dst) static void sigmoid(const Mat &src, Mat &dst)
...@@ -86,6 +86,10 @@ class LSTMLayerImpl : public LSTMLayer ...@@ -86,6 +86,10 @@ class LSTMLayerImpl : public LSTMLayer
int dtype; int dtype;
bool allocated; bool allocated;
BlobShape outTailShape; //shape of single output sample
BlobShape outTsMatShape, outTsShape; //shape of N output samples
BlobShape outResShape; //shape of T timestamps and N output samples
bool useTimestampDim; bool useTimestampDim;
bool produceCellOutput; bool produceCellOutput;
...@@ -97,6 +101,7 @@ public: ...@@ -97,6 +101,7 @@ public:
useTimestampDim = true; useTimestampDim = true;
produceCellOutput = false; produceCellOutput = false;
allocated = false; allocated = false;
outTailShape = BlobShape::empty();
} }
void setUseTimstampsDim(bool use) void setUseTimstampsDim(bool use)
...@@ -113,13 +118,19 @@ public: ...@@ -113,13 +118,19 @@ public:
void setC(const Blob &C) void setC(const Blob &C)
{ {
CV_Assert(!allocated || C.total() == cInternal.total()); CV_Assert(cInternal.empty() || C.total() == cInternal.total());
if (!cInternal.empty())
C.reshaped(BlobShape::like(cInternal)).matRefConst().copyTo(cInternal);
else
C.matRefConst().copyTo(cInternal); C.matRefConst().copyTo(cInternal);
} }
void setH(const Blob &H) void setH(const Blob &H)
{ {
CV_Assert(!allocated || H.total() == hInternal.total()); CV_Assert(hInternal.empty() || H.total() == hInternal.total());
if (!hInternal.empty())
H.reshaped(BlobShape::like(hInternal)).matRefConst().copyTo(hInternal);
else
H.matRefConst().copyTo(hInternal); H.matRefConst().copyTo(hInternal);
} }
...@@ -128,8 +139,8 @@ public: ...@@ -128,8 +139,8 @@ public:
CV_Assert(!cInternal.empty()); CV_Assert(!cInternal.empty());
//TODO: add convinient Mat -> Blob constructor //TODO: add convinient Mat -> Blob constructor
Blob res; Blob res(outTsShape, cInternal.type());
res.fill(BlobShape::like(cInternal), cInternal.type(), cInternal.data); res.fill(res.shape(), res.type(), cInternal.data);
return res; return res;
} }
...@@ -137,11 +148,17 @@ public: ...@@ -137,11 +148,17 @@ public:
{ {
CV_Assert(!hInternal.empty()); CV_Assert(!hInternal.empty());
Blob res; Blob res(outTsShape, hInternal.type());
res.fill(BlobShape::like(hInternal), hInternal.type(), hInternal.data); res.fill(res.shape(), res.type(), hInternal.data);
return res; return res;
} }
void setOutShape(const BlobShape &outTailShape_)
{
CV_Assert(!allocated || outTailShape_.total() == outTailShape.total());
outTailShape = outTailShape_;
}
void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias) void setWeights(const Blob &Wh, const Blob &Wx, const Blob &bias)
{ {
CV_Assert(Wh.dims() == 2 && Wx.dims() == 2); CV_Assert(Wh.dims() == 2 && Wx.dims() == 2);
...@@ -160,31 +177,64 @@ public: ...@@ -160,31 +177,64 @@ public:
void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output) void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output)
{ {
CV_Assert(blobs.size() == 3); CV_Assert(blobs.size() == 3);
Blob &Wh = blobs[0], &Wx = blobs[1]; CV_Assert(input.size() == 1);
Blob &Wh = blobs[0], &Wx = blobs[1];
numOut = Wh.size(1); numOut = Wh.size(1);
numInp = Wx.size(1); numInp = Wx.size(1);
CV_Assert(input.size() == 1); if (!outTailShape.isEmpty())
CV_Assert(input[0]->dims() > 2 && (int)input[0]->total(2) == numInp); CV_Assert(outTailShape.total() == numOut);
else
outTailShape = BlobShape(numOut);
if (useTimestampDim)
{
CV_Assert(input[0]->dims() >= 2 && (int)input[0]->total(2) == numInp);
numTimeStamps = input[0]->size(0); numTimeStamps = input[0]->size(0);
numSamples = input[0]->size(1); numSamples = input[0]->size(1);
dtype = input[0]->type(); outResShape = BlobShape(numTimeStamps, numSamples) + outTailShape;
}
else
{
CV_Assert(input[0]->dims() >= 1 && (int)input[0]->total(1) == numInp);
numTimeStamps = 1;
numSamples = input[0]->size(0);
outResShape = BlobShape(numSamples) + outTailShape;
}
outTsMatShape = BlobShape(numSamples, numOut);
outTsShape = BlobShape(numSamples) + outTailShape;
dtype = input[0]->type();
CV_Assert(dtype == CV_32F || dtype == CV_64F); CV_Assert(dtype == CV_32F || dtype == CV_64F);
CV_Assert(Wh.type() == dtype); CV_Assert(Wh.type() == dtype);
BlobShape outShape(numTimeStamps, numSamples, numOut); output.resize( (produceCellOutput) ? 2 : 1 );
output.resize(2); output[0].create(outResShape, dtype);
output[0].create(outShape, dtype); if (produceCellOutput)
output[1].create(outShape, dtype); output[1].create(outResShape, dtype);
hInternal.create(numSamples, numOut, dtype); if (hInternal.empty())
{
hInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype);
hInternal.setTo(0); hInternal.setTo(0);
}
else
{
CV_Assert((int)hInternal.total() == numSamples*numOut);
hInternal = hInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr());
}
cInternal.create(numSamples, numOut, dtype); if (cInternal.empty())
{
cInternal.create(outTsMatShape.dims(), outTsMatShape.ptr(), dtype);
cInternal.setTo(0); cInternal.setTo(0);
}
else
{
CV_Assert((int)cInternal.total() == numSamples*numOut);
cInternal = cInternal.reshape(1, outTsMatShape.dims(), outTsMatShape.ptr());
}
gates.create(numSamples, 4*numOut, dtype); gates.create(numSamples, 4*numOut, dtype);
...@@ -252,6 +302,22 @@ void LSTMLayer::forward(std::vector<Blob*>&, std::vector<Blob>&) ...@@ -252,6 +302,22 @@ void LSTMLayer::forward(std::vector<Blob*>&, std::vector<Blob>&)
CV_Error(Error::StsInternal, "This function should be unreached"); CV_Error(Error::StsInternal, "This function should be unreached");
} }
int LSTMLayer::inputNameToIndex(String inputName)
{
if (inputName.toLowerCase() == "x")
return 0;
return -1;
}
int LSTMLayer::outputNameToIndex(String outputName)
{
if (outputName.toLowerCase() == "h")
return 0;
else if (outputName.toLowerCase() == "c")
return 1;
return -1;
}
class RNNLayerImpl : public RNNLayer class RNNLayerImpl : public RNNLayer
{ {
......
...@@ -181,78 +181,81 @@ enum RunLayerMode ...@@ -181,78 +181,81 @@ enum RunLayerMode
{ {
ALLOC_ONLY = 1, ALLOC_ONLY = 1,
FORWARD_ONLY = 2, FORWARD_ONLY = 2,
ALLOC_AND_FORWARD = 3 ALLOC_AND_FORWARD = ALLOC_ONLY | FORWARD_ONLY
}; };
void runLayer(Ptr<Layer> layer, std::vector<Blob> &inpBlobs, std::vector<Blob> &outBlobs, int mode=ALLOC_AND_FORWARD) typedef Ptr<std::vector<Blob*> > PtrToVecPtrBlob;
PtrToVecPtrBlob
runLayer(Ptr<Layer> layer, std::vector<Blob> &inpBlobs, std::vector<Blob> &outBlobs, int mode=ALLOC_AND_FORWARD)
{ {
std::vector<Blob*> inpPtrs(inpBlobs.size()); PtrToVecPtrBlob inpPtrs( new std::vector<Blob*>() );
inpPtrs->reserve(inpBlobs.size());
for (size_t i = 0; i < inpBlobs.size(); i++) for (size_t i = 0; i < inpBlobs.size(); i++)
inpPtrs[i] = &inpBlobs[i]; inpPtrs->push_back(&inpBlobs[i]);
if (mode & ALLOC_ONLY) layer->allocate(*inpPtrs, outBlobs);
if (mode & FORWARD_ONLY) layer->forward(*inpPtrs, outBlobs);
if (mode & ALLOC_ONLY) layer->allocate(inpPtrs, outBlobs); return inpPtrs;
if (mode & FORWARD_ONLY) layer->forward(inpPtrs, outBlobs);
} }
class Layer_LSTM_Test : public ::testing::Test class Layer_LSTM_Test : public ::testing::Test
{ {
public: public:
int Nx, Nc; int numInp, numOut;
Blob Wh, Wx, b; Blob Wh, Wx, b;
Ptr<LSTMLayer> layer; Ptr<LSTMLayer> layer;
std::vector<Blob> inputs, outputs; std::vector<Blob> inputs, outputs;
std::vector<Blob*> inputsPtr;
Layer_LSTM_Test(int _Nx = 31, int _Nc = 100) Layer_LSTM_Test() {}
void init(const BlobShape &inpShape_, const BlobShape &outShape_)
{ {
Nx = _Nx; numInp = inpShape_.total();
Nc = _Nc; numOut = outShape_.total();
Wh = Blob(BlobShape(4 * Nc, Nc)); Wh = Blob(BlobShape(4 * numOut, numOut));
Wx = Blob(BlobShape(4 * Nc, Nx)); Wx = Blob(BlobShape(4 * numOut, numInp));
b = Blob(BlobShape(4 * Nc, 1)); b = Blob(BlobShape(4 * numOut, 1));
layer = LSTMLayer::create(); layer = LSTMLayer::create();
layer->setWeights(Wh, Wx, b); layer->setWeights(Wh, Wx, b);
} layer->setOutShape(outShape_);
void allocateAndForward()
{
inputsPtr.clear();
for (size_t i = 0; i < inputs.size(); i++)
inputsPtr.push_back(&inputs[i]);
layer->allocate(inputsPtr, outputs);
layer->forward(inputsPtr, outputs);
} }
}; };
TEST_F(Layer_LSTM_Test, BasicTest_1) TEST_F(Layer_LSTM_Test, get_set_test)
{ {
inputs.push_back(Blob(BlobShape(1, 2, 3, Nx))); BlobShape TN(4);
allocateAndForward(); BlobShape inpShape(5, 3, 2), inpResShape = TN + inpShape;
BlobShape outShape(3, 1, 2), outResShape = TN + outShape;
EXPECT_EQ(outputs.size(), 2); init(inpShape, outShape);
EXPECT_EQ(outputs[0].shape(), BlobShape(1, 2, 3, Nc)); layer->setProduceCellOutput(true);
EXPECT_EQ(outputs[1].shape(), BlobShape(1, 2, 3, Nc)); layer->setUseTimstampsDim(false);
} layer->setOutShape(outShape);
TEST_F(Layer_LSTM_Test, BasicTest_2) layer->setC(Blob(outResShape));
{ layer->setH(Blob(outResShape));
inputs.push_back(Blob(BlobShape(1, 2, 3, Nx)));
inputs.push_back(Blob(BlobShape(1, 2, 3, Nc)));
inputs.push_back(Blob(BlobShape(1, 2, 3, Nc)));
allocateAndForward();
EXPECT_EQ(outputs.size(), 2); inputs.push_back(Blob(inpResShape));
EXPECT_EQ(outputs[0].shape(), BlobShape(1, 2, 3, Nc)); runLayer(layer, inputs, outputs);
EXPECT_EQ(outputs[1].shape(), BlobShape(1, 2, 3, Nc));
EXPECT_EQ(2, outputs.size());
EXPECT_EQ(outResShape, outputs[0].shape());
EXPECT_EQ(outResShape, outputs[1].shape());
EXPECT_EQ(outResShape, layer->getC().shape());
EXPECT_EQ(outResShape, layer->getH().shape());
EXPECT_EQ(0, layer->inputNameToIndex("x"));
EXPECT_EQ(0, layer->outputNameToIndex("h"));
EXPECT_EQ(1, layer->outputNameToIndex("c"));
} }
TEST(Layer_LSTM_Test_Accuracy_Reference_with_, CaffeRecurrent) TEST(Layer_LSTM_Test_Accuracy_Reference_with_, CaffeRecurrent)
{ {
Ptr<LSTMLayer> layer = LSTMLayer::create(); Ptr<LSTMLayer> layer = LSTMLayer::create();
Blob Wx = blobFromNPY(_tf("lstm.prototxt.w_0.npy")); Blob Wx = blobFromNPY(_tf("lstm.prototxt.w_0.npy"));
...@@ -262,13 +265,11 @@ TEST(Layer_LSTM_Test_Accuracy_Reference_with_, CaffeRecurrent) ...@@ -262,13 +265,11 @@ TEST(Layer_LSTM_Test_Accuracy_Reference_with_, CaffeRecurrent)
Blob inp = blobFromNPY(_tf("blob.npy")); Blob inp = blobFromNPY(_tf("blob.npy"));
std::vector<Blob> inputs(1, inp), outputs; std::vector<Blob> inputs(1, inp), outputs;
runLayer(layer, inputs, outputs, ALLOC_ONLY | FORWARD_ONLY); runLayer(layer, inputs, outputs);
Blob &h_t_gathered = outputs[0]; Blob &h_t_gathered = outputs[0];
Blob h_t_reference = blobFromNPY(_tf("lstm.prototxt.h_1.npy")); Blob h_t_reference = blobFromNPY(_tf("lstm.prototxt.h_1.npy"));
//h_t_gathered.reshape(h_t_reference.shape());
normAssert(h_t_reference, h_t_gathered); normAssert(h_t_reference, h_t_gathered);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment