Commit 266692e1 authored by Vitaliy Lyudvichenko's avatar Vitaliy Lyudvichenko

Improving of Caffe importer compatibility

parent 8ecae046
...@@ -48,6 +48,7 @@ using namespace cv::dnn; ...@@ -48,6 +48,7 @@ using namespace cv::dnn;
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <sstream>
#include <algorithm> #include <algorithm>
#include <google/protobuf/message.h> #include <google/protobuf/message.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
...@@ -63,279 +64,297 @@ using ::google::protobuf::Reflection; ...@@ -63,279 +64,297 @@ using ::google::protobuf::Reflection;
namespace namespace
{ {
class CaffeImporter : public Importer
{
caffe::NetParameter net;
caffe::NetParameter netBinary;
public: template<typename T>
static cv::String toString(const T &v)
{
std::ostringstream ss;
ss << v;
return ss.str();
}
CaffeImporter(const char *pototxt, const char *caffeModel) class CaffeImporter : public Importer
{ {
ReadNetParamsFromTextFileOrDie(pototxt, &net); caffe::NetParameter net;
caffe::NetParameter netBinary;
if (caffeModel && caffeModel[0]) public:
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
}
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params) CaffeImporter(const char *pototxt, const char *caffeModel)
{ {
const Reflection *refl = msg.GetReflection(); ReadNetParamsFromTextFileOrDie(pototxt, &net);
int type = field->cpp_type();
bool isRepeated = field->is_repeated(); if (caffeModel && caffeModel[0])
const std::string &name = field->name(); ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
}
#define SET_UP_FILED(getter, arrayConstr, gtype) \
if (isRepeated) { \ void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \ {
params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \ const Reflection *refl = msg.GetReflection();
} \ int type = field->cpp_type();
else { \ bool isRepeated = field->is_repeated();
params.set(name, refl->getter(msg, field)); \ const std::string &name = field->name();
}
#define SET_UP_FILED(getter, arrayConstr, gtype) \
switch (type) if (isRepeated) { \
{ const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \
case FieldDescriptor::CPPTYPE_INT32: params.set(name, DictValue::arrayConstr(v.begin(), (int)v.size())); \
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32); } \
break; else { \
case FieldDescriptor::CPPTYPE_UINT32: params.set(name, refl->getter(msg, field)); \
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
break;
case FieldDescriptor::CPPTYPE_INT64:
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
break;
case FieldDescriptor::CPPTYPE_UINT64:
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
break;
case FieldDescriptor::CPPTYPE_BOOL:
SET_UP_FILED(GetBool, arrayInt, bool);
break;
case FieldDescriptor::CPPTYPE_DOUBLE:
SET_UP_FILED(GetDouble, arrayReal, double);
break;
case FieldDescriptor::CPPTYPE_FLOAT:
SET_UP_FILED(GetFloat, arrayReal, float);
break;
case FieldDescriptor::CPPTYPE_STRING:
if (isRepeated) {
const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
}
else {
params.set(name, refl->GetString(msg, field));
}
break;
case FieldDescriptor::CPPTYPE_ENUM:
if (isRepeated) {
int size = refl->FieldSize(msg, field);
std::vector<cv::String> buf(size);
for (int i = 0; i < size; i++)
buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
params.set(name, DictValue::arrayString(buf.begin(), size));
}
else {
params.set(name, refl->GetEnum(msg, field)->name());
}
break;
default:
CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
break;
} }
}
inline static bool ends_with_param(const std::string &str) switch (type)
{ {
static const std::string _param("_param"); case FieldDescriptor::CPPTYPE_INT32:
return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0; SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int32);
break;
case FieldDescriptor::CPPTYPE_UINT32:
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint32);
break;
case FieldDescriptor::CPPTYPE_INT64:
SET_UP_FILED(GetInt32, arrayInt, ::google::protobuf::int64);
break;
case FieldDescriptor::CPPTYPE_UINT64:
SET_UP_FILED(GetUInt32, arrayInt, ::google::protobuf::uint64);
break;
case FieldDescriptor::CPPTYPE_BOOL:
SET_UP_FILED(GetBool, arrayInt, bool);
break;
case FieldDescriptor::CPPTYPE_DOUBLE:
SET_UP_FILED(GetDouble, arrayReal, double);
break;
case FieldDescriptor::CPPTYPE_FLOAT:
SET_UP_FILED(GetFloat, arrayReal, float);
break;
case FieldDescriptor::CPPTYPE_STRING:
if (isRepeated) {
const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
}
else {
params.set(name, refl->GetString(msg, field));
}
break;
case FieldDescriptor::CPPTYPE_ENUM:
if (isRepeated) {
int size = refl->FieldSize(msg, field);
std::vector<cv::String> buf(size);
for (int i = 0; i < size; i++)
buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
params.set(name, DictValue::arrayString(buf.begin(), size));
}
else {
params.set(name, refl->GetEnum(msg, field)->name());
}
break;
default:
CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
break;
} }
}
void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false) inline static bool ends_with_param(const std::string &str)
{ {
const Descriptor *msgDesc = msg.GetDescriptor(); static const std::string _param("_param");
const Reflection *msgRefl = msg.GetReflection(); return (str.size() >= _param.size()) && str.compare(str.size() - _param.size(), _param.size(), _param) == 0;
}
for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++) void extractLayerParams(const Message &msg, cv::dnn::LayerParams &params, bool isInternal = false)
{ {
const FieldDescriptor *fd = msgDesc->field(fieldId); const Descriptor *msgDesc = msg.GetDescriptor();
const Reflection *msgRefl = msg.GetReflection();
if (!isInternal && !ends_with_param(fd->name()))
continue;
bool hasData = fd->is_required() ||
(fd->is_optional() && msgRefl->HasField(msg, fd)) ||
(fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
if (!hasData)
continue;
if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
{
if (fd->is_repeated()) //Extract only first item!
extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
else
extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
}
else
{
addParam(msg, fd, params);
}
}
}
BlobShape blobShapeFromProto(const caffe::BlobProto &pbBlob) for (int fieldId = 0; fieldId < msgDesc->field_count(); fieldId++)
{ {
if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width()) const FieldDescriptor *fd = msgDesc->field(fieldId);
{
return BlobShape(pbBlob.num(), pbBlob.channels(), pbBlob.height(), pbBlob.width()); if (!isInternal && !ends_with_param(fd->name()))
} continue;
else if (pbBlob.has_shape())
{
const caffe::BlobShape &_shape = pbBlob.shape();
BlobShape shape = BlobShape::all(_shape.dim_size());
for (int i = 0; i < _shape.dim_size(); i++) bool hasData = fd->is_required() ||
shape[i] = (int)_shape.dim(i); (fd->is_optional() && msgRefl->HasField(msg, fd)) ||
(fd->is_repeated() && msgRefl->FieldSize(msg, fd) > 0);
if (!hasData)
continue;
return shape; if (fd->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE)
{
if (fd->is_repeated()) //Extract only first item!
extractLayerParams(msgRefl->GetRepeatedMessage(msg, fd, 0), params, true);
else
extractLayerParams(msgRefl->GetMessage(msg, fd), params, true);
} }
else else
{ {
CV_Error(Error::StsError, "Unknown shape of input blob"); addParam(msg, fd, params);
return BlobShape();
} }
} }
}
void blobFromProto(const caffe::BlobProto &pbBlob, cv::dnn::Blob &dstBlob) BlobShape blobShapeFromProto(const caffe::BlobProto &pbBlob)
{
if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
{ {
BlobShape shape = blobShapeFromProto(pbBlob); return BlobShape(pbBlob.num(), pbBlob.channels(), pbBlob.height(), pbBlob.width());
}
dstBlob.create(shape, CV_32F); else if (pbBlob.has_shape())
CV_Assert(pbBlob.data_size() == (int)dstBlob.matRefConst().total()); {
const caffe::BlobShape &_shape = pbBlob.shape();
BlobShape shape = BlobShape::all(_shape.dim_size());
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT); for (int i = 0; i < _shape.dim_size(); i++)
float *dstData = dstBlob.matRef().ptr<float>(); shape[i] = (int)_shape.dim(i);
for (int i = 0; i < pbBlob.data_size(); i++) return shape;
dstData[i] = pbBlob.data(i);
} }
else
void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
{ {
const std::string &name = layer.name(); CV_Error(Error::StsError, "Unknown shape of input blob");
return BlobShape();
}
}
int li; void blobFromProto(const caffe::BlobProto &pbBlob, cv::dnn::Blob &dstBlob)
for (li = 0; li != netBinary.layer_size(); li++) {
{ BlobShape shape = blobShapeFromProto(pbBlob);
if (netBinary.layer(li).name() == name)
break;
}
if (li == netBinary.layer_size() || netBinary.layer(li).blobs_size() == 0) dstBlob.create(shape, CV_32F);
return; CV_Assert(pbBlob.data_size() == (int)dstBlob.matRefConst().total());
const caffe::LayerParameter &binLayer = netBinary.layer(li); CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
layerParams.blobs.resize(binLayer.blobs_size()); float *dstData = dstBlob.matRef().ptr<float>();
for (int bi = 0; bi < binLayer.blobs_size(); bi++)
{
blobFromProto(binLayer.blobs(bi), layerParams.blobs[bi]);
}
}
struct BlobNote for (int i = 0; i < pbBlob.data_size(); i++)
{ dstData[i] = pbBlob.data(i);
BlobNote(const std::string &_name, int _layerId, int _outNum) : }
name(_name.c_str()), layerId(_layerId), outNum(_outNum) {}
const char *name; void extractBinaryLayerParms(const caffe::LayerParameter& layer, LayerParams& layerParams)
int layerId, outNum; {
}; const std::string &name = layer.name();
void populateNet(Net dstNet) int li;
for (li = 0; li != netBinary.layer_size(); li++)
{ {
int layersSize = net.layer_size(); if (netBinary.layer(li).name() == name)
std::vector<BlobNote> addedBlobs; break;
addedBlobs.reserve(layersSize + 1); }
//setup input layer names if (li == netBinary.layer_size() || netBinary.layer(li).blobs_size() == 0)
{ return;
std::vector<String> netInputs(net.input_size());
for (int inNum = 0; inNum < net.input_size(); inNum++)
{
addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
netInputs[inNum] = net.input(inNum);
}
dstNet.setNetInputs(netInputs);
}
for (int li = 0; li < layersSize; li++) const caffe::LayerParameter &binLayer = netBinary.layer(li);
{ layerParams.blobs.resize(binLayer.blobs_size());
const caffe::LayerParameter &layer = net.layer(li); for (int bi = 0; bi < binLayer.blobs_size(); bi++)
String name = layer.name(); {
String type = layer.type(); blobFromProto(binLayer.blobs(bi), layerParams.blobs[bi]);
LayerParams layerParams; }
}
extractLayerParams(layer, layerParams); struct BlobNote
extractBinaryLayerParms(layer, layerParams); {
BlobNote(const std::string &_name, int _layerId, int _outNum) :
name(_name.c_str()), layerId(_layerId), outNum(_outNum) {}
int id = dstNet.addLayer(name, type, layerParams); const char *name;
int layerId, outNum;
};
for (int inNum = 0; inNum < layer.bottom_size(); inNum++) std::vector<BlobNote> addedBlobs;
addInput(layer.bottom(inNum), id, inNum, dstNet, addedBlobs); std::map<String, int> layerCounter;
for (int outNum = 0; outNum < layer.top_size(); outNum++) void populateNet(Net dstNet)
addOutput(layer, id, outNum, addedBlobs); {
int layersSize = net.layer_size();
layerCounter.clear();
addedBlobs.clear();
addedBlobs.reserve(layersSize + 1);
//setup input layer names
{
std::vector<String> netInputs(net.input_size());
for (int inNum = 0; inNum < net.input_size(); inNum++)
{
addedBlobs.push_back(BlobNote(net.input(inNum), 0, inNum));
netInputs[inNum] = net.input(inNum);
} }
dstNet.setNetInputs(netInputs);
} }
void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, std::vector<BlobNote> &addedBlobs) for (int li = 0; li < layersSize; li++)
{ {
const std::string &name = layer.top(outNum); const caffe::LayerParameter &layer = net.layer(li);
String name = layer.name();
String type = layer.type();
LayerParams layerParams;
bool haveDups = false; extractLayerParams(layer, layerParams);
for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--) extractBinaryLayerParms(layer, layerParams);
{
if (addedBlobs[idx].name == name)
{
haveDups = true;
break;
}
}
if (haveDups) int repetitions = layerCounter[name]++;
{ if (repetitions)
bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name; name += String("_") + toString(repetitions);
if (!isInplace)
CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources"); int id = dstNet.addLayer(name, type, layerParams);
}
addedBlobs.push_back(BlobNote(name, layerId, outNum)); for (int inNum = 0; inNum < layer.bottom_size(); inNum++)
addInput(layer.bottom(inNum), id, inNum, dstNet);
for (int outNum = 0; outNum < layer.top_size(); outNum++)
addOutput(layer, id, outNum);
} }
void addInput(const std::string &name, int layerId, int inNum, Net &dstNet, std::vector<BlobNote> &addedBlobs) addedBlobs.clear();
{ }
int idx;
for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--) void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum)
{ {
if (addedBlobs[idx].name == name) const std::string &name = layer.top(outNum);
break;
}
if (idx < 0) bool haveDups = false;
for (int idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
{
if (addedBlobs[idx].name == name)
{ {
CV_Error(Error::StsObjectNotFound, "Can't found output blob \"" + name + "\""); haveDups = true;
return; break;
} }
}
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum); if (haveDups)
{
bool isInplace = layer.bottom_size() > outNum && layer.bottom(outNum) == name;
if (!isInplace)
CV_Error(Error::StsBadArg, "Duplicate blobs produced by multiple sources");
} }
~CaffeImporter() addedBlobs.push_back(BlobNote(name, layerId, outNum));
}
void addInput(const std::string &name, int layerId, int inNum, Net &dstNet)
{
int idx;
for (idx = (int)addedBlobs.size() - 1; idx >= 0; idx--)
{ {
if (addedBlobs[idx].name == name)
break;
}
if (idx < 0)
{
CV_Error(Error::StsObjectNotFound, "Can't find output blob \"" + name + "\"");
return;
} }
dstNet.connect(addedBlobs[idx].layerId, addedBlobs[idx].outNum, layerId, inNum);
}
}; ~CaffeImporter()
{
}
};
} }
......
...@@ -60,7 +60,7 @@ namespace dnn ...@@ -60,7 +60,7 @@ namespace dnn
{ {
template<typename T> template<typename T>
String toString(const T &v) static String toString(const T &v)
{ {
std::ostringstream ss; std::ostringstream ss;
ss << v; ss << v;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment