Commit cc989301 authored by Artur Wojcik's avatar Artur Wojcik Committed by Michał Karzyński

[ONNX] Tensor: add support for raw_data (#1552)

parent 3ffccb33
...@@ -65,6 +65,22 @@ namespace ngraph ...@@ -65,6 +65,22 @@ namespace ngraph
} }
}; };
struct data_type_undefined : ngraph_error
{
data_type_undefined()
: ngraph_error{"data type is not defined"}
{
}
};
struct segments_unsupported : ngraph_error
{
segments_unsupported()
: ngraph_error{"loading segments not supported"}
{
}
};
} // namespace tensor } // namespace tensor
} // namespace error } // namespace error
...@@ -82,6 +98,13 @@ namespace ngraph ...@@ -82,6 +98,13 @@ namespace ngraph
{ {
return {std::begin(container), std::end(container)}; return {std::begin(container), std::end(container)};
} }
template <typename T>
inline std::vector<T> __get_raw_data(const std::string& raw_data)
{
auto it = reinterpret_cast<const T*>(raw_data.data());
return {it, it + (raw_data.size() / sizeof(T))};
}
} }
} }
...@@ -94,6 +117,10 @@ namespace ngraph ...@@ -94,6 +117,10 @@ namespace ngraph
template <> template <>
inline std::vector<double> get_data(const onnx::TensorProto& tensor) inline std::vector<double> get_data(const onnx::TensorProto& tensor)
{ {
if (tensor.has_raw_data())
{
return detail::__get_raw_data<double>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_DOUBLE) if (tensor.data_type() == onnx::TensorProto_DataType_DOUBLE)
{ {
return detail::__get_data<double>(tensor.double_data()); return detail::__get_data<double>(tensor.double_data());
...@@ -121,6 +148,10 @@ namespace ngraph ...@@ -121,6 +148,10 @@ namespace ngraph
template <> template <>
inline std::vector<float> get_data(const onnx::TensorProto& tensor) inline std::vector<float> get_data(const onnx::TensorProto& tensor)
{ {
if (tensor.has_raw_data())
{
return detail::__get_raw_data<float>(tensor.raw_data());
}
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) or if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) or
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16)) (tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
{ {
...@@ -144,6 +175,10 @@ namespace ngraph ...@@ -144,6 +175,10 @@ namespace ngraph
template <> template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor) inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{ {
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int32_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT32) if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{ {
return detail::__get_data<int32_t>(tensor.int32_data()); return detail::__get_data<int32_t>(tensor.int32_data());
...@@ -154,6 +189,10 @@ namespace ngraph ...@@ -154,6 +189,10 @@ namespace ngraph
template <> template <>
inline std::vector<int64_t> get_data(const onnx::TensorProto& tensor) inline std::vector<int64_t> get_data(const onnx::TensorProto& tensor)
{ {
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int64_t>(tensor.raw_data());
}
if (tensor.data_type() != onnx::TensorProto_DataType_INT64) if (tensor.data_type() != onnx::TensorProto_DataType_INT64)
{ {
throw error::tensor::invalid_data_type{tensor.data_type()}; throw error::tensor::invalid_data_type{tensor.data_type()};
...@@ -164,6 +203,10 @@ namespace ngraph ...@@ -164,6 +203,10 @@ namespace ngraph
template <> template <>
inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor) inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor)
{ {
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint64_t>(tensor.raw_data());
}
if (tensor.data_type() != onnx::TensorProto_DataType_UINT64) if (tensor.data_type() != onnx::TensorProto_DataType_UINT64)
{ {
throw error::tensor::invalid_data_type{tensor.data_type()}; throw error::tensor::invalid_data_type{tensor.data_type()};
...@@ -213,6 +256,10 @@ namespace ngraph ...@@ -213,6 +256,10 @@ namespace ngraph
template <typename T> template <typename T>
std::vector<T> get_data() const std::vector<T> get_data() const
{ {
if (m_tensor_proto->has_segment())
{
throw error::tensor::segments_unsupported{};
}
return detail::tensor::get_data<T>(*m_tensor_proto); return detail::tensor::get_data<T>(*m_tensor_proto);
} }
...@@ -254,6 +301,8 @@ namespace ngraph ...@@ -254,6 +301,8 @@ namespace ngraph
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16; case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32; case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64; case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
throw error::tensor::data_type_undefined{};
default: throw error::tensor::unsupported_data_type{m_tensor_proto->data_type()}; default: throw error::tensor::unsupported_data_type{m_tensor_proto->data_type()};
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment