Commit 5ba8aa78 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #14460 from dkurt:dnn_tf_no_extra_clone

parents 455323ed a6ed8f26
...@@ -770,43 +770,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) ...@@ -770,43 +770,47 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
} }
} }
Mat getTensorContent(const tensorflow::TensorProto &tensor) Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy)
{ {
const std::string& content = tensor.tensor_content(); const std::string& content = tensor.tensor_content();
Mat m;
switch (tensor.dtype()) switch (tensor.dtype())
{ {
case tensorflow::DT_FLOAT: case tensorflow::DT_FLOAT:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str());
else else
{ {
const RepeatedField<float>& field = tensor.float_val(); const RepeatedField<float>& field = tensor.float_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_32FC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_DOUBLE: case tensorflow::DT_DOUBLE:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str());
else else
{ {
const RepeatedField<double>& field = tensor.double_val(); const RepeatedField<double>& field = tensor.double_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_64FC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_INT32: case tensorflow::DT_INT32:
{ {
if (!content.empty()) if (!content.empty())
return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone(); m = Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str());
else else
{ {
const RepeatedField<int32_t>& field = tensor.int_val(); const RepeatedField<int32_t>& field = tensor.int_val();
CV_Assert(!field.empty()); CV_Assert(!field.empty());
return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone(); m = Mat(1, field.size(), CV_32SC1, (void*)field.data());
} }
break;
} }
case tensorflow::DT_HALF: case tensorflow::DT_HALF:
{ {
...@@ -825,20 +829,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor) ...@@ -825,20 +829,20 @@ Mat getTensorContent(const tensorflow::TensorProto &tensor)
} }
// Reinterpret as a signed shorts just for a convertFp16 call. // Reinterpret as a signed shorts just for a convertFp16 call.
Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data); Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
Mat floats(halfs.size(), CV_32FC1); convertFp16(halfsSigned, m);
convertFp16(halfsSigned, floats); break;
return floats;
} }
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:
{ {
CV_Assert(!content.empty()); CV_Assert(!content.empty());
return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone(); m = Mat(1, content.size(), CV_8UC1, (void*)content.c_str());
break;
} }
default: default:
CV_Error(Error::StsError, "Tensor's data type is not supported"); CV_Error(Error::StsError, "Tensor's data type is not supported");
break; break;
} }
return Mat(); return copy ? m.clone() : m;
} }
void releaseTensor(tensorflow::TensorProto* tensor) void releaseTensor(tensorflow::TensorProto* tensor)
......
...@@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net); ...@@ -21,7 +21,7 @@ void RemoveIdentityOps(tensorflow::GraphDef& net);
void simplifySubgraphs(tensorflow::GraphDef& net); void simplifySubgraphs(tensorflow::GraphDef& net);
Mat getTensorContent(const tensorflow::TensorProto &tensor); Mat getTensorContent(const tensorflow::TensorProto &tensor, bool copy = true);
void releaseTensor(tensorflow::TensorProto* tensor); void releaseTensor(tensorflow::TensorProto* tensor);
......
...@@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob) ...@@ -109,7 +109,7 @@ void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
dstBlob.create(shape, CV_32F); dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor); Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total(); int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total()); CV_Assert(size == (int)dstBlob.total());
...@@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds ...@@ -509,7 +509,7 @@ void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &ds
dstBlob.create(shape, CV_32F); dstBlob.create(shape, CV_32F);
Mat tensorContent = getTensorContent(tensor); Mat tensorContent = getTensorContent(tensor, /*no copy*/false);
int size = tensorContent.total(); int size = tensorContent.total();
CV_Assert(size == (int)dstBlob.total()); CV_Assert(size == (int)dstBlob.total());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment