Commit f37f4cf3 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #9994 from r2d3:dnn_memory_load

parents e7d62d6e f723cede
...@@ -644,11 +644,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN ...@@ -644,11 +644,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/ */
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String()); CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
/** @brief Reads a network model stored in Caffe model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferProto buffer containing the content of the .prototxt file
* @param lenProto length of bufferProto
* @param bufferModel buffer containing the content of the .caffemodel file
* @param lenModel length of bufferModel
*/
CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel = NULL, size_t lenModel = 0);
/** @brief Reads a network model stored in Tensorflow model file. /** @brief Reads a network model stored in Tensorflow model file.
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls. * @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
*/ */
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String()); CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
/** @brief Reads a network model stored in Tensorflow model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the pb file
* @param lenModel length of bufferModel
* @param bufferConfig buffer containing the content of the pbtxt file
* @param lenConfig length of bufferConfig
*/
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);
/** @brief Reads a network model stored in Torch model file. /** @brief Reads a network model stored in Torch model file.
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls. * @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
*/ */
......
...@@ -92,6 +92,17 @@ public: ...@@ -92,6 +92,17 @@ public:
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary); ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
} }
CaffeImporter(const char *dataProto, size_t lenProto,
const char *dataModel, size_t lenModel)
{
CV_TRACE_FUNCTION();
ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);
if (dataModel != NULL && lenModel > 0)
ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
}
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params) void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
{ {
const Reflection *refl = msg.GetReflection(); const Reflection *refl = msg.GetReflection();
...@@ -398,6 +409,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String ...@@ -398,6 +409,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String
return net; return net;
} }
Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel, size_t lenModel)
{
CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
Net net;
caffeImporter.populateNet(net);
return net;
}
#endif //HAVE_PROTOBUF #endif //HAVE_PROTOBUF
CV__DNN_EXPERIMENTAL_NS_END CV__DNN_EXPERIMENTAL_NS_END
......
...@@ -1107,28 +1107,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) { ...@@ -1107,28 +1107,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
bool ReadProtoFromBinary(ZeroCopyInputStream* input, Message *proto) {
CodedInputStream coded_input(input);
coded_input.SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
return proto->ParseFromCodedStream(&coded_input);
}
bool ReadProtoFromTextFile(const char* filename, Message* proto) { bool ReadProtoFromTextFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in); std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\""; CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs); IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto); return google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
} }
bool ReadProtoFromBinaryFile(const char* filename, Message* proto) { bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary); std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\""; CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs); IstreamInputStream raw_input(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912); return ReadProtoFromBinary(&raw_input, proto);
}
bool ReadProtoFromTextBuffer(const char* data, size_t len, Message* proto) {
ArrayInputStream input(data, len);
return google::protobuf::TextFormat::Parse(&input, proto);
}
bool success = proto->ParseFromCodedStream(coded_input);
delete coded_input; bool ReadProtoFromBinaryBuffer(const char* data, size_t len, Message* proto) {
delete raw_input; ArrayInputStream raw_input(data, len);
fs.close(); return ReadProtoFromBinary(&raw_input, proto);
return success;
} }
void ReadNetParamsFromTextFileOrDie(const char* param_file, void ReadNetParamsFromTextFileOrDie(const char* param_file,
...@@ -1138,6 +1147,13 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file, ...@@ -1138,6 +1147,13 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
UpgradeNetAsNeeded(param_file, param); UpgradeNetAsNeeded(param_file, param);
} }
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}
void ReadNetParamsFromBinaryFileOrDie(const char* param_file, void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
NetParameter* param) { NetParameter* param) {
CHECK(ReadProtoFromBinaryFile(param_file, param)) CHECK(ReadProtoFromBinaryFile(param_file, param))
...@@ -1145,6 +1161,13 @@ void ReadNetParamsFromBinaryFileOrDie(const char* param_file, ...@@ -1145,6 +1161,13 @@ void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
UpgradeNetAsNeeded(param_file, param); UpgradeNetAsNeeded(param_file, param);
} }
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}
} }
} }
#endif #endif
...@@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file, ...@@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
void ReadNetParamsFromBinaryFileOrDie(const char* param_file, void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
caffe::NetParameter* param); caffe::NetParameter* param);
// Read parameters from a memory buffer into a NetParammeter proto message.
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);
// Utility functions used internally by Caffe and TensorFlow loaders
bool ReadProtoFromTextFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromTextBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
} }
} }
#endif #endif
......
...@@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in ...@@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
class TFImporter : public Importer { class TFImporter : public Importer {
public: public:
TFImporter(const char *model, const char *config = NULL); TFImporter(const char *model, const char *config = NULL);
TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig = NULL, size_t lenConfig = 0);
void populateNet(Net dstNet); void populateNet(Net dstNet);
~TFImporter() {} ~TFImporter() {}
...@@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config) ...@@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
ReadTFNetParamsFromTextFileOrDie(config, &netTxt); ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
} }
TFImporter::TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig, size_t lenConfig)
{
if (dataModel != NULL && lenModel > 0)
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
if (dataConfig != NULL && lenConfig > 0)
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
}
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob) void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
{ {
MatShape shape; MatShape shape;
...@@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config) ...@@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
return net; return net;
} }
Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
const char* bufferConfig, size_t lenConfig)
{
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
Net net;
importer.populateNet(net);
return net;
}
CV__DNN_EXPERIMENTAL_NS_END CV__DNN_EXPERIMENTAL_NS_END
}} // namespace }} // namespace
...@@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi ...@@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi
#include "graph.pb.h" #include "graph.pb.h"
#include "tf_io.hpp" #include "tf_io.hpp"
#include "../caffe/caffe_io.hpp"
#include "../caffe/glog_emulator.hpp" #include "../caffe/glog_emulator.hpp"
namespace cv { namespace cv {
...@@ -36,41 +37,28 @@ using namespace ::google::protobuf::io; ...@@ -36,41 +37,28 @@ using namespace ::google::protobuf::io;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
// TODO: remove Caffe duplicate void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) { tensorflow::GraphDef* param) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary); CHECK(ReadProtoFromBinaryFile(param_file, param))
CHECK(fs.is_open()) << "Can't open \"" << filename << "\""; << "Failed to parse GraphDef file: " << param_file;
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
bool success = proto->ParseFromCodedStream(coded_input);
delete coded_input;
delete raw_input;
fs.close();
return success;
} }
bool ReadProtoFromTextFileTF(const char* filename, Message* proto) { void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
std::ifstream fs(filename, std::ifstream::in); tensorflow::GraphDef* param) {
CHECK(fs.is_open()) << "Can't open \"" << filename << "\""; CHECK(ReadProtoFromBinaryBuffer(data, len, param))
IstreamInputStream input(&fs); << "Failed to parse GraphDef buffer";
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
} }
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file, void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) { tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFileTF(param_file, param)) CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse GraphDef file: " << param_file; << "Failed to parse GraphDef file: " << param_file;
} }
void ReadTFNetParamsFromTextFileOrDie(const char* param_file, void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) { tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextFileTF(param_file, param)) CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse GraphDef file: " << param_file; << "Failed to parse GraphDef buffer";
} }
} }
......
...@@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file, ...@@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file, void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param); tensorflow::GraphDef* param);
// Read parameters from a memory buffer into a GraphDef proto message.
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);
} }
} }
......
...@@ -55,6 +55,24 @@ static std::string _tf(TString filename) ...@@ -55,6 +55,24 @@ static std::string _tf(TString filename)
return (getOpenCVExtraDir() + "/dnn/") + filename; return (getOpenCVExtraDir() + "/dnn/") + filename;
} }
TEST(Test_Caffe, memory_read)
{
const string proto = findDataFile("dnn/bvlc_googlenet.prototxt", false);
const string model = findDataFile("dnn/bvlc_googlenet.caffemodel", false);
string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));
Net net = readNetFromCaffe(dataProto.c_str(), dataProto.size());
ASSERT_FALSE(net.empty());
Net net2 = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
ASSERT_FALSE(net2.empty());
}
TEST(Test_Caffe, read_gtsrb) TEST(Test_Caffe, read_gtsrb)
{ {
Net net = readNetFromCaffe(_tf("gtsrb.prototxt")); Net net = readNetFromCaffe(_tf("gtsrb.prototxt"));
...@@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet) ...@@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet)
ASSERT_FALSE(net.empty()); ASSERT_FALSE(net.empty());
} }
TEST(Reproducibility_AlexNet, Accuracy) typedef testing::TestWithParam<tuple<bool> > Reproducibility_AlexNet;
TEST_P(Reproducibility_AlexNet, Accuracy)
{ {
bool readFromMemory = get<0>(GetParam());
Net net; Net net;
{ {
const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false); const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false);
const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false); const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false);
net = readNetFromCaffe(proto, model); if (readFromMemory)
{
string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));
net = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
}
else
net = readNetFromCaffe(proto, model);
ASSERT_FALSE(net.empty()); ASSERT_FALSE(net.empty());
} }
...@@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy) ...@@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy)
normAssert(ref, out); normAssert(ref, out);
} }
INSTANTIATE_TEST_CASE_P(Test_Caffe, Reproducibility_AlexNet, testing::Values(true, false));
#if !defined(_WIN32) || defined(_WIN64) #if !defined(_WIN32) || defined(_WIN64)
TEST(Reproducibility_FCN, Accuracy) TEST(Reproducibility_FCN, Accuracy)
{ {
......
...@@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm ...@@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm
EXPECT_LE(normInf, lInf) << comment; EXPECT_LE(normInf, lInf) << comment;
} }
inline bool readFileInMemory(const std::string& filename, std::string& content)
{
std::ios::openmode mode = std::ios::in | std::ios::binary;
std::ifstream ifs(filename.c_str(), mode);
if (!ifs.is_open())
return false;
content.clear();
ifs.seekg(0, std::ios::end);
content.reserve(ifs.tellg());
ifs.seekg(0, std::ios::beg);
content.assign((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
return true;
}
#endif #endif
...@@ -75,14 +75,32 @@ static std::string path(const std::string& file) ...@@ -75,14 +75,32 @@ static std::string path(const std::string& file)
} }
static void runTensorFlowNet(const std::string& prefix, bool hasText = false, static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
double l1 = 1e-5, double lInf = 1e-4) double l1 = 1e-5, double lInf = 1e-4,
bool memoryLoad = false)
{ {
std::string netPath = path(prefix + "_net.pb"); std::string netPath = path(prefix + "_net.pb");
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : ""); std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
std::string inpPath = path(prefix + "_in.npy"); std::string inpPath = path(prefix + "_in.npy");
std::string outPath = path(prefix + "_out.npy"); std::string outPath = path(prefix + "_out.npy");
Net net = readNetFromTensorflow(netPath, netConfig); Net net;
if (memoryLoad)
{
// Load files into a memory buffers
string dataModel;
ASSERT_TRUE(readFileInMemory(netPath, dataModel));
string dataConfig;
if (hasText)
ASSERT_TRUE(readFileInMemory(netConfig, dataConfig));
net = readNetFromTensorflow(dataModel.c_str(), dataModel.size(),
dataConfig.c_str(), dataConfig.size());
}
else
net = readNetFromTensorflow(netPath, netConfig);
ASSERT_FALSE(net.empty());
cv::Mat input = blobFromNPY(inpPath); cv::Mat input = blobFromNPY(inpPath);
cv::Mat target = blobFromNPY(outPath); cv::Mat target = blobFromNPY(outPath);
...@@ -216,4 +234,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor) ...@@ -216,4 +234,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
runTensorFlowNet("resize_nearest_neighbor"); runTensorFlowNet("resize_nearest_neighbor");
} }
TEST(Test_TensorFlow, memory_read)
{
double l1 = 1e-5;
double lInf = 1e-4;
runTensorFlowNet("lstm", true, l1, lInf, true);
runTensorFlowNet("batch_norm", false, l1, lInf, true);
runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment