fcn_semsegm.cpp 4.38 KB
Newer Older
1 2 3
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
4
#include <opencv2/core/ocl.hpp>
5 6 7 8 9 10 11 12
using namespace cv;
using namespace cv::dnn;

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
using namespace std;

13 14 15
// Model variant name; main() derives the "<fcnType>-heavy-pascal" prototxt and
// caffemodel file names from it.
static const string fcnType("fcn8s");

static vector<cv::Vec3b> readColors(const string &filename = "pascal-classes.txt")
16
{
17
    vector<cv::Vec3b> colors;
18

19
    ifstream fp(filename.c_str());
20 21
    if (!fp.is_open())
    {
22
        cerr << "File with colors not found: " << filename << endl;
23 24 25
        exit(-1);
    }

26
    string line;
27 28
    while (!fp.eof())
    {
29
        getline(fp, line);
30 31
        if (line.length())
        {
32
            stringstream ss(line);
33

34
            string name; ss >> name;
35 36 37 38 39 40 41 42 43 44 45 46 47
            int temp;
            cv::Vec3b color;
            ss >> temp; color[0] = temp;
            ss >> temp; color[1] = temp;
            ss >> temp; color[2] = temp;
            colors.push_back(color);
        }
    }

    fp.close();
    return colors;
}

48
/**
 * Converts the network's per-class score map into a colour image.
 *
 * For every pixel, the channel (class index) with the highest score is chosen.
 * Ties keep the lowest index. The pixel is then painted with colors[classIndex].
 *
 * @param score  network output blob; channel ch, row r is read through
 *               score.ptrf(0, ch, r). The first ptrf index is fixed at 0,
 *               presumably the image index within the batch (TODO confirm).
 * @param colors palette indexed by class; must hold at least
 *               score.channels() entries.
 * @param segm   output, (re)allocated as a rows x cols CV_8UC3 image.
 * @throws cv::Exception (via CV_Assert) if there are no channels, more than
 *         256 channels, or fewer colours than channels.
 */
static void colorizeSegmentation(dnn::Blob &score, const vector<cv::Vec3b> &colors, cv::Mat &segm)
{
    const int rows = score.rows();
    const int cols = score.cols();
    const int chns = score.channels();

    // Class indices are stored as uchar and used to index `colors`, so both
    // bounds must hold. At least one channel is needed to seed the arg-max.
    CV_Assert(chns > 0 && chns <= 256 && static_cast<size_t>(chns) <= colors.size());

    cv::Mat maxCl(rows, cols, CV_8UC1);
    cv::Mat maxVal(rows, cols, CV_32FC1);
    for (int ch = 0; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float *ptrScore = score.ptrf(0, ch, row);
            uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
            float *ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                // Channel 0 seeds the running maximum. maxVal/maxCl are freshly
                // allocated and hold indeterminate values until this write.
                if (ch == 0 || ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = static_cast<uchar>(ch);
                }
            }
        }
    }

    segm.create(rows, cols, CV_8UC3);
    for (int row = 0; row < rows; row++)
    {
        const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
        cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
        for (int col = 0; col < cols; col++)
        {
            ptrSegm[col] = colors[ptrMaxCl[col]];
        }
    }
}

int main(int argc, char **argv)
{
89 90
    cv::dnn::initModule();          //Required if OpenCV is built as static libs
    cv::ocl::setUseOpenCL(false);   //OpenCL switcher
91

92 93 94
    String modelTxt = fcnType + "-heavy-pascal.prototxt";
    String modelBin = fcnType + "-heavy-pascal.caffemodel";
    String imageFile = (argc > 1) ? argv[1] : "rgb.jpg";
95

96
    vector<cv::Vec3b> colors = readColors();
97 98 99 100 101 102 103 104 105

    //! [Create the importer of Caffe model]
    Ptr<dnn::Importer> importer;
    try                                     //Try to import Caffe GoogleNet model
    {
        importer = dnn::createCaffeImporter(modelTxt, modelBin);
    }
    catch (const cv::Exception &err)        //Importer can throw errors, we will catch them
    {
106
        cerr << err.msg << endl;
107 108 109 110 111
    }
    //! [Create the importer of Caffe model]

    if (!importer)
    {
112 113 114 115 116
        cerr << "Can't load network by using the following files: " << endl;
        cerr << "prototxt:   " << modelTxt << endl;
        cerr << "caffemodel: " << modelBin << endl;
        cerr << fcnType << "-heavy-pascal.caffemodel can be downloaded here:" << endl;
        cerr << "http://dl.caffe.berkeleyvision.org/" << fcnType << "-heavy-pascal.caffemodel" << endl;
117 118 119 120 121 122 123 124 125 126 127 128 129
        exit(-1);
    }

    //! [Initialize network]
    dnn::Net net;
    importer->populateNet(net);
    importer.release();                     //We don't need importer anymore
    //! [Initialize network]

    //! [Prepare blob]
    Mat img = imread(imageFile);
    if (img.empty())
    {
130
        cerr << "Can't read image from the file: " << imageFile << endl;
131 132 133 134
        exit(-1);
    }

    resize(img, img, Size(500, 500));       //FCN accepts 500x500 RGB-images
135
    dnn::Blob inputBlob = dnn::Blob::fromImages(img);   //Convert Mat to dnn::Blob batch of images
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
    //! [Prepare blob]

    //! [Set input blob]
    net.setBlob(".data", inputBlob);        //set the network input
    //! [Set input blob]

    //! [Make forward pass]
    net.forward();                          //compute output
    //! [Make forward pass]

    //! [Gather output]
    dnn::Blob score = net.getBlob("score");

    cv::Mat colorize;
    colorizeSegmentation(score, colors, colorize);
    cv::Mat show;
    cv::addWeighted(img, 0.4, colorize, 0.6, 0.0, show);
    cv::imshow("show", show);
    cv::waitKey(0);
    return 0;
} //main