// caffe_converter.cpp
/*
  By downloading, copying, installing or using the software you agree to this license.
  If you do not agree to this license, do not download, install,
  copy or use the software.


                            License Agreement
                 For Open Source Computer Vision Library
                         (3-clause BSD License)

  Copyright (C) 2000-2016, Intel Corporation, all rights reserved.
  Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved.
  Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved.
  Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved.
  Copyright (C) 2015-2016, OpenCV Foundation, all rights reserved.
  Copyright (C) 2015-2016, Itseez Inc., all rights reserved.
  Third party copyrights are property of their respective owners.

  Redistribution and use in source and binary forms, with or without modification,
  are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright notice,
      this list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above copyright notice,
      this list of conditions and the following disclaimer in the documentation
      and/or other materials provided with the distribution.

    * Neither the names of the copyright holders nor the names of the contributors
      may be used to endorse or promote products derived from this software
      without specific prior written permission.

  This software is provided by the copyright holders and contributors "as is" and
  any express or implied warranties, including, but not limited to, the implied
  warranties of merchantability and fitness for a particular purpose are disclaimed.
  In no event shall copyright holders or contributors be liable for any direct,
  indirect, incidental, special, exemplary, or consequential damages
  (including, but not limited to, procurement of substitute goods or services;
  loss of use, data, or profits; or business interruption) however caused
  and on any theory of liability, whether in contract, strict liability,
  or tort (including negligence or otherwise) arising in any way out of
  the use of this software, even if advised of the possibility of such damage.
 */

#include "precomp.hpp"
46
#include <opencv2/imgproc.hpp>
47 48

#include <tiny_dnn/tiny_dnn.h>
49
#include <tiny_dnn/io/caffe/caffe.pb.cc>
50 51 52 53 54 55 56 57 58 59 60 61 62

using namespace tiny_dnn;
using namespace tiny_dnn::activation;
using namespace std;

namespace cv {
namespace dnn2 {

/*
 !CaffeConverter Implementation
 */
class CaffeConverter_Impl : public CaffeConverter {
 public:
63 64 65
    explicit CaffeConverter_Impl(const String& model_file,
                                 const String& trained_file,
                                 const String& mean_file) {
66 67 68 69 70 71 72 73 74 75 76
        net_ = create_net_from_caffe_prototxt(model_file);
        reload_weight_from_caffe_protobinary(trained_file, net_.get());

        const size_t width  = (*net_)[0]->in_data_shape()[0].width_;
        const size_t height = (*net_)[0]->in_data_shape()[0].height_;

        mean_ = compute_mean(mean_file, width, height);
    }

    ~CaffeConverter_Impl() {}

77
    virtual void eval(InputArray image, std::vector<float>& results);
78 79

 private:
80
    Mat compute_mean(const string& mean_file, const size_t width,
81 82
		         const size_t height);

83
    ColorConversionCodes get_cvt_codes(const int src_channels,
84 85
                                           const int dst_channels);

86 87 88
    void preprocess(const Mat& img, const Mat& mean,
            const int num_channels, const Size& geometry,
            vector<Mat>* input_channels);
89

90
    Mat mean_;
91 92 93
    std::shared_ptr<network<sequential>> net_;
};

94
Mat
95 96 97 98 99 100
CaffeConverter_Impl::compute_mean(const string& mean_file,
                                  const size_t width,
				  const size_t height) {
    caffe::BlobProto blob;
    ::detail::read_proto_from_binary(mean_file, &blob);

101
    vector<Mat> channels;
102 103 104 105 106 107 108 109
    auto data = blob.mutable_data()->mutable_data();

    const size_t offset = blob.height() * blob.width();

    for (int i = 0; i < blob.channels(); i++, data += offset) {
        channels.emplace_back(blob.height(), blob.width(), CV_32FC1, data);
    }

110 111
    Mat meanChannel;
    merge(channels, meanChannel);
112

113
    return Mat(Size(width, height), meanChannel.type(), mean(meanChannel));
114 115
}

116
ColorConversionCodes
117 118 119 120 121
CaffeConverter_Impl::get_cvt_codes(const int src_channels,
                                   const int dst_channels) {
    assert(src_channels != dst_channels);

    if (dst_channels == 3) {
122
        return src_channels == 1 ? COLOR_GRAY2BGR : COLOR_BGRA2BGR;
123
    } else if (dst_channels == 1) {
124
        return src_channels == 3 ? COLOR_BGR2GRAY : COLOR_BGRA2GRAY;
125 126 127 128 129
    } else {
        throw runtime_error("unsupported color code");
    }
}

130 131
void CaffeConverter_Impl::preprocess(const Mat& img,
                                     const Mat& mean,
132
                                     const int num_channels,
133 134 135
                                     const Size& geometry,
                                     vector<Mat>* input_channels) {
    Mat sample;
136 137 138

    // convert color
    if (img.channels() != num_channels) {
139
        cvtColor(img, sample,
140 141 142 143 144 145
                     get_cvt_codes(img.channels(), num_channels));
    } else {
        sample = img;
    }

    // resize
146 147
    Mat sample_resized;
    resize(sample, sample_resized, geometry);
148

149
    Mat sample_float;
150 151 152 153 154
    sample_resized.convertTo(sample_float,
                             num_channels == 3 ? CV_32FC3 : CV_32FC1);

    // subtract mean
    if (mean.size().width > 0) {
155 156 157
        Mat sample_normalized;
        subtract(sample_float, mean, sample_normalized);
        split(sample_normalized, *input_channels);
158 159
    }
    else {
160
        split(sample_float, *input_channels);
161 162 163
    }
}

164 165 166
void CaffeConverter_Impl::eval(InputArray image,
                               std::vector<float>& results) {
    const Mat img = image.getMat();
167 168 169 170 171

    const size_t channels = (*net_)[0]->in_data_shape()[0].depth_;
    const size_t width    = (*net_)[0]->in_data_shape()[0].width_;
    const size_t height   = (*net_)[0]->in_data_shape()[0].height_;

172
    vector<Mat> input_channels;
173 174 175 176 177 178 179 180
    vector<float> inputvec(width*height*channels);

    for (size_t i = 0; i < channels; i++) {
        input_channels.emplace_back(height, width, CV_32FC1,
                                    &inputvec[width*height*i]);
    }

    // subtract mean from input
181
    preprocess(img, mean_, 3, Size(width, height), &input_channels);
182 183 184 185

    const vector<tiny_dnn::float_t> vec(inputvec.begin(), inputvec.end());

    // perform inderence
186
    auto result = net_->predict(vec);
187 188

    // allocate output
189 190
    results.clear();
    results.reserve(result.size());
191 192

    for (size_t i = 0; i < result.size(); i++) {
193
        results.push_back(result[i]);
194 195 196
    }
}

197 198 199
Ptr<CaffeConverter> CaffeConverter::create(const String& model_file,
                                           const String& trained_file,
                                           const String& mean_file) {
200 201 202 203 204
    return makePtr<CaffeConverter_Impl>(model_file, trained_file, mean_file);
}

} // namespace dnn2
} // namespace cv