Commit 575b7189 authored by Sergey Lavrushkin's avatar Sergey Lavrushkin Committed by Pedro Arthur

Adds ESPCN super resolution filter merged with SRCNN filter.

Signed-off-by: 's avatarPedro Arthur <bygrandao@gmail.com>
parent d24c9e55
......@@ -260,7 +260,7 @@ External library support:
--enable-libsrt enable Haivision SRT protocol via libsrt [no]
--enable-libssh enable SFTP protocol via libssh [no]
--enable-libtensorflow enable TensorFlow as a DNN module backend
for DNN based filters like srcnn [no]
for DNN based filters like sr [no]
--enable-libtesseract enable Tesseract, needed for ocr filter [no]
--enable-libtheora enable Theora encoding via libtheora [no]
--enable-libtls enable LibreSSL (via libtls), needed for https support
......@@ -3402,8 +3402,8 @@ spectrumsynth_filter_deps="avcodec"
spectrumsynth_filter_select="fft"
spp_filter_deps="gpl avcodec"
spp_filter_select="fft idctdsp fdctdsp me_cmp pixblockdsp"
srcnn_filter_deps="avformat"
srcnn_filter_select="dnn"
sr_filter_deps="avformat swscale"
sr_filter_select="dnn"
stereo3d_filter_deps="gpl"
subtitles_filter_deps="avformat avcodec libass"
super2xsai_filter_deps="gpl"
......@@ -6823,7 +6823,7 @@ enabled signature_filter && prepend avfilter_deps "avcodec avformat"
enabled smartblur_filter && prepend avfilter_deps "swscale"
enabled spectrumsynth_filter && prepend avfilter_deps "avcodec"
enabled spp_filter && prepend avfilter_deps "avcodec"
enabled srcnn_filter && prepend avfilter_deps "avformat"
enabled sr_filter && prepend avfilter_deps "avformat"
enabled subtitles_filter && prepend avfilter_deps "avformat avcodec"
enabled uspp_filter && prepend avfilter_deps "avcodec"
enabled zoompan_filter && prepend avfilter_deps "swscale"
......
......@@ -340,7 +340,7 @@ OBJS-$(CONFIG_SMARTBLUR_FILTER) += vf_smartblur.o
OBJS-$(CONFIG_SOBEL_FILTER) += vf_convolution.o
OBJS-$(CONFIG_SPLIT_FILTER) += split.o
OBJS-$(CONFIG_SPP_FILTER) += vf_spp.o
OBJS-$(CONFIG_SRCNN_FILTER) += vf_srcnn.o
OBJS-$(CONFIG_SR_FILTER) += vf_sr.o
OBJS-$(CONFIG_SSIM_FILTER) += vf_ssim.o framesync.o
OBJS-$(CONFIG_STEREO3D_FILTER) += vf_stereo3d.o
OBJS-$(CONFIG_STREAMSELECT_FILTER) += f_streamselect.o framesync.o
......
......@@ -328,7 +328,7 @@ extern AVFilter ff_vf_smartblur;
extern AVFilter ff_vf_sobel;
extern AVFilter ff_vf_split;
extern AVFilter ff_vf_spp;
extern AVFilter ff_vf_srcnn;
extern AVFilter ff_vf_sr;
extern AVFilter ff_vf_ssim;
extern AVFilter ff_vf_stereo3d;
extern AVFilter ff_vf_streamselect;
......
This diff is collapsed.
......@@ -25,6 +25,7 @@
#include "dnn_backend_tf.h"
#include "dnn_srcnn.h"
#include "dnn_espcn.h"
#include "libavformat/avio.h"
#include <tensorflow/c/c_api.h>
......@@ -35,9 +36,7 @@ typedef struct TFModel{
TF_Status* status;
TF_Output input, output;
TF_Tensor* input_tensor;
TF_Tensor* output_tensor;
const DNNData* input_data;
const DNNData* output_data;
DNNData* output_data;
} TFModel;
static void free_buffer(void* data, size_t length)
......@@ -78,13 +77,13 @@ static TF_Buffer* read_graph(const char* model_filename)
return graph_buf;
}
static DNNReturnType set_input_output_tf(void* model, const DNNData* input, const DNNData* output)
static DNNReturnType set_input_output_tf(void* model, DNNData* input, DNNData* output)
{
TFModel* tf_model = (TFModel*)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
int64_t output_dims[] = {1, output->height, output->width, output->channels};
TF_SessionOptions* sess_opts;
const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init");
TF_Tensor* output_tensor;
// Input operation should be named 'x'
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
......@@ -100,6 +99,7 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
if (!tf_model->input_tensor){
return DNN_ERROR;
}
input->data = (float*)TF_TensorData(tf_model->input_tensor);
// Output operation should be named 'y'
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
......@@ -107,17 +107,6 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
return DNN_ERROR;
}
tf_model->output.index = 0;
if (tf_model->output_tensor){
TF_DeleteTensor(tf_model->output_tensor);
}
tf_model->output_tensor = TF_AllocateTensor(TF_FLOAT, output_dims, 4,
output_dims[1] * output_dims[2] * output_dims[3] * sizeof(float));
if (!tf_model->output_tensor){
return DNN_ERROR;
}
tf_model->input_data = input;
tf_model->output_data = output;
if (tf_model->session){
TF_CloseSession(tf_model->session, tf_model->status);
......@@ -144,6 +133,26 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
}
}
// Execute network to get output height, width and number of channels
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
&tf_model->output, &output_tensor, 1,
NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
else{
output->height = TF_Dim(output_tensor, 1);
output->width = TF_Dim(output_tensor, 2);
output->channels = TF_Dim(output_tensor, 3);
output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
if (!output->data){
return DNN_ERROR;
}
tf_model->output_data = output;
TF_DeleteTensor(output_tensor);
}
return DNN_SUCCESS;
}
......@@ -166,7 +175,7 @@ DNNModel* ff_dnn_load_model_tf(const char* model_filename)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
tf_model->output_tensor = NULL;
tf_model->output_data = NULL;
graph_def = read_graph(model_filename);
if (!graph_def){
......@@ -215,6 +224,17 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
graph_def->length = srcnn_tf_size;
graph_def->data_deallocator = free_buffer;
break;
case DNN_ESPCN:
graph_data = av_malloc(espcn_tf_size);
if (!graph_data){
TF_DeleteBuffer(graph_def);
return NULL;
}
memcpy(graph_data, espcn_tf_model, espcn_tf_size);
graph_def->data = (void*)graph_data;
graph_def->length = espcn_tf_size;
graph_def->data_deallocator = free_buffer;
break;
default:
TF_DeleteBuffer(graph_def);
return NULL;
......@@ -234,7 +254,7 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
tf_model->output_tensor = NULL;
tf_model->output_data = NULL;
tf_model->graph = TF_NewGraph();
tf_model->status = TF_NewStatus();
......@@ -259,23 +279,21 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model)
{
TFModel* tf_model = (TFModel*)model->model;
memcpy(TF_TensorData(tf_model->input_tensor), tf_model->input_data->data,
tf_model->input_data->height * tf_model->input_data->width *
tf_model->input_data->channels * sizeof(float));
TF_Tensor* output_tensor;
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
&tf_model->output, &tf_model->output_tensor, 1,
&tf_model->output, &output_tensor, 1,
NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
else{
memcpy(tf_model->output_data->data, TF_TensorData(tf_model->output_tensor),
memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
tf_model->output_data->height * tf_model->output_data->width *
tf_model->output_data->channels * sizeof(float));
TF_DeleteTensor(output_tensor);
return DNN_SUCCESS;
}
......@@ -300,9 +318,7 @@ void ff_dnn_free_model_tf(DNNModel** model)
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
if (tf_model->output_tensor){
TF_DeleteTensor(tf_model->output_tensor);
}
av_freep(&tf_model->output_data->data);
av_freep(&tf_model);
av_freep(model);
}
......
This diff is collapsed.
......@@ -30,7 +30,7 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
typedef enum {DNN_SRCNN} DNNDefaultModel;
typedef enum {DNN_SRCNN, DNN_ESPCN} DNNDefaultModel;
typedef struct DNNData{
float* data;
......@@ -42,7 +42,7 @@ typedef struct DNNModel{
void* model;
// Sets model input and output, while allocating additional memory for intermediate calculations.
// Should be called at least once before model execution.
DNNReturnType (*set_input_output)(void* model, const DNNData* input, const DNNData* output);
DNNReturnType (*set_input_output)(void* model, DNNData* input, DNNData* output);
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
......
......@@ -20,13 +20,13 @@
/**
* @file
* Default cnn weights for x2 upsampling with srcnn filter.
* Default cnn weights for x2 upsampling with srcnn model.
*/
#ifndef AVFILTER_DNN_SRCNN_H
#define AVFILTER_DNN_SRCNN_H
static const float conv1_kernel[] = {
static const float srcnn_conv1_kernel[] = {
-0.08866338f, 0.055409566f, 0.037196506f, -0.11961404f,
-0.12341991f, 0.29963422f, -0.0911817f, -0.00013613555f,
-0.049023595f, 0.038421184f, -0.077267796f, 0.027273094f,
......@@ -1325,7 +1325,7 @@ static const float conv1_kernel[] = {
-0.013759381f, 0.026358005f, 0.088238746f, 0.082134426f
};
static const float conv1_biases[] = {
static const float srcnn_conv1_biases[] = {
-0.016606892f, -0.011107335f, -0.0048309686f, -0.04867378f,
-0.030040957f, -0.07297248f, -0.019458665f, -0.009738028f,
0.6951231f, -0.07369442f, -0.01354204f, 0.010336088f,
......@@ -1344,7 +1344,7 @@ static const float conv1_biases[] = {
0.054407462f, -0.08068252f, -0.009446503f, -0.04663234f
};
static const float conv2_kernel[] = {
static const float srcnn_conv2_kernel[] = {
-0.24004751f, 0.1037138f, 0.11173403f, 0.04352092f,
-0.23728481f, 0.12153747f, -0.23676059f, -0.28548065f,
-0.612738f, -0.12218937f, -0.06005159f, 0.1850652f,
......@@ -1859,7 +1859,7 @@ static const float conv2_kernel[] = {
0.11089696f, -0.08941251f, -0.3529318f, 0.0654588f
};
static const float conv2_biases[] = {
static const float srcnn_conv2_biases[] = {
0.12326373f, 0.13270757f, 0.07082674f, 0.051456157f,
0.058445618f, 0.13153197f, 0.0809729f, 0.10153213f,
0.055915363f, 0.05228166f, -0.11212896f, 0.07462141f,
......@@ -1870,7 +1870,7 @@ static const float conv2_biases[] = {
-0.086404406f, 0.06046943f, -0.1733751f, 0.2654999f
};
static const float conv3_kernel[] = {
static const float srcnn_conv3_kernel[] = {
-0.01733648f, 0.01492609f, 0.019393086f, -0.004445322f,
0.026939709f, 0.00038831023f, 0.004221528f, 0.0050745453f,
0.0129861f, 0.008007169f, 0.008950762f, 0.005279691f,
......@@ -2073,7 +2073,7 @@ static const float conv3_kernel[] = {
0.012931146f, 0.0046948805f, 0.013098622f, -0.015422701f
};
static const float conv3_biases[] = {
static const float srcnn_conv3_biases[] = {
0.05037664f
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment