Commit 0b4fc061 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #1234 from dkurt:fix_reshape_layer

parents aa0d8060 12b17910
...@@ -55,7 +55,6 @@ static void computeShapeByReshapeMask(const MatShape &srcShape, ...@@ -55,7 +55,6 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
{ {
int srcShapeSize = (int)srcShape.size(); int srcShapeSize = (int)srcShape.size();
int maskShapeSize = (int)maskShape.size(); int maskShapeSize = (int)maskShape.size();
int maskTotal = abs(total(maskShape)); // Mask might have negative ones.
if (srcRange == Range::all()) if (srcRange == Range::all())
srcRange = Range(0, srcShapeSize); srcRange = Range(0, srcShapeSize);
...@@ -66,8 +65,15 @@ static void computeShapeByReshapeMask(const MatShape &srcShape, ...@@ -66,8 +65,15 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
srcRange.end = srcRange.end == INT_MAX ? srcShapeSize : srcRange.start + sz; srcRange.end = srcRange.end == INT_MAX ? srcShapeSize : srcRange.start + sz;
} }
if (maskTotal != 0) bool explicitMask = !maskShape.empty(); // All mask values are positive.
for (int i = 0, n = maskShape.size(); i < n && explicitMask; ++i)
{ {
explicitMask = maskShape[i] > 0;
}
// Working range of source shape is a range where area(src) == area(mask).
if (explicitMask)
{
int maskTotal = total(maskShape);
for (int i = srcRange.start + 1; i < srcRange.end; ++i) for (int i = srcRange.start + 1; i < srcRange.end; ++i)
{ {
if (total(srcShape, i, srcRange.end) != maskTotal) if (total(srcShape, i, srcRange.end) != maskTotal)
......
...@@ -169,14 +169,20 @@ TEST(Layer_Test_MVN, Accuracy) ...@@ -169,14 +169,20 @@ TEST(Layer_Test_MVN, Accuracy)
testLayerUsingCaffeModels("layer_mvn"); testLayerUsingCaffeModels("layer_mvn");
} }
TEST(Layer_Test_Reshape, squeeze) void testReshape(const MatShape& inputShape, const MatShape& targetShape,
int axis = 0, int num_axes = -1, bool reorder_dims = false,
MatShape mask = MatShape())
{ {
LayerParams params; LayerParams params;
params.set("axis", 2); params.set("axis", axis);
params.set("num_axes", 1); params.set("num_axes", num_axes);
params.set("reorder_dims", reorder_dims);
if (!mask.empty())
{
params.set("dim", DictValue::arrayInt<int*>(&mask[0], mask.size()));
}
int sz[] = {4, 3, 1, 2}; Mat inp(inputShape.size(), &inputShape[0], CV_32F);
Mat inp(4, sz, CV_32F);
std::vector<Mat> inpVec(1, inp); std::vector<Mat> inpVec(1, inp);
std::vector<Mat> outVec, intVec; std::vector<Mat> outVec, intVec;
...@@ -185,9 +191,23 @@ TEST(Layer_Test_Reshape, squeeze) ...@@ -185,9 +191,23 @@ TEST(Layer_Test_Reshape, squeeze)
Mat& out = outVec[0]; Mat& out = outVec[0];
MatShape shape(out.size.p, out.size.p + out.dims); MatShape shape(out.size.p, out.size.p + out.dims);
int sh0[] = {4, 3, 2}; EXPECT_EQ(shape, targetShape);
MatShape shape0(sh0, sh0+3); }
EXPECT_EQ(shape, shape0);
TEST(Layer_Test_Reshape, Accuracy)
{
{
int inp[] = {4, 3, 1, 2};
int out[] = {4, 3, 2};
testReshape(MatShape(inp, inp + 4), MatShape(out, out + 3), 2, 1);
}
{
int inp[] = {1, 128, 4, 4};
int out[] = {1, 2048};
int mask[] = {-1, 2048};
testReshape(MatShape(inp, inp + 4), MatShape(out, out + 2), 0, -1, true,
MatShape(mask, mask + 2));
}
} }
TEST(Layer_Test_BatchNorm, Accuracy) TEST(Layer_Test_BatchNorm, Accuracy)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment