Commit 08112f38 authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Faster-RCNN models support

parent 84535a60
......@@ -74,7 +74,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
class CV_EXPORTS BlankLayer : public Layer
{
public:
static Ptr<BlankLayer> create(const LayerParams &params);
static Ptr<Layer> create(const LayerParams &params);
};
//! LSTM recurrent layer
......@@ -567,6 +567,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
static Ptr<ResizeNearestNeighborLayer> create(const LayerParams& params);
};
class CV_EXPORTS ProposalLayer : public Layer
{
public:
static Ptr<ProposalLayer> create(const LayerParams& params);
};
//! @}
//! @}
CV__DNN_EXPERIMENTAL_NS_END
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
......@@ -547,6 +547,7 @@ message LayerParameter {
optional PowerParameter power_param = 122;
optional PReLUParameter prelu_param = 131;
optional PriorBoxParameter prior_box_param = 150;
optional ProposalParameter proposal_param = 201;
optional PythonParameter python_param = 130;
optional RecurrentParameter recurrent_param = 146;
optional ReductionParameter reduction_param = 136;
......@@ -854,6 +855,9 @@ message SaveOutputParameter {
message DropoutParameter {
optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
// Faster-RCNN framework's parameter.
// source: https://github.com/rbgirshick/caffe-fast-rcnn/tree/faster-rcnn
optional bool scale_train = 2 [default = true]; // scale train or test phase
}
// DummyDataLayer fills any number of arbitrarily shaped blobs with random
......@@ -1618,3 +1622,14 @@ message ROIPoolingParameter {
// input scale to the scale used when pooling
optional float spatial_scale = 3 [default = 1];
}
message ProposalParameter {
optional uint32 feat_stride = 1 [default = 16];
optional uint32 base_size = 2 [default = 16];
optional uint32 min_size = 3 [default = 16];
repeated float ratio = 4;
repeated float scale = 5;
optional uint32 pre_nms_topn = 6 [default = 6000];
optional uint32 post_nms_topn = 7 [default = 300];
optional float nms_thresh = 8 [default = 0.7];
}
......@@ -122,6 +122,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Normalize, NormalizeBBoxLayer);
CV_DNN_REGISTER_LAYER_CLASS(Shift, ShiftLayer);
CV_DNN_REGISTER_LAYER_CLASS(Padding, PaddingLayer);
CV_DNN_REGISTER_LAYER_CLASS(Proposal, ProposalLayer);
CV_DNN_REGISTER_LAYER_CLASS(Scale, ScaleLayer);
CV_DNN_REGISTER_LAYER_CLASS(LSTM, LSTMLayer);
......
......@@ -92,9 +92,25 @@ public:
}
};
Ptr<BlankLayer> BlankLayer::create(const LayerParams& params)
Ptr<Layer> BlankLayer::create(const LayerParams& params)
{
return Ptr<BlankLayer>(new BlankLayerImpl(params));
// In case of Caffe's Dropout layer from Faster-RCNN framework,
// https://github.com/rbgirshick/caffe-fast-rcnn/tree/faster-rcnn
// return Power layer.
if (!params.get<bool>("scale_train", true))
{
float scale = 1 - params.get<float>("dropout_ratio", 0.5f);
CV_Assert(scale > 0);
LayerParams powerParams;
powerParams.name = params.name;
powerParams.type = "Power";
powerParams.set("scale", scale);
return PowerLayer::create(powerParams);
}
else
return Ptr<BlankLayer>(new BlankLayerImpl(params));
}
}
......
......@@ -88,6 +88,7 @@ public:
else if (params.has("pooled_w") || params.has("pooled_h") || params.has("spatial_scale"))
{
type = ROI;
computeMaxIdx = false;
}
setParamsFrom(params);
ceilMode = params.get<bool>("ceil_mode", true);
......@@ -294,24 +295,17 @@ public:
int ystart, yend;
const float *srcData;
int xstartROI = 0;
float roiRatio = 0;
if (poolingType == ROI)
{
const float *roisData = rois->ptr<float>(n);
int ystartROI = scaleAndRoundRoi(roisData[2], spatialScale);
int yendROI = scaleAndRoundRoi(roisData[4], spatialScale);
int roiHeight = std::max(yendROI - ystartROI + 1, 1);
roiRatio = (float)roiHeight / height;
float roiRatio = (float)roiHeight / height;
ystart = ystartROI + y0 * roiRatio;
yend = ystartROI + std::ceil((y0 + 1) * roiRatio);
xstartROI = scaleAndRoundRoi(roisData[1], spatialScale);
int xendROI = scaleAndRoundRoi(roisData[3], spatialScale);
int roiWidth = std::max(xendROI - xstartROI + 1, 1);
roiRatio = (float)roiWidth / width;
CV_Assert(roisData[0] < src->size[0]);
srcData = src->ptr<float>(roisData[0], c);
}
......@@ -331,22 +325,12 @@ public:
ofs0 += delta;
int x1 = x0 + delta;
if( poolingType == MAX || poolingType == ROI)
if( poolingType == MAX)
for( ; x0 < x1; x0++ )
{
int xstart, xend;
if (poolingType == ROI)
{
xstart = xstartROI + x0 * roiRatio;
xend = xstartROI + std::ceil((x0 + 1) * roiRatio);
}
else
{
xstart = x0 * stride_w - pad_w;
xend = xstart + kernel_w;
}
int xstart = x0 * stride_w - pad_w;
int xend = min(xstart + kernel_w, inp_width);
xstart = max(xstart, 0);
xend = min(xend, inp_width);
if (xstart >= xend || ystart >= yend)
{
dstData[x0] = 0;
......@@ -493,7 +477,7 @@ public:
}
}
}
else
else if (poolingType == AVE)
{
for( ; x0 < x1; x0++ )
{
......@@ -543,6 +527,37 @@ public:
}
}
}
else // ROI
{
const float *roisData = rois->ptr<float>(n);
int xstartROI = scaleAndRoundRoi(roisData[1], spatialScale);
int xendROI = scaleAndRoundRoi(roisData[3], spatialScale);
int roiWidth = std::max(xendROI - xstartROI + 1, 1);
float roiRatio = (float)roiWidth / width;
for( ; x0 < x1; x0++ )
{
int xstart = xstartROI + x0 * roiRatio;
int xend = xstartROI + std::ceil((x0 + 1) * roiRatio);
xstart = max(xstart, 0);
xend = min(xend, inp_width);
if (xstart >= xend || ystart >= yend)
{
dstData[x0] = 0;
if (compMaxIdx && dstMaskData)
dstMaskData[x0] = -1;
continue;
}
float max_val = -FLT_MAX;
for (int y = ystart; y < yend; ++y)
for (int x = xstart; x < xend; ++x)
{
const int index = y * inp_width + x;
float val = srcData[index];
max_val = std::max(max_val, val);
}
dstData[x0] = max_val;
}
}
}
}
};
......
......@@ -183,6 +183,7 @@ public:
_minSize = getParameter<float>(params, "min_size", 0, false, 0);
_flip = getParameter<bool>(params, "flip", 0, false, true);
_clip = getParameter<bool>(params, "clip", 0, false, true);
_bboxesNormalized = getParameter<bool>(params, "normalized_bbox", 0, false, true);
_scales.clear();
_aspectRatios.clear();
......@@ -251,7 +252,7 @@ public:
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const
{
CV_Assert(inputs.size() == 2);
CV_Assert(!inputs.empty());
int layerHeight = inputs[0][2];
int layerWidth = inputs[0][3];
......@@ -282,6 +283,8 @@ public:
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_Assert(inputs.size() == 2);
size_t real_numPriors = _numPriors / pow(2, _offsetsX.size() - 1);
if (_scales.empty())
_scales.resize(real_numPriors, 1.0f);
......@@ -323,7 +326,8 @@ public:
{
float center_x = (w + _offsetsX[i]) * stepX;
float center_y = (h + _offsetsY[i]) * stepY;
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth, _imageHeight, outputPtr);
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth,
_imageHeight, _bboxesNormalized, outputPtr);
}
if (_maxSize > 0)
{
......@@ -333,7 +337,8 @@ public:
{
float center_x = (w + _offsetsX[i]) * stepX;
float center_y = (h + _offsetsY[i]) * stepY;
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth, _imageHeight, outputPtr);
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth,
_imageHeight, _bboxesNormalized, outputPtr);
}
}
......@@ -349,7 +354,8 @@ public:
{
float center_x = (w + _offsetsX[i]) * stepX;
float center_y = (h + _offsetsY[i]) * stepY;
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth, _imageHeight, outputPtr);
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth,
_imageHeight, _bboxesNormalized, outputPtr);
}
}
......@@ -363,7 +369,8 @@ public:
{
float center_x = (w + _offsetsX[j]) * stepX;
float center_y = (h + _offsetsY[j]) * stepY;
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth, _imageHeight, outputPtr);
outputPtr = addPrior(center_x, center_y, _boxWidth, _boxHeight, _imageWidth,
_imageHeight, _bboxesNormalized, outputPtr);
}
}
}
......@@ -437,6 +444,7 @@ private:
bool _flip;
bool _clip;
bool _explicitSizes;
bool _bboxesNormalized;
size_t _numPriors;
......@@ -444,12 +452,22 @@ private:
static const std::string _layerName;
static float* addPrior(float center_x, float center_y, float width, float height,
float imgWidth, float imgHeight, float* dst)
float imgWidth, float imgHeight, bool normalized, float* dst)
{
dst[0] = (center_x - width * 0.5f) / imgWidth; // xmin
dst[1] = (center_y - height * 0.5f) / imgHeight; // ymin
dst[2] = (center_x + width * 0.5f) / imgWidth; // xmax
dst[3] = (center_y + height * 0.5f) / imgHeight; // ymax
if (normalized)
{
dst[0] = (center_x - width * 0.5f) / imgWidth; // xmin
dst[1] = (center_y - height * 0.5f) / imgHeight; // ymin
dst[2] = (center_x + width * 0.5f) / imgWidth; // xmax
dst[3] = (center_y + height * 0.5f) / imgHeight; // ymax
}
else
{
dst[0] = center_x - width * 0.5f; // xmin
dst[1] = center_y - height * 0.5f; // ymin
dst[2] = center_x + width * 0.5f - 1.0f; // xmax
dst[3] = center_y + height * 0.5f - 1.0f; // ymax
}
return dst + 4;
}
};
......
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "../precomp.hpp"
#include "layers_common.hpp"
namespace cv { namespace dnn {
class ProposalLayerImpl : public ProposalLayer
{
public:
ProposalLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
uint32_t featStride = params.get<uint32_t>("feat_stride", 16);
uint32_t baseSize = params.get<uint32_t>("base_size", 16);
// uint32_t minSize = params.get<uint32_t>("min_size", 16);
uint32_t keepTopBeforeNMS = params.get<uint32_t>("pre_nms_topn", 6000);
keepTopAfterNMS = params.get<uint32_t>("post_nms_topn", 300);
float nmsThreshold = params.get<float>("nms_thresh", 0.7);
DictValue ratios = params.get("ratio");
DictValue scales = params.get("scale");
{
LayerParams lp;
lp.set("step", featStride);
lp.set("flip", false);
lp.set("clip", false);
lp.set("normalized_bbox", false);
// Unused values.
float variance[] = {0.1f, 0.1f, 0.2f, 0.2f};
lp.set("variance", DictValue::arrayReal<float*>(&variance[0], 4));
// Compute widths and heights explicitly.
std::vector<float> widths, heights;
widths.reserve(ratios.size() * scales.size());
heights.reserve(ratios.size() * scales.size());
for (int i = 0; i < ratios.size(); ++i)
{
float ratio = ratios.get<float>(i);
for (int j = 0; j < scales.size(); ++j)
{
float scale = scales.get<float>(j);
float width = std::floor(baseSize / sqrt(ratio) + 0.5f);
float height = std::floor(width * ratio + 0.5f);
widths.push_back(scale * width);
heights.push_back(scale * height);
}
}
lp.set("width", DictValue::arrayReal<float*>(&widths[0], widths.size()));
lp.set("height", DictValue::arrayReal<float*>(&heights[0], heights.size()));
priorBoxLayer = PriorBoxLayer::create(lp);
}
{
int order[] = {0, 2, 3, 1};
LayerParams lp;
lp.set("order", DictValue::arrayInt<int*>(&order[0], 4));
deltasPermute = PermuteLayer::create(lp);
scoresPermute = PermuteLayer::create(lp);
}
{
LayerParams lp;
lp.set("code_type", "CENTER_SIZE");
lp.set("num_classes", 1);
lp.set("share_location", true);
lp.set("background_label_id", 1); // We won't pass background scores so set it out of range [0, num_classes)
lp.set("variance_encoded_in_target", true);
lp.set("keep_top_k", keepTopAfterNMS);
lp.set("top_k", keepTopBeforeNMS);
lp.set("nms_threshold", nmsThreshold);
lp.set("normalized_bbox", false);
lp.set("clip", true);
detectionOutputLayer = DetectionOutputLayer::create(lp);
}
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const
{
// We need to allocate the following blobs:
// - output priors from PriorBoxLayer
// - permuted priors
// - permuted scores
CV_Assert(inputs.size() == 3);
const MatShape& scores = inputs[0];
const MatShape& bboxDeltas = inputs[1];
std::vector<MatShape> layerInputs, layerOutputs, layerInternals;
// Prior boxes layer.
layerInputs.assign(1, scores);
priorBoxLayer->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
CV_Assert(layerOutputs.size() == 1);
CV_Assert(layerInternals.empty());
internals.push_back(layerOutputs[0]);
// Scores permute layer.
CV_Assert(scores.size() == 4);
MatShape objectScores = scores;
CV_Assert((scores[1] & 1) == 0); // Number of channels is even.
objectScores[1] /= 2;
layerInputs.assign(1, objectScores);
scoresPermute->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
CV_Assert(layerOutputs.size() == 1);
CV_Assert(layerInternals.empty());
internals.push_back(layerOutputs[0]);
// BBox predictions permute layer.
layerInputs.assign(1, bboxDeltas);
deltasPermute->getMemoryShapes(layerInputs, 1, layerOutputs, layerInternals);
CV_Assert(layerOutputs.size() == 1);
CV_Assert(layerInternals.empty());
internals.push_back(layerOutputs[0]);
outputs.resize(1, shape(keepTopAfterNMS, 5));
return false;
}
void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
std::vector<Mat*> layerInputs;
std::vector<Mat> layerOutputs;
// Scores permute layer.
Mat scores = getObjectScores(*inputs[0]);
layerInputs.assign(1, &scores);
layerOutputs.assign(1, Mat(shape(scores.size[0], scores.size[2],
scores.size[3], scores.size[1]), CV_32FC1));
scoresPermute->finalize(layerInputs, layerOutputs);
// BBox predictions permute layer.
Mat* bboxDeltas = inputs[1];
CV_Assert(bboxDeltas->dims == 4);
layerInputs.assign(1, bboxDeltas);
layerOutputs.assign(1, Mat(shape(bboxDeltas->size[0], bboxDeltas->size[2],
bboxDeltas->size[3], bboxDeltas->size[1]), CV_32FC1));
deltasPermute->finalize(layerInputs, layerOutputs);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr)
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_Assert(inputs.size() == 3);
CV_Assert(internals.size() == 3);
const Mat& scores = *inputs[0];
const Mat& bboxDeltas = *inputs[1];
const Mat& imInfo = *inputs[2];
Mat& priorBoxes = internals[0];
Mat& permuttedScores = internals[1];
Mat& permuttedDeltas = internals[2];
CV_Assert(imInfo.total() >= 2);
// We've chosen the smallest data type because we need just a shape from it.
fakeImageBlob.create(shape(1, 1, imInfo.at<float>(0), imInfo.at<float>(1)), CV_8UC1);
// Generate prior boxes.
std::vector<Mat> layerInputs(2), layerOutputs(1, priorBoxes);
layerInputs[0] = scores;
layerInputs[1] = fakeImageBlob;
priorBoxLayer->forward(layerInputs, layerOutputs, internals);
// Permute scores.
layerInputs.assign(1, getObjectScores(scores));
layerOutputs.assign(1, permuttedScores);
scoresPermute->forward(layerInputs, layerOutputs, internals);
// Permute deltas.
layerInputs.assign(1, bboxDeltas);
layerOutputs.assign(1, permuttedDeltas);
deltasPermute->forward(layerInputs, layerOutputs, internals);
// Sort predictions by scores and apply NMS. DetectionOutputLayer allocates
// output internally because of different number of objects after NMS.
layerInputs.resize(4);
layerInputs[0] = permuttedDeltas;
layerInputs[1] = permuttedScores;
layerInputs[2] = priorBoxes;
layerInputs[3] = fakeImageBlob;
layerOutputs[0] = Mat();
detectionOutputLayer->forward(layerInputs, layerOutputs, internals);
// DetectionOutputLayer produces 1x1xNx7 output where N might be less or
// equal to keepTopAfterNMS. We fill the rest by zeros.
const int numDets = layerOutputs[0].total() / 7;
CV_Assert(numDets <= keepTopAfterNMS);
Mat src = layerOutputs[0].reshape(1, numDets).colRange(3, 7);
Mat dst = outputs[0].rowRange(0, numDets);
src.copyTo(dst.colRange(1, 5));
dst.col(0).setTo(0); // First column are batch ids. Keep it zeros too.
if (numDets < keepTopAfterNMS)
outputs[0].rowRange(numDets, keepTopAfterNMS).setTo(0);
}
private:
// A first half of channels are background scores. We need only a second one.
static Mat getObjectScores(const Mat& m)
{
CV_Assert(m.dims == 4);
CV_Assert(m.size[0] == 1);
int channels = m.size[1];
CV_Assert((channels & 1) == 0);
return slice(m, Range::all(), Range(channels / 2, channels));
}
Ptr<PriorBoxLayer> priorBoxLayer;
Ptr<DetectionOutputLayer> detectionOutputLayer;
Ptr<PermuteLayer> deltasPermute;
Ptr<PermuteLayer> scoresPermute;
uint32_t keepTopAfterNMS;
Mat fakeImageBlob;
};
Ptr<ProposalLayer> ProposalLayer::create(const LayerParams& params)
{
return Ptr<ProposalLayer>(new ProposalLayerImpl(params));
}
} // namespace dnn
} // namespace cv
......@@ -576,4 +576,27 @@ TEST(Layer_Test_ROIPooling, Accuracy)
normAssert(out, ref);
}
TEST(Layer_Test_FasterRCNN_Proposal, Accuracy)
{
Net net = readNetFromCaffe(_tf("net_faster_rcnn_proposal.prototxt"));
Mat scores = blobFromNPY(_tf("net_faster_rcnn_proposal.scores.npy"));
Mat deltas = blobFromNPY(_tf("net_faster_rcnn_proposal.deltas.npy"));
Mat imInfo = (Mat_<float>(1, 3) << 600, 800, 1.6f);
Mat ref = blobFromNPY(_tf("net_faster_rcnn_proposal.npy"));
net.setInput(scores, "rpn_cls_prob_reshape");
net.setInput(deltas, "rpn_bbox_pred");
net.setInput(imInfo, "im_info");
Mat out = net.forward();
const int numDets = ref.size[0];
EXPECT_LE(numDets, out.size[0]);
normAssert(out.rowRange(0, numDets), ref);
if (numDets < out.size[0])
EXPECT_EQ(countNonZero(out.rowRange(numDets, out.size[0])), 0);
}
}
// Faster-RCNN models use custom layer called 'Proposal' written in Python. To
// map it into OpenCV's layer replace a layer node with [type: 'Python'] to the
// following definition:
// layer {
// name: 'proposal'
// type: 'Proposal'
// bottom: 'rpn_cls_prob_reshape'
// bottom: 'rpn_bbox_pred'
// bottom: 'im_info'
// top: 'rois'
// proposal_param {
// ratio: 0.5
// ratio: 1.0
// ratio: 2.0
// scale: 8
// scale: 16
// scale: 32
// }
// }
#include <iostream>
#include <opencv2/dnn.hpp>
#include <opencv2/dnn/all_layers.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace dnn;
const char* about = "This sample is used to run Faster-RCNN object detection "
"models from https://github.com/rbgirshick/py-faster-rcnn with OpenCV.";
const char* keys =
"{ help h | | print help message }"
"{ proto p | | path to .prototxt }"
"{ model m | | path to .caffemodel }"
"{ image i | | path to input image }"
"{ conf c | 0.8 | minimal confidence }";
const char* classNames[] = {
"__background__",
"aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair",
"cow", "diningtable", "dog", "horse",
"motorbike", "person", "pottedplant",
"sheep", "sofa", "train", "tvmonitor"
};
static const int kInpWidth = 800;
static const int kInpHeight = 600;
int main(int argc, char** argv)
{
// Parse command line arguments.
CommandLineParser parser(argc, argv, keys);
if (argc == 1 || parser.has("help"))
{
std::cout << about << std::endl;
return 0;
}
String protoPath = parser.get<String>("proto");
String modelPath = parser.get<String>("model");
String imagePath = parser.get<String>("image");
float confThreshold = parser.get<float>("conf");
CV_Assert(!protoPath.empty(), !modelPath.empty(), !imagePath.empty());
// Load a model.
Net net = readNetFromCaffe(protoPath, modelPath);
// Create a preprocessing layer that does final bounding boxes applying predicted
// deltas to objects locations proposals and doing non-maximum suppression over it.
LayerParams lp;
lp.set("code_type", "CENTER_SIZE"); // An every bounding box is [xmin, ymin, xmax, ymax]
lp.set("num_classes", 21);
lp.set("share_location", (int)false); // Separate predictions for different classes.
lp.set("background_label_id", 0);
lp.set("variance_encoded_in_target", (int)true);
lp.set("keep_top_k", 100);
lp.set("nms_threshold", 0.3);
lp.set("normalized_bbox", (int)false);
Ptr<Layer> detectionOutputLayer = DetectionOutputLayer::create(lp);
Mat img = imread(imagePath);
resize(img, img, Size(kInpWidth, kInpHeight));
Mat blob = blobFromImage(img, 1.0, Size(), Scalar(102.9801, 115.9465, 122.7717), false, false);
Mat imInfo = (Mat_<float>(1, 3) << img.rows, img.cols, 1.6f);
net.setInput(blob, "data");
net.setInput(imInfo, "im_info");
std::vector<Mat> outs;
std::vector<String> outNames(3);
outNames[0] = "proposal";
outNames[1] = "bbox_pred";
outNames[2] = "cls_prob";
net.forward(outs, outNames);
Mat proposals = outs[0].colRange(1, 5).clone(); // Only last 4 columns.
Mat& deltas = outs[1];
Mat& scores = outs[2];
// Reshape proposals from Nx4 to 1x1xN*4
std::vector<int> shape(3, 1);
shape[2] = (int)proposals.total();
proposals = proposals.reshape(1, shape);
// Run postprocessing layer.
std::vector<Mat> layerInputs(3), layerOutputs(1), layerInternals;
layerInputs[0] = deltas.reshape(1, 1);
layerInputs[1] = scores.reshape(1, 1);
layerInputs[2] = proposals;
detectionOutputLayer->forward(layerInputs, layerOutputs, layerInternals);
// Draw detections.
Mat detections = layerOutputs[0];
const float* data = (float*)detections.data;
for (size_t i = 0; i < detections.total(); i += 7)
{
// An every detection is a vector [id, classId, confidence, left, top, right, bottom]
float confidence = data[i + 2];
if (confidence > confThreshold)
{
int classId = (int)data[i + 1];
int left = max(0, min((int)data[i + 3], img.cols - 1));
int top = max(0, min((int)data[i + 4], img.rows - 1));
int right = max(0, min((int)data[i + 5], img.cols - 1));
int bottom = max(0, min((int)data[i + 6], img.rows - 1));
// Draw a bounding box.
rectangle(img, Point(left, top), Point(right, bottom), Scalar(0, 255, 0));
// Put a label with a class name and confidence.
String label = cv::format("%s, %.3f", classNames[classId], confidence);
int baseLine;
Size labelSize = cv::getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
top = max(top, labelSize.height);
rectangle(img, Point(left, top - labelSize.height),
Point(left + labelSize.width, top + baseLine),
Scalar(255, 255, 255), FILLED);
putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
}
}
imshow("frame", img);
waitKey();
return 0;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment