Unverified Commit 5dc708a1 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] automatically detect value types of TensorProto and AttributeProto (#2262)

* [ONNX] detect automatically value types of TensorProto and AttributeProto
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: style apply
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent c5bf6812
...@@ -29,6 +29,13 @@ namespace ngraph ...@@ -29,6 +29,13 @@ namespace ngraph
class Graph; class Graph;
class Model; class Model;
// Detecting automatically the underlying type used to store the information
// for data type of values an attribute is holding. A bug was discovered in
// protobuf which forced ONNX team to switch from `enum AttributeProto_AttributeType`
// to `int32` in order to workaround the bug. This line allows using both versions
// of ONNX generated wrappers.
using AttributeProto_AttributeType = decltype(onnx::AttributeProto{}.type());
namespace error namespace error
{ {
namespace attribute namespace attribute
...@@ -37,7 +44,7 @@ namespace ngraph ...@@ -37,7 +44,7 @@ namespace ngraph
{ {
struct Attribute : ngraph_error struct Attribute : ngraph_error
{ {
Attribute(const std::string& msg, onnx::AttributeProto_AttributeType type) Attribute(const std::string& msg, AttributeProto_AttributeType type)
: ngraph_error{msg + ": " + : ngraph_error{msg + ": " +
onnx::AttributeProto_AttributeType_Name(type)} onnx::AttributeProto_AttributeType_Name(type)}
{ {
...@@ -48,7 +55,7 @@ namespace ngraph ...@@ -48,7 +55,7 @@ namespace ngraph
struct InvalidData : detail::Attribute struct InvalidData : detail::Attribute
{ {
explicit InvalidData(onnx::AttributeProto_AttributeType type) explicit InvalidData(AttributeProto_AttributeType type)
: Attribute{"invalid attribute type", type} : Attribute{"invalid attribute type", type}
{ {
} }
...@@ -56,7 +63,7 @@ namespace ngraph ...@@ -56,7 +63,7 @@ namespace ngraph
struct UnsupportedType : detail::Attribute struct UnsupportedType : detail::Attribute
{ {
explicit UnsupportedType(onnx::AttributeProto_AttributeType type) explicit UnsupportedType(AttributeProto_AttributeType type)
: Attribute{"unsupported attribute type", type} : Attribute{"unsupported attribute type", type}
{ {
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include <utility>
#include <vector> #include <vector>
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
...@@ -26,24 +27,33 @@ namespace ngraph ...@@ -26,24 +27,33 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
// Detecting automatically the underlying type used to store the information
// for data type of values a tensor is holding. A bug was discovered in protobuf
// which forced ONNX team to switch from `enum TensorProto_DataType` to `int32`
// in order to workaround the bug. This line allows using both versions of ONNX
// generated wrappers.
using TensorProto_DataType = decltype(onnx::TensorProto{}.data_type());
namespace error namespace error
{ {
namespace tensor namespace tensor
{ {
struct invalid_data_type : ngraph_error struct invalid_data_type : ngraph_error
{ {
explicit invalid_data_type(onnx::TensorProto_DataType type) explicit invalid_data_type(TensorProto_DataType type)
: ngraph_error{"invalid data type: " + : ngraph_error{"invalid data type: " +
onnx::TensorProto_DataType_Name(type)} onnx::TensorProto_DataType_Name(
static_cast<onnx::TensorProto_DataType>(type))}
{ {
} }
}; };
struct unsupported_data_type : ngraph_error struct unsupported_data_type : ngraph_error
{ {
explicit unsupported_data_type(onnx::TensorProto_DataType type) explicit unsupported_data_type(TensorProto_DataType type)
: ngraph_error{"unsupported data type: " + : ngraph_error{"unsupported data type: " +
onnx::TensorProto_DataType_Name(type)} onnx::TensorProto_DataType_Name(
static_cast<onnx::TensorProto_DataType>(type))}
{ {
} }
}; };
...@@ -306,7 +316,7 @@ namespace ngraph ...@@ -306,7 +316,7 @@ namespace ngraph
} }
} }
operator onnx::TensorProto_DataType() const { return m_tensor_proto->data_type(); } operator TensorProto_DataType() const { return m_tensor_proto->data_type(); }
private: private:
const onnx::TensorProto* m_tensor_proto; const onnx::TensorProto* m_tensor_proto;
Shape m_shape; Shape m_shape;
......
...@@ -43,9 +43,10 @@ namespace ngraph ...@@ -43,9 +43,10 @@ namespace ngraph
}; };
struct unsupported_element_type : ngraph_error struct unsupported_element_type : ngraph_error
{ {
explicit unsupported_element_type(onnx::TensorProto_DataType type) explicit unsupported_element_type(TensorProto_DataType type)
: ngraph_error{"unsupported value info element type: " + : ngraph_error{"unsupported value info element type: " +
onnx::TensorProto_DataType_Name(type)} onnx::TensorProto_DataType_Name(
static_cast<onnx::TensorProto_DataType>(type))}
{ {
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment