blob.cpp 11 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
43
#include <opencv2/dnn/shape_utils.hpp>
44 45 46 47 48 49

namespace cv
{
namespace dnn
{

50 51 52 53
// Default constructor: creates a blob without allocating any storage.
// The Mat/UMat bookkeeping state is marked UNINITIALIZED through
// CV_DNN_UMAT_ONLY (presumably a no-op unless CV_DNN_UMAT is defined --
// confirm against the macro definition).
Blob::Blob()
{
    CV_DNN_UMAT_ONLY(state = UNINITIALIZED);
}
54

55 56 57 58 59
// Constructs a blob and immediately allocates storage for it.
//
// shape      - requested blob shape
// type       - element type forwarded to create()
// allocFlags - allocation flags forwarded to create()
Blob::Blob(const BlobShape &shape, int type, int allocFlags)
{
    // Start from the UNINITIALIZED state so that create() can pick the
    // initial Mat/UMat state from the allocation flags.
    CV_DNN_UMAT_ONLY(state = UNINITIALIZED);

    create(shape, type, allocFlags);
}
60

61 62
// Constructs a blob from an existing array.
//
// With CV_DNN_UMAT the array is stored in `um` when it is a UMat and in `m`
// otherwise, and `state` records which of the two holds the data.
// Without CV_DNN_UMAT the array is always stored in `m`.
Blob::Blob(InputArray data)
{
#ifdef CV_DNN_UMAT
    if (!data.isUMat())
    {
        m = data.getMat();
        state = HEAD_AT_MAT;
    }
    else
    {
        um = data.getUMat();
        state = HEAD_AT_UMAT;
    }
#else
    m = data.getMat();
#endif
}

void Blob::create(const BlobShape &shape, int type, int allocFlags)
{
81
#ifndef CV_DNN_UMAT
82 83
    CV_Assert(allocFlags & ALLOC_MAT);
    m.create(shape.dims(), shape.ptr(), type);
84
#else
85
    CV_Assert(allocFlags & ALLOC_MAT || allocFlags & ALLOC_UMAT);
86

87 88 89 90
    if (allocFlags & ALLOC_MAT)
        m.create(shape.dims(), shape.ptr(), type);
    if (allocFlags & ALLOC_UMAT)
        um.create(shape.dims(), shape.ptr(), type);
91

92
    if (state == UNINITIALIZED)
93
    {
94 95 96
        if (allocFlags & ALLOC_MAT && allocFlags & ALLOC_UMAT)
            state = SYNCED;
        else if (allocFlags & ALLOC_MAT)
97 98 99 100
            state = HEAD_AT_MAT;
        else
            state = HEAD_AT_UMAT;
    }
101 102
#endif
}
103

104 105 106 107 108
// Replaces the blob contents with the given array.
//
// Without CV_DNN_UMAT only a Mat is accepted. With CV_DNN_UMAT a Mat or a
// UMat is accepted, and `state` is moved to the storage that received it.
// Any other kind of input fails the assertion.
void Blob::fill(InputArray in)
{
#ifndef CV_DNN_UMAT
    CV_Assert(in.isMat());
    m = in.getMat();
#else
    CV_Assert(in.isMat() || in.isUMat());

    if (!in.isMat())
    {
        um = in.getUMat();
        state = HEAD_AT_UMAT;
    }
    else
    {
        m = in.getMat();
        state = HEAD_AT_MAT;
    }
#endif
}
123

124 125 126 127
// Returns the number of channels contributed by `mat`: the regular channel
// count for matrices with at most 2 dimensions, and the extent of the first
// dimension (the number of planes) for higher-dimensional matrices.
static inline int getMatChannels(const Mat &mat)
{
    if (mat.dims > 2)
        return mat.size[0];

    return mat.channels();
}
128

129 130 131 132
// Computes the 4-dimensional blob shape needed to hold all matrices in `vmat`.
//
// vmat        - non-empty list of non-empty matrices; each is either an image
//               with at most 2 dimensions (any channel count) or a
//               single-channel 3-dimensional volume whose first dimension
//               enumerates the planes.
// requestedCn - channel count of the resulting blob; when <= 0 the channel
//               count of the first matrix is used and all matrices must agree
//               with it.
//
// Returns a shape [num, channels, rows, cols] where num is the total number
// of channels/planes in `vmat` divided by `channels`.
//
// Raises an error when the spatial sizes differ, when the channel counts
// differ (only checked if requestedCn <= 0), or when the total channel count
// is not a multiple of the blob channel count.
static BlobShape getBlobShape(std::vector<Mat> &vmat, int requestedCn = -1)
{
    BlobShape shape(BlobShape::all(4));
    int cnSum = 0;

    CV_Assert(vmat.size() > 0);

    for (size_t i = 0; i < vmat.size(); i++)
    {
        const Mat &mat = vmat[i];
        CV_Assert(!mat.empty());
        CV_Assert((mat.dims == 3 && mat.channels() == 1) || mat.dims <= 2);

        const int matCn = getMatChannels(mat);
        cnSum += matCn;

        // cv::Mat documents rows == cols == -1 for matrices with more than
        // two dimensions, so the plane extents of a 3-dimensional volume
        // must be read from size[] instead of rows/cols.
        const bool isVolume = mat.dims > 2;
        const int matRows = isVolume ? mat.size[1] : mat.rows;
        const int matCols = isVolume ? mat.size[2] : mat.cols;

        if (i == 0)
        {
            shape[-1] = matCols;
            shape[-2] = matRows;
            shape[-3] = (requestedCn <= 0) ? matCn : requestedCn;
        }
        else
        {
            if (matCols != shape[-1] || matRows != shape[-2])
                CV_Error(Error::StsError, "Each Mat.size() must be equal");

            if (requestedCn <= 0 && matCn != shape[-3])
                CV_Error(Error::StsError, "Each Mat.chnannels() (or number of planes) must be equal");
        }
    }

    if (cnSum % shape[-3] != 0)
        CV_Error(Error::StsError, "Total number of channels in vector is not a multiple of requsted channel number");

    shape[0] = cnSum / shape[-3];
    return shape;
}

// Converts the supported kinds of input array into a vector of Mat.
//
// A single Mat or UMat yields a one-element vector, a vector of Mat is
// returned as a copy of the vector, and a vector of UMat is converted through
// getMatVector(). Any other kind of input fails the assertion at the end.
static std::vector<Mat> extractMatVector(InputArray in)
{
    if (in.isMat() || in.isUMat())
        return std::vector<Mat>(1, in.getMat());

    if (in.isMatVector())
        return *static_cast<const std::vector<Mat>*>(in.getObj());

    if (in.isUMatVector())
    {
        std::vector<Mat> mats;
        in.getMatVector(mats);
        return mats;
    }

    // None of the supported kinds matched, so this assertion always fails
    // here; the return below only keeps the compiler satisfied.
    CV_Assert(in.isMat() || in.isMatVector() || in.isUMat() || in.isUMatVector());
    return std::vector<Mat>();
}
190

191 192 193 194 195
// Fills this blob with a batch built from one or several images.
//
// image - a Mat/UMat or a vector of them (see extractMatVector)
// dstCn - channel count of the blob, or -1 to take it from the images
//
// The blob is (re)created as CV_32F in Mat storage. Every input is converted
// to CV_32F and written consecutively into the blob memory: images with at
// most 2 dimensions are split into separate channel planes, 3-dimensional
// volumes are converted in one piece.
void Blob::batchFromImages(InputArray image, int dstCn)
{
    CV_Assert(dstCn == -1 || dstCn > 0);
    std::vector<Mat> images = extractMatVector(image);
    BlobShape blobShape = getBlobShape(images, dstCn);

    const int depth = CV_32F;
    create(blobShape, depth, ALLOC_MAT);
    uchar *writePos = matRef().ptr();
    const int bytesPerElem = CV_ELEM_SIZE(depth);

    // Headers that point into the blob memory; cv::split() writes into them.
    std::vector<Mat> planes(blobShape[-3]);
    for (size_t idx = 0; idx < images.size(); idx++)
    {
        Mat src = images[idx];

        if (src.dims > 2)
        {
            // Volume: convert straight into the blob memory.
            src.convertTo(Mat(src.dims, src.size, depth, writePos), depth);
            writePos += bytesPerElem * src.total();
            continue;
        }

        src.convertTo(src, depth);

        // One destination plane per channel, laid out back to back.
        planes.clear();
        const int channelCount = src.channels();
        for (int cn = 0; cn < channelCount; cn++)
        {
            planes.push_back(Mat(src.rows, src.cols, depth, writePos));
            writePos += bytesPerElem * src.total();
        }

        cv::split(src, planes);
    }
}
227

228 229 230 231 232 233 234 235 236 237
// Convenience factory: builds a new blob from one or several images by
// delegating to batchFromImages() on a default-constructed blob.
//
// image - input image(s), forwarded unchanged
// dstCn - requested channel count, forwarded unchanged
Blob Blob::fromImages(InputArray image, int dstCn)
{
    Blob batch;
    batch.batchFromImages(image, dstCn);
    return batch;
}

void Blob::fill(const BlobShape &shape, int type, void *data, bool deepCopy)
{
    if (deepCopy)
238
    {
239 240 241 242 243 244 245 246 247
        create(shape, type);
        memcpy(ptr(), data, this->total() * CV_ELEM_SIZE(type));
    }
    else
    {
        m = Mat(shape.dims(), shape.ptr(), type, data);
    }
    CV_DNN_UMAT_ONLY(state = HEAD_AT_MAT);
}
248

249 250 251 252 253 254
// Sets every element of the blob to `value`.
//
// allocFlags selects which storage is written when CV_DNN_UMAT is enabled:
//   -1          - write the storage that currently holds the head; when the
//                 blob is SYNCED or UNINITIALIZED both storages are written
//                 (an UNINITIALIZED blob becomes SYNCED)
//   ALLOC_BOTH  - write both storages and mark the blob SYNCED
//   ALLOC_MAT   - write through matRef()
//   ALLOC_UMAT  - write through umatRef()
// Any other value raises StsBadArg. Without CV_DNN_UMAT the flags are ignored
// and the Mat storage is written.
void Blob::setTo(InputArray value, int allocFlags)
{
#ifndef CV_DNN_UMAT
    m.setTo(value);
#else
    if (allocFlags == -1)
    {
        if (state == HEAD_AT_UMAT)
        {
            um.setTo(value);
            return;
        }
        if (state == HEAD_AT_MAT)
        {
            m.setTo(value);
            return;
        }

        // SYNCED or UNINITIALIZED
        um.setTo(value);
        m.setTo(value);
        if (state == UNINITIALIZED)
            state = SYNCED;
        return;
    }

    if (allocFlags == ALLOC_BOTH)
    {
        m.setTo(value);
        um.setTo(value);
        state = SYNCED;
        return;
    }

    if (allocFlags == ALLOC_MAT)
    {
        matRef().setTo(value);
        return;
    }

    if (allocFlags == ALLOC_UMAT)
    {
        umatRef().setTo(value);
        return;
    }

    CV_Error(Error::StsBadArg, "allocFlags sholud be -1 or one of Blob::AllocFlag values");
#endif
}

// Makes the Mat storage usable when the head is currently at the UMat.
//
// syncData - when true the UMat contents are copied into the Mat; when false
//            the Mat is only (re)allocated with the blob's dimensions and type.
//
// Nothing happens when the blob is UNINITIALIZED, SYNCED or already has its
// head at the Mat. After an update the state becomes SYNCED. Without
// CV_DNN_UMAT this is a no-op.
void Blob::updateMat(bool syncData) const
{
#ifdef CV_DNN_UMAT
    const bool matIsCurrent =
        state == UNINITIALIZED || state == SYNCED || state == HEAD_AT_MAT;
    if (matIsCurrent)
        return;

    // Any state other than HEAD_AT_UMAT at this point is unexpected.
    if (state != HEAD_AT_UMAT)
        CV_Error(Error::StsInternal, "");

    if (syncData)
        um.copyTo(m);
    else
        m.create(dims(), sizes(), type());

    state = SYNCED;
#else
    (void)syncData;
#endif
}

// Makes the UMat storage usable when the head is currently at the Mat.
//
// syncData - when true the Mat contents are copied into the UMat; when false
//            the UMat is only (re)allocated with the blob's dimensions and
//            type.
//
// Nothing happens when the blob is UNINITIALIZED, SYNCED or already has its
// head at the UMat. Any other unexpected state raises StsInternal. Without
// CV_DNN_UMAT this is a no-op.
void Blob::updateUMat(bool syncData) const
{
#ifdef CV_DNN_UMAT
    if (state == UNINITIALIZED || state == SYNCED || state == HEAD_AT_UMAT)
    {
        return;
    }
    else if (state == HEAD_AT_MAT)
    {
        if (syncData)
        {
            m.copyTo(um);
            // Both storages now hold the same data. Recording that (as
            // updateMat() does) avoids repeating the full copy on every
            // subsequent call while nothing has been modified.
            state = SYNCED;
        }
        else
        {
            // Allocation only: the UMat contents are not valid yet, so the
            // head deliberately stays at the Mat.
            um.create(dims(), sizes(), type());
        }
    }
    else
    {
        CV_Error(Error::StsInternal, "");
    }
#else
    (void)syncData;
#endif
}

// Brings both storages up to date by calling updateMat() and then
// updateUMat() with their default arguments (the defaults are declared
// outside this file -- presumably they request a data copy; confirm).
void Blob::sync() const
{
    updateMat();
    updateUMat();
}
342

343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
// Returns the blob shape packed as a 4-element vector in the order
// (num, channels, rows, cols).
Vec4i Blob::shape4() const
{
    const Vec4i packed(num(), channels(), rows(), cols());
    return packed;
}

//BlobShape

// Prints a shape as a bracketed, comma-separated list of its dimensions,
// e.g. "[1, 3, 224, 224]"; a shape without dimensions prints as "[]".
std::ostream &operator<< (std::ostream &stream, const BlobShape &shape)
{
    const int dimCount = shape.dims();

    stream << "[";

    // All dimensions but the last are followed by a separator.
    for (int axis = 0; axis + 1 < dimCount; axis++)
        stream << shape[axis] << ", ";

    if (dimCount > 0)
        stream << shape[-1];

    stream << "]";
    return stream;
}

// Computes the shape obtained by replacing a range of axes of `srcShape`
// with the axes described by `maskShape`.
//
// srcShape  - source shape
// maskShape - description of the new axes; for every entry:
//               > 0  - the new dimension has exactly this size
//               == 0 - the dimension is copied from the source axis with the
//                      same index (which must exist in the source shape)
//               == -1 - the dimension is inferred so that the total number of
//                       elements is preserved; allowed at most once
//             any other value raises StsBadArg.
// srcRange  - range of source axes to replace; Range::all() means every
//             axis. Otherwise the start is canonicalized with
//             canonicalAxis() and an end of INT_MAX means "up to the last
//             axis".
//
// Axes outside srcRange are copied to the result unchanged. Raises an error
// if the inferred dimension cannot be determined, and asserts that the total
// element count is preserved when nothing has to be inferred.
BlobShape computeShapeByReshapeMask(const BlobShape &srcShape, const BlobShape &maskShape, Range srcRange /*= Range::all()*/)
{
    if (srcRange == Range::all())
        srcRange = Range(0, srcShape.dims());
    else
    {
        int sz = srcRange.size();
        srcRange.start = srcShape.canonicalAxis(srcRange.start);
        srcRange.end =  (srcRange.end == INT_MAX) ? srcShape.dims() : srcRange.start + sz;
    }

    CV_Assert(0 <= srcRange.start && srcRange.start <= srcRange.end && srcRange.end <= srcShape.dims());
    BlobShape dstShape(srcShape.dims() - srcRange.size() + maskShape.dims(), (const int*)NULL);

    // Keep the axes before and after the replaced range as they are.
    std::copy(srcShape.ptr(), srcShape.ptr() + srcRange.start, dstShape.ptr());
    std::copy(srcShape.ptr() + srcRange.end, srcShape.ptr() + srcShape.dims(), dstShape.ptr() + srcRange.start + maskShape.dims());

    int inferDim = -1;
    for (int i = 0; i < maskShape.dims(); i++)
    {
        if (maskShape[i] > 0)
        {
            dstShape[srcRange.start + i] = maskShape[i];
        }
        else if (maskShape[i] == 0)
        {
            if (srcRange.start + i >= srcShape.dims())
                CV_Error(Error::StsBadArg, format("Copy dim[%d] (which has zero size) is out of the source shape bounds", srcRange.start + i));
            dstShape[srcRange.start + i] = srcShape[srcRange.start + i];
        }
        else if (maskShape[i] == -1)
        {
            if (inferDim != -1)
                CV_Error(Error::StsAssert, "Duplicate of inferred dim (which is denoted by -1)");
            inferDim = srcRange.start + i;
            // Placeholder so that total() below yields the product of all
            // other dimensions.
            dstShape[inferDim] = 1;
        }
        else
            CV_Error(Error::StsBadArg, "maskShape[i] >= -1");
    }

    if (inferDim != -1)
    {
        ptrdiff_t srcTotal = srcShape.total();
        ptrdiff_t dstTotal = dstShape.total();
        // A zero-sized dimension elsewhere in the destination makes dstTotal
        // zero; the modulo/division below would then divide by zero, so
        // report it as a dimension that cannot be inferred.
        if (dstTotal == 0 || srcTotal % dstTotal != 0)
            CV_Error(Error::StsBackTrace, "Can't infer a dim denoted by -1");

        dstShape[inferDim] = (int)(srcTotal / dstTotal);
    }
    else
    {
        CV_Assert(srcShape.total() == dstShape.total());
    }

    return dstShape;
}
419 420 421

}
}