// text_detectorCNN.cpp
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "precomp.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/core.hpp"
#include "opencv2/dnn.hpp"

#include <fstream>
#include <algorithm>

using namespace cv::dnn;

namespace cv
{
namespace text
{

class TextDetectorCNNImpl : public TextDetectorCNN
{
sghoshcvc's avatar
sghoshcvc committed
22
protected:
23 24 25
    Net net_;
    std::vector<Size> sizes_;
    int inputChannelCount_;
sghoshcvc's avatar
sghoshcvc committed
26

27 28
    void getOutputs(const float* buffer,int nbrTextBoxes,int nCol,
                               std::vector<Rect>& Bbox, std::vector<float>& confidence, Size inputShape)
sghoshcvc's avatar
sghoshcvc committed
29
    {
30 31 32 33
        for(int k = 0; k < nbrTextBoxes; k++)
        {
            float x_min = buffer[k*nCol + 3]*inputShape.width;
            float y_min = buffer[k*nCol + 4]*inputShape.height;
34

35 36
            float x_max = buffer[k*nCol + 5]*inputShape.width;
            float y_max = buffer[k*nCol + 6]*inputShape.height;
sghoshcvc's avatar
sghoshcvc committed
37

38
            CV_Assert(x_min < x_max, y_min < y_max);
sghoshcvc's avatar
sghoshcvc committed
39

40 41
            x_min = std::max(0.f, x_min);
            y_min = std::max(0.f, y_min);
sghoshcvc's avatar
sghoshcvc committed
42

43 44
            x_max = std::min(inputShape.width - 1.f,  x_max);
            y_max = std::min(inputShape.height - 1.f,  y_max);
sghoshcvc's avatar
sghoshcvc committed
45

46 47
            int wd = cvRound(x_max - x_min);
            int ht = cvRound(y_max - y_min);
sghoshcvc's avatar
sghoshcvc committed
48

49 50 51
            Bbox.push_back(Rect(cvRound(x_min), cvRound(y_min), wd, ht));
            confidence.push_back(buffer[k*nCol + 2]);
        }
sghoshcvc's avatar
sghoshcvc committed
52 53 54
    }

public:
55 56
    TextDetectorCNNImpl(const String& modelArchFilename, const String& modelWeightsFilename, std::vector<Size> detectionSizes) :
        sizes_(detectionSizes)
sghoshcvc's avatar
sghoshcvc committed
57
    {
58 59 60
        net_ = readNetFromCaffe(modelArchFilename, modelWeightsFilename);
        CV_Assert(!net_.empty());
        inputChannelCount_ = 3;
sghoshcvc's avatar
sghoshcvc committed
61 62
    }

63
    void detect(InputArray inputImage_, std::vector<Rect>& Bbox, std::vector<float>& confidence)
sghoshcvc's avatar
sghoshcvc committed
64
    {
65
        CV_Assert(inputImage_.channels() == inputChannelCount_);
66
        Mat inputImage = inputImage_.getMat();
67 68
        Bbox.resize(0);
        confidence.resize(0);
69

70 71 72
        for(size_t i = 0; i < sizes_.size(); i++)
        {
            Size inputGeometry = sizes_[i];
73
            net_.setInput(blobFromImage(inputImage, 1, inputGeometry, Scalar(123, 117, 104), false, false), "data");
74 75 76 77 78
            Mat outputNet = net_.forward();
            int nbrTextBoxes = outputNet.size[2];
            int nCol = outputNet.size[3];
            int outputChannelCount = outputNet.size[1];
            CV_Assert(outputChannelCount == 1);
79
            getOutputs((float*)(outputNet.data), nbrTextBoxes, nCol, Bbox, confidence, inputImage.size());
80 81
        }
     }
sghoshcvc's avatar
sghoshcvc committed
82 83
};

/** @brief Factory: builds a CNN text detector from a Caffe architecture/weights pair.
 *
 * @param modelArchFilename    Caffe architecture file.
 * @param modelWeightsFilename Caffe weights file.
 * @param detectionSizes       input geometries; detect() runs one pass per entry.
 * @return the newly created detector.
 */
Ptr<TextDetectorCNN> TextDetectorCNN::create(const String &modelArchFilename, const String &modelWeightsFilename, std::vector<Size> detectionSizes)
{
    Ptr<TextDetectorCNN> detector = makePtr<TextDetectorCNNImpl>(modelArchFilename, modelWeightsFilename, detectionSizes);
    return detector;
}

/** @brief Factory overload that uses a single 300x300 detection geometry.
 *
 * @param modelArchFilename    Caffe architecture file.
 * @param modelWeightsFilename Caffe weights file.
 * @return the newly created detector.
 */
Ptr<TextDetectorCNN> TextDetectorCNN::create(const String &modelArchFilename, const String &modelWeightsFilename)
{
    std::vector<Size> defaultSizes;
    defaultSizes.push_back(Size(300, 300));
    return create(modelArchFilename, modelWeightsFilename, defaultSizes);
}
} //namespace text
} //namespace cv