Commit 9a075c46 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] Add lacking data types to ONNX Tensor (#2366)

parent af85d85d
......@@ -181,6 +181,34 @@ namespace ngraph
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int8_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int8_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT8)
{
return detail::__get_data<int8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int16_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int16_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT16)
{
return detail::__get_data<int16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{
......@@ -209,6 +237,48 @@ namespace ngraph
return detail::__get_data<int64_t>(tensor.int64_data());
}
template <>
inline std::vector<uint8_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint8_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT8)
{
return detail::__get_data<uint8_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint16_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint16_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT16)
{
return detail::__get_data<uint16_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint32_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint32_t>(tensor.raw_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT32)
{
return detail::__get_data<uint32_t>(tensor.uint64_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor)
{
......
......@@ -65,6 +65,20 @@ namespace ngraph
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int8>(const Tensor& tensor)
{
return __make_ng_constant<int8_t>(element::i8, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int16>(const Tensor& tensor)
{
return __make_ng_constant<int16_t>(element::i16, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
......@@ -79,6 +93,20 @@ namespace ngraph
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint8>(const Tensor& tensor)
{
return __make_ng_constant<uint8_t>(element::u8, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint16>(const Tensor& tensor)
{
return __make_ng_constant<uint16_t>(element::u16, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
......@@ -103,8 +131,12 @@ namespace ngraph
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int8);
MAKE_NG_CONSTANT(Tensor::Type::int16);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint8);
MAKE_NG_CONSTANT(Tensor::Type::uint16);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment