Commit f2440cea authored by Dmitry Kurtaev

Update tutorials. A new cv::dnn::readNet function

parent 8e4fe30d
...@@ -13,50 +13,53 @@ We will demonstrate results of this example on the following picture. ...@@ -13,50 +13,53 @@ We will demonstrate results of this example on the following picture.
Source Code Source Code
----------- -----------
We will be using snippets from the example application, that can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/caffe_googlenet.cpp). We will be using snippets from the example application, that can be downloaded [here](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.cpp).
@include dnn/caffe_googlenet.cpp @include dnn/classification.cpp
Explanation Explanation
----------- -----------
-# Firstly, download GoogLeNet model files: -# Firstly, download GoogLeNet model files:
[bvlc_googlenet.prototxt ](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/bvlc_googlenet.prototxt) and [bvlc_googlenet.prototxt ](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/bvlc_googlenet.prototxt) and
[bvlc_googlenet.caffemodel](http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel) [bvlc_googlenet.caffemodel](http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel)
Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes: Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
[synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt). [classification_classes_ILSVRC2012.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/classification_classes_ILSVRC2012.txt).
Put these files into working dir of this program example. Put these files into working dir of this program example.
-# Read and initialize network using path to .prototxt and .caffemodel files -# Read and initialize network using path to .prototxt and .caffemodel files
@snippet dnn/caffe_googlenet.cpp Read and initialize network @snippet dnn/classification.cpp Read and initialize network
-# Check that network was read successfully You can skip an argument `framework` if one of the files `model` or `config` has an
@snippet dnn/caffe_googlenet.cpp Check that network was read successfully extension `.caffemodel` or `.prototxt`.
This way function cv::dnn::readNet can automatically detects a model's format.
-# Read input image and convert to the blob, acceptable by GoogleNet -# Read input image and convert to the blob, acceptable by GoogleNet
@snippet dnn/caffe_googlenet.cpp Prepare blob @snippet dnn/classification.cpp Open a video file or an image file or a camera stream
We convert the image to a 4-dimensional blob (so-called batch) with 1x3x224x224 shape after applying necessary pre-processing like resizing and mean subtraction using cv::dnn::blobFromImage constructor.
-# Pass the blob to the network cv::VideoCapture can load both images and videos.
@snippet dnn/caffe_googlenet.cpp Set input blob
In bvlc_googlenet.prototxt the network input blob named as "data", therefore this blob labeled as ".data" in opencv_dnn API. @snippet dnn/classification.cpp Create a 4D blob from a frame
We convert the image to a 4-dimensional blob (so-called batch) with `1x3x224x224` shape
after applying necessary pre-processing like resizing and mean subtraction
`(-104, -117, -123)` for each blue, green and red channels correspondingly using cv::dnn::blobFromImage function.
Other blobs labeled as "name_of_layer.name_of_layer_output". -# Pass the blob to the network
@snippet dnn/classification.cpp Set input blob
-# Make forward pass -# Make forward pass
@snippet dnn/caffe_googlenet.cpp Make forward pass @snippet dnn/classification.cpp Make forward pass
During the forward pass output of each network layer is computed, but in this example we need output from "prob" layer only. During the forward pass output of each network layer is computed, but in this example we need output from the last layer only.
-# Determine the best class -# Determine the best class
@snippet dnn/caffe_googlenet.cpp Gather output @snippet dnn/classification.cpp Get a class with a highest score
We put the output of "prob" layer, which contain probabilities for each of 1000 ILSVRC2012 image classes, to the `prob` blob. We put the output of network, which contain probabilities for each of 1000 ILSVRC2012 image classes, to the `prob` blob.
And find the index of element with maximal value in this one. This index correspond to the class of the image. And find the index of element with maximal value in this one. This index corresponds to the class of the image.
-# Print results -# Run an example from command line
@snippet dnn/caffe_googlenet.cpp Print results @code
For our image we get: ./example_dnn_classification --model=bvlc_googlenet.caffemodel --config=bvlc_googlenet.prototxt --width=224 --height=224 --classes=classification_classes_ILSVRC2012.txt --input=space_shuttle.jpg --mean="104 117 123"
> Best class: #812 'space shuttle' @endcode
> For our image we get prediction of class `space shuttle` with more than 99% sureness.
> Probability: 99.6378%
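Review note: the updated tutorial text describes turning a frame into a `1x3x224x224` blob after subtracting the per-channel mean `104 117 123` (blue, green, red). The sketch below illustrates only that layout change and the mean subtraction, in plain Python on nested lists. The helper name `blob_from_bgr` is hypothetical and not part of OpenCV, and resizing and the optional red/blue swap done by cv::dnn::blobFromImage are deliberately left out.

```python
def blob_from_bgr(image, mean=(104.0, 117.0, 123.0), scale=1.0):
    """Hypothetical helper: convert an H x W x 3 BGR image (nested lists)
    into a 1 x 3 x H x W blob, subtracting the per-channel mean first and
    then applying the scale factor."""
    height, width = len(image), len(image[0])
    planes = [
        [
            [scale * (image[y][x][c] - mean[c]) for x in range(width)]
            for y in range(height)
        ]
        for c in range(3)  # one plane per channel: blue, green, red
    ]
    return [planes]  # a batch holding a single image


# A tiny 2 x 2 image stands in for a resized 224 x 224 frame.
tiny = [[(104, 117, 123), (105, 117, 123)],
        [(0, 0, 0), (255, 255, 255)]]
blob = blob_from_bgr(tiny)
print(len(blob), len(blob[0]), len(blob[0][0]), len(blob[0][0][0]))
```

With a real frame the same layout yields the `1x3x224x224` shape mentioned in the tutorial once the image has been resized to 224x224.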
...@@ -74,46 +74,7 @@ When you build OpenCV add the following configuration flags: ...@@ -74,46 +74,7 @@ When you build OpenCV add the following configuration flags:
- `HALIDE_ROOT_DIR` - path to Halide build directory - `HALIDE_ROOT_DIR` - path to Halide build directory
## Sample ## Set Halide as a preferable backend
@code
@include dnn/squeezenet_halide.cpp net.setPreferableBackend(DNN_BACKEND_HALIDE);
@endcode
## Explanation
Download Caffe model from SqueezeNet repository: [train_val.prototxt](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/train_val.prototxt) and [squeezenet_v1.1.caffemodel](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel).
Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
[synset_words.txt](https://raw.githubusercontent.com/opencv/opencv/master/samples/data/dnn/synset_words.txt).
Put these files into working dir of this program example.
-# Read and initialize network using path to .prototxt and .caffemodel files
@snippet dnn/squeezenet_halide.cpp Read and initialize network
-# Check that network was read successfully
@snippet dnn/squeezenet_halide.cpp Check that network was read successfully
-# Read input image and convert to the 4-dimensional blob, acceptable by SqueezeNet v1.1
@snippet dnn/squeezenet_halide.cpp Prepare blob
-# Pass the blob to the network
@snippet dnn/squeezenet_halide.cpp Set input blob
-# Enable Halide backend for layers where it is implemented
@snippet dnn/squeezenet_halide.cpp Enable Halide backend
-# Make forward pass
@snippet dnn/squeezenet_halide.cpp Make forward pass
Remember that the first forward pass after initialization require quite more
time that the next ones. It's because of runtime compilation of Halide pipelines
at the first invocation.
-# Determine the best class
@snippet dnn/squeezenet_halide.cpp Determine the best class
-# Print results
@snippet dnn/squeezenet_halide.cpp Print results
For our image we get:
> Best class: #812 'space shuttle'
>
> Probability: 97.9812%
...@@ -683,6 +683,29 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN ...@@ -683,6 +683,29 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/ */
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true); CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
/**
* @brief Read deep learning network represented in one of the supported formats.
* @param[in] model Binary file contains trained weights. The following file
* extensions are expected for models from different frameworks:
* * `*.caffemodel` (Caffe, http://caffe.berkeleyvision.org/)
* * `*.pb` (TensorFlow, https://www.tensorflow.org/)
* * `*.t7` | `*.net` (Torch, http://torch.ch/)
* * `*.weights` (Darknet, https://pjreddie.com/darknet/)
* @param[in] config Text file contains network configuration. It could be a
* file with the following extensions:
* * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)
* * `*.pbtxt` (TensorFlow, https://www.tensorflow.org/)
* * `*.cfg` (Darknet, https://pjreddie.com/darknet/)
* @param[in] framework Explicit framework name tag to determine a format.
* @returns Net object.
*
* This function automatically detects an origin framework of trained model
* and calls an appropriate function such @ref readNetFromCaffe, @ref readNetFromTensorflow,
* @ref readNetFromTorch or @ref readNetFromDarknet. An order of @p model and @p config
* arguments does not matter.
*/
CV_EXPORTS_W Net readNet(String model, String config = "", String framework = "");
/** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework. /** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework.
* @warning This function has the same limitations as readNetFromTorch(). * @warning This function has the same limitations as readNetFromTorch().
*/ */
......
...@@ -2805,5 +2805,41 @@ BackendWrapper::BackendWrapper(const Ptr<BackendWrapper>& base, const MatShape& ...@@ -2805,5 +2805,41 @@ BackendWrapper::BackendWrapper(const Ptr<BackendWrapper>& base, const MatShape&
BackendWrapper::~BackendWrapper() {} BackendWrapper::~BackendWrapper() {}
Net readNet(String model, String config, String framework)
{
framework = framework.toLowerCase();
const std::string modelExt = model.substr(model.rfind('.') + 1);
const std::string configExt = config.substr(config.rfind('.') + 1);
if (framework == "caffe" || modelExt == "caffemodel" || configExt == "caffemodel" ||
modelExt == "prototxt" || configExt == "prototxt")
{
if (modelExt == "prototxt" || configExt == "caffemodel")
std::swap(model, config);
return readNetFromCaffe(config, model);
}
if (framework == "tensorflow" || modelExt == "pb" || configExt == "pb" ||
modelExt == "pbtxt" || configExt == "pbtxt")
{
if (modelExt == "pbtxt" || configExt == "pb")
std::swap(model, config);
return readNetFromTensorflow(model, config);
}
if (framework == "torch" || modelExt == "t7" || modelExt == "net" ||
configExt == "t7" || configExt == "net")
{
return readNetFromTorch(model.empty() ? config : model);
}
if (framework == "darknet" || modelExt == "weights" || configExt == "weights" ||
modelExt == "cfg" || configExt == "cfg")
{
if (modelExt == "cfg" || configExt == "weights")
std::swap(model, config);
return readNetFromDarknet(config, model);
}
CV_Error(Error::StsError, "Cannot determine an origin framework of files: " +
model + (config.empty() ? "" : ", " + config));
return Net();
}
CV__DNN_EXPERIMENTAL_NS_END CV__DNN_EXPERIMENTAL_NS_END
}} // namespace }} // namespace
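Review note: the dispatch added in `readNet` above can be summarized as "pick the framework from the explicit tag or from either file extension, then put the weights and configuration paths in the order the specific reader expects". The following is a rough Python sketch of that idea under stated assumptions. `detect_framework` and the extension table are hypothetical names written for illustration, it does not call OpenCV, and it may differ from the C++ code in corner cases (for example when two Torch files are passed).

```python
# Extension table taken from the readNet documentation in this commit:
# (framework name, weights extensions, configuration extensions).
FRAMEWORK_EXTS = [
    ("caffe", {"caffemodel"}, {"prototxt"}),
    ("tensorflow", {"pb"}, {"pbtxt"}),
    ("torch", {"t7", "net"}, set()),
    ("darknet", {"weights"}, {"cfg"}),
]


def _ext(path):
    # Text after the last dot; the whole string when there is no dot,
    # which mirrors substr(rfind('.') + 1) in the C++ code.
    return path[path.rfind(".") + 1:]


def detect_framework(model, config="", framework=""):
    """Hypothetical helper: return (framework, weights_path, config_path)."""
    framework = framework.lower()
    seen = {_ext(model), _ext(config)}
    for name, weight_exts, config_exts in FRAMEWORK_EXTS:
        if framework == name or seen & (weight_exts | config_exts):
            # Accept the two paths in either order.
            if _ext(model) in config_exts or _ext(config) in weight_exts:
                model, config = config, model
            return name, model, config
    raise ValueError("Cannot determine an origin framework of files: "
                     + ", ".join(p for p in (model, config) if p))


print(detect_framework("bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel"))
```

The frameworks are checked in the same order as in the C++ function, so a pair of files with conflicting extensions resolves to the first match.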
...@@ -57,4 +57,22 @@ TEST(imagesFromBlob, Regression) ...@@ -57,4 +57,22 @@ TEST(imagesFromBlob, Regression)
} }
} }
TEST(readNet, Regression)
{
Net net = readNet(findDataFile("dnn/squeezenet_v1.1.prototxt", false),
findDataFile("dnn/squeezenet_v1.1.caffemodel", false));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/opencv_face_detector.caffemodel", false),
findDataFile("dnn/opencv_face_detector.prototxt", false));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/openface_nn4.small2.v1.t7", false));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/tiny-yolo-voc.cfg", false),
findDataFile("dnn/tiny-yolo-voc.weights", false));
EXPECT_FALSE(net.empty());
net = readNet(findDataFile("dnn/ssd_mobilenet_v1_coco.pbtxt", false),
findDataFile("dnn/ssd_mobilenet_v1_coco.pb", false));
EXPECT_FALSE(net.empty());
}
}} // namespace }} // namespace
...@@ -14,14 +14,12 @@ ...@@ -14,14 +14,12 @@
| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR | | [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
| [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR | | [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
### Classification ### Classification
| Model | Scale | Size WxH| Mean subtraction | Channels order | | Model | Scale | Size WxH| Mean subtraction | Channels order |
|---------------|-------|-----------|--------------------|-------| |---------------|-------|-----------|--------------------|-------|
| GoogLeNet | `1.0` | `224x224` | `104 117 123` | BGR | | GoogLeNet | `1.0` | `224x224` | `104 117 123` | BGR |
| [SqueezeNet](https://github.com/DeepScale/SqueezeNet) | `1.0` | `227x227` | `0 0 0` | BGR | | [SqueezeNet](https://github.com/DeepScale/SqueezeNet) | `1.0` | `227x227` | `0 0 0` | BGR |
## References ## References
* [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py) * [Models downloading script](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/download_models.py)
* [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn) * [Configuration files adopted for OpenCV](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp> #include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
const char* keys = const char* keys =
"{ help h | | Print help message. }" "{ help h | | Print help message. }"
...@@ -33,8 +34,6 @@ using namespace dnn; ...@@ -33,8 +34,6 @@ using namespace dnn;
std::vector<std::string> classes; std::vector<std::string> classes;
Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
CommandLineParser parser(argc, argv, keys); CommandLineParser parser(argc, argv, keys);
...@@ -49,6 +48,11 @@ int main(int argc, char** argv) ...@@ -49,6 +48,11 @@ int main(int argc, char** argv)
bool swapRB = parser.get<bool>("rgb"); bool swapRB = parser.get<bool>("rgb");
int inpWidth = parser.get<int>("width"); int inpWidth = parser.get<int>("width");
int inpHeight = parser.get<int>("height"); int inpHeight = parser.get<int>("height");
String model = parser.get<String>("model");
String config = parser.get<String>("config");
String framework = parser.get<String>("framework");
int backendId = parser.get<int>("backend");
int targetId = parser.get<int>("target");
// Parse mean values. // Parse mean values.
Scalar mean; Scalar mean;
...@@ -77,22 +81,24 @@ int main(int argc, char** argv) ...@@ -77,22 +81,24 @@ int main(int argc, char** argv)
} }
} }
// Load a model.
CV_Assert(parser.has("model")); CV_Assert(parser.has("model"));
Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework")); //! [Read and initialize network]
net.setPreferableBackend(parser.get<int>("backend")); Net net = readNet(model, config, framework);
net.setPreferableTarget(parser.get<int>("target")); net.setPreferableBackend(backendId);
net.setPreferableTarget(targetId);
//! [Read and initialize network]
// Create a window // Create a window
static const std::string kWinName = "Deep learning image classification in OpenCV"; static const std::string kWinName = "Deep learning image classification in OpenCV";
namedWindow(kWinName, WINDOW_NORMAL); namedWindow(kWinName, WINDOW_NORMAL);
// Open a video file or an image file or a camera stream. //! [Open a video file or an image file or a camera stream]
VideoCapture cap; VideoCapture cap;
if (parser.has("input")) if (parser.has("input"))
cap.open(parser.get<String>("input")); cap.open(parser.get<String>("input"));
else else
cap.open(0); cap.open(0);
//! [Open a video file or an image file or a camera stream]
// Process frames. // Process frames.
Mat frame, blob; Mat frame, blob;
...@@ -105,24 +111,29 @@ int main(int argc, char** argv) ...@@ -105,24 +111,29 @@ int main(int argc, char** argv)
break; break;
} }
// Create a 4D blob from a frame. //! [Create a 4D blob from a frame]
blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false); blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
//! [Create a 4D blob from a frame]
// Run a model. //! [Set input blob]
net.setInput(blob); net.setInput(blob);
Mat out = net.forward(); //! [Set input blob]
out = out.reshape(1, 1); //! [Make forward pass]
Mat prob = net.forward();
//! [Make forward pass]
// Get a class with a highest score. //! [Get a class with a highest score]
Point classIdPoint; Point classIdPoint;
double confidence; double confidence;
minMaxLoc(out, 0, &confidence, 0, &classIdPoint); minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
int classId = classIdPoint.x; int classId = classIdPoint.x;
//! [Get a class with a highest score]
// Put efficiency information. // Put efficiency information.
std::vector<double> layersTimes; std::vector<double> layersTimes;
double t = net.getPerfProfile(layersTimes); double freq = getTickFrequency() / 1000;
std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency()); double t = net.getPerfProfile(layersTimes) / freq;
std::string label = format("Inference time: %.2f ms", t);
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0)); putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
// Print predicted class. // Print predicted class.
...@@ -135,19 +146,3 @@ int main(int argc, char** argv) ...@@ -135,19 +146,3 @@ int main(int argc, char** argv)
} }
return 0; return 0;
} }
Net readNet(const std::string& model, const std::string& config, const std::string& framework)
{
std::string modelExt = model.substr(model.rfind('.'));
if (framework == "caffe" || modelExt == ".caffemodel")
return readNetFromCaffe(config, model);
else if (framework == "tensorflow" || modelExt == ".pb")
return readNetFromTensorflow(model, config);
else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
return readNetFromTorch(model);
else if (framework == "darknet" || modelExt == ".weights")
return readNetFromDarknet(config, model);
else
CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
return Net();
}
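Review note: two small computations in this sample changed shape in the commit: the inference time is now converted to milliseconds before formatting, and the best class is taken from the flattened output. The sketch below restates both in plain Python. `ticks_to_ms` and `best_class` are hypothetical helper names, and the tick values are made-up inputs rather than measurements.

```python
def ticks_to_ms(ticks, tick_frequency):
    """Convert a tick count to milliseconds.

    tick_frequency is in ticks per second, so dividing it by 1000 gives
    ticks per millisecond, as in the updated sample code."""
    ticks_per_ms = tick_frequency / 1000.0
    return ticks / ticks_per_ms


def best_class(scores):
    """Return (index, score) of the largest value in a flat list of scores."""
    index = max(range(len(scores)), key=scores.__getitem__)
    return index, scores[index]


# Made-up inputs for illustration only.
label = "Inference time: %.2f ms" % ticks_to_ms(2_500_000, 1e9)
class_id, confidence = best_class([0.1, 0.7, 0.2])
print(label, class_id, confidence)
```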
...@@ -48,19 +48,7 @@ if args.classes: ...@@ -48,19 +48,7 @@ if args.classes:
classes = f.read().rstrip('\n').split('\n') classes = f.read().rstrip('\n').split('\n')
# Load a network # Load a network
modelExt = args.model[args.model.rfind('.'):] net = cv.dnn.readNet(args.model, args.config, args.framework)
if args.framework == 'caffe' or modelExt == '.caffemodel':
net = cv.dnn.readNetFromCaffe(args.config, args.model)
elif args.framework == 'tensorflow' or modelExt == '.pb':
net = cv.dnn.readNetFromTensorflow(args.model, args.config)
elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
net = cv.dnn.readNetFromTorch(args.model)
elif args.framework == 'darknet' or modelExt == '.weights':
net = cv.dnn.readNetFromDarknet(args.config, args.model)
else:
print('Cannot determine an origin framework of model from file %s' % args.model)
sys.exit(0)
net.setPreferableBackend(args.backend) net.setPreferableBackend(args.backend)
net.setPreferableTarget(args.target) net.setPreferableTarget(args.target)
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp> #include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
const char* keys = const char* keys =
"{ help h | | Print help message. }" "{ help h | | Print help message. }"
...@@ -35,8 +36,6 @@ using namespace dnn; ...@@ -35,8 +36,6 @@ using namespace dnn;
float confThreshold; float confThreshold;
std::vector<std::string> classes; std::vector<std::string> classes;
Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
void postprocess(Mat& frame, const Mat& out, Net& net); void postprocess(Mat& frame, const Mat& out, Net& net);
void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame); void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame);
...@@ -95,7 +94,7 @@ int main(int argc, char** argv) ...@@ -95,7 +94,7 @@ int main(int argc, char** argv)
// Create a window // Create a window
static const std::string kWinName = "Deep learning object detection in OpenCV"; static const std::string kWinName = "Deep learning object detection in OpenCV";
namedWindow(kWinName, WINDOW_NORMAL); namedWindow(kWinName, WINDOW_NORMAL);
int initialConf = confThreshold * 100; int initialConf = (int)(confThreshold * 100);
createTrackbar("Confidence threshold, %", kWinName, &initialConf, 99, callback); createTrackbar("Confidence threshold, %", kWinName, &initialConf, 99, callback);
// Open a video file or an image file or a camera stream. // Open a video file or an image file or a camera stream.
...@@ -135,8 +134,9 @@ int main(int argc, char** argv) ...@@ -135,8 +134,9 @@ int main(int argc, char** argv)
// Put efficiency information. // Put efficiency information.
std::vector<double> layersTimes; std::vector<double> layersTimes;
double t = net.getPerfProfile(layersTimes); double freq = getTickFrequency() / 1000;
std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency()); double t = net.getPerfProfile(layersTimes) / freq;
std::string label = format("Inference time: %.2f ms", t);
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0)); putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
imshow(kWinName, frame); imshow(kWinName, frame);
...@@ -160,10 +160,10 @@ void postprocess(Mat& frame, const Mat& out, Net& net) ...@@ -160,10 +160,10 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
float confidence = data[i + 2]; float confidence = data[i + 2];
if (confidence > confThreshold) if (confidence > confThreshold)
{ {
int left = data[i + 3]; int left = (int)data[i + 3];
int top = data[i + 4]; int top = (int)data[i + 4];
int right = data[i + 5]; int right = (int)data[i + 5];
int bottom = data[i + 6]; int bottom = (int)data[i + 6];
int classId = (int)(data[i + 1]) - 1; // Skip 0th background class id. int classId = (int)(data[i + 1]) - 1; // Skip 0th background class id.
drawPred(classId, confidence, left, top, right, bottom, frame); drawPred(classId, confidence, left, top, right, bottom, frame);
} }
...@@ -208,7 +208,7 @@ void postprocess(Mat& frame, const Mat& out, Net& net) ...@@ -208,7 +208,7 @@ void postprocess(Mat& frame, const Mat& out, Net& net)
int height = (int)(data[3] * frame.rows); int height = (int)(data[3] * frame.rows);
int left = centerX - width / 2; int left = centerX - width / 2;
int top = centerY - height / 2; int top = centerY - height / 2;
drawPred(classId, confidence, left, top, left + width, top + height, frame); drawPred(classId, (float)confidence, left, top, left + width, top + height, frame);
} }
} }
} }
...@@ -238,21 +238,5 @@ void drawPred(int classId, float conf, int left, int top, int right, int bottom, ...@@ -238,21 +238,5 @@ void drawPred(int classId, float conf, int left, int top, int right, int bottom,
void callback(int pos, void*) void callback(int pos, void*)
{ {
confThreshold = pos * 0.01; confThreshold = pos * 0.01f;
}
Net readNet(const std::string& model, const std::string& config, const std::string& framework)
{
std::string modelExt = model.substr(model.rfind('.'));
if (framework == "caffe" || modelExt == ".caffemodel")
return readNetFromCaffe(config, model);
else if (framework == "tensorflow" || modelExt == ".pb")
return readNetFromTensorflow(model, config);
else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
return readNetFromTorch(model);
else if (framework == "darknet" || modelExt == ".weights")
return readNetFromDarknet(config, model);
else
CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
return Net();
} }
...@@ -49,19 +49,7 @@ if args.classes: ...@@ -49,19 +49,7 @@ if args.classes:
classes = f.read().rstrip('\n').split('\n') classes = f.read().rstrip('\n').split('\n')
# Load a network # Load a network
modelExt = args.model[args.model.rfind('.'):] net = cv.dnn.readNet(args.model, args.config, args.framework)
if args.framework == 'caffe' or modelExt == '.caffemodel':
net = cv.dnn.readNetFromCaffe(args.config, args.model)
elif args.framework == 'tensorflow' or modelExt == '.pb':
net = cv.dnn.readNetFromTensorflow(args.model, args.config)
elif args.framework == 'torch' or modelExt in ['.t7', '.net']:
net = cv.dnn.readNetFromTorch(args.model)
elif args.framework == 'darknet' or modelExt == '.weights':
net = cv.dnn.readNetFromDarknet(args.config, args.model)
else:
print('Cannot determine an origin framework of model from file %s' % args.model)
sys.exit(0)
net.setPreferableBackend(args.backend) net.setPreferableBackend(args.backend)
net.setPreferableTarget(args.target) net.setPreferableTarget(args.target)
......