caffe_importer.cpp 13 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
using namespace cv;
using namespace cv::dnn;

#if HAVE_PROTOBUF
#include "caffe.pb.h"

#include <iostream>
#include <fstream>
51
#include <sstream>
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
#include <algorithm>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "caffe_io.hpp"

using ::google::protobuf::RepeatedField;
using ::google::protobuf::RepeatedPtrField;
using ::google::protobuf::Message;
using ::google::protobuf::Descriptor;
using ::google::protobuf::FieldDescriptor;
using ::google::protobuf::Reflection;

namespace
{

68 69 70 71 72 73 74
// Formats a value with its stream insertion operator and returns the text as cv::String.
template<typename T>
static cv::String toString(const T &v)
{
    std::ostringstream formatter;
    formatter << v;
    const std::string text = formatter.str();
    return text;
}
75

76 77 78 79
class CaffeImporter : public Importer
{
    // Network description read from the text file passed to the constructor.
    caffe::NetParameter net;
    // Contents of the binary model file; read by the constructor only when a non-empty path is given.
    caffe::NetParameter netBinary;
80

81
public:
82

83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
    // Reads the text network description into `net` and, when `caffeModel` is a
    // non-null, non-empty path, the binary model into `netBinary`.
    CaffeImporter(const char *pototxt, const char *caffeModel)
    {
        ReadNetParamsFromTextFileOrDie(pototxt, &net);

        const bool hasBinaryModel = caffeModel && caffeModel[0] != '\0';
        if (hasBinaryModel)
        {
            ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
        }
    }

    void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
    {
        const Reflection *refl = msg.GetReflection();
        int type = field->cpp_type();
        bool isRepeated = field->is_repeated();
        const std::string &name = field->name();

        #define SET_UP_FILED(getter, arrayConstr, gtype)                                    \
            if (isRepeated) {                                                               \
                const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field);  \
                params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size()));                  \
            }                                                                               \
            else {                                                                          \
                params.set(name, refl->getter(msg, field));                               \
105 106
            }

107
        switch (type)
108
        {
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
        case FieldDescriptor::CPPTYPE_INT32:
            SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
            break;
        case FieldDescriptor::CPPTYPE_UINT32:
            SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
            break;
        case FieldDescriptor::CPPTYPE_INT64:
            SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
            break;
        case FieldDescriptor::CPPTYPE_UINT64:
            SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
            break;
        case FieldDescriptor::CPPTYPE_BOOL:
            SET_UP_FILED(GetBool, arrayInt, bool);
            break;
        case FieldDescriptor::CPPTYPE_DOUBLE:
            SET_UP_FILED(GetDouble, arrayReal, double);
            break;
        case FieldDescriptor::CPPTYPE_FLOAT:
            SET_UP_FILED(GetFloat, arrayReal, float);
            break;
        case FieldDescriptor::CPPTYPE_STRING:
            if (isRepeated) {
                const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
                params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
            }
            else {
                params.set(name, refl->GetString(msg, field));
            }
            break;
        case FieldDescriptor::CPPTYPE_ENUM:
            if (isRepeated) {
                int size = refl->FieldSize(msg, field);
                std::vector<cv::String> buf(size);
                for (int i = 0; i < size; i++)
                    buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
                params.set(name, DictValue::arrayString(buf.begin(), size));
            }
            else {
                params.set(name, refl->GetEnum(msg, field)->name());
            }
            break;
        default:
            CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
            break;
154
        }
155
    }
156

157 158 159 160 161
    // Returns true when `str` ends with the suffix "_param".
    inline static bool ends_with_param(const std::string &str)
    {
        static const std::string suffix("_param");

        // A string shorter than the suffix cannot end with it.
        if (str.size() < suffix.size())
            return false;

        const std::string::size_type tailStart = str.size() - suffix.size();
        return str.compare(tailStart, suffix.size(), suffix) == 0;
    }
162

163 164 165 166
    void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false)
    {
        const Descriptor *msgDesc = msg.GetDescriptor();
        const Reflection *msgRefl = msg.GetReflection();
167

168
        for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)
169
        {
170 171 172 173
            const FieldDescriptor *fd = msgDesc->field(fieldId);

            if (!isInternal && !ends_with_param(fd->name()))
                continue;
174

175 176 177 178 179
            bool hasData =  fd->is_required() ||
                            (fd->is_optional() && msgRefl->HasField(msg, fd)) ||
                            (fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
            if (!hasData)
                continue;
180

181 182 183 184 185 186
            if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
            {
                if (fd->is_repeated()) //Extract only first item!
                    extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
                else
                    extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
187 188 189
            }
            else
            {
190
                addParam(msg, fd, params);
191 192
            }
        }
193
    }
194

195 196 197
    // Derives the shape of a serialized blob.
    // Uses the num/channels/height/width fields when any of them is set, otherwise the
    // `shape` message; raises Error::StsError when neither is present.
    BlobShape blobShapeFromProto(const caffe::BlobProto &pbBlob)
    {
        const bool hasExplicitDims = pbBlob.has_num() || pbBlob.has_channels() ||
                                     pbBlob.has_height() || pbBlob.has_width();
        if (hasExplicitDims)
            return BlobShape(pbBlob.num(), pbBlob.channels(), pbBlob.height(), pbBlob.width());

        if (!pbBlob.has_shape())
        {
            CV_Error(Error::StsError, "Unknown shape of input blob");
            return BlobShape();
        }

        const caffe::BlobShape &protoShape = pbBlob.shape();
        const int dimsCount = protoShape.dim_size();

        BlobShape result = BlobShape::all(dimsCount);
        for (int d = 0; d < dimsCount; ++d)
            result[d] = (int)protoShape.dim(d);

        return result;
    }
217

218 219 220
    // Allocates `dstBlob` as a CV_32F blob with the shape stored in `pbBlob` and copies
    // the serialized values into it. Asserts that the number of stored values equals the
    // number of elements of the allocated blob.
    void blobFromProto(const caffe::BlobProto &pbBlob, cv::dnn::Blob &dstBlob)
    {
        BlobShape blobShape = blobShapeFromProto(pbBlob);
        dstBlob.create(blobShape, CV_32F);

        CV_Assert(pbBlob.data_size() == (int)dstBlob.matRefConst().total());
        CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);

        float *out = dstBlob.matRef().ptr<float>();
        const int valuesCount = pbBlob.data_size();
        for (int k = 0; k < valuesCount; ++k)
        {
            out[k] = pbBlob.data(k);
        }
    }
231

232 233 234
    // Looks up, in the binary model, the first layer with the same name as `layer` and
    // converts its blobs into layerParams.blobs. Does nothing when no such layer exists
    // or when it has no blobs.
    void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
    {
        const std::string &wantedName = layer.name();
        const int binLayersCount = netBinary.layer_size();

        int found = -1;
        for (int i = 0; i < binLayersCount && found < 0; ++i)
        {
            if (netBinary.layer(i).name() == wantedName)
                found = i;
        }
        if (found < 0)
            return;

        const caffe::LayerParameter &binLayer = netBinary.layer(found);
        const int blobsCount = binLayer.blobs_size();
        if (blobsCount == 0)
            return;

        layerParams.blobs.resize(blobsCount);
        for (int b = 0; b < blobsCount; ++b)
            blobFromProto(binLayer.blobs(b), layerParams.blobs[b]);
    }
253

254 255 256 257
    // Record of a blob that is available as an input for later layers:
    // its name, the id of the layer that produces it and the index of that layer's output.
    struct BlobNote
    {
        BlobNote(const std::string &_name, int _layerId, int _outNum) :
            name(_name), layerId(_layerId), outNum(_outNum) {}

        // Owned copy of the blob name. Keeping only a c_str() pointer would dangle as soon
        // as the source string is a temporary, is modified, or is destroyed.
        std::string name;
        int layerId, outNum;
    };
262

263 264
    // Blobs registered so far (net inputs and layer tops); scanned from the back by addInput()/addOutput().
    std::vector<BlobNote> addedBlobs;
    // How many times each layer name has been seen; populateNet() uses it to make repeated names unique.
    std::map<String, int> layerCounter;
265

266 267 268 269 270 271 272 273 274 275 276 277 278 279
    // Fills `dstNet` from the parsed network description:
    //  1. registers the network inputs (recorded as outputs of layer id 0),
    //  2. adds every layer together with its parameters and, when found, its binary blobs,
    //  3. connects each bottom to the most recently registered blob with the same name.
    // A layer name that repeats gets the suffix "_<count>" to keep names unique.
    void populateNet(Net dstNet)
    {
        const int totalLayers = net.layer_size();

        layerCounter.clear();
        addedBlobs.clear();
        addedBlobs.reserve(totalLayers + 1);

        // Network inputs.
        {
            const int inputsCount = net.input_size();
            std::vector<String> inputNames(inputsCount);
            for (int i = 0; i < inputsCount; ++i)
            {
                const std::string &inputName = net.input(i);
                inputNames[i] = inputName;
                addedBlobs.push_back(BlobNote(inputName, 0, i));
            }
            dstNet.setNetInputs(inputNames);
        }

        // Layers, in the order they appear in the description.
        for (int layerIdx = 0; layerIdx < totalLayers; ++layerIdx)
        {
            const caffe::LayerParameter &layer = net.layer(layerIdx);

            LayerParams layerParams;
            extractLayerParams(layer, layerParams);
            extractBinaryLayerParms(layer, layerParams);

            String type = layer.type();
            String uniqueName = layer.name();
            const int seenBefore = layerCounter[uniqueName]++;
            if (seenBefore != 0)
                uniqueName += String("_") + toString(seenBefore);

            const int layerId = dstNet.addLayer(uniqueName, type, layerParams);

            const int bottoms = layer.bottom_size();
            for (int in = 0; in < bottoms; ++in)
                addInput(layer.bottom(in), layerId, in, dstNet);

            const int tops = layer.top_size();
            for (int out = 0; out < tops; ++out)
                addOutput(layer, layerId, out);
        }

        addedBlobs.clear();
    }

    // Registers top number `outNum` of `layer` as a blob produced by layer `layerId`.
    // A name that is already registered is accepted only when the layer works in place,
    // i.e. its bottom with the same index has the same name; otherwise Error::StsBadArg is raised.
    void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
    {
        const std::string &topName = layer.top(outNum);

        bool alreadyRegistered = false;
        for (size_t i = 0; i < addedBlobs.size() && !alreadyRegistered; ++i)
            alreadyRegistered = (addedBlobs[i].name == topName);

        if (alreadyRegistered)
        {
            const bool inPlace = layer.bottom_size() > outNum && layer.bottom(outNum) == topName;
            if (!inPlace)
                CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
        }

        addedBlobs.push_back(BlobNote(topName, layerId, outNum));
    }

    void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)
    {
        int idx;
        for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
338
        {
339 340 341
            if (addedBlobs[idx].name == name)
                break;
        }
342

343 344 345 346
        if (idx < 0)
        {
            CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");
            return;
347 348
        }

349 350
        dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
    }
351

352 353 354 355 356 357
    // Empty destructor: no explicit cleanup is performed here.
    ~CaffeImporter() {}

};
358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374

}

// Creates a Caffe importer for the given text description and (possibly empty) binary model path.
Ptr<Importer> cv::dnn::createCaffeImporter(const String &prototxt, const String &caffeModel)
{
    CaffeImporter *importer = new CaffeImporter(prototxt.c_str(), caffeModel.c_str());
    return Ptr<Importer>(importer);
}

#else //HAVE_PROTOBUF

// Variant compiled when HAVE_PROTOBUF is not set: always raises StsNotImplemented.
Ptr<Importer> cv::dnn::createCaffeImporter(const String&, const String&)
{
    CV_Error(Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
    return Ptr<Importer>();
}

#endif //HAVE_PROTOBUF
375 376 377

// Builds a Net from a Caffe text description and an optional binary model.
// Any exception thrown while creating the importer is swallowed; in that case
// (or when no importer is returned) a default-constructed Net is returned.
Net cv::dnn::readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String()*/)
{
    Ptr<Importer> importer;
    try
    {
        importer = createCaffeImporter(prototxt, caffeModel);
    }
    catch (...)
    {
        // Importer creation failed: fall through with an empty pointer.
    }

    if (!importer)
        return Net();

    Net loaded;
    importer->populateNet(loaded);
    return loaded;
}