Commit dda33bf3 authored by Fanny Monori's avatar Fanny Monori Committed by Alexander Alekhin

Merge pull request #2229 from fannymonori:gsoc_dnn_superres

* Adding dnn based super resolution module.

* Fixed whitespace error in unit test

* Fixed errors with time measuring functions.

* Updated unit tests in dnn superres

* Deleted unnecessary indents in dnn superres

* Refactored includes in dnn superres

* Moved video upsampling functions to sample code in dnn superres.

* Replaced couts with CV_Error in dnn superres

* Moved benchmarking functionality to sample codes in dnn superres.

* Added performance test to dnn superres

* Resolve buildbot errors

* update dnn_superres

- avoid highgui dependency
- cleanup public API
- use InputArray/OutputArray
- test: avoid legacy test API
parent 26129cfe
@@ -24,6 +24,8 @@ $ cmake -D OPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -D BUILD_opencv_<r
- **dnn_objdetect**: Object Detection using CNNs -- Implements compact CNN Model for object detection. Trained using Caffe but uses opencv_dnn module.
- **dnn_superres**: Superresolution using CNNs -- Contains four trained convolutional neural networks to upscale images.
- **dnns_easily_fooled**: Subvert DNNs -- This code can use the activations in a network to fool the networks into recognizing something else.
- **dpm**: Deformable Part Model -- Felzenszwalb's Cascade with deformable parts object recognition code.
set(the_description "Super Resolution using CNNs")
ocv_define_module(dnn_superres opencv_core opencv_imgproc opencv_dnn
OPTIONAL opencv_datasets opencv_quality # samples
)
# Super Resolution using Convolutional Neural Networks
This module contains several learning-based algorithms for upscaling an image.
## Usage
Run the following command to build this module:
```make
cmake -DOPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -DBUILD_opencv_dnn_superres=ON <opencv_source_dir>
```
Refer to the tutorials to understand how to use this module.
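A minimal usage sketch (the model file and image names here are placeholders; download a trained model first, see the links below):
```cpp
#include <opencv2/dnn_superres.hpp>
#include <opencv2/imgcodecs.hpp>

int main()
{
    cv::dnn_superres::DnnSuperResImpl sr;
    sr.readModel("ESPCN_x2.pb"); // path to a downloaded model file
    sr.setModel("espcn", 2);     // algorithm name and upscale factor

    cv::Mat img = cv::imread("input.png");
    cv::Mat result;
    sr.upsample(img, result);    // result is the 2x upscaled image
    cv::imwrite("upscaled.png", result);
    return 0;
}
```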
## Models
There are four trained models included.
#### EDSR
Trained models can be downloaded from [here](https://github.com/Saafke/EDSR_Tensorflow/tree/master/models).
- Size of the model: ~38.5 MB. This is a quantized version, so that it can be uploaded to GitHub. (The original was 150 MB.)
- This model was trained for 3 days with a batch size of 16
- Link to implementation code: https://github.com/Saafke/EDSR_Tensorflow
- x2, x3, x4 trained models available
- Advantage: Highly accurate
- Disadvantage: Slow and large filesize
- Speed: < 3 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
- Original paper: [Enhanced Deep Residual Networks for Single Image Super-Resolution](https://arxiv.org/pdf/1707.02921.pdf) [1]
#### ESPCN
Trained models can be downloaded from [here](https://github.com/fannymonori/TF-ESPCN/tree/master/export).
- Size of the model: ~100 KB
- This model was trained for ~100 iterations with a batch size of 32
- Link to implementation code: https://github.com/fannymonori/TF-ESPCN
- x2, x3, x4 trained models available
- Advantage: It is tiny and fast, and still performs well.
- Disadvantage: Performs worse visually than newer, more robust models.
- Speed: < 0.01 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
- Original paper: [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](<https://arxiv.org/abs/1609.05158>) [2]
#### FSRCNN
Trained models can be downloaded from [here](https://github.com/Saafke/FSRCNN_Tensorflow/tree/master/models).
- Size of the model: ~40 KB (~9 KB for FSRCNN-small)
- This model was trained for ~30 iterations with a batch size of 1
- Link to implementation code: https://github.com/Saafke/FSRCNN_Tensorflow
- Advantage: Fast, small and accurate
- Disadvantage: Not state-of-the-art accuracy
- Speed: < 0.01 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
- Notes: FSRCNN-small has fewer parameters; it is therefore less accurate but faster.
- Original paper: [Accelerating the Super-Resolution Convolutional Neural Network](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html) [3]
#### LapSRN
Trained models can be downloaded from [here](https://github.com/fannymonori/TF-LapSRN/tree/master/export).
- Size of the model: between 1 and 5 MB
- This model was trained for ~50 iterations with a batch size of 32
- Link to implementation code: https://github.com/fannymonori/TF-LapSRN
- x2, x4, x8 trained models available
- Advantage: The model can do multi-scale super-resolution with one forward pass. It supports 2x, 4x and 8x upscaling, as well as the combined [2x, 4x] and [2x, 4x, 8x] outputs (see the sketch after this section).
- Disadvantage: It is slower than ESPCN and FSRCNN, and the accuracy is worse than EDSR.
- Speed: < 0.1 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
- Original paper: [Deep laplacian pyramid networks for fast and accurate super-resolution](<https://arxiv.org/abs/1710.01992>) [4]
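A minimal sketch of multi-scale upsampling with LapSRN (assuming the x4 model has been downloaded; the node names below are the ones used by the uploaded models):
```cpp
// One forward pass, two upscaled outputs.
cv::dnn_superres::DnnSuperResImpl sr;
sr.readModel("LapSRN_x4.pb");
sr.setModel("lapsrn", 4); // pass the largest scale of the model

cv::Mat img = cv::imread("input.png");
std::vector<int> scales = {2, 4};
std::vector<cv::String> node_names = {"NCHW_output_2x", "NCHW_output_4x"};
std::vector<cv::Mat> outputs; // outputs[0] is the 2x result, outputs[1] the 4x result
sr.upsampleMultioutput(img, outputs, scales, node_names);
```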
### Benchmarks
Comparison of the different algorithms at scale x4 on monarch.png.
| | Inference time in seconds (CPU)| PSNR | SSIM |
| ------------- |:-------------------:| ---------:|--------:|
| ESPCN |0.01159 | 26.5471 | 0.88116 |
| EDSR |3.26758 |**29.2404** |**0.92112** |
| FSRCNN | 0.01298 | 26.5646 | 0.88064 |
| LapSRN |0.28257 |26.7330 |0.88622 |
| Bicubic |0.00031 |26.0635 |0.87537 |
| Nearest neighbor |**0.00014** |23.5628 |0.81741 |
| Lanczos |0.00101 |25.9115 |0.87057 |
### References
[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution"**, <i> 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. </i> [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
[2] Shi, W., Caballero, J., Huszár, F., Totz, J., Aitken, A., Bishop, R., Rueckert, D. and Wang, Z., **"Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"**, <i>Proceedings of the IEEE conference on computer vision and pattern recognition</i> **CVPR 2016**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2016/papers/Shi_Real-Time_Single_Image_CVPR_2016_paper.pdf)] [[arXiv](https://arxiv.org/abs/1609.05158)]
[3] Chao Dong, Chen Change Loy, Xiaoou Tang. **"Accelerating the Super-Resolution Convolutional Neural Network"**, <i> in Proceedings of European Conference on Computer Vision </i>**ECCV 2016**. [[PDF](http://personal.ie.cuhk.edu.hk/~ccloy/files/eccv_2016_accelerating.pdf)]
[[arXiv](https://arxiv.org/abs/1608.00367)] [[Project Page](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html)]
[4] Lai, W. S., Huang, J. B., Ahuja, N., and Yang, M. H., **"Deep laplacian pyramid networks for fast and accurate super-resolution"**, <i> In Proceedings of the IEEE conference on computer vision and pattern recognition </i>**CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017/papers/Lai_Deep_Laplacian_Pyramid_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1710.01992)] [[Project Page](http://vllab.ucmerced.edu/wlai24/LapSRN/)]
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef _OPENCV_DNN_SUPERRES_HPP_
#define _OPENCV_DNN_SUPERRES_HPP_
/** @defgroup dnn_superres DNN used for super resolution
This module contains functionality for upscaling an image via convolutional neural networks.
The following four models are implemented:
- EDSR <https://arxiv.org/abs/1707.02921>
- ESPCN <https://arxiv.org/abs/1609.05158>
- FSRCNN <https://arxiv.org/abs/1608.00367>
- LapSRN <https://arxiv.org/abs/1710.01992>
*/
#include "opencv2/core.hpp"
#include "opencv2/dnn.hpp"
namespace cv
{
namespace dnn_superres
{
//! @addtogroup dnn_superres
//! @{
/** @brief A class to upscale images via convolutional neural networks.
The following four models are implemented:
- edsr
- espcn
- fsrcnn
- lapsrn
*/
class CV_EXPORTS DnnSuperResImpl
{
private:
/** @brief Net which holds the desired neural network
*/
dnn::Net net;
std::string alg; //algorithm
int sc; //scale factor
void preprocess(InputArray inpImg, OutputArray outpImg);
void reconstruct_YCrCb(InputArray inpImg, InputArray origImg, OutputArray outpImg, int scale);
void reconstruct_YCrCb(InputArray inpImg, InputArray origImg, OutputArray outpImg);
void preprocess_YCrCb(InputArray inpImg, OutputArray outpImg);
public:
/** @brief Empty constructor
*/
DnnSuperResImpl();
/** @brief Constructor which immediately sets the desired model
@param algo String containing one of the desired models:
- __edsr__
- __espcn__
- __fsrcnn__
- __lapsrn__
@param scale Integer specifying the upscale factor
*/
DnnSuperResImpl(const std::string& algo, int scale);
/** @brief Read the model from the given path
@param path Path to the model file.
*/
void readModel(const std::string& path);
/** @brief Read the model from the given path
@param weights Path to the model weights file.
@param definition Path to the model definition file.
*/
void readModel(const std::string& weights, const std::string& definition);
/** @brief Set desired model
@param algo String containing one of the desired models:
- __edsr__
- __espcn__
- __fsrcnn__
- __lapsrn__
@param scale Integer specifying the upscale factor
*/
void setModel(const std::string& algo, int scale);
/** @brief Upsample via neural network
@param img Image to upscale
@param result Destination upscaled image
*/
void upsample(InputArray img, OutputArray result);
/** @brief Upsample via neural network of multiple outputs
@param img Image to upscale
@param imgs_new Destination upscaled images
@param scale_factors Scaling factors of the output nodes
@param node_names Names of the output nodes in the neural network
*/
void upsampleMultioutput(InputArray img, std::vector<Mat> &imgs_new, const std::vector<int>& scale_factors, const std::vector<String>& node_names);
/** @brief Returns the scale factor of the model.
@return Current scale factor.
*/
int getScale();
/** @brief Returns the algorithm of the model.
@return Current algorithm.
*/
std::string getAlgorithm();
};
//! @} dnn_superres
}} // cv::dnn_superres::
#endif
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "perf_precomp.hpp"
using namespace std;
using namespace cv;
using namespace perf;
namespace opencv_test { namespace {
typedef perf::TestBaseWithParam<tuple<tuple<string,string,int>, string> > dnn_superres;
#define MODEL testing::Values(tuple<string,string,int> {"espcn","ESPCN_x2.pb",2}, \
tuple<string,string,int> {"lapsrn","LapSRN_x4.pb",4})
#define IMAGES testing::Values("cv/dnn_superres/butterfly.png", "cv/shared/baboon.png", "cv/shared/lena.png")
const string TEST_DIR = "cv/dnn_superres";
PERF_TEST_P(dnn_superres, upsample, testing::Combine(MODEL, IMAGES))
{
tuple<string,string,int> model = get<0>( GetParam() );
string image_name = get<1>( GetParam() );
string model_name = get<0>(model);
string model_filename = get<1>(model);
int scale = get<2>(model);
string model_path = getDataPath( TEST_DIR + "/" + model_filename );
string image_path = getDataPath( image_name );
DnnSuperResImpl sr;
sr.readModel(model_path);
sr.setModel(model_name, scale);
Mat img = imread(image_path);
Mat img_new(img.rows * scale, img.cols * scale, CV_8UC3);
declare.in(img, WARMUP_RNG).out(img_new).iterations(10);
TEST_CYCLE() { sr.upsample(img, img_new); }
SANITY_CHECK_NOTHING();
}
}}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "perf_precomp.hpp"
CV_PERF_TEST_MAIN( dnn_superres )
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef __OPENCV_PERF_PRECOMP_HPP__
#define __OPENCV_PERF_PRECOMP_HPP__
#include "opencv2/ts.hpp"
#include "opencv2/dnn_superres.hpp"
namespace opencv_test {
using namespace cv::dnn_superres;
}
#endif
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <iostream>
#include <opencv2/dnn_superres.hpp>
#include <opencv2/quality.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
static void showBenchmark(vector<Mat> images, string title, Size imageSize,
const vector<String> imageTitles,
const vector<double> psnrValues,
const vector<double> ssimValues)
{
int fontFace = FONT_HERSHEY_COMPLEX_SMALL;
int fontScale = 1;
Scalar fontColor = Scalar(255, 255, 255);
int len = static_cast<int>(images.size());
int cols = 2, rows = 2;
Mat fullImage = Mat::zeros(Size((cols * 10) + imageSize.width * cols, (rows * 10) + imageSize.height * rows),
images[0].type());
stringstream ss;
int h_ = -1;
for (int i = 0; i < len; i++) {
int fontStart = 15;
int w_ = i % cols;
if (i % cols == 0)
h_++;
Rect ROI((w_ * (10 + imageSize.width)), (h_ * (10 + imageSize.height)), imageSize.width, imageSize.height);
Mat tmp;
resize(images[i], tmp, Size(ROI.width, ROI.height));
ss << imageTitles[i];
putText(tmp,
ss.str(),
Point(5, fontStart),
fontFace,
fontScale,
fontColor,
1,
16);
ss.str("");
fontStart += 20;
ss << "PSNR: " << psnrValues[i];
putText(tmp,
ss.str(),
Point(5, fontStart),
fontFace,
fontScale,
fontColor,
1,
16);
ss.str("");
fontStart += 20;
ss << "SSIM: " << ssimValues[i];
putText(tmp,
ss.str(),
Point(5, fontStart),
fontFace,
fontScale,
fontColor,
1,
16);
ss.str("");
fontStart += 20;
tmp.copyTo(fullImage(ROI));
}
namedWindow(title, 1);
imshow(title, fullImage);
waitKey();
}
static Vec2d getQualityValues(Mat orig, Mat upsampled)
{
double psnr = PSNR(upsampled, orig);
Scalar q = quality::QualitySSIM::compute(upsampled, orig, noArray());
double ssim = mean(Vec3d((q[0]), q[1], q[2]))[0];
return Vec2d(psnr, ssim);
}
int main(int argc, char *argv[])
{
// Check for valid command line arguments, print usage
// if insufficient arguments were given.
if (argc < 5) {
cout << "usage: Arg 1: image path | Path to image" << endl;
cout << "\t Arg 2: algorithm | edsr, espcn, fsrcnn or lapsrn" << endl;
cout << "\t Arg 3: path to model file \n";
cout << "\t Arg 4: scale | 2, 3, 4 or 8 \n";
return -1;
}
string path = string(argv[1]);
string algorithm = string(argv[2]);
string model = string(argv[3]);
int scale = atoi(argv[4]);
Mat img = imread(path);
if (img.empty()) {
cerr << "Couldn't load image: " << img << "\n";
return -2;
}
//Crop the image so the images will be aligned
int width = img.cols - (img.cols % scale);
int height = img.rows - (img.rows % scale);
Mat cropped = img(Rect(0, 0, width, height));
//Downscale the image for benchmarking
Mat img_downscaled;
resize(cropped, img_downscaled, Size(), 1.0 / scale, 1.0 / scale);
//Make dnn super resolution instance
DnnSuperResImpl sr;
vector <Mat> allImages;
Mat img_new;
//Read and set the dnn model
sr.readModel(model);
sr.setModel(algorithm, scale);
sr.upsample(img_downscaled, img_new);
vector<double> psnrValues = vector<double>();
vector<double> ssimValues = vector<double>();
//DL MODEL
Vec2d quality = getQualityValues(cropped, img_new);
psnrValues.push_back(quality[0]);
ssimValues.push_back(quality[1]);
cout << sr.getAlgorithm() << ":" << endl;
cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
cout << "----------------------" << endl;
//BICUBIC
Mat bicubic;
resize(img_downscaled, bicubic, Size(), scale, scale, INTER_CUBIC);
quality = getQualityValues(cropped, bicubic);
psnrValues.push_back(quality[0]);
ssimValues.push_back(quality[1]);
cout << "Bicubic " << endl;
cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
cout << "----------------------" << endl;
//NEAREST NEIGHBOR
Mat nearest;
resize(img_downscaled, nearest, Size(), scale, scale, INTER_NEAREST);
quality = getQualityValues(cropped, nearest);
psnrValues.push_back(quality[0]);
ssimValues.push_back(quality[1]);
cout << "Nearest neighbor" << endl;
cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
cout << "----------------------" << endl;
//LANCZOS
Mat lanczos;
resize(img_downscaled, lanczos, Size(), scale, scale, INTER_LANCZOS4);
quality = getQualityValues(cropped, lanczos);
psnrValues.push_back(quality[0]);
ssimValues.push_back(quality[1]);
cout << "Lanczos" << endl;
cout << "PSNR: " << quality[0] << " SSIM: " << quality[1] << endl;
cout << "-----------------------------------------------" << endl;
vector <Mat> imgs{img_new, bicubic, nearest, lanczos};
vector <String> titles{sr.getAlgorithm(), "Bicubic", "Nearest neighbor", "Lanczos"};
showBenchmark(imgs, "Quality benchmark", Size(bicubic.cols, bicubic.rows), titles, psnrValues, ssimValues);
waitKey(0);
return 0;
}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <iostream>
#include <opencv2/dnn_superres.hpp>
#include <opencv2/quality.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
static void showBenchmark(vector<Mat> images, string title, Size imageSize,
const vector<String> imageTitles,
const vector<double> perfValues)
{
int fontFace = FONT_HERSHEY_COMPLEX_SMALL;
int fontScale = 1;
Scalar fontColor = Scalar(255, 255, 255);
int len = static_cast<int>(images.size());
int cols = 2, rows = 2;
Mat fullImage = Mat::zeros(Size((cols * 10) + imageSize.width * cols, (rows * 10) + imageSize.height * rows),
images[0].type());
stringstream ss;
int h_ = -1;
for (int i = 0; i < len; i++) {
int fontStart = 15;
int w_ = i % cols;
if (i % cols == 0)
h_++;
Rect ROI((w_ * (10 + imageSize.width)), (h_ * (10 + imageSize.height)), imageSize.width, imageSize.height);
Mat tmp;
resize(images[i], tmp, Size(ROI.width, ROI.height));
ss << imageTitles[i];
putText(tmp,
ss.str(),
Point(5, fontStart),
fontFace,
fontScale,
fontColor,
1,
16);
ss.str("");
fontStart += 20;
ss << perfValues[i];
putText(tmp,
ss.str(),
Point(5, fontStart),
fontFace,
fontScale,
fontColor,
1,
16);
tmp.copyTo(fullImage(ROI));
}
namedWindow(title, 1);
imshow(title, fullImage);
waitKey();
}
int main(int argc, char *argv[])
{
// Check for valid command line arguments, print usage
// if insufficient arguments were given.
if (argc < 5) {
cout << "usage: Arg 1: image path | Path to image" << endl;
cout << "\t Arg 2: algorithm | edsr, espcn, fsrcnn or lapsrn" << endl;
cout << "\t Arg 3: path to model file \n";
cout << "\t Arg 4: scale | 2, 3, 4 or 8 \n";
return -1;
}
string path = string(argv[1]);
string algorithm = string(argv[2]);
string model = string(argv[3]);
int scale = atoi(argv[4]);
Mat img = imread(path);
if (img.empty()) {
cerr << "Couldn't load image: " << img << "\n";
return -2;
}
//Crop the image so the images will be aligned
int width = img.cols - (img.cols % scale);
int height = img.rows - (img.rows % scale);
Mat cropped = img(Rect(0, 0, width, height));
//Downscale the image for benchmarking
Mat img_downscaled;
resize(cropped, img_downscaled, Size(), 1.0 / scale, 1.0 / scale);
//Make dnn super resolution instance
DnnSuperResImpl sr;
Mat img_new;
//Read and set the dnn model
sr.readModel(model);
sr.setModel(algorithm, scale);
double elapsed = 0.0;
vector<double> perf;
TickMeter tm;
//DL MODEL
tm.start();
sr.upsample(img_downscaled, img_new);
tm.stop();
elapsed = tm.getTimeSec() / tm.getCounter();
perf.push_back(elapsed);
cout << sr.getAlgorithm() << " : " << elapsed << endl;
//BICUBIC
Mat bicubic;
tm.reset();
tm.start();
resize(img_downscaled, bicubic, Size(), scale, scale, INTER_CUBIC);
tm.stop();
elapsed = tm.getTimeSec() / tm.getCounter();
perf.push_back(elapsed);
cout << "Bicubic" << " : " << elapsed << endl;
//NEAREST NEIGHBOR
Mat nearest;
tm.reset();
tm.start();
resize(img_downscaled, nearest, Size(), scale, scale, INTER_NEAREST);
tm.stop();
elapsed = tm.getTimeSec() / tm.getCounter();
perf.push_back(elapsed);
cout << "Nearest" << " : " << elapsed << endl;
//LANCZOS
Mat lanczos;
tm.reset();
tm.start();
resize(img_downscaled, lanczos, Size(), scale, scale, INTER_LANCZOS4);
tm.stop();
elapsed = tm.getTimeSec() / tm.getCounter();
perf.push_back(elapsed);
cout << "Lanczos" << " : " << elapsed << endl;
vector <Mat> imgs{img_new, bicubic, nearest, lanczos};
vector <String> titles{sr.getAlgorithm(), "Bicubic", "Nearest neighbor", "Lanczos"};
showBenchmark(imgs, "Time benchmark", Size(bicubic.cols, bicubic.rows), titles, perf);
waitKey(0);
return 0;
}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <iostream>
#include <sstream>
#include <opencv2/dnn_superres.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
int main(int argc, char *argv[])
{
// Check for valid command line arguments, print usage
// if insufficient arguments were given.
if (argc < 5) {
cout << "usage: Arg 1: image | Path to image" << endl;
cout << "\t Arg 2: scales in a format of 2,4,8\n";
cout << "\t Arg 3: output node names in a format of nchw_output_0,nchw_output_1\n";
cout << "\t Arg 4: path to model file \n";
return -1;
}
string img_path = string(argv[1]);
string scales_str = string(argv[2]);
string output_names_str = string(argv[3]);
std::string path = string(argv[4]);
//Parse the scaling factors
std::stringstream ss(scales_str);
std::vector<int> scales;
std::string token;
char delim = ',';
while (std::getline(ss, token, delim)) {
scales.push_back(atoi(token.c_str()));
}
//Parse the output node names
ss = std::stringstream(output_names_str);
std::vector<String> node_names;
while (std::getline(ss, token, delim)) {
node_names.push_back(token);
}
// Load the image
Mat img = cv::imread(img_path);
if (img.empty())
{
std::cerr << "Couldn't load image: " << img << "\n";
return -2;
}
//Make dnn super resolution instance
DnnSuperResImpl sr;
int scale = *max_element(scales.begin(), scales.end());
std::vector<Mat> outputs;
sr.readModel(path);
sr.setModel("lapsrn", scale);
sr.upsampleMultioutput(img, outputs, scales, node_names);
for(unsigned int i = 0; i < outputs.size(); i++)
{
cv::namedWindow("Upsampled image", WINDOW_AUTOSIZE);
cv::imshow("Upsampled image", outputs[i]);
//cv::imwrite("./saved.jpg", img_new);
cv::waitKey(0);
}
return 0;
}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include <iostream>
#include <opencv2/dnn_superres.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
int main(int argc, char *argv[])
{
// Check for valid command line arguments, print usage
// if insufficient arguments were given.
if (argc < 6) {
cout << "usage: Arg 1: input video path" << endl;
cout << "\t Arg 2: output video path" << endl;
cout << "\t Arg 3: algorithm | edsr, espcn, fsrcnn or lapsrn" << endl;
cout << "\t Arg 4: scale | 2, 3, 4 or 8 \n";
cout << "\t Arg 5: path to model file \n";
return -1;
}
string input_path = string(argv[1]);
string output_path = string(argv[2]);
string algorithm = string(argv[3]);
int scale = atoi(argv[4]);
string path = string(argv[5]);
VideoCapture input_video(input_path);
if (!input_video.isOpened())
{
std::cerr << "Could not open the video." << std::endl;
return -1;
}
int ex = static_cast<int>(input_video.get(CAP_PROP_FOURCC));
Size S = Size((int) input_video.get(CAP_PROP_FRAME_WIDTH) * scale,
(int) input_video.get(CAP_PROP_FRAME_HEIGHT) * scale);
VideoWriter output_video;
output_video.open(output_path, ex, input_video.get(CAP_PROP_FPS), S, true);
DnnSuperResImpl sr;
sr.readModel(path);
sr.setModel(algorithm, scale);
for(;;)
{
Mat frame, output_frame;
input_video >> frame;
if ( frame.empty() )
break;
sr.upsample(frame, output_frame);
output_video << output_frame;
namedWindow("Upsampled video", WINDOW_AUTOSIZE);
imshow("Upsampled video", output_frame);
namedWindow("Original video", WINDOW_AUTOSIZE);
imshow("Original video", frame);
char c=(char)waitKey(25);
if(c==27)
break;
}
input_video.release();
output_video.release();
return 0;
}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "precomp.hpp"
#include "opencv2/dnn_superres.hpp"
namespace cv
{
namespace dnn_superres
{
/** @brief Class for importing DepthToSpace layer from the ESPCN model
*/
class DepthToSpace CV_FINAL : public cv::dnn::Layer
{
public:
DepthToSpace(const cv::dnn::LayerParams &params);
static cv::Ptr<cv::dnn::Layer> create(cv::dnn::LayerParams& params);
virtual bool getMemoryShapes(const std::vector<std::vector<int> > &inputs,
const int,
std::vector<std::vector<int> > &outputs,
std::vector<std::vector<int> > &) const CV_OVERRIDE;
virtual void forward(cv::InputArrayOfArrays inputs_arr,
cv::OutputArrayOfArrays outputs_arr,
cv::OutputArrayOfArrays) CV_OVERRIDE;
/// Register this layer
static void registerLayer()
{
static bool initialized = false;
if (!initialized)
{
//Register custom layer that implements pixel shuffling
cv::dnn::LayerFactory::registerLayer("DepthToSpace", DepthToSpace::create);
initialized = true;
}
}
};
DnnSuperResImpl::DnnSuperResImpl()
{
DepthToSpace::registerLayer();
}
DnnSuperResImpl::DnnSuperResImpl(const std::string& algo, int scale)
: alg(algo), sc(scale)
{
DepthToSpace::registerLayer();
}
void DnnSuperResImpl::readModel(const std::string& path)
{
if ( path.size() )
{
this->net = dnn::readNetFromTensorflow(path);
CV_LOG_INFO(NULL, "Successfully loaded model: " << path);
}
else
{
CV_Error(Error::StsBadArg, std::string("Could not load model: ") + path);
}
}
void DnnSuperResImpl::readModel(const std::string& weights, const std::string& definition)
{
if ( weights.size() && definition.size() )
{
this->net = dnn::readNetFromTensorflow(weights, definition);
CV_LOG_INFO(NULL, "Successfully loaded model: " << weights << " " << definition);
}
else
{
CV_Error(Error::StsBadArg, std::string("Could not load model: ") + weights + " " + definition);
}
}
void DnnSuperResImpl::setModel(const std::string& algo, int scale)
{
this->sc = scale;
this->alg = algo;
}
void DnnSuperResImpl::upsample(InputArray img, OutputArray result)
{
if (net.empty())
CV_Error(Error::StsError, "Model not specified. Please set model via setModel().");
if (this->alg == "espcn" || this->alg == "lapsrn" || this->alg == "fsrcnn")
{
//Preprocess the image: convert to YCrCb float image and normalize
Mat preproc_img;
preprocess_YCrCb(img, preproc_img);
//Split the image: only the Y channel is used for inference
Mat ycbcr_channels[3];
split(preproc_img, ycbcr_channels);
Mat Y = ycbcr_channels[0];
//Create blob from image so it has size 1,1,Width,Height
cv::Mat blob;
dnn::blobFromImage(Y, blob, 1.0);
//Get the HR output
this->net.setInput(blob);
Mat blob_output = this->net.forward();
//Convert from blob
std::vector <Mat> model_outs;
dnn::imagesFromBlob(blob_output, model_outs);
Mat out_img = model_outs[0];
//Reconstruct: upscale the Cr and Cb space and merge the three layer
reconstruct_YCrCb(out_img, preproc_img, result, this->sc);
}
else if (this->alg == "edsr")
{
//BGR mean of the Div2K dataset
Scalar mean = Scalar(103.1545782, 111.561547, 114.35629928);
//Convert to float
Mat float_img;
img.getMat().convertTo(float_img, CV_32F, 1.0);
//Create blob from image so it has size [1,3,Width,Height] and subtract dataset mean
cv::Mat blob;
dnn::blobFromImage(float_img, blob, 1.0, Size(), mean);
//Get the HR output
this->net.setInput(blob);
Mat blob_output = this->net.forward();
//Convert from blob
std::vector <Mat> model_outs;
dnn::imagesFromBlob(blob_output, model_outs);
//Post-process: add mean.
Mat(model_outs[0] + mean).convertTo(result, CV_8U);
}
else
{
CV_Error(cv::Error::StsNotImplemented, std::string("Unknown/unsupported superres algorithm: ") + this->alg);
}
}
void DnnSuperResImpl::upsampleMultioutput(InputArray img, std::vector<Mat> &imgs_new, const std::vector<int>& scale_factors, const std::vector<String>& node_names)
{
CV_Assert(!img.empty());
CV_Assert(scale_factors.size() == node_names.size());
CV_Assert(!scale_factors.empty());
CV_Assert(!node_names.empty());
if ( this->alg != "lapsrn" )
{
CV_Error(cv::Error::StsBadArg, "Only LapSRN support multiscale upsampling for now.");
return;
}
if (net.empty())
CV_Error(Error::StsError, "Model not specified. Please set model via setModel().");
if (this->alg == "lapsrn")
{
Mat orig = img.getMat();
//Preprocess the image: convert to YCrCb float image and normalize
Mat preproc_img;
preprocess_YCrCb(orig, preproc_img);
//Split the image: only the Y channel is used for inference
Mat ycbcr_channels[3];
split(preproc_img, ycbcr_channels);
Mat Y = ycbcr_channels[0];
//Create blob from image so it has size 1,1,Width,Height
cv::Mat blob;
dnn::blobFromImage(Y, blob, 1.0);
//Get the HR outputs
std::vector <Mat> outputs_blobs;
this->net.setInput(blob);
this->net.forward(outputs_blobs, node_names);
for(unsigned int i = 0; i < scale_factors.size(); i++)
{
std::vector <Mat> model_outs;
dnn::imagesFromBlob(outputs_blobs[i], model_outs);
Mat out_img = model_outs[0];
Mat reconstructed;
reconstruct_YCrCb(out_img, preproc_img, reconstructed, scale_factors[i]);
imgs_new.push_back(reconstructed);
}
}
}
int DnnSuperResImpl::getScale()
{
return this->sc;
}
std::string DnnSuperResImpl::getAlgorithm()
{
return this->alg;
}
void DnnSuperResImpl::preprocess_YCrCb(InputArray inpImg, OutputArray outImg)
{
if ( inpImg.type() == CV_8UC1 )
{
inpImg.getMat().convertTo(outImg, CV_32F, 1.0 / 255.0);
}
else if ( inpImg.type() == CV_32FC1 )
{
inpImg.getMat().convertTo(outImg, CV_32F, 1.0 / 255.0);
}
else if ( inpImg.type() == CV_32FC3 )
{
Mat img_float;
inpImg.getMat().convertTo(img_float, CV_32F, 1.0 / 255.0);
cvtColor(img_float, outImg, COLOR_BGR2YCrCb);
}
else if ( inpImg.type() == CV_8UC3 )
{
Mat ycrcb;
cvtColor(inpImg, ycrcb, COLOR_BGR2YCrCb);
ycrcb.convertTo(outImg, CV_32F, 1.0 / 255.0);
}
else
{
CV_Error(Error::StsBadArg, std::string("Not supported image type: ") + typeToString(inpImg.type()));
}
}
void DnnSuperResImpl::reconstruct_YCrCb(InputArray inpImg, InputArray origImg, OutputArray outImg, int scale)
{
if ( origImg.type() == CV_32FC3 )
{
Mat orig_channels[3];
split(origImg.getMat(), orig_channels);
Mat Cr, Cb;
cv::resize(orig_channels[1], Cr, cv::Size(), scale, scale);
cv::resize(orig_channels[2], Cb, cv::Size(), scale, scale);
std::vector <Mat> channels;
channels.push_back(inpImg.getMat());
channels.push_back(Cr);
channels.push_back(Cb);
Mat merged_img;
merge(channels, merged_img);
Mat merged_8u_img;
merged_img.convertTo(merged_8u_img, CV_8U, 255.0);
cvtColor(merged_8u_img, outImg, COLOR_YCrCb2BGR);
}
else if ( origImg.type() == CV_32FC1 )
{
inpImg.getMat().convertTo(outImg, CV_8U, 255.0);
}
else
{
CV_Error(Error::StsBadArg, std::string("Not supported image type: ") + typeToString(origImg.type()));
}
}
DepthToSpace::DepthToSpace(const cv::dnn::LayerParams &params) : Layer(params)
{
}
cv::Ptr<cv::dnn::Layer> DepthToSpace::create(cv::dnn::LayerParams &params)
{
return cv::Ptr<cv::dnn::Layer>(new DepthToSpace(params));
}
bool DepthToSpace::getMemoryShapes(const std::vector <std::vector<int>> &inputs,
const int, std::vector <std::vector<int>> &outputs, std::vector <std::vector<int>> &) const
{
std::vector<int> outShape(4);
int scale;
if( inputs[0][1] == 4 || inputs[0][1] == 9 || inputs[0][1] == 16 ) //Only one image channel
{
scale = static_cast<int>(sqrt(inputs[0][1]));
}
else // Three image channels
{
scale = static_cast<int>(sqrt(inputs[0][1]/3));
}
outShape[0] = inputs[0][0];
outShape[1] = static_cast<int>(inputs[0][1] / pow(scale,2));
outShape[2] = static_cast<int>(scale * inputs[0][2]);
outShape[3] = static_cast<int>(scale * inputs[0][3]);
outputs.assign(1, outShape);
return false;
}
void DepthToSpace::forward(cv::InputArrayOfArrays inputs_arr, cv::OutputArrayOfArrays outputs_arr,
cv::OutputArrayOfArrays)
{
std::vector <cv::Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
cv::Mat &inp = inputs[0];
cv::Mat &out = outputs[0];
const float *inpData = (float *) inp.data;
float *outData = (float *) out.data;
const int inpHeight = inp.size[2];
const int inpWidth = inp.size[3];
const int numChannels = out.size[1];
const int outHeight = out.size[2];
const int outWidth = out.size[3];
int scale = int(outHeight / inpHeight);
int count = 0;
for (int ch = 0; ch < numChannels; ch++)
{
for (int y = 0; y < outHeight; y++)
{
for (int x = 0; x < outWidth; x++)
{
//Map the output (x, y) position back to the input row/column and source channel
int in_y = y / scale;
int in_x = x / scale;
int c_coord = numChannels * scale * (y % scale) + numChannels * (x % scale) + ch;
int index = (((c_coord * inpHeight) + in_y) * inpWidth) + in_x;
outData[count++] = inpData[index];
}
}
}
}
}} // cv::dnn_superres::
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef __OPENCV_DNN_SUPERRES_PRECOMP_HPP__
#define __OPENCV_DNN_SUPERRES_PRECOMP_HPP__
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include "opencv2/core.hpp"
#include <opencv2/core/utils/logger.hpp>
#include "opencv2/dnn.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/dnn_superres.hpp"
#endif // __OPENCV_DNN_SUPERRES_PRECOMP_HPP__
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "test_precomp.hpp"
namespace opencv_test { namespace {
const std::string DNN_SUPERRES_DIR = "dnn_superres";
const std::string IMAGE_FILENAME = "butterfly.png";
/****************************************************************************************\
* Test single output models *
\****************************************************************************************/
void runSingleModel(std::string algorithm, int scale, std::string model_filename)
{
SCOPED_TRACE(algorithm);
Ptr <DnnSuperResImpl> dnn_sr = makePtr<DnnSuperResImpl>();
std::string path = cvtest::findDataFile(DNN_SUPERRES_DIR + "/" + IMAGE_FILENAME);
Mat img = imread(path);
ASSERT_FALSE(img.empty()) << "Test image can't be loaded: " << path;
std::string pb_path = cvtest::findDataFile(DNN_SUPERRES_DIR + "/" + model_filename);
dnn_sr->readModel(pb_path);
dnn_sr->setModel(algorithm, scale);
ASSERT_EQ(scale, dnn_sr->getScale());
ASSERT_EQ(algorithm, dnn_sr->getAlgorithm());
Mat result;
dnn_sr->upsample(img, result);
ASSERT_FALSE(result.empty()) << "Could not perform upsampling for scale algorithm " << algorithm << " and scale factor " << scale;
int new_cols = img.cols * scale;
int new_rows = img.rows * scale;
ASSERT_EQ(new_cols, result.cols);
ASSERT_EQ(new_rows, result.rows);
}
TEST(CV_DnnSuperResSingleOutputTest, accuracy)
{
//x2
runSingleModel("espcn", 2, "ESPCN_x2.pb");
}
/****************************************************************************************\
* Test multi output models *
\****************************************************************************************/
void runMultiModel(std::string algorithm, int scale, std::string model_filename,
std::vector<int> scales, std::vector<String> node_names)
{
SCOPED_TRACE(algorithm);
Ptr <DnnSuperResImpl> dnn_sr = makePtr<DnnSuperResImpl>();
std::string path = cvtest::findDataFile(DNN_SUPERRES_DIR + "/" + IMAGE_FILENAME);
Mat img = imread(path);
ASSERT_FALSE(img.empty()) << "Test image can't be loaded: " << path;
std::string pb_path = cvtest::findDataFile(DNN_SUPERRES_DIR + "/" + model_filename);
dnn_sr->readModel(pb_path);
dnn_sr->setModel(algorithm, scale);
ASSERT_EQ(scale, dnn_sr->getScale());
ASSERT_EQ(algorithm, dnn_sr->getAlgorithm());
std::vector<Mat> outputs;
dnn_sr->upsampleMultioutput(img, outputs, scales, node_names);
for(unsigned int i = 0; i < outputs.size(); i++)
{
SCOPED_TRACE(cv::format("i=%d scale[i]=%d", i, scales[i]));
ASSERT_FALSE(outputs[i].empty());
int new_cols = img.cols * scales[i];
int new_rows = img.rows * scales[i];
EXPECT_EQ(new_cols, outputs[i].cols);
EXPECT_EQ(new_rows, outputs[i].rows);
}
}
TEST(CV_DnnSuperResMultiOutputTest, accuracy)
{
//LAPSRN
//x4
std::vector<String> names_4x {"NCHW_output_2x", "NCHW_output_4x"};
std::vector<int> scales_4x {2, 4};
runMultiModel("lapsrn", 4, "LapSRN_x4.pb", scales_4x, names_4x);
}
}}
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "test_precomp.hpp"
CV_TEST_MAIN("cv")
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef __OPENCV_TEST_PRECOMP_HPP__
#define __OPENCV_TEST_PRECOMP_HPP__
#include "opencv2/ts.hpp"
#include "opencv2/dnn_superres.hpp"
namespace opencv_test {
using namespace cv::dnn_superres;
}
#endif
Super-resolution benchmarking {#tutorial_dnn_superres_benchmark}
===========================
Benchmarking
----
The super-resolution module contains sample code for benchmarking, in order to compare different models and algorithms.
A sample benchmark is presented here, followed by a collection of benchmarking results.
The benchmarks were run on an Intel i7-9700K CPU on Ubuntu 18.04.02.
Source Code of the sample
-----------
@includelineno dnn_superres/samples/dnn_superres_benchmark_quality.cpp
Explanation
-----------
-# **Read and downscale the image**
@code{.cpp}
int width = img.cols - (img.cols % scale);
int height = img.rows - (img.rows % scale);
Mat cropped = img(Rect(0, 0, width, height));
Mat img_downscaled;
cv::resize(cropped, img_downscaled, cv::Size(), 1.0 / scale, 1.0 / scale);
@endcode
Downscale the image by the scaling factor. A crop is necessary beforehand so that the original and the upscaled images align.
-# **Set the model**
@code{.cpp}
DnnSuperResImpl sr;
sr.readModel(path);
sr.setModel(algorithm, scale);
sr.upsample(img_downscaled, img_new);
@endcode
Instantiate a dnn super-resolution object. Read and set the algorithm and scaling factor.
-# **Perform benchmarking**
@code{.cpp}
double psnr = PSNR(img_new, cropped);
Scalar q = cv::quality::QualitySSIM::compute(img_new, cropped, cv::noArray());
double ssim = mean(cv::Vec3f(q[0], q[1], q[2]))[0];
@endcode
Calculate PSNR and SSIM. Use OpenCV's PSNR (core opencv) and SSIM (contrib) functions to compare the images.
Repeat this with other upscaling algorithms, such as other DL models or interpolation methods (e.g. bicubic, nearest neighbor), as sketched below.
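For example, a bicubic baseline can be computed with the same metrics (a sketch; INTER_NEAREST and INTER_LANCZOS4 work analogously):
@code{.cpp}
Mat bicubic;
cv::resize(img_downscaled, bicubic, cv::Size(), scale, scale, cv::INTER_CUBIC);
double psnr_bicubic = PSNR(bicubic, cropped);
Scalar q_bicubic = cv::quality::QualitySSIM::compute(bicubic, cropped, cv::noArray());
double ssim_bicubic = mean(cv::Vec3d(q_bicubic[0], q_bicubic[1], q_bicubic[2]))[0];
@endcode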
Benchmarking results
-----------
Dataset benchmarking
----
### General100 dataset
<center>
##### 2x scaling factor
| | Avg inference time in sec (CPU)| Avg PSNR | Avg SSIM |
| ------------- |:-------------------:| ---------:|--------:|
| ESPCN | **0.008795** | 32.7059 | 0.9276 |
| EDSR | 5.923450 | **34.1300** | **0.9447** |
| FSRCNN | 0.021741 | 32.8886 | 0.9301 |
| LapSRN | 0.114812 | 32.2681 | 0.9248 |
| Bicubic | 0.000208 | 32.1638 | 0.9305 |
| Nearest neighbor | 0.000114 | 29.1665 | 0.9049 |
| Lanczos | 0.001094 | 32.4687 | 0.9327 |
##### 3x scaling factor
| | Avg inference time in sec (CPU)| Avg PSNR | Avg SSIM |
| ------------- |:-------------------:| ---------:|--------:|
| ESPCN | **0.005495** | 28.4229 | 0.8474 |
| EDSR | 2.455510 | **29.9828** | **0.8801** |
| FSRCNN | 0.008807 | 28.3068 | 0.8429 |
| LapSRN | 0.282575 |26.7330 |0.8862 |
| Bicubic | 0.000311 |26.0635 |0.8754 |
| Nearest neighbor | 0.000148 |23.5628 |0.8174 |
| Lanczos | 0.001012 |25.9115 |0.8706 |
##### 4x scaling factor
| | Avg inference time in sec (CPU)| Avg PSNR | Avg SSIM |
| ------------- |:-------------------:| ---------:|--------:|
| ESPCN | **0.004311** | 26.6870 | 0.7891 |
| EDSR | 1.607570 | **28.1552** | **0.8317** |
| FSRCNN | 0.005302 | 26.6088 | 0.7863 |
| LapSRN | 0.121229 |26.7383 |0.7896 |
| Bicubic | 0.000311 |26.0635 |0.8754 |
| Nearest neighbor | 0.000148 |23.5628 |0.8174 |
| Lanczos | 0.001012 |25.9115 |0.8706 |
</center>
Images
----
<center>
#### 2x scaling factor
|Set5: butterfly.png | size: 256x256 | ||
|:-------------:|:-------------------:|:-------------:|:----:|
|![Original](images/orig_butterfly.jpg)|![Bicubic interpolation](images/bicubic_butterfly.jpg)|![Nearest neighbor interpolation](images/nearest_butterfly.jpg)|![Lanczos interpolation](images/lanczos_butterfly.jpg) |
|PSNR / SSIM / Speed (CPU)|26.6645 / 0.9048 / 0.000201 |23.6854 / 0.8698 / **0.000075** | **26.9476** / **0.9075** / 0.001039|
|![ESPCN](images/espcn_butterfly.jpg)| ![FSRCNN](images/fsrcnn_butterfly.jpg) | ![LapSRN](images/lapsrn_butterfly.jpg) | ![EDSR](images/edsr_butterfly.jpg)
|29.0341 / 0.9354 / **0.004157**| 29.0077 / 0.9345 / 0.006325 | 27.8212 / 0.9230 / 0.037937 | **30.0347** / **0.9453** / 2.077280 |
#### 3x scaling factor
|Urban100: img_001.png | size: 1024x644 | ||
|:-------------:|:-------------------:|:-------------:|:----:|
|![Original](images/orig_urban.jpg)|![Bicubic interpolation](images/bicubic_urban.jpg)|![Nearest neighbor interpolation](images/nearest_urban.jpg)|![Lanczos interpolation](images/lanczos_urban.jpg) |
|PSNR / SSIM / Speed (CPU)| 27.0474 / **0.8484** / 0.000391 | 26.0842 / 0.8353 / **0.000236** | **27.0704** / 0.8483 / 0.002234|
|![ESPCN](images/espcn_urban.jpg)| ![FSRCNN](images/fsrcnn_urban.jpg) | LapSRN is not trained for 3x <br/> because of its architecture | ![EDSR](images/edsr_urban.jpg)
|28.0118 / 0.8588 / **0.030748**| 28.0184 / 0.8597 / 0.094173 | | **30.5671** / **0.9019** / 9.517580 |
#### 4x scaling factor
|Set14: comic.png | size: 250x361 | ||
|:-------------:|:-------------------:|:-------------:|:----:|
|![Original](images/orig_comic.jpg)|![Bicubic interpolation](images/bicubic_comic.jpg)|![Nearest neighbor interpolation](images/nearest_comic.jpg)|![Lanczos interpolation](images/lanczos_comic.jpg) |
|PSNR / SSIM / Speed (CPU)| **19.6766** / **0.6413** / 0.000262 |18.5106 / 0.5879 / **0.000085** | 19.4948 / 0.6317 / 0.001098|
|![ESPCN](images/espcn_comic.jpg)| ![FSRCNN](images/fsrcnn_comic.jpg) | ![LapSRN](images/lapsrn_comic.jpg) | ![EDSR](images/edsr_comic.jpg)
|20.0417 / 0.6302 / **0.001894**| 20.0885 / 0.6384 / 0.002103 | 20.0676 / 0.6339 / 0.061640 | **20.5233** / **0.6901** / 0.665876 |
#### 8x scaling factor
|Div2K: 0006.png | size: 1356x2040 | |
|:-------------:|:-------------------:|:-------------:|
|![Original](images/orig_div2k.jpg)|![Bicubic interpolation](images/bicubic_div2k.jpg)|![Nearest neighbor interpolation](images/nearest_div2k.jpg)|
|PSNR / SSIM / Speed (CPU)| 26.3139 / **0.8033** / 0.001107| 23.8291 / 0.7340 / **0.000611** |
|![Lanczos interpolation](images/lanczos_div2k.jpg)| ![LapSRN](images/lapsrn_div2k.jpg) | |
|26.1565 / 0.7962 / 0.004782| **26.7046** / 0.7987 / 2.274290 | |
</center>
Super Resolution using CNNs {#tutorial_table_of_content_dnn_superres}
============================
- @subpage tutorial_dnn_superres_upscale_image_multi
*Author:* Fanny Monori
How to upscale images using the 'dnn_superres' interface: multi-output
- @subpage tutorial_dnn_superres_upscale_video
*Author:* Fanny Monori
How to upscale a video using the 'dnn_superres' interface.
- @subpage tutorial_dnn_superres_benchmark
*Authors:* Fanny Monori & Xavier Weber
Benchmarking of the algorithms.
Upscaling images: multi-output {#tutorial_dnn_superres_upscale_image_multi}
===========================
In this tutorial you will learn how to use the 'dnn_superres' interface to upscale an image via a multi-output pre-trained neural network.
OpenCV's dnn module supports accessing multiple output nodes in one inference run, if the names of the nodes are given.
Currently there is one included model capable of producing multiple outputs in a single inference run: the LapSRN model.
LapSRN supports multiple outputs with one forward pass: it can produce 2x, 4x and 8x upscaled images, as well as the combined (2x, 4x) and (2x, 4x, 8x) outputs.
The uploaded trained model files have the following output node names:
- 2x model: NCHW_output
- 4x model: NCHW_output_2x, NCHW_output_4x
- 8x model: NCHW_output_2x, NCHW_output_4x, NCHW_output_8x
Building
----
When building OpenCV, run the following command to build the 'dnn_superres' module:
```make
cmake -DOPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -DBUILD_opencv_dnn_superres=ON <opencv_source_dir>
```
Or make sure you check the dnn_superres module in the GUI version of CMake: cmake-gui.
Source Code of the sample
-----------
Run the sample code with the following command
```run
./bin/example_dnn_superres_dnn_superres_multioutput path/to/image.png 2,4 NCHW_output_2x,NCHW_output_4x \
path/to/opencv_contrib/modules/dnn_superres/models/LapSRN_x4.pb
```
@includelineno dnn_superres/samples/dnn_superres_multioutput.cpp
Explanation
-----------
-# **Set header and namespaces**
@code{.cpp}
#include <opencv2/dnn_superres.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
@endcode
-# **Create the Dnn Superres object**
@code{.cpp}
DnnSuperResImpl sr;
@endcode
Instantiate a dnn super-resolution object.
-# **Read the model**
@code{.cpp}
path = "models/LapSRN_x8.pb"
sr.readModel(path);
@endcode
Read the model from the given path.
-# **Set the model**
@code{.cpp}
sr.setModel("lapsrn", 8);
@endcode
Set the algorithm and scaling factor. The largest scaling factor of the model should be given here.
-# **Give the node names and scaling factors**
@code{.cpp}
std::vector<int> scales{2, 4, 8};
std::vector<String> node_names{"NCHW_output_2x", "NCHW_output_4x", "NCHW_output_8x"};
@endcode
Set the scaling factors, and the output node names in the model.
-# **Upscale an image**
@code{.cpp}
Mat img = cv::imread(img_path);
std::vector<Mat> outputs;
sr.upsampleMultioutput(img, outputs, scales, node_names);
@endcode
Run the inference. The output images will be stored in a Mat vector.
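The outputs can also be written to disk (a sketch; the file name pattern is arbitrary):
@code{.cpp}
for (size_t i = 0; i < outputs.size(); i++)
{
    cv::imwrite(cv::format("upscaled_%dx.png", scales[i]), outputs[i]);
}
@endcode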
Upscaling video {#tutorial_dnn_superres_upscale_video}
===========================
In this tutorial you will learn how to use the 'dnn_superres' interface to upscale video via pre-trained neural networks.
Building
----
When building OpenCV, run the following command to build the 'dnn_superres' module:
```make
cmake -DOPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -DBUILD_opencv_dnn_superres=ON <opencv_source_dir>
```
Or make sure you check the dnn_superres module in the GUI version of CMake: cmake-gui.
Source Code of the sample
-----------
@includelineno dnn_superres/samples/dnn_superres_video.cpp
Explanation
-----------
-# **Set header and namespaces**
@code{.cpp}
#include <opencv2/dnn_superres.hpp>
using namespace std;
using namespace cv;
using namespace dnn_superres;
@endcode
-# **Create the Dnn Superres object**
@code{.cpp}
DnnSuperResImpl sr;
@endcode
Instantiate a dnn super-resolution object.
-# **Read the model**
@code{.cpp}
path = "models/ESPCN_x2.pb"
sr.readModel(path);
sr.setModel("espcn", 2);
@endcode
Read the model from the given path and set the algorithm and scaling factor.
-# **Upscale a video**
@code{.cpp}
for(;;)
{
Mat frame, output_frame;
input_video >> frame;
if ( frame.empty() )
break;
sr.upsample(frame, output_frame);
...
}
@endcode
Process and upsample video frame by frame.
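Note that the output video must be opened with the upscaled frame size. In the sample this is done before the loop:
@code{.cpp}
int ex = static_cast<int>(input_video.get(CAP_PROP_FOURCC));
Size S = Size((int) input_video.get(CAP_PROP_FRAME_WIDTH) * scale,
              (int) input_video.get(CAP_PROP_FRAME_HEIGHT) * scale);
VideoWriter output_video;
output_video.open(output_path, ex, input_video.get(CAP_PROP_FPS), S, true);
@endcode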