// ssd_object_detection.cpp
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;

#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace std;

// Side length, in pixels, of the square network input: preprocess() resizes
// every frame to width x height before it is fed to the net.
static const size_t width  = 300;
static const size_t height = 300;

15
Mat getMean(const size_t& imageHeight, const size_t& imageWidth)
16 17 18 19 20 21 22
{
    Mat mean;

    const int meanValues[3] = {104, 117, 123};
    vector<Mat> meanChannels;
    for(size_t i = 0; i < 3; i++)
    {
23
        Mat channel(imageHeight, imageWidth, CV_32F, Scalar(meanValues[i]));
24 25 26 27 28 29
        meanChannels.push_back(channel);
    }
    cv::merge(meanChannels, mean);
    return mean;
}

30
Mat preprocess(const Mat& frame)
31
{
32 33 34
    Mat preprocessed;
    frame.convertTo(preprocessed, CV_32FC3);
    resize(preprocessed, preprocessed, Size(width, height)); //SSD accepts 300x300 RGB-images
35 36

    Mat mean = getMean(width, height);
37 38 39
    cv::subtract(preprocessed, mean, preprocessed);

    return preprocessed;
40 41 42 43 44
}

// Description printed by --help. Adjacent string literals are concatenated
// with no separator, so each piece must carry its own trailing space.
const char* about = "This sample uses Single-Shot Detector "
                    "(https://arxiv.org/abs/1512.02325) "
                    "to detect objects on image\n"; // TODO: link

// Option specification for cv::CommandLineParser: one
// "{ name | default | description }" entry per command-line key.
const char* params =
    "{ help           | false | print usage         }"
    "{ proto          |       | model configuration }"
    "{ model          |       | model weights       }"
    "{ image          |       | image for detection }"
    "{ min_confidence | 0.5   | min confidence      }";

int main(int argc, char** argv)
{
    cv::CommandLineParser parser(argc, argv, params);

    if (parser.get<bool>("help"))
    {
        std::cout << about << std::endl;
        parser.printMessage();
        return 0;
    }

64 65
    cv::dnn::initModule();          //Required if OpenCV is built as static libs

66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
    String modelConfiguration = parser.get<string>("proto");
    String modelBinary = parser.get<string>("model");

    //! [Create the importer of Caffe model]
    Ptr<dnn::Importer> importer;

    // Import Caffe SSD model
    try
    {
        importer = dnn::createCaffeImporter(modelConfiguration, modelBinary);
    }
    catch (const cv::Exception &err) //Importer can throw errors, we will catch them
    {
        cerr << err.msg << endl;
    }
    //! [Create the importer of Caffe model]

    if (!importer)
    {
        cerr << "Can't load network by using the following files: " << endl;
        cerr << "prototxt:   " << modelConfiguration << endl;
        cerr << "caffemodel: " << modelBinary << endl;
        cerr << "Models can be downloaded here:" << endl;
        cerr << "https://github.com/weiliu89/caffe/tree/ssd#models" << endl;
        exit(-1);
    }

    //! [Initialize network]
    dnn::Net net;
    importer->populateNet(net);
    importer.release();          //We don't need importer anymore
    //! [Initialize network]

99
    cv::Mat frame = cv::imread(parser.get<string>("image"), -1);
100

101
    //! [Prepare blob]
102
    Mat preprocessedFrame = preprocess(frame);
103

104
    dnn::Blob inputBlob = dnn::Blob::fromImages(preprocessedFrame); //Convert Mat to dnn::Blob image
105 106 107 108 109 110 111 112 113 114 115 116 117 118
    //! [Prepare blob]

    //! [Set input blob]
    net.setBlob(".data", inputBlob);                //set the network input
    //! [Set input blob]

    //! [Make forward pass]
    net.forward();                                  //compute output
    //! [Make forward pass]

    //! [Gather output]
    dnn::Blob detection = net.getBlob("detection_out");
    Mat detectionMat(detection.rows(), detection.cols(), CV_32F, detection.ptrf());

119
    float confidenceThreshold = parser.get<float>("min_confidence");
Anna Petrovicheva's avatar
Anna Petrovicheva committed
120
    for(int i = 0; i < detectionMat.rows; i++)
121
    {
122 123 124 125 126 127 128 129 130 131
        float confidence = detectionMat.at<float>(i, 2);

        if(confidence > confidenceThreshold)
        {
            size_t objectClass = detectionMat.at<float>(i, 1);

            float xLeftBottom = detectionMat.at<float>(i, 3) * frame.cols;
            float yLeftBottom = detectionMat.at<float>(i, 4) * frame.rows;
            float xRightTop = detectionMat.at<float>(i, 5) * frame.cols;
            float yRightTop = detectionMat.at<float>(i, 6) * frame.rows;
132

133 134
            std::cout << "Class: " << objectClass << std::endl;
            std::cout << "Confidence: " << confidence << std::endl;
135

136 137 138 139
            std::cout << " " << xLeftBottom
                      << " " << yLeftBottom
                      << " " << xRightTop
                      << " " << yRightTop << std::endl;
140

141 142 143
            Rect object(xLeftBottom, yLeftBottom,
                        xRightTop - xLeftBottom,
                        yRightTop - yLeftBottom);
144

145 146
            rectangle(frame, object, Scalar(0, 255, 0));
        }
147 148
    }

149 150 151 152
    imshow("detections", frame);
    waitKey();

    return 0;
153
} // main