Commit fcfb2976 authored by Easton Liu's avatar Easton Liu

Add ability to read thresh and nms_threshold from YOLO layer in YOLOV3 cfg file.

Currently the thresh is hard-coded to be 0.2 and nms_threshold as 0.4.
parent 3132c8ee
...@@ -371,7 +371,7 @@ namespace cv { ...@@ -371,7 +371,7 @@ namespace cv {
fused_layer_names.push_back(last_layer); fused_layer_names.push_back(last_layer);
} }
void setYolo(int classes, const std::vector<int>& mask, const std::vector<float>& anchors) void setYolo(int classes, const std::vector<int>& mask, const std::vector<float>& anchors, float thresh, float nms_threshold)
{ {
cv::dnn::LayerParams region_param; cv::dnn::LayerParams region_param;
region_param.name = "Region-name"; region_param.name = "Region-name";
...@@ -382,6 +382,8 @@ namespace cv { ...@@ -382,6 +382,8 @@ namespace cv {
region_param.set<int>("classes", classes); region_param.set<int>("classes", classes);
region_param.set<int>("anchors", numAnchors); region_param.set<int>("anchors", numAnchors);
region_param.set<bool>("logistic", true); region_param.set<bool>("logistic", true);
region_param.set<float>("thresh", thresh);
region_param.set<float>("nms_threshold", nms_threshold);
std::vector<float> usedAnchors(numAnchors * 2); std::vector<float> usedAnchors(numAnchors * 2);
for (int i = 0; i < numAnchors; ++i) for (int i = 0; i < numAnchors; ++i)
...@@ -646,6 +648,8 @@ namespace cv { ...@@ -646,6 +648,8 @@ namespace cv {
{ {
int classes = getParam<int>(layer_params, "classes", -1); int classes = getParam<int>(layer_params, "classes", -1);
int num_of_anchors = getParam<int>(layer_params, "num", -1); int num_of_anchors = getParam<int>(layer_params, "num", -1);
float thresh = getParam<float>(layer_params, "thresh", 0.2);
float nms_threshold = getParam<float>(layer_params, "nms_threshold", 0.4);
std::string anchors_values = getParam<std::string>(layer_params, "anchors", std::string()); std::string anchors_values = getParam<std::string>(layer_params, "anchors", std::string());
CV_Assert(!anchors_values.empty()); CV_Assert(!anchors_values.empty());
...@@ -658,7 +662,7 @@ namespace cv { ...@@ -658,7 +662,7 @@ namespace cv {
CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size()); CV_Assert(classes > 0 && num_of_anchors > 0 && (num_of_anchors * 2) == anchors_vec.size());
setParams.setPermute(false); setParams.setPermute(false);
setParams.setYolo(classes, mask_vec, anchors_vec); setParams.setYolo(classes, mask_vec, anchors_vec, thresh, nms_threshold);
} }
else { else {
CV_Error(cv::Error::StsParseError, "Unknown layer type: " + layer_type); CV_Error(cv::Error::StsParseError, "Unknown layer type: " + layer_type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment