// torch_enet.cpp
/*
Sample of using OpenCV dnn module with Torch ENet model.
*/

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;

#include <fstream>
#include <iostream>
#include <cstdlib>
#include <sstream>
using namespace std;

// Option table handed to cv::CommandLineParser in main().
// Each "{...}" group describes one option; every option here has an empty
// default value (nothing between the two '|' separators). Adjacent string
// literals are concatenated by the compiler into one specification string.
const String keys =
    "{help h    || Sample app for loading ENet Torch model. "
                   "The model and class names list can be downloaded here: "
                   "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
    "{model m   || path to Torch .net model file (model_best.net) }"
    "{image i   || path to image file }"
    "{c_names c || path to file with classnames for channels (optional, categories.txt) }"
    "{result r  || path to save output blob (optional, binary format, NCHW order) }"
    "{show s    || whether to show all output channels or not}"
    "{o_blob    || output blob's name. If empty, last blob's name in net is used}";

static void colorizeSegmentation(const Mat &score, Mat &segm,
30 31
                                 Mat &legend, vector<String> &classNames, vector<Vec3b> &colors);
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames);
arrybn's avatar
arrybn committed
32 33 34

int main(int argc, char **argv)
{
35
    CommandLineParser parser(argc, argv, keys);
arrybn's avatar
arrybn committed
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54

    if (parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }

    String modelFile = parser.get<String>("model");
    String imageFile = parser.get<String>("image");

    if (!parser.check())
    {
        parser.printErrors();
        return 0;
    }

    String classNamesFile = parser.get<String>("c_names");
    String resultFile = parser.get<String>("result");

55 56
    //! [Read model and initialize network]
    dnn::Net net = dnn::readNetFromTorch(modelFile);
arrybn's avatar
arrybn committed
57 58

    //! [Prepare blob]
59
    Mat img = imread(imageFile), input;
arrybn's avatar
arrybn committed
60 61 62 63 64 65
    if (img.empty())
    {
        std::cerr << "Can't read image from the file: " << imageFile << std::endl;
        exit(-1);
    }

66 67
    Size origSize = img.size();
    Size inputImgSize = cv::Size(1024, 512);
arrybn's avatar
arrybn committed
68

69
    if (inputImgSize != origSize)
arrybn's avatar
arrybn committed
70 71
        resize(img, img, inputImgSize);       //Resize image to input size

72
    Mat inputBlob = blobFromImage(img, 1./255, true);   //Convert Mat to image batch
arrybn's avatar
arrybn committed
73 74 75
    //! [Prepare blob]

    //! [Set input blob]
76
    net.setBlob("", inputBlob);        //set the network input
arrybn's avatar
arrybn committed
77 78
    //! [Set input blob]

79
    const int N = 3;
80
    TickMeter tm;
arrybn's avatar
arrybn committed
81 82

    //! [Make forward pass]
83 84 85 86 87 88 89 90 91
    for( int i = 0; i < N; i++ )
    {
        TickMeter tm_;
        tm_.start();
        net.forward();                          //compute output
        tm_.stop();
        if( i == 0 || tm_.getTimeTicks() < tm.getTimeTicks() )
            tm = tm_;
    }
arrybn's avatar
arrybn committed
92 93

    //! [Gather output]
94

arrybn's avatar
arrybn committed
95 96 97 98 99 100
    String oBlob = net.getLayerNames().back();
    if (!parser.get<String>("o_blob").empty())
    {
        oBlob = parser.get<String>("o_blob");
    }

101
    Mat result = net.getBlob(oBlob);   //gather output of "prob" layer
arrybn's avatar
arrybn committed
102 103 104 105 106 107 108 109 110

    if (!resultFile.empty()) {
        CV_Assert(result.isContinuous());

        ofstream fout(resultFile.c_str(), ios::out | ios::binary);
        fout.write((char*)result.data, result.total() * sizeof(float));
        fout.close();
    }

111
    std::cout << "Output blob: " << result.size[0] << " x " << result.size[1] << " x " << result.size[2] << " x " << result.size[3] << "\n";
arrybn's avatar
arrybn committed
112 113
    std::cout << "Inference time, ms: " << tm.getTimeMilli()  << std::endl;

114 115 116
    if (parser.has("show"))
    {
        std::vector<String> classNames;
117
        vector<cv::Vec3b> colors;
118
        if(!classNamesFile.empty()) {
119
            colors = readColors(classNamesFile, classNames);
120 121
        }
        Mat segm, legend;
122
        colorizeSegmentation(result, segm, legend, classNames, colors);
123 124

        Mat show;
125
        addWeighted(img, 0.1, segm, 0.9, 0.0, show);
126

127
        cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST);
128 129 130 131
        imshow("Result", show);
        if(classNames.size())
            imshow("Legend", legend);
        waitKey();
arrybn's avatar
arrybn committed
132 133 134 135 136
    }

    return 0;
} //main

137
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames, vector<Vec3b> &colors)
138
{
139 140 141
    const int rows = score.size[2];
    const int cols = score.size[3];
    const int chns = score.size[1];
142 143 144 145 146 147 148

    cv::Mat maxCl(rows, cols, CV_8UC1);
    cv::Mat maxVal(rows, cols, CV_32FC1);
    for (int ch = 0; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
149
            const float *ptrScore = score.ptr<float>(0, ch, row);
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
            uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
            float *ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                if (ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = ch;
                }
            }
        }
    }

    segm.create(rows, cols, CV_8UC3);
    for (int row = 0; row < rows; row++)
    {
        const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
        cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
        for (int col = 0; col < cols; col++)
        {
            ptrSegm[col] = colors[ptrMaxCl[col]];
        }
    }

    if (classNames.size() == colors.size())
    {
        int blockHeight = 30;
        legend.create(blockHeight*classNames.size(), 200, CV_8UC3);
        for(int i = 0; i < classNames.size(); i++)
        {
            cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight);
            block = colors[i];
            putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
        }
    }
}
186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220

/**
 * Reads class names and their colors from a text file.
 *
 * Each line is parsed as a whitespace-separated name followed by up to three
 * integers, which are stored in that order as elements 0, 1 and 2 of the
 * color. Lines that are empty or contain only whitespace (for example a lone
 * '\r' from CRLF files) are skipped. Missing or unparsable components are
 * stored as 0 so that the entry keeps its position in the list; values
 * outside the 8-bit range are clamped.
 *
 * Terminates the program with exit(-1) if the file cannot be opened.
 *
 * @param filename   Path of the file to read.
 * @param classNames Output: cleared, then filled with one name per parsed line.
 * @return One color per parsed line, in the same order as `classNames`.
 */
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames)
{
    vector<cv::Vec3b> colors;
    classNames.clear();

    ifstream fp(filename.c_str());
    if (!fp.is_open())
    {
        cerr << "File with colors not found: " << filename << endl;
        exit(-1);
    }

    string line;
    // Looping on getline() itself also stops on read errors, not only at EOF.
    while (getline(fp, line))
    {
        stringstream ss(line);

        string name;
        if (!(ss >> name))
            continue;   // nothing but whitespace on this line

        // Start from 0 so a failed extraction never leaves a component uninitialized.
        int components[3] = {0, 0, 0};
        for (int i = 0; i < 3; i++)
        {
            if (!(ss >> components[i]))
            {
                components[i] = 0;
                break;
            }
        }

        cv::Vec3b color;
        for (int i = 0; i < 3; i++)
            color[i] = saturate_cast<uchar>(components[i]);

        classNames.push_back(name);
        colors.push_back(color);
    }

    return colors;
}