From e4aa39f9e5a528bfe64b97126489b06c954e60ab Mon Sep 17 00:00:00 2001
From: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Date: Thu, 28 Sep 2017 16:51:47 +0300
Subject: [PATCH] Parsing of text TensorFlow graphs. MobileNet-SSD for 90 classes.

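A minimal usage sketch of the extended API from Python (the file names below are
placeholders for a frozen graph and its text graph definition from opencv_extra):

    import cv2 as cv

    # The binary .pb file stores the weights; the optional .pbtxt text graph
    # describes the topology and is handled by the new text importer.
    net = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb',
                                       'ssd_mobilenet_v1_coco.pbtxt')
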
---
 modules/dnn/include/opencv2/dnn/dnn.hpp       |   2 +-
 .../dnn/src/layers/detection_output_layer.cpp |  28 ++-
 modules/dnn/src/layers/prior_box_layer.cpp    |  39 +++-
 modules/dnn/src/tensorflow/tf_importer.cpp    | 189 +++++++++++++++---
 modules/dnn/src/tensorflow/tf_io.cpp          |  15 ++
 modules/dnn/src/tensorflow/tf_io.hpp          |   3 +
 modules/dnn/test/test_tf_importer.cpp         |  60 ++++--
 samples/dnn/mobilenet_ssd_accuracy.py         | 131 ++++++++++++
 samples/dnn/mobilenet_ssd_python.py           |  94 ++++++---
 samples/dnn/shrink_tf_graph_weights.py        |  62 ++++++
 10 files changed, 538 insertions(+), 85 deletions(-)
 create mode 100644 samples/dnn/mobilenet_ssd_accuracy.py
 create mode 100644 samples/dnn/shrink_tf_graph_weights.py

diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp
index cb015d4eba..640be7a525 100644
--- a/modules/dnn/include/opencv2/dnn/dnn.hpp
+++ b/modules/dnn/include/opencv2/dnn/dnn.hpp
@@ -629,7 +629,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
     /** @brief Reads a network model stored in Tensorflow model file.
       * @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
       */
-    CV_EXPORTS_W Net readNetFromTensorflow(const String &model);
+    CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
 
     /** @brief Reads a network model stored in Torch model file.
       * @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
diff --git a/modules/dnn/src/layers/detection_output_layer.cpp b/modules/dnn/src/layers/detection_output_layer.cpp
index 0b72326f7e..505b9c7b74 100644
--- a/modules/dnn/src/layers/detection_output_layer.cpp
+++ b/modules/dnn/src/layers/detection_output_layer.cpp
@@ -81,6 +81,8 @@ public:
 
     float _nmsThreshold;
     int _topK;
+    // Whether predicted bounding boxes are represented in YXHW instead of XYWH layout.
+    bool _locPredTransposed;
 
     enum { _numAxes = 4 };
     static const std::string _layerName;
@@ -148,6 +150,7 @@ public:
         _keepTopK = getParameter<int>(params, "keep_top_k");
         _confidenceThreshold = getParameter<float>(params, "confidence_threshold", 0, false, -FLT_MAX);
         _topK = getParameter<int>(params, "top_k", 0, false, -1);
+        _locPredTransposed = getParameter<bool>(params, "loc_pred_transposed", 0, false, false);
 
         getCodeType(params);
 
@@ -209,7 +212,7 @@ public:
             // Retrieve all location predictions
             std::vector<LabelBBox> allLocationPredictions;
             GetLocPredictions(locationData, num, numPriors, _numLocClasses,
-                              _shareLocation, allLocationPredictions);
+                              _shareLocation, _locPredTransposed, allLocationPredictions);
 
             // Retrieve all confidences
             GetConfidenceScores(confidenceData, num, numPriors, _numClasses, allConfidenceScores);
@@ -540,11 +543,14 @@ public:
     //    num_loc_classes: number of location classes. It is 1 if share_location is
     //      true; and is equal to number of classes needed to predict otherwise.
     //    share_location: if true, all classes share the same location prediction.
+    //    loc_pred_transposed: if true, the four bounding box values are
+    //                         [y,x,height,width] instead of [x,y,width,height].
     //    loc_preds: stores the location prediction, where each item contains
     //      location prediction for an image.
     static void GetLocPredictions(const float* locData, const int num,
                            const int numPredsPerClass, const int numLocClasses,
-                           const bool shareLocation, std::vector<LabelBBox>& locPreds)
+                           const bool shareLocation, const bool locPredTransposed,
+                           std::vector<LabelBBox>& locPreds)
     {
         locPreds.clear();
         if (shareLocation)
@@ -566,10 +572,20 @@ public:
                         labelBBox[label].resize(numPredsPerClass);
                     }
                     caffe::NormalizedBBox& bbox = labelBBox[label][p];
-                    bbox.set_xmin(locData[startIdx + c * 4]);
-                    bbox.set_ymin(locData[startIdx + c * 4 + 1]);
-                    bbox.set_xmax(locData[startIdx + c * 4 + 2]);
-                    bbox.set_ymax(locData[startIdx + c * 4 + 3]);
+                    if (locPredTransposed)
+                    {
+                        bbox.set_ymin(locData[startIdx + c * 4]);
+                        bbox.set_xmin(locData[startIdx + c * 4 + 1]);
+                        bbox.set_ymax(locData[startIdx + c * 4 + 2]);
+                        bbox.set_xmax(locData[startIdx + c * 4 + 3]);
+                    }
+                    else
+                    {
+                        bbox.set_xmin(locData[startIdx + c * 4]);
+                        bbox.set_ymin(locData[startIdx + c * 4 + 1]);
+                        bbox.set_xmax(locData[startIdx + c * 4 + 2]);
+                        bbox.set_ymax(locData[startIdx + c * 4 + 3]);
+                    }
                 }
             }
         }
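
For illustration, the reordering enabled by the new loc_pred_transposed flag,
sketched in Python with made-up box values:

    loc = [0.1, 0.2, 0.5, 0.6]  # four location values of one prediction
    transposed = True           # True for YXHW (TensorFlow), False for XYWH (Caffe-SSD)
    if transposed:
        ymin, xmin, ymax, xmax = loc
    else:
        xmin, ymin, xmax, ymax = loc
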
diff --git a/modules/dnn/src/layers/prior_box_layer.cpp b/modules/dnn/src/layers/prior_box_layer.cpp
index 8fa99ac840..75831d0269 100644
--- a/modules/dnn/src/layers/prior_box_layer.cpp
+++ b/modules/dnn/src/layers/prior_box_layer.cpp
@@ -124,6 +124,20 @@ public:
         }
     }
 
+    void getScales(const LayerParams &params)
+    {
+        DictValue scalesParameter;
+        bool scalesRetrieved = getParameterDict(params, "scales", scalesParameter);
+        if (scalesRetrieved)
+        {
+            _scales.resize(scalesParameter.size());
+            for (int i = 0; i < scalesParameter.size(); ++i)
+            {
+                _scales[i] = scalesParameter.get<float>(i);
+            }
+        }
+    }
+
     void getVariance(const LayerParams &params)
     {
         DictValue varianceParameter;
@@ -169,13 +183,14 @@ public:
         _flip = getParameter<bool>(params, "flip");
         _clip = getParameter<bool>(params, "clip");
 
+        _scales.clear();
         _aspectRatios.clear();
-        _aspectRatios.push_back(1.);
 
         getAspectRatios(params);
         getVariance(params);
+        getScales(params);
 
-        _numPriors = _aspectRatios.size();
+        _numPriors = _aspectRatios.size() + 1;  // + 1 for the aspect ratio 1.0
 
         _maxSize = -1;
         if (params.has("max_size"))
@@ -231,6 +246,11 @@ public:
         CV_TRACE_FUNCTION();
         CV_TRACE_ARG_VALUE(name, "name", name.c_str());
 
+        if (_scales.empty())
+            _scales.resize(_numPriors, 1.0f);
+        else
+            CV_Assert(_scales.size() == _numPriors);
+
         int _layerWidth = inputs[0]->size[3];
         int _layerHeight = inputs[0]->size[2];
 
@@ -256,7 +276,7 @@ public:
         {
             for (size_t w = 0; w < _layerWidth; ++w)
             {
-                _boxWidth = _boxHeight = _minSize;
+                _boxWidth = _boxHeight = _minSize * _scales[0];
 
                 float center_x = (w + 0.5) * stepX;
                 float center_y = (h + 0.5) * stepY;
@@ -272,7 +292,7 @@ public:
                 if (_maxSize > 0)
                 {
                     // second prior: aspect_ratio = 1, size = sqrt(min_size * max_size)
-                    _boxWidth = _boxHeight = sqrt(_minSize * _maxSize);
+                    _boxWidth = _boxHeight = sqrt(_minSize * _maxSize) * _scales[1];
                     // xmin
                     outputPtr[idx++] = (center_x - _boxWidth / 2.) / _imageWidth;
                     // ymin
@@ -284,15 +304,13 @@ public:
                 }
 
                 // rest of priors
+                CV_Assert((_maxSize > 0 ? 2 : 1) + _aspectRatios.size() == _scales.size());
                 for (size_t r = 0; r < _aspectRatios.size(); ++r)
                 {
                     float ar = _aspectRatios[r];
-                    if (fabs(ar - 1.) < 1e-6)
-                    {
-                        continue;
-                    }
-                    _boxWidth = _minSize * sqrt(ar);
-                    _boxHeight = _minSize / sqrt(ar);
+                    float scale = _scales[(_maxSize > 0 ? 2 : 1) + r];
+                    _boxWidth = _minSize * sqrt(ar) * scale;
+                    _boxHeight = _minSize / sqrt(ar) * scale;
                     // xmin
                     outputPtr[idx++] = (center_x - _boxWidth / 2.) / _imageWidth;
                     // ymin
@@ -363,6 +381,7 @@ public:
 
     std::vector<float> _aspectRatios;
     std::vector<float> _variance;
+    std::vector<float> _scales;
 
     bool _flip;
     bool _clip;
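
A Python sketch of the updated prior sizing, under assumed parameters
(min_size, max_size, the aspect ratios, and the per-prior scales are made-up
values; the assertion mirrors the one added to the layer):

    import math

    min_size, max_size = 30.0, 60.0
    aspect_ratios = [2.0, 0.5]       # the ratio 1.0 is implicit now
    scales = [1.0, 1.0, 1.05, 0.95]  # one scale per generated prior

    assert (2 if max_size > 0 else 1) + len(aspect_ratios) == len(scales)

    priors = [(min_size * scales[0],) * 2]                           # ratio 1.0
    priors.append((math.sqrt(min_size * max_size) * scales[1],) * 2) # ratio 1.0, size sqrt(min*max)
    for i, ar in enumerate(aspect_ratios):
        s = scales[(2 if max_size > 0 else 1) + i]
        priors.append((min_size * math.sqrt(ar) * s, min_size / math.sqrt(ar) * s))
    print(priors)  # (box_width, box_height) of every prior
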
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index aca1cd5055..f2e83c087e 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -321,10 +321,10 @@ DictValue parseDims(const tensorflow::TensorProto &tensor) {
     CV_Assert(tensor.dtype() == tensorflow::DT_INT32);
     CV_Assert(dims == 1);
 
-    int size = tensor.tensor_content().size() / sizeof(int);
-    const int *data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
+    Mat values = getTensorContent(tensor);
+    CV_Assert(values.type() == CV_32SC1);
     // TODO: add reordering shape if dims == 4
-    return DictValue::arrayInt(data, size);
+    return DictValue::arrayInt((int*)values.data, values.total());
 }
 
 void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
@@ -448,7 +448,7 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
 
 class TFImporter : public Importer {
 public:
-    TFImporter(const char *model);
+    TFImporter(const char *model, const char *config = NULL);
     void populateNet(Net dstNet);
     ~TFImporter() {}
 
@@ -463,13 +463,20 @@ private:
                                                 int input_blob_index = -1, int* actual_inp_blob_idx = 0);
 
 
-    tensorflow::GraphDef net;
+    // Binary serialized TensorFlow graph that includes weights.
+    tensorflow::GraphDef netBin;
+    // Optional text definition of a TensorFlow graph. It is more flexible than the
+    // binary format and may be used to build a network, with the binary graph serving
+    // only as weights storage. Similar to Caffe's `.prototxt` and `.caffemodel` pair.
+    tensorflow::GraphDef netTxt;
 };
 
-TFImporter::TFImporter(const char *model)
+TFImporter::TFImporter(const char *model, const char *config)
 {
     if (model && model[0])
-        ReadTFNetParamsFromBinaryFileOrDie(model, &net);
+        ReadTFNetParamsFromBinaryFileOrDie(model, &netBin);
+    if (config && config[0])
+        ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
 }
 
 void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
@@ -557,21 +564,23 @@ const tensorflow::TensorProto& TFImporter::getConstBlob(const tensorflow::NodeDe
         *actual_inp_blob_idx = input_blob_index;
     }
 
-    return net.node(const_layers.at(kernel_inp.name)).attr().at("value").tensor();
+    int nodeIdx = const_layers.at(kernel_inp.name);
+    if (nodeIdx < netBin.node_size() && netBin.node(nodeIdx).name() == kernel_inp.name)
+    {
+        return netBin.node(nodeIdx).attr().at("value").tensor();
+    }
+    else
+    {
+        CV_Assert(nodeIdx < netTxt.node_size() &&
+                  netTxt.node(nodeIdx).name() == kernel_inp.name);
+        return netTxt.node(nodeIdx).attr().at("value").tensor();
+    }
 }
 
-
-void TFImporter::populateNet(Net dstNet)
+static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>& const_layers,
+                          std::set<String>& layers_to_ignore)
 {
-    RemoveIdentityOps(net);
-
-    std::map<int, String> layers_to_ignore;
-
-    int layersSize = net.node_size();
-
-    // find all Const layers for params
-    std::map<String, int> value_id;
-    for (int li = 0; li < layersSize; li++)
+    for (int li = 0; li < net.node_size(); li++)
     {
         const tensorflow::NodeDef &layer = net.node(li);
         String name = layer.name();
@@ -582,11 +591,27 @@ void TFImporter::populateNet(Net dstNet)
 
         if (layer.attr().find("value") != layer.attr().end())
         {
-            value_id.insert(std::make_pair(name, li));
+            CV_Assert(const_layers.insert(std::make_pair(name, li)).second);
         }
-
-        layers_to_ignore[li] = name;
+        layers_to_ignore.insert(name);
     }
+}
+
+void TFImporter::populateNet(Net dstNet)
+{
+    RemoveIdentityOps(netBin);
+    RemoveIdentityOps(netTxt);
+
+    std::set<String> layers_to_ignore;
+
+    tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin;
+
+    int layersSize = net.node_size();
+
+    // find all Const layers for params
+    std::map<String, int> value_id;
+    addConstNodes(netBin, value_id, layers_to_ignore);
+    addConstNodes(netTxt, value_id, layers_to_ignore);
 
     std::map<String, int> layer_id;
 
@@ -597,7 +622,7 @@ void TFImporter::populateNet(Net dstNet)
         String type = layer.op();
         LayerParams layerParams;
 
-        if(layers_to_ignore.find(li) != layers_to_ignore.end())
+        if(layers_to_ignore.find(name) != layers_to_ignore.end())
             continue;
 
         if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
@@ -627,7 +652,7 @@ void TFImporter::populateNet(Net dstNet)
                 StrIntVector next_layers = getNextLayers(net, name, "Conv2D");
                 CV_Assert(next_layers.size() == 1);
                 layer = net.node(next_layers[0].second);
-                layers_to_ignore[next_layers[0].second] = next_layers[0].first;
+                layers_to_ignore.insert(next_layers[0].first);
                 name = layer.name();
                 type = layer.op();
             }
@@ -644,7 +669,7 @@ void TFImporter::populateNet(Net dstNet)
 
                 blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
                 ExcludeLayer(net, weights_layer_index, 0, false);
-                layers_to_ignore[weights_layer_index] = next_layers[0].first;
+                layers_to_ignore.insert(next_layers[0].first);
             }
 
             kernelFromTensor(getConstBlob(layer, value_id), layerParams.blobs[0]);
@@ -684,7 +709,7 @@ void TFImporter::populateNet(Net dstNet)
                 layerParams.set("pad_mode", "");  // We use padding values.
                 CV_Assert(next_layers.size() == 1);
                 ExcludeLayer(net, next_layers[0].second, 0, false);
-                layers_to_ignore[next_layers[0].second] = next_layers[0].first;
+                layers_to_ignore.insert(next_layers[0].first);
             }
 
             int id = dstNet.addLayer(name, "Convolution", layerParams);
@@ -748,7 +773,7 @@ void TFImporter::populateNet(Net dstNet)
                 int weights_layer_index = next_layers[0].second;
                 blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
                 ExcludeLayer(net, weights_layer_index, 0, false);
-                layers_to_ignore[weights_layer_index] = next_layers[0].first;
+                layers_to_ignore.insert(next_layers[0].first);
             }
 
             int kernel_blob_index = -1;
@@ -778,6 +803,30 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
         }
+        else if (type == "Flatten")
+        {
+            int id = dstNet.addLayer(name, "Flatten", layerParams);
+            layer_id[name] = id;
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+        }
+        else if (type == "Transpose")
+        {
+            Mat perm = getTensorContent(getConstBlob(layer, value_id, 1));
+            CV_Assert(perm.type() == CV_32SC1);
+            int* permData = (int*)perm.data;
+            if (perm.total() == 4)
+            {
+                for (int i = 0; i < 4; ++i)
+                    permData[i] = toNCHW[permData[i]];
+            }
+            layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
+
+            int id = dstNet.addLayer(name, "Permute", layerParams);
+            layer_id[name] = id;
+
+            // one input only
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+        }
         else if (type == "Const")
         {
         }
@@ -807,7 +856,7 @@ void TFImporter::populateNet(Net dstNet)
         {
             int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
             int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
-            layerParams.set("axis", toNCHW[axis]);
+            layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW[axis] : axis);
 
             int id = dstNet.addLayer(name, "Concat", layerParams);
             layer_id[name] = id;
@@ -929,6 +978,19 @@ void TFImporter::populateNet(Net dstNet)
                 else  // is a vector
                 {
                     layerParams.blobs.resize(1, scaleMat);
+
+                    StrIntVector next_layers = getNextLayers(net, name, "Add");
+                    if (!next_layers.empty())
+                    {
+                        layerParams.set("bias_term", true);
+                        layerParams.blobs.resize(2);
+
+                        int weights_layer_index = next_layers[0].second;
+                        blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs.back());
+                        ExcludeLayer(net, weights_layer_index, 0, false);
+                        layers_to_ignore.insert(next_layers[0].first);
+                    }
+
                     id = dstNet.addLayer(name, "Scale", layerParams);
                 }
                 layer_id[name] = id;
@@ -1037,7 +1099,7 @@ void TFImporter::populateNet(Net dstNet)
 
                 blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
                 ExcludeLayer(net, weights_layer_index, 0, false);
-                layers_to_ignore[weights_layer_index] = next_layers[0].first;
+                layers_to_ignore.insert(next_layers[0].first);
             }
 
             kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
@@ -1148,6 +1210,71 @@ void TFImporter::populateNet(Net dstNet)
 
             connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
         }
+        else if (type == "PriorBox")
+        {
+            if (hasLayerAttr(layer, "min_size"))
+                layerParams.set("min_size", getLayerAttr(layer, "min_size").i());
+            if (hasLayerAttr(layer, "max_size"))
+                layerParams.set("max_size", getLayerAttr(layer, "max_size").i());
+            if (hasLayerAttr(layer, "flip"))
+                layerParams.set("flip", getLayerAttr(layer, "flip").b());
+            if (hasLayerAttr(layer, "clip"))
+                layerParams.set("clip", getLayerAttr(layer, "clip").b());
+            if (hasLayerAttr(layer, "offset"))
+                layerParams.set("offset", getLayerAttr(layer, "offset").f());
+            if (hasLayerAttr(layer, "variance"))
+            {
+                Mat variance = getTensorContent(getLayerAttr(layer, "variance").tensor());
+                layerParams.set("variance",
+                                DictValue::arrayReal<float*>((float*)variance.data, variance.total()));
+            }
+            if (hasLayerAttr(layer, "aspect_ratio"))
+            {
+                Mat aspectRatios = getTensorContent(getLayerAttr(layer, "aspect_ratio").tensor());
+                layerParams.set("aspect_ratio",
+                               DictValue::arrayReal<float*>((float*)aspectRatios.data, aspectRatios.total()));
+            }
+            if (hasLayerAttr(layer, "scales"))
+            {
+                Mat scales = getTensorContent(getLayerAttr(layer, "scales").tensor());
+                layerParams.set("scales",
+                               DictValue::arrayReal<float*>((float*)scales.data, scales.total()));
+            }
+            int id = dstNet.addLayer(name, "PriorBox", layerParams);
+            layer_id[name] = id;
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+        }
+        else if (type == "DetectionOutput")
+        {
+            // op: "DetectionOutput"
+            // input_0: "locations"
+            // input_1: "classifications"
+            // input_2: "prior_boxes"
+            if (hasLayerAttr(layer, "num_classes"))
+                layerParams.set("num_classes", getLayerAttr(layer, "num_classes").i());
+            if (hasLayerAttr(layer, "share_location"))
+                layerParams.set("share_location", getLayerAttr(layer, "share_location").b());
+            if (hasLayerAttr(layer, "background_label_id"))
+                layerParams.set("background_label_id", getLayerAttr(layer, "background_label_id").i());
+            if (hasLayerAttr(layer, "nms_threshold"))
+                layerParams.set("nms_threshold", getLayerAttr(layer, "nms_threshold").f());
+            if (hasLayerAttr(layer, "top_k"))
+                layerParams.set("top_k", getLayerAttr(layer, "top_k").i());
+            if (hasLayerAttr(layer, "code_type"))
+                layerParams.set("code_type", getLayerAttr(layer, "code_type").s());
+            if (hasLayerAttr(layer, "keep_top_k"))
+                layerParams.set("keep_top_k", getLayerAttr(layer, "keep_top_k").i());
+            if (hasLayerAttr(layer, "confidence_threshold"))
+                layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
+            if (hasLayerAttr(layer, "loc_pred_transposed"))
+                layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
+
+            int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
+            layer_id[name] = id;
+            for (int i = 0; i < 3; ++i)
+                connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
+        }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" || type == "Softmax" ||
                  type == "Identity" || type == "Relu6")
@@ -1188,9 +1315,9 @@ Ptr<Importer> createTensorflowImporter(const String&)
 
 #endif //HAVE_PROTOBUF
 
-Net readNetFromTensorflow(const String &model)
+Net readNetFromTensorflow(const String &model, const String &config)
 {
-    TFImporter importer(model.c_str());
+    TFImporter importer(model.c_str(), config.c_str());
     Net net;
     importer.populateNet(net);
     return net;
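
The new two-graph scheme, sketched with TensorFlow 1.x Python protobufs
(assuming the tensorflow package is installed; paths are placeholders):

    import tensorflow as tf
    from google.protobuf import text_format

    net_bin = tf.GraphDef()
    with open('frozen_inference_graph.pb', 'rb') as f:
        net_bin.ParseFromString(f.read())     # binary graph: weights storage

    net_txt = tf.GraphDef()
    with open('ssd_mobilenet_v1_coco.pbtxt') as f:
        text_format.Merge(f.read(), net_txt)  # text graph: editable topology

    # Mirrors `netTxt.ByteSize() != 0 ? netTxt : netBin` from populateNet.
    graph_to_walk = net_txt if len(net_txt.node) else net_bin
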
diff --git a/modules/dnn/src/tensorflow/tf_io.cpp b/modules/dnn/src/tensorflow/tf_io.cpp
index d96d0065e5..694ddd6c48 100644
--- a/modules/dnn/src/tensorflow/tf_io.cpp
+++ b/modules/dnn/src/tensorflow/tf_io.cpp
@@ -52,12 +52,27 @@ bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
     return success;
 }
 
+bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
+    std::ifstream fs(filename, std::ifstream::in);
+    CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
+    IstreamInputStream input(&fs);
+    bool success = google::protobuf::TextFormat::Parse(&input, proto);
+    fs.close();
+    return success;
+}
+
 void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
                                       tensorflow::GraphDef* param) {
   CHECK(ReadProtoFromBinaryFileTF(param_file, param))
       << "Failed to parse GraphDef file: " << param_file;
 }
 
+void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
+                                      tensorflow::GraphDef* param) {
+  CHECK(ReadProtoFromTextFileTF(param_file, param))
+      << "Failed to parse GraphDef file: " << param_file;
+}
+
 }
 }
 #endif
diff --git a/modules/dnn/src/tensorflow/tf_io.hpp b/modules/dnn/src/tensorflow/tf_io.hpp
index a3abd1d360..151d5f5b6e 100644
--- a/modules/dnn/src/tensorflow/tf_io.hpp
+++ b/modules/dnn/src/tensorflow/tf_io.hpp
@@ -22,6 +22,9 @@ namespace dnn {
 void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
                                       tensorflow::GraphDef* param);
 
+void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
+                                      tensorflow::GraphDef* param);
+
 }
 }
 
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index d2ed0903a8..e5b94efa79 100644
--- a/modules/dnn/test/test_tf_importer.cpp
+++ b/modules/dnn/test/test_tf_importer.cpp
@@ -74,14 +74,15 @@ static std::string path(const std::string& file)
     return findDataFile("dnn/tensorflow/" + file, false);
 }
 
-static void runTensorFlowNet(const std::string& prefix,
+static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
                              double l1 = 1e-5, double lInf = 1e-4)
 {
     std::string netPath = path(prefix + "_net.pb");
+    std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
     std::string inpPath = path(prefix + "_in.npy");
     std::string outPath = path(prefix + "_out.npy");
 
-    Net net = readNetFromTensorflow(netPath);
+    Net net = readNetFromTensorflow(netPath, netConfig);
 
     cv::Mat input = blobFromNPY(inpPath);
     cv::Mat target = blobFromNPY(outPath);
@@ -120,6 +121,7 @@ TEST(Test_TensorFlow, batch_norm)
 {
     runTensorFlowNet("batch_norm");
     runTensorFlowNet("fused_batch_norm");
+    runTensorFlowNet("batch_norm_text", true);
 }
 
 TEST(Test_TensorFlow, pooling)
@@ -148,26 +150,60 @@ TEST(Test_TensorFlow, reshape)
 {
     runTensorFlowNet("shift_reshape_no_reorder");
     runTensorFlowNet("reshape_reduce");
+    runTensorFlowNet("flatten", true);
 }
 
 TEST(Test_TensorFlow, fp16)
 {
     const float l1 = 1e-3;
     const float lInf = 1e-2;
-    runTensorFlowNet("fp16_single_conv", l1, lInf);
-    runTensorFlowNet("fp16_deconvolution", l1, lInf);
-    runTensorFlowNet("fp16_max_pool_odd_same", l1, lInf);
-    runTensorFlowNet("fp16_padding_valid", l1, lInf);
-    runTensorFlowNet("fp16_eltwise_add_mul", l1, lInf);
-    runTensorFlowNet("fp16_max_pool_odd_valid", l1, lInf);
-    runTensorFlowNet("fp16_pad_and_concat", l1, lInf);
-    runTensorFlowNet("fp16_max_pool_even", l1, lInf);
-    runTensorFlowNet("fp16_padding_same", l1, lInf);
+    runTensorFlowNet("fp16_single_conv", false, l1, lInf);
+    runTensorFlowNet("fp16_deconvolution", false, l1, lInf);
+    runTensorFlowNet("fp16_max_pool_odd_same", false, l1, lInf);
+    runTensorFlowNet("fp16_padding_valid", false, l1, lInf);
+    runTensorFlowNet("fp16_eltwise_add_mul", false, l1, lInf);
+    runTensorFlowNet("fp16_max_pool_odd_valid", false, l1, lInf);
+    runTensorFlowNet("fp16_pad_and_concat", false, l1, lInf);
+    runTensorFlowNet("fp16_max_pool_even", false, l1, lInf);
+    runTensorFlowNet("fp16_padding_same", false, l1, lInf);
+}
+
+TEST(Test_TensorFlow, MobileNet_SSD)
+{
+    std::string netPath = findDataFile("dnn/ssd_mobilenet_v1_coco.pb", false);
+    std::string netConfig = findDataFile("dnn/ssd_mobilenet_v1_coco.pbtxt", false);
+    std::string imgPath = findDataFile("dnn/street.png", false);
+
+    Mat inp;
+    resize(imread(imgPath), inp, Size(300, 300));
+    inp = blobFromImage(inp, 1.0f / 127.5, Size(), Scalar(127.5, 127.5, 127.5), true);
+
+    std::vector<String> outNames(3);
+    outNames[0] = "concat";
+    outNames[1] = "concat_1";
+    outNames[2] = "detection_out";
+
+    std::vector<Mat> target(outNames.size());
+    for (int i = 0; i < outNames.size(); ++i)
+    {
+        std::string path = findDataFile("dnn/tensorflow/ssd_mobilenet_v1_coco." + outNames[i] + ".npy", false);
+        target[i] = blobFromNPY(path);
+    }
+
+    Net net = readNetFromTensorflow(netPath, netConfig);
+    net.setInput(inp);
+
+    std::vector<Mat> output;
+    net.forward(output, outNames);
+
+    normAssert(target[0].reshape(1, 1), output[0].reshape(1, 1));
+    normAssert(target[1].reshape(1, 1), output[1].reshape(1, 1), "", 1e-5, 2e-4);
+    normAssert(target[2].reshape(1, 1), output[2].reshape(1, 1), "", 4e-5, 1e-2);
 }
 
 TEST(Test_TensorFlow, lstm)
 {
-    runTensorFlowNet("lstm");
+    runTensorFlowNet("lstm", true);
 }
 
 TEST(Test_TensorFlow, split)
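
A NumPy analog of the normAssert checks used above, with the default
thresholds of runTensorFlowNet (the output array here is only a stand-in for
the network's result):

    import numpy as np

    target = np.load('lstm_out.npy')  # reference blob from opencv_extra
    output = target.copy()            # stand-in for net.forward() output

    l1, lInf = 1e-5, 1e-4
    assert np.abs(target - output).mean() <= l1   # average absolute difference
    assert np.abs(target - output).max() <= lInf  # maximum absolute difference
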
diff --git a/samples/dnn/mobilenet_ssd_accuracy.py b/samples/dnn/mobilenet_ssd_accuracy.py
new file mode 100644
index 0000000000..378d2fe596
--- /dev/null
+++ b/samples/dnn/mobilenet_ssd_accuracy.py
@@ -0,0 +1,131 @@
+# Script to evaluate MobileNet-SSD object detection model trained in TensorFlow
+# using both TensorFlow and OpenCV. Example:
+#
+# python mobilenet_ssd_accuracy.py \
+#   --weights=frozen_inference_graph.pb \
+#   --prototxt=ssd_mobilenet_v1_coco.pbtxt \
+#   --images=val2017 \
+#   --annotations=annotations/instances_val2017.json
+#
+# Tested on COCO 2017 object detection dataset, http://cocodataset.org/#download
+import os
+import cv2 as cv
+import json
+import argparse
+
+parser = argparse.ArgumentParser(
+    description='Evaluate MobileNet-SSD model using both TensorFlow and OpenCV. '
+                'COCO evaluation framework is required: http://cocodataset.org')
+parser.add_argument('--weights', required=True,
+                    help='Path to frozen_inference_graph.pb of MobileNet-SSD model. '
+                         'Download it at https://github.com/tensorflow/models/tree/master/research/object_detection')
+parser.add_argument('--prototxt', help='Path to ssd_mobilenet_v1_coco.pbtxt from opencv_extra.', required=True)
+parser.add_argument('--images', help='Path to COCO validation images directory.', required=True)
+parser.add_argument('--annotations', help='Path to COCO annotations file.', required=True)
+args = parser.parse_args()
+
+### Get OpenCV predictions #####################################################
+net = cv.dnn.readNetFromTensorflow(args.weights, args.prototxt)
+
+detections = []
+for imgName in os.listdir(args.images):
+    inp = cv.imread(os.path.join(args.images, imgName))
+    rows = inp.shape[0]
+    cols = inp.shape[1]
+    inp = cv.resize(inp, (300, 300))
+
+    net.setInput(cv.dnn.blobFromImage(inp, 1.0/127.5, (300, 300), (127.5, 127.5, 127.5), True))
+    out = net.forward()
+
+    for i in range(out.shape[2]):
+        score = float(out[0, 0, i, 2])
+        # Confidence threshold is in prototxt.
+        classId = int(out[0, 0, i, 1])
+
+        x = out[0, 0, i, 3] * cols
+        y = out[0, 0, i, 4] * rows
+        w = out[0, 0, i, 5] * cols - x
+        h = out[0, 0, i, 6] * rows - y
+        detections.append({
+          "image_id": int(imgName.rstrip('0')[:imgName.rfind('.')]),
+          "category_id": classId,
+          "bbox": [x, y, w, h],
+          "score": score
+        })
+
+with open('cv_result.json', 'wt') as f:
+    json.dump(detections, f)
+
+### Get TensorFlow predictions #################################################
+import tensorflow as tf
+
+with tf.gfile.FastGFile(args.weights, 'rb') as f:
+    # Load the model
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+
+with tf.Session() as sess:
+    # Restore session
+    sess.graph.as_default()
+    tf.import_graph_def(graph_def, name='')
+
+    detections = []
+    for imgName in os.listdir(args.images):
+        inp = cv.imread(os.path.join(args.images, imgName))
+        rows = inp.shape[0]
+        cols = inp.shape[1]
+        inp = cv.resize(inp, (300, 300))
+        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB
+        out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
+                        sess.graph.get_tensor_by_name('detection_scores:0'),
+                        sess.graph.get_tensor_by_name('detection_boxes:0'),
+                        sess.graph.get_tensor_by_name('detection_classes:0')],
+                       feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})
+        num_detections = int(out[0][0])
+        for i in range(num_detections):
+            classId = int(out[3][0][i])
+            score = float(out[1][0][i])
+            bbox = [float(v) for v in out[2][0][i]]
+            if score > 0.01:
+                x = bbox[1] * cols
+                y = bbox[0] * rows
+                w = bbox[3] * cols - x
+                h = bbox[2] * rows - y
+                detections.append({
+                  "image_id": int(imgName.rstrip('0')[:imgName.rfind('.')]),
+                  "category_id": classId,
+                  "bbox": [x, y, w, h],
+                  "score": score
+                })
+
+with open('tf_result.json', 'wt') as f:
+    json.dump(detections, f)
+
+### Evaluation part ############################################################
+
+# %matplotlib inline
+import matplotlib.pyplot as plt
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import numpy as np
+import skimage.io as io
+import pylab
+pylab.rcParams['figure.figsize'] = (10.0, 8.0)
+
+annType = ['segm', 'bbox', 'keypoints']
+annType = annType[1]      # specify type here
+prefix = 'person_keypoints' if annType == 'keypoints' else 'instances'
+print('Running demo for *%s* results.' % annType)
+
+#initialize COCO ground truth api
+cocoGt=COCO(args.annotations)
+
+#initialize COCO detections api
+for resFile in ['tf_result.json', 'cv_result.json']:
+    print(resFile)
+    cocoDt=cocoGt.loadRes(resFile)
+
+    cocoEval = COCOeval(cocoGt,cocoDt,annType)
+    cocoEval.evaluate()
+    cocoEval.accumulate()
+    cocoEval.summarize()
diff --git a/samples/dnn/mobilenet_ssd_python.py b/samples/dnn/mobilenet_ssd_python.py
index 039c244457..f031a7c669 100644
--- a/samples/dnn/mobilenet_ssd_python.py
+++ b/samples/dnn/mobilenet_ssd_python.py
@@ -1,3 +1,14 @@
+# This script is used to demonstrate MobileNet-SSD network using OpenCV deep learning module.
+#
+# It works with model taken from https://github.com/chuanqi305/MobileNet-SSD/ that
+# was trained in Caffe-SSD framework, https://github.com/weiliu89/caffe/tree/ssd.
+# Model detects objects from 20 classes.
+#
+# Also TensorFlow model from TensorFlow object detection model zoo may be used to
+# detect objects from 90 classes:
+# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
+# Text graph definition must be taken from opencv_extra:
+# https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/ssd_mobilenet_v1_coco.pbtxt
 import numpy as np
 import argparse
 
@@ -13,27 +24,58 @@ WHRatio = inWidth / float(inHeight)
 inScaleFactor = 0.007843
 meanVal = 127.5
 
-classNames = ('background',
-              'aeroplane', 'bicycle', 'bird', 'boat',
-              'bottle', 'bus', 'car', 'cat', 'chair',
-              'cow', 'diningtable', 'dog', 'horse',
-              'motorbike', 'person', 'pottedplant',
-              'sheep', 'sofa', 'train', 'tvmonitor')
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        description='Script to run MobileNet-SSD object detection network '
+                    'trained either in Caffe or TensorFlow frameworks.')
     parser.add_argument("--video", help="path to video file. If empty, camera's stream will be used")
     parser.add_argument("--prototxt", default="MobileNetSSD_deploy.prototxt",
-                        help="path to caffe prototxt")
-    parser.add_argument("-c", "--caffemodel", default="MobileNetSSD_deploy.caffemodel",
-                        help="path to caffemodel file, download it here: "
-                        "https://github.com/chuanqi305/MobileNet-SSD/")
-    parser.add_argument("--thr", default=0.2, help="confidence threshold to filter out weak detections")
+                        help='Path to text network file: '
+                             'MobileNetSSD_deploy.prototxt for Caffe model or '
+                             'ssd_mobilenet_v1_coco.pbtxt from opencv_extra for TensorFlow model')
+    parser.add_argument("--weights", default="MobileNetSSD_deploy.caffemodel",
+                        help='Path to weights: '
+                             'MobileNetSSD_deploy.caffemodel for Caffe model or '
+                             'frozen_inference_graph.pb from TensorFlow.')
+    parser.add_argument("--num_classes", default=20, type=int,
+                        help="Number of classes. It's 20 for Caffe model from "
+                             "https://github.com/chuanqi305/MobileNet-SSD/ and 90 for "
+                             "TensorFlow model from https://github.com/tensorflow/models/tree/master/research/object_detection")
+    parser.add_argument("--thr", default=0.2, type=float, help="confidence threshold to filter out weak detections")
     args = parser.parse_args()
 
-    net = cv.dnn.readNetFromCaffe(args.prototxt, args.caffemodel)
-
-    if len(args.video):
+    if args.num_classes == 20:
+        net = cv.dnn.readNetFromCaffe(args.prototxt, args.weights)
+        swapRB = False
+        classNames = { 0: 'background',
+            1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat',
+            5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair',
+            10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse',
+            14: 'motorbike', 15: 'person', 16: 'pottedplant',
+            17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor' }
+    else:
+        assert(args.num_classes == 90)
+        net = cv.dnn.readNetFromTensorflow(args.weights, args.prototxt)
+        swapRB = True
+        classNames = { 0: 'background',
+            1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',
+            7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',
+            13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',
+            18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
+            24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',
+            32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
+            37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
+            41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
+            46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
+            51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
+            56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
+            61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
+            67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
+            75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',
+            80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',
+            86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush' }
+
+    if args.video:
         cap = cv.VideoCapture(args.video)
     else:
         cap = cv.VideoCapture(0)
@@ -41,7 +83,7 @@ if __name__ == "__main__":
     while True:
         # Capture frame-by-frame
         ret, frame = cap.read()
-        blob = cv.dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), meanVal, False)
+        blob = cv.dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), (meanVal, meanVal, meanVal), swapRB)
         net.setInput(blob)
         detections = net.forward()
 
@@ -74,14 +116,16 @@ if __name__ == "__main__":
 
                 cv.rectangle(frame, (xLeftBottom, yLeftBottom), (xRightTop, yRightTop),
                               (0, 255, 0))
-                label = classNames[class_id] + ": " + str(confidence)
-                labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-
-                cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
-                                     (xLeftBottom + labelSize[0], yLeftBottom + baseLine),
-                                     (255, 255, 255), cv.FILLED)
-                cv.putText(frame, label, (xLeftBottom, yLeftBottom),
-                            cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
+                if class_id in classNames:
+                    label = classNames[class_id] + ": " + str(confidence)
+                    labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+
+                    yLeftBottom = max(yLeftBottom, labelSize[1])
+                    cv.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
+                                         (xLeftBottom + labelSize[0], yLeftBottom + baseLine),
+                                         (255, 255, 255), cv.FILLED)
+                    cv.putText(frame, label, (xLeftBottom, yLeftBottom),
+                                cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
 
         cv.imshow("detections", frame)
         if cv.waitKey(1) >= 0:
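
The preprocessing difference between the two models, condensed (values mirror
the script: inScaleFactor is about 1/127.5, the mean is 127.5 per channel, and
swapRB is True only for the TensorFlow model):

    import cv2 as cv
    import numpy as np

    frame = np.zeros((480, 640, 3), np.uint8)  # stand-in for a captured frame
    blob = cv.dnn.blobFromImage(frame, 0.007843, (300, 300),
                                (127.5, 127.5, 127.5), True)
    # blob is NCHW and holds roughly (pixel - 127.5) / 127.5 per channel,
    # with the R and B channels swapped for the TensorFlow model.
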
diff --git a/samples/dnn/shrink_tf_graph_weights.py b/samples/dnn/shrink_tf_graph_weights.py
new file mode 100644
index 0000000000..799d6c758c
--- /dev/null
+++ b/samples/dnn/shrink_tf_graph_weights.py
@@ -0,0 +1,62 @@
+# This file is part of OpenCV project.
+# It is subject to the license terms in the LICENSE file found in the top-level directory
+# of this distribution and at http://opencv.org/license.html.
+#
+# Copyright (C) 2017, Intel Corporation, all rights reserved.
+# Third party copyrights are property of their respective owners.
+import tensorflow as tf
+import struct
+import argparse
+import numpy as np
+
+parser = argparse.ArgumentParser(description='Convert weights of a frozen TensorFlow graph to fp16.')
+parser.add_argument('--input', required=True, help='Path to frozen graph.')
+parser.add_argument('--output', required=True, help='Path to output graph.')
+parser.add_argument('--ops', default=['Conv2D', 'MatMul'], nargs='+',
+                    help='List of ops whose weights are converted.')
+args = parser.parse_args()
+
+DT_FLOAT = 1
+DT_HALF = 19
+
+# In frozen graphs, every node that uses weights is connected to a Const node
+# through an Identity node, usually named after the Const node with a '/read' suffix.
+# We'll replace all of these Identity nodes with Cast nodes.
+
+# Load the model
+with tf.gfile.FastGFile(args.input, 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+
+# Collect all inputs of the desired nodes.
+inputs = []
+for node in graph_def.node:
+    if node.op in args.ops:
+        inputs += node.input
+
+weightsNodes = []
+for node in graph_def.node:
+    # From all the collected inputs we keep only the Identity nodes.
+    if node.name in inputs and node.op == 'Identity' and node.attr['T'].type == DT_FLOAT:
+        weightsNodes.append(node.input[0])
+
+        # Replace Identity to Cast.
+        node.op = 'Cast'
+        node.attr['DstT'].type = DT_FLOAT
+        node.attr['SrcT'].type = DT_HALF
+        del node.attr['T']
+        del node.attr['_class']
+
+# Convert weights to half precision.
+for node in graph_def.node:
+    if node.name in weightsNodes:
+        node.attr['dtype'].type = DT_HALF
+        node.attr['value'].tensor.dtype = DT_HALF
+
+        floats = node.attr['value'].tensor.tensor_content
+
+        floats = struct.unpack('f' * (len(floats) // 4), floats)
+        halfs = np.array(floats).astype(np.float16).view(np.uint16)
+        node.attr['value'].tensor.tensor_content = struct.pack('H' * len(halfs), *halfs)
+
+tf.train.write_graph(graph_def, "", args.output, as_text=False)
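
An example invocation, plus a round-trip check of the fp16 packing performed
above (the file names are placeholders):

    # python shrink_tf_graph_weights.py --input frozen_inference_graph.pb \
    #                                   --output graph_fp16.pb
    import struct
    import numpy as np

    floats = (0.1, -2.5, 3.14159)
    halfs = np.array(floats, np.float32).astype(np.float16).view(np.uint16)
    packed = struct.pack('H' * len(halfs), *halfs)
    restored = np.frombuffer(packed, np.uint16).view(np.float16)
    print(restored)  # matches the inputs within fp16 precision
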
-- 
2.18.0