Commit 575b7189 authored by Sergey Lavrushkin's avatar Sergey Lavrushkin Committed by Pedro Arthur

Adds ESPCN super resolution filter merged with SRCNN filter.

Signed-off-by: 's avatarPedro Arthur <bygrandao@gmail.com>
parent d24c9e55
...@@ -260,7 +260,7 @@ External library support: ...@@ -260,7 +260,7 @@ External library support:
--enable-libsrt enable Haivision SRT protocol via libsrt [no] --enable-libsrt enable Haivision SRT protocol via libsrt [no]
--enable-libssh enable SFTP protocol via libssh [no] --enable-libssh enable SFTP protocol via libssh [no]
--enable-libtensorflow enable TensorFlow as a DNN module backend --enable-libtensorflow enable TensorFlow as a DNN module backend
for DNN based filters like srcnn [no] for DNN based filters like sr [no]
--enable-libtesseract enable Tesseract, needed for ocr filter [no] --enable-libtesseract enable Tesseract, needed for ocr filter [no]
--enable-libtheora enable Theora encoding via libtheora [no] --enable-libtheora enable Theora encoding via libtheora [no]
--enable-libtls enable LibreSSL (via libtls), needed for https support --enable-libtls enable LibreSSL (via libtls), needed for https support
...@@ -3402,8 +3402,8 @@ spectrumsynth_filter_deps="avcodec" ...@@ -3402,8 +3402,8 @@ spectrumsynth_filter_deps="avcodec"
spectrumsynth_filter_select="fft" spectrumsynth_filter_select="fft"
spp_filter_deps="gpl avcodec" spp_filter_deps="gpl avcodec"
spp_filter_select="fft idctdsp fdctdsp me_cmp pixblockdsp" spp_filter_select="fft idctdsp fdctdsp me_cmp pixblockdsp"
srcnn_filter_deps="avformat" sr_filter_deps="avformat swscale"
srcnn_filter_select="dnn" sr_filter_select="dnn"
stereo3d_filter_deps="gpl" stereo3d_filter_deps="gpl"
subtitles_filter_deps="avformat avcodec libass" subtitles_filter_deps="avformat avcodec libass"
super2xsai_filter_deps="gpl" super2xsai_filter_deps="gpl"
...@@ -6823,7 +6823,7 @@ enabled signature_filter && prepend avfilter_deps "avcodec avformat" ...@@ -6823,7 +6823,7 @@ enabled signature_filter && prepend avfilter_deps "avcodec avformat"
enabled smartblur_filter && prepend avfilter_deps "swscale" enabled smartblur_filter && prepend avfilter_deps "swscale"
enabled spectrumsynth_filter && prepend avfilter_deps "avcodec" enabled spectrumsynth_filter && prepend avfilter_deps "avcodec"
enabled spp_filter && prepend avfilter_deps "avcodec" enabled spp_filter && prepend avfilter_deps "avcodec"
enabled srcnn_filter && prepend avfilter_deps "avformat" enabled sr_filter && prepend avfilter_deps "avformat"
enabled subtitles_filter && prepend avfilter_deps "avformat avcodec" enabled subtitles_filter && prepend avfilter_deps "avformat avcodec"
enabled uspp_filter && prepend avfilter_deps "avcodec" enabled uspp_filter && prepend avfilter_deps "avcodec"
enabled zoompan_filter && prepend avfilter_deps "swscale" enabled zoompan_filter && prepend avfilter_deps "swscale"
......
...@@ -340,7 +340,7 @@ OBJS-$(CONFIG_SMARTBLUR_FILTER) += vf_smartblur.o ...@@ -340,7 +340,7 @@ OBJS-$(CONFIG_SMARTBLUR_FILTER) += vf_smartblur.o
OBJS-$(CONFIG_SOBEL_FILTER) += vf_convolution.o OBJS-$(CONFIG_SOBEL_FILTER) += vf_convolution.o
OBJS-$(CONFIG_SPLIT_FILTER) += split.o OBJS-$(CONFIG_SPLIT_FILTER) += split.o
OBJS-$(CONFIG_SPP_FILTER) += vf_spp.o OBJS-$(CONFIG_SPP_FILTER) += vf_spp.o
OBJS-$(CONFIG_SRCNN_FILTER) += vf_srcnn.o OBJS-$(CONFIG_SR_FILTER) += vf_sr.o
OBJS-$(CONFIG_SSIM_FILTER) += vf_ssim.o framesync.o OBJS-$(CONFIG_SSIM_FILTER) += vf_ssim.o framesync.o
OBJS-$(CONFIG_STEREO3D_FILTER) += vf_stereo3d.o OBJS-$(CONFIG_STEREO3D_FILTER) += vf_stereo3d.o
OBJS-$(CONFIG_STREAMSELECT_FILTER) += f_streamselect.o framesync.o OBJS-$(CONFIG_STREAMSELECT_FILTER) += f_streamselect.o framesync.o
......
...@@ -328,7 +328,7 @@ extern AVFilter ff_vf_smartblur; ...@@ -328,7 +328,7 @@ extern AVFilter ff_vf_smartblur;
extern AVFilter ff_vf_sobel; extern AVFilter ff_vf_sobel;
extern AVFilter ff_vf_split; extern AVFilter ff_vf_split;
extern AVFilter ff_vf_spp; extern AVFilter ff_vf_spp;
extern AVFilter ff_vf_srcnn; extern AVFilter ff_vf_sr;
extern AVFilter ff_vf_ssim; extern AVFilter ff_vf_ssim;
extern AVFilter ff_vf_stereo3d; extern AVFilter ff_vf_stereo3d;
extern AVFilter ff_vf_streamselect; extern AVFilter ff_vf_streamselect;
......
This diff is collapsed.
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "dnn_backend_tf.h" #include "dnn_backend_tf.h"
#include "dnn_srcnn.h" #include "dnn_srcnn.h"
#include "dnn_espcn.h"
#include "libavformat/avio.h" #include "libavformat/avio.h"
#include <tensorflow/c/c_api.h> #include <tensorflow/c/c_api.h>
...@@ -35,9 +36,7 @@ typedef struct TFModel{ ...@@ -35,9 +36,7 @@ typedef struct TFModel{
TF_Status* status; TF_Status* status;
TF_Output input, output; TF_Output input, output;
TF_Tensor* input_tensor; TF_Tensor* input_tensor;
TF_Tensor* output_tensor; DNNData* output_data;
const DNNData* input_data;
const DNNData* output_data;
} TFModel; } TFModel;
static void free_buffer(void* data, size_t length) static void free_buffer(void* data, size_t length)
...@@ -78,13 +77,13 @@ static TF_Buffer* read_graph(const char* model_filename) ...@@ -78,13 +77,13 @@ static TF_Buffer* read_graph(const char* model_filename)
return graph_buf; return graph_buf;
} }
static DNNReturnType set_input_output_tf(void* model, const DNNData* input, const DNNData* output) static DNNReturnType set_input_output_tf(void* model, DNNData* input, DNNData* output)
{ {
TFModel* tf_model = (TFModel*)model; TFModel* tf_model = (TFModel*)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels}; int64_t input_dims[] = {1, input->height, input->width, input->channels};
int64_t output_dims[] = {1, output->height, output->width, output->channels};
TF_SessionOptions* sess_opts; TF_SessionOptions* sess_opts;
const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init"); const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init");
TF_Tensor* output_tensor;
// Input operation should be named 'x' // Input operation should be named 'x'
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x"); tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
...@@ -100,6 +99,7 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons ...@@ -100,6 +99,7 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
if (!tf_model->input_tensor){ if (!tf_model->input_tensor){
return DNN_ERROR; return DNN_ERROR;
} }
input->data = (float*)TF_TensorData(tf_model->input_tensor);
// Output operation should be named 'y' // Output operation should be named 'y'
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y"); tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
...@@ -107,17 +107,6 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons ...@@ -107,17 +107,6 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
return DNN_ERROR; return DNN_ERROR;
} }
tf_model->output.index = 0; tf_model->output.index = 0;
if (tf_model->output_tensor){
TF_DeleteTensor(tf_model->output_tensor);
}
tf_model->output_tensor = TF_AllocateTensor(TF_FLOAT, output_dims, 4,
output_dims[1] * output_dims[2] * output_dims[3] * sizeof(float));
if (!tf_model->output_tensor){
return DNN_ERROR;
}
tf_model->input_data = input;
tf_model->output_data = output;
if (tf_model->session){ if (tf_model->session){
TF_CloseSession(tf_model->session, tf_model->status); TF_CloseSession(tf_model->session, tf_model->status);
...@@ -144,6 +133,26 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons ...@@ -144,6 +133,26 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
} }
} }
// Execute network to get output height, width and number of channels
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
&tf_model->output, &output_tensor, 1,
NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
else{
output->height = TF_Dim(output_tensor, 1);
output->width = TF_Dim(output_tensor, 2);
output->channels = TF_Dim(output_tensor, 3);
output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
if (!output->data){
return DNN_ERROR;
}
tf_model->output_data = output;
TF_DeleteTensor(output_tensor);
}
return DNN_SUCCESS; return DNN_SUCCESS;
} }
...@@ -166,7 +175,7 @@ DNNModel* ff_dnn_load_model_tf(const char* model_filename) ...@@ -166,7 +175,7 @@ DNNModel* ff_dnn_load_model_tf(const char* model_filename)
} }
tf_model->session = NULL; tf_model->session = NULL;
tf_model->input_tensor = NULL; tf_model->input_tensor = NULL;
tf_model->output_tensor = NULL; tf_model->output_data = NULL;
graph_def = read_graph(model_filename); graph_def = read_graph(model_filename);
if (!graph_def){ if (!graph_def){
...@@ -215,6 +224,17 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) ...@@ -215,6 +224,17 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
graph_def->length = srcnn_tf_size; graph_def->length = srcnn_tf_size;
graph_def->data_deallocator = free_buffer; graph_def->data_deallocator = free_buffer;
break; break;
case DNN_ESPCN:
graph_data = av_malloc(espcn_tf_size);
if (!graph_data){
TF_DeleteBuffer(graph_def);
return NULL;
}
memcpy(graph_data, espcn_tf_model, espcn_tf_size);
graph_def->data = (void*)graph_data;
graph_def->length = espcn_tf_size;
graph_def->data_deallocator = free_buffer;
break;
default: default:
TF_DeleteBuffer(graph_def); TF_DeleteBuffer(graph_def);
return NULL; return NULL;
...@@ -234,7 +254,7 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) ...@@ -234,7 +254,7 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
} }
tf_model->session = NULL; tf_model->session = NULL;
tf_model->input_tensor = NULL; tf_model->input_tensor = NULL;
tf_model->output_tensor = NULL; tf_model->output_data = NULL;
tf_model->graph = TF_NewGraph(); tf_model->graph = TF_NewGraph();
tf_model->status = TF_NewStatus(); tf_model->status = TF_NewStatus();
...@@ -259,23 +279,21 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) ...@@ -259,23 +279,21 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model) DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model)
{ {
TFModel* tf_model = (TFModel*)model->model; TFModel* tf_model = (TFModel*)model->model;
TF_Tensor* output_tensor;
memcpy(TF_TensorData(tf_model->input_tensor), tf_model->input_data->data,
tf_model->input_data->height * tf_model->input_data->width *
tf_model->input_data->channels * sizeof(float));
TF_SessionRun(tf_model->session, NULL, TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1, &tf_model->input, &tf_model->input_tensor, 1,
&tf_model->output, &tf_model->output_tensor, 1, &tf_model->output, &output_tensor, 1,
NULL, 0, NULL, tf_model->status); NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){ if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR; return DNN_ERROR;
} }
else{ else{
memcpy(tf_model->output_data->data, TF_TensorData(tf_model->output_tensor), memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
tf_model->output_data->height * tf_model->output_data->width * tf_model->output_data->height * tf_model->output_data->width *
tf_model->output_data->channels * sizeof(float)); tf_model->output_data->channels * sizeof(float));
TF_DeleteTensor(output_tensor);
return DNN_SUCCESS; return DNN_SUCCESS;
} }
...@@ -300,9 +318,7 @@ void ff_dnn_free_model_tf(DNNModel** model) ...@@ -300,9 +318,7 @@ void ff_dnn_free_model_tf(DNNModel** model)
if (tf_model->input_tensor){ if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor); TF_DeleteTensor(tf_model->input_tensor);
} }
if (tf_model->output_tensor){ av_freep(&tf_model->output_data->data);
TF_DeleteTensor(tf_model->output_tensor);
}
av_freep(&tf_model); av_freep(&tf_model);
av_freep(model); av_freep(model);
} }
......
This diff is collapsed.
...@@ -30,7 +30,7 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; ...@@ -30,7 +30,7 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
typedef enum {DNN_SRCNN} DNNDefaultModel; typedef enum {DNN_SRCNN, DNN_ESPCN} DNNDefaultModel;
typedef struct DNNData{ typedef struct DNNData{
float* data; float* data;
...@@ -42,7 +42,7 @@ typedef struct DNNModel{ ...@@ -42,7 +42,7 @@ typedef struct DNNModel{
void* model; void* model;
// Sets model input and output, while allocating additional memory for intermediate calculations. // Sets model input and output, while allocating additional memory for intermediate calculations.
// Should be called at least once before model execution. // Should be called at least once before model execution.
DNNReturnType (*set_input_output)(void* model, const DNNData* input, const DNNData* output); DNNReturnType (*set_input_output)(void* model, DNNData* input, DNNData* output);
} DNNModel; } DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
/** /**
* @file * @file
* Default cnn weights for x2 upsampling with srcnn filter. * Default cnn weights for x2 upsampling with srcnn model.
*/ */
#ifndef AVFILTER_DNN_SRCNN_H #ifndef AVFILTER_DNN_SRCNN_H
#define AVFILTER_DNN_SRCNN_H #define AVFILTER_DNN_SRCNN_H
static const float conv1_kernel[] = { static const float srcnn_conv1_kernel[] = {
-0.08866338f, 0.055409566f, 0.037196506f, -0.11961404f, -0.08866338f, 0.055409566f, 0.037196506f, -0.11961404f,
-0.12341991f, 0.29963422f, -0.0911817f, -0.00013613555f, -0.12341991f, 0.29963422f, -0.0911817f, -0.00013613555f,
-0.049023595f, 0.038421184f, -0.077267796f, 0.027273094f, -0.049023595f, 0.038421184f, -0.077267796f, 0.027273094f,
...@@ -1325,7 +1325,7 @@ static const float conv1_kernel[] = { ...@@ -1325,7 +1325,7 @@ static const float conv1_kernel[] = {
-0.013759381f, 0.026358005f, 0.088238746f, 0.082134426f -0.013759381f, 0.026358005f, 0.088238746f, 0.082134426f
}; };
static const float conv1_biases[] = { static const float srcnn_conv1_biases[] = {
-0.016606892f, -0.011107335f, -0.0048309686f, -0.04867378f, -0.016606892f, -0.011107335f, -0.0048309686f, -0.04867378f,
-0.030040957f, -0.07297248f, -0.019458665f, -0.009738028f, -0.030040957f, -0.07297248f, -0.019458665f, -0.009738028f,
0.6951231f, -0.07369442f, -0.01354204f, 0.010336088f, 0.6951231f, -0.07369442f, -0.01354204f, 0.010336088f,
...@@ -1344,7 +1344,7 @@ static const float conv1_biases[] = { ...@@ -1344,7 +1344,7 @@ static const float conv1_biases[] = {
0.054407462f, -0.08068252f, -0.009446503f, -0.04663234f 0.054407462f, -0.08068252f, -0.009446503f, -0.04663234f
}; };
static const float conv2_kernel[] = { static const float srcnn_conv2_kernel[] = {
-0.24004751f, 0.1037138f, 0.11173403f, 0.04352092f, -0.24004751f, 0.1037138f, 0.11173403f, 0.04352092f,
-0.23728481f, 0.12153747f, -0.23676059f, -0.28548065f, -0.23728481f, 0.12153747f, -0.23676059f, -0.28548065f,
-0.612738f, -0.12218937f, -0.06005159f, 0.1850652f, -0.612738f, -0.12218937f, -0.06005159f, 0.1850652f,
...@@ -1859,7 +1859,7 @@ static const float conv2_kernel[] = { ...@@ -1859,7 +1859,7 @@ static const float conv2_kernel[] = {
0.11089696f, -0.08941251f, -0.3529318f, 0.0654588f 0.11089696f, -0.08941251f, -0.3529318f, 0.0654588f
}; };
static const float conv2_biases[] = { static const float srcnn_conv2_biases[] = {
0.12326373f, 0.13270757f, 0.07082674f, 0.051456157f, 0.12326373f, 0.13270757f, 0.07082674f, 0.051456157f,
0.058445618f, 0.13153197f, 0.0809729f, 0.10153213f, 0.058445618f, 0.13153197f, 0.0809729f, 0.10153213f,
0.055915363f, 0.05228166f, -0.11212896f, 0.07462141f, 0.055915363f, 0.05228166f, -0.11212896f, 0.07462141f,
...@@ -1870,7 +1870,7 @@ static const float conv2_biases[] = { ...@@ -1870,7 +1870,7 @@ static const float conv2_biases[] = {
-0.086404406f, 0.06046943f, -0.1733751f, 0.2654999f -0.086404406f, 0.06046943f, -0.1733751f, 0.2654999f
}; };
static const float conv3_kernel[] = { static const float srcnn_conv3_kernel[] = {
-0.01733648f, 0.01492609f, 0.019393086f, -0.004445322f, -0.01733648f, 0.01492609f, 0.019393086f, -0.004445322f,
0.026939709f, 0.00038831023f, 0.004221528f, 0.0050745453f, 0.026939709f, 0.00038831023f, 0.004221528f, 0.0050745453f,
0.0129861f, 0.008007169f, 0.008950762f, 0.005279691f, 0.0129861f, 0.008007169f, 0.008950762f, 0.005279691f,
...@@ -2073,7 +2073,7 @@ static const float conv3_kernel[] = { ...@@ -2073,7 +2073,7 @@ static const float conv3_kernel[] = {
0.012931146f, 0.0046948805f, 0.013098622f, -0.015422701f 0.012931146f, 0.0046948805f, 0.013098622f, -0.015422701f
}; };
static const float conv3_biases[] = { static const float srcnn_conv3_biases[] = {
0.05037664f 0.05037664f
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment