Commit cc989301 authored by Artur Wojcik's avatar Artur Wojcik Committed by Michał Karzyński

[ONNX] Tensor: add support for raw_data (#1552)

parent 3ffccb33
......@@ -65,6 +65,22 @@ namespace ngraph
}
};
struct data_type_undefined : ngraph_error
{
data_type_undefined()
: ngraph_error{"data type is not defined"}
{
}
};
struct segments_unsupported : ngraph_error
{
segments_unsupported()
: ngraph_error{"loading segments not supported"}
{
}
};
} // namespace tensor
} // namespace error
......@@ -82,6 +98,13 @@ namespace ngraph
{
return {std::begin(container), std::end(container)};
}
template <typename T>
inline std::vector<T> __get_raw_data(const std::string& raw_data)
{
auto it = reinterpret_cast<const T*>(raw_data.data());
return {it, it + (raw_data.size() / sizeof(T))};
}
}
}
......@@ -94,6 +117,10 @@ namespace ngraph
template <>
inline std::vector<double> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<double>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_DOUBLE)
{
return detail::__get_data<double>(tensor.double_data());
......@@ -121,6 +148,10 @@ namespace ngraph
template <>
inline std::vector<float> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<float>(tensor.raw_data());
}
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) or
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
{
......@@ -144,6 +175,10 @@ namespace ngraph
template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int32_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{
return detail::__get_data<int32_t>(tensor.int32_data());
......@@ -154,6 +189,10 @@ namespace ngraph
template <>
inline std::vector<int64_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int64_t>(tensor.raw_data());
}
if (tensor.data_type() != onnx::TensorProto_DataType_INT64)
{
throw error::tensor::invalid_data_type{tensor.data_type()};
......@@ -164,6 +203,10 @@ namespace ngraph
template <>
inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint64_t>(tensor.raw_data());
}
if (tensor.data_type() != onnx::TensorProto_DataType_UINT64)
{
throw error::tensor::invalid_data_type{tensor.data_type()};
......@@ -213,6 +256,10 @@ namespace ngraph
template <typename T>
std::vector<T> get_data() const
{
if (m_tensor_proto->has_segment())
{
throw error::tensor::segments_unsupported{};
}
return detail::tensor::get_data<T>(*m_tensor_proto);
}
......@@ -254,6 +301,8 @@ namespace ngraph
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
throw error::tensor::data_type_undefined{};
default: throw error::tensor::unsupported_data_type{m_tensor_proto->data_type()};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment