reshape_layer.cpp 8.17 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "../precomp.hpp"
#include "layers_common.hpp"
44
#include <opencv2/dnn/shape_utils.hpp>
45 46 47 48 49 50

namespace cv
{
namespace dnn
{

51 52
static void computeShapeByReshapeMask(const MatShape &srcShape,
                                      const MatShape &maskShape,
53
                                      Range srcRange /*= Range::all()*/,
54
                                      MatShape& dstShape)
55
{
56 57
    int srcShapeSize = (int)srcShape.size();
    int maskShapeSize = (int)maskShape.size();
58

59 60 61 62 63
    if (srcRange == Range::all())
        srcRange = Range(0, srcShapeSize);
    else
    {
        int sz = srcRange.size();
64
        srcRange.start = clamp(srcRange.start, srcShapeSize);
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
        srcRange.end = srcRange.end == INT_MAX ? srcShapeSize : srcRange.start + sz;
    }

    CV_Assert(0 <= srcRange.start && srcRange.start <= srcRange.end && srcRange.end <= srcShapeSize);
    int dstShapeSize = srcShapeSize - srcRange.size() + maskShapeSize;
    dstShape.resize(dstShapeSize);

    std::copy(srcShape.begin(), srcShape.begin() + srcRange.start, dstShape.begin());
    std::copy(srcShape.begin() + srcRange.end, srcShape.begin() + srcShapeSize, dstShape.begin() + srcRange.start + maskShapeSize);

    int inferDim = -1;
    for (int i = 0; i < maskShapeSize; i++)
    {
        if (maskShape[i] > 0)
        {
            dstShape[srcRange.start + i] = maskShape[i];
        }
        else if (maskShape[i] == 0)
        {
            if (srcRange.start + i >= srcShapeSize)
                CV_Error(Error::StsBadArg, format("Copy dim[%d] (which has zero size) is out of the source shape bounds", srcRange.start + i));
            dstShape[srcRange.start + i] = srcShape[srcRange.start + i];
        }
        else if (maskShape[i] == -1)
        {
            if (inferDim != -1)
                CV_Error(Error::StsAssert, "Duplicate of inferred dim (which is denoted by -1)");
            inferDim = srcRange.start + i;
            dstShape[inferDim] = 1;
        }
        else
            CV_Error(Error::StsBadArg, "maskShape[i] >= -1");
    }

99 100
    size_t srcTotal = total(srcShape);
    size_t dstTotal = total(dstShape);
101 102 103 104 105

    if (inferDim != -1)
    {
        if (srcTotal % dstTotal != 0)
            CV_Error(Error::StsBackTrace, "Can't infer a dim denoted by -1");
106

107 108 109
        dstShape[inferDim] = (int)(srcTotal / dstTotal);
    }
    else
110
    {
111
        CV_Assert(srcTotal == dstTotal);
112 113 114
    }
}

115 116

class ReshapeLayerImpl : public ReshapeLayer
117
{
118
public:
119 120
    ReshapeLayerImpl(const LayerParams& params):
        performReordering(false)
121
    {
122 123 124 125 126 127 128 129 130
        setParamsFrom(params);
        int axis = params.get<int>("axis", 0);
        int numAxes = params.get<int>("num_axes", -1);
        enableReordering = params.get<bool>("reorder_dims", false);
        CV_Assert(numAxes >= -1);
        newShapeRange = (numAxes == -1) ? Range(axis, INT_MAX) : Range(axis, axis + numAxes);

        newShapeDesc.clear();
        if (params.has("dim"))
131
        {
132 133 134 135 136 137 138
            const DictValue &paramShape = params.get("dim");
            int i, dims = paramShape.size();
            newShapeDesc.resize(dims);
            for (i = 0; i < dims; i++)
                newShapeDesc[i] = paramShape.get<int>(i);
        }
    }
139

140 141 142 143
    bool getMemoryShapes(const std::vector<MatShape> &inputs,
                         const int requiredOutputs,
                         std::vector<MatShape> &outputs,
                         std::vector<MatShape> &internals) const
144
    {
145
        outputs.clear();
146

147 148
        for (size_t i = 0; i < inputs.size(); i++)
        {
149 150
            outputs.push_back(MatShape());
            computeShapeByReshapeMask(inputs[i], newShapeDesc, newShapeRange, outputs.back());
151
        }
152
        internals = outputs;
153 154 155 156 157 158 159 160 161 162 163

        return true;
    }

    void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
    {
        CV_Assert(inputs.size());
        CV_Assert(outputs.size());
        Mat srcBlob = *inputs[0];
        int dims = srcBlob.dims;
        MatShape inputShape = shape(srcBlob), outShape = shape(outputs[0]);
164 165 166 167 168 169 170 171 172 173

        // input.total() == output.total(). So if reordering is require,
        // one of the sizes will be are not equal.
        // Example where reordering is require: from 1x128x4x4 to 1x2048
        // Example where reordering is NOT require: from 1x1024x1x1 to 1x1024.
        bool reorderingRequire = false;
        const int minDims = min(dims, (int)outShape.size());
        for (int i = 0; !reorderingRequire && i < minDims; ++i)
            reorderingRequire = inputShape[i] != outShape[i];
        performReordering = enableReordering && reorderingRequire;
174
    }
175

176
    void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
177
    {
178
        for (size_t i = 0; i < inputs.size(); i++)
179 180
        {
            Mat srcBlob = *inputs[i];
181
            MatShape inputShape = shape(srcBlob);
182 183 184

            if (performReordering)
            {
185
                float *dstData = internals[i].ptr<float>();
186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
                const float *srcData = srcBlob.ptr<float>();

                int num = inputShape[0], channels = inputShape[1], height = inputShape[2], width = inputShape[3];
                int total = num*channels*height*width;
                for(int i_n = 0; i_n < num; i_n++) {
                    for(int i_c = 0; i_c < channels; i_c++) {
                        for(int i_h = 0; i_h < height; i_h++) {
                            for(int i_w = 0; i_w < width; i_w++) {
                                int src_i = channels*height*width*i_n + height*width*i_c + width*i_h + i_w;
                                int dst_i = channels*height*width*i_n + i_c + channels*width*i_h + channels*i_w;

                                CV_Assert(dst_i < total);
                                CV_Assert(src_i < total);

                                dstData[dst_i] = srcData[src_i];
                            }
202 203 204
                        }
                    }
                }
205
                internals[i].copyTo(outputs[i]);
206 207
            }
        }
208 209
    }

210
private:
211
    std::vector<std::vector<int> > outShapes;
212
    bool enableReordering, performReordering;
213 214 215
};

// Factory for the reshape layer: returns a new ReshapeLayerImpl configured
// from `params`, owned by the returned Ptr.
Ptr<ReshapeLayer> ReshapeLayer::create(const LayerParams& params)
{
    return Ptr<ReshapeLayer>(new ReshapeLayerImpl(params));
}


}
}