Commit 777d7784 authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Free Convolution and MatMul weights after TensorFlow layers import

parent 9ffe4694
...@@ -612,7 +612,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) ...@@ -612,7 +612,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
Mat getTensorContent(const tensorflow::TensorProto &tensor) Mat getTensorContent(const tensorflow::TensorProto &tensor)
{ {
std::string content = tensor.tensor_content(); const std::string& content = tensor.tensor_content();
switch (tensor.dtype()) switch (tensor.dtype())
{ {
case tensorflow::DT_FLOAT: case tensorflow::DT_FLOAT:
...@@ -681,6 +681,14 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor) ...@@ -681,6 +681,14 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor)
return Mat(); return Mat();
} }
// Frees the raw tensor_content buffer of a TensorProto once its data has
// been copied into cv::Mat blobs, reducing peak memory during TF import.
// No-op for tensors whose weights live in typed repeated fields
// (float_val etc.) rather than tensor_content.
void releaseTensor(tensorflow::TensorProto* tensor)
{
    // Use the const accessor for the emptiness check: mutable_tensor_content()
    // would force-allocate an empty string on tensors with no raw content.
    if (!tensor->tensor_content().empty())
    {
        // release_tensor_content() hands back ownership of the heap-allocated
        // string holding the weight bytes; deleting it frees that memory.
        delete tensor->release_tensor_content();
    }
}
CV__DNN_EXPERIMENTAL_NS_END CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv }} // namespace dnn, namespace cv
......
...@@ -23,6 +23,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net); ...@@ -23,6 +23,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net);
Mat getTensorContent(const tensorflow::TensorProto &tensor); Mat getTensorContent(const tensorflow::TensorProto &tensor);
// Frees the raw tensor_content bytes of a TensorProto after import has
// copied them into cv::Mat blobs (lowers peak memory while loading a graph).
void releaseTensor(tensorflow::TensorProto* tensor);
CV__DNN_EXPERIMENTAL_NS_END CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv }} // namespace dnn, namespace cv
......
...@@ -677,7 +677,9 @@ void TFImporter::populateNet(Net dstNet) ...@@ -677,7 +677,9 @@ void TFImporter::populateNet(Net dstNet)
layers_to_ignore.insert(next_layers[0].first); layers_to_ignore.insert(next_layers[0].first);
} }
kernelFromTensor(getConstBlob(layer, value_id), layerParams.blobs[0]); const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id);
kernelFromTensor(kernelTensor, layerParams.blobs[0]);
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
int* kshape = layerParams.blobs[0].size.p; int* kshape = layerParams.blobs[0].size.p;
if (type == "DepthwiseConv2dNative") if (type == "DepthwiseConv2dNative")
{ {
...@@ -788,7 +790,9 @@ void TFImporter::populateNet(Net dstNet) ...@@ -788,7 +790,9 @@ void TFImporter::populateNet(Net dstNet)
} }
int kernel_blob_index = -1; int kernel_blob_index = -1;
blobFromTensor(getConstBlob(layer, value_id, -1, &kernel_blob_index), layerParams.blobs[0]); const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
blobFromTensor(kernelTensor, layerParams.blobs[0]);
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
Mat data = layerParams.blobs[0].t(); Mat data = layerParams.blobs[0].t();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment