Commit 81283e9d authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #1194 from dkurt:fix_torch_reshape

parents d95053f2 e0fe0aa3
......@@ -55,6 +55,7 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
{
int srcShapeSize = (int)srcShape.size();
int maskShapeSize = (int)maskShape.size();
int maskTotal = abs(total(maskShape)); // Mask might have negative ones.
if (srcRange == Range::all())
srcRange = Range(0, srcShapeSize);
......@@ -65,6 +66,19 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
srcRange.end = srcRange.end == INT_MAX ? srcShapeSize : srcRange.start + sz;
}
if (maskTotal != 0)
{
for (int i = srcRange.start + 1; i < srcRange.end; ++i)
{
if (total(srcShape, i, srcRange.end) != maskTotal)
{
srcRange.start = i - 1;
break;
}
}
CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal);
}
CV_Assert(0 <= srcRange.start && srcRange.start <= srcRange.end && srcRange.end <= srcShapeSize);
int dstShapeSize = srcShapeSize - srcRange.size() + maskShapeSize;
dstShape.resize(dstShapeSize);
......
......@@ -122,6 +122,7 @@ TEST(Torch_Importer, run_reshape)
{
runTorchNet("net_reshape");
runTorchNet("net_reshape_batch");
runTorchNet("net_reshape_single_sample");
}
TEST(Torch_Importer, run_linear)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment