Commit c89abac0 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[ONNX] Fix support for float16 in ONNX frontend. (#3563)

Initial support for float16 was added before the ngraph::float16 data
type was introduced in nGraph. This PR aligns the ONNX frontend with the
latest float16 implementation.
parent e5330c16
......@@ -90,10 +90,8 @@ namespace ngraph
{
}
};
} // namespace tensor
} // namespace error
}
}
namespace detail
{
......@@ -109,11 +107,36 @@ namespace ngraph
return {std::begin(container), std::end(container)};
}
/// Returns the size in bytes of an ONNX data type.
///
/// The argument is compared against the onnx::TensorProto_DataType_*
/// enumerators; for most of them the result is sizeof() of the C++ type
/// named in the matching case. Any value not listed in the switch reaches
/// NGRAPH_UNREACHABLE with the message "Unsupported data type".
inline size_t __get_onnx_data_size(int data_type)
{
switch (data_type)
{
case onnx::TensorProto_DataType_FLOAT: return sizeof(float);
case onnx::TensorProto_DataType_UINT8: return sizeof(uint8_t);
case onnx::TensorProto_DataType_INT8: return sizeof(int8_t);
case onnx::TensorProto_DataType_UINT16: return sizeof(uint16_t);
case onnx::TensorProto_DataType_INT16: return sizeof(int16_t);
case onnx::TensorProto_DataType_INT32: return sizeof(int32_t);
case onnx::TensorProto_DataType_INT64: return sizeof(int64_t);
case onnx::TensorProto_DataType_BOOL: return sizeof(bool);
// Size is hard-coded here rather than taken from sizeof() of a C++ type.
case onnx::TensorProto_DataType_FLOAT16: return 2;
case onnx::TensorProto_DataType_DOUBLE: return sizeof(double);
case onnx::TensorProto_DataType_UINT32: return sizeof(uint32_t);
case onnx::TensorProto_DataType_UINT64: return sizeof(uint64_t);
// Complex types are sized as two floats / two doubles respectively.
case onnx::TensorProto_DataType_COMPLEX64: return 2 * sizeof(float);
case onnx::TensorProto_DataType_COMPLEX128: return 2 * sizeof(double);
default: NGRAPH_UNREACHABLE("Unsupported data type");
}
}
// NOTE(review): this span appears to show both the pre-change and the
// post-change lines of the diff (a one-argument and a two-argument signature,
// and two return statements), so it is not a single compilable definition as
// rendered here.
//
// Reinterprets the bytes of raw_data as a sequence of T and returns them as a
// std::vector<T>.
//
// NOTE(review): in the two-argument form the element count is
// raw_data.size() / __get_onnx_data_size(onnx_data_type), while the pointer
// advances in units of sizeof(T). If sizeof(T) is larger than the ONNX element
// size (e.g. T = double with a FLOAT tensor), the end pointer would lie past
// the end of raw_data -- confirm callers only pair T with a same-sized ONNX
// type.
template <typename T>
inline std::vector<T> __get_raw_data(const std::string& raw_data)
inline std::vector<T> __get_raw_data(const std::string& raw_data,
int onnx_data_type)
{
auto it = reinterpret_cast<const T*>(raw_data.data());
return {it, it + (raw_data.size() / sizeof(T))};
return {it,
it + (raw_data.size() / __get_onnx_data_size(onnx_data_type))};
}
}
}
......@@ -130,14 +153,14 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<double>(tensor.raw_data());
return detail::__get_raw_data<double>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_DOUBLE)
{
return detail::__get_data<double>(tensor.double_data());
}
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) ||
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT)
{
return detail::__get_data<double>(tensor.float_data());
}
......@@ -161,10 +184,9 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<float>(tensor.raw_data());
return detail::__get_raw_data<float>(tensor.raw_data(), tensor.data_type());
}
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) ||
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT))
{
return detail::__get_data<float>(tensor.float_data());
}
......@@ -183,12 +205,23 @@ namespace ngraph
throw error::tensor::invalid_data_type{tensor.data_type()};
}
/// Specialization of get_data for ngraph::float16 elements.
///
/// Checks via NGRAPH_CHECK that the tensor's data type is FLOAT16 and that it
/// carries raw data, then returns the result of detail::__get_raw_data called
/// with the tensor's raw bytes and its data type.
template <>
inline std::vector<ngraph::float16> get_data(const onnx::TensorProto& tensor)
{
// Only FLOAT16 tensors are accepted by this specialization.
NGRAPH_CHECK(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16,
"Expected FLOAT16 data type");
// Raw data is required; this specialization reads no typed data field.
NGRAPH_CHECK(tensor.has_raw_data(), "Expected raw data for FLOAT16 data type");
return detail::__get_raw_data<ngraph::float16>(tensor.raw_data(),
tensor.data_type());
}
template <>
inline std::vector<int8_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int8_t>(tensor.raw_data());
return detail::__get_raw_data<int8_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT8)
{
......@@ -202,7 +235,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int16_t>(tensor.raw_data());
return detail::__get_raw_data<int16_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT16)
{
......@@ -216,7 +250,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int32_t>(tensor.raw_data());
return detail::__get_raw_data<int32_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{
......@@ -230,7 +265,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<int64_t>(tensor.raw_data());
return detail::__get_raw_data<int64_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() != onnx::TensorProto_DataType_INT64)
{
......@@ -244,7 +280,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint8_t>(tensor.raw_data());
return detail::__get_raw_data<uint8_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT8)
{
......@@ -258,7 +295,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint16_t>(tensor.raw_data());
return detail::__get_raw_data<uint16_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT16)
{
......@@ -272,7 +310,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint32_t>(tensor.raw_data());
return detail::__get_raw_data<uint32_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT32)
{
......@@ -286,7 +325,8 @@ namespace ngraph
{
if (tensor.has_raw_data())
{
return detail::__get_raw_data<uint64_t>(tensor.raw_data());
return detail::__get_raw_data<uint64_t>(tensor.raw_data(),
tensor.data_type());
}
if (tensor.data_type() != onnx::TensorProto_DataType_UINT64)
{
......@@ -371,8 +411,8 @@ namespace ngraph
switch (m_tensor_proto->data_type())
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f32;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: return element::f32;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f16;
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: return element::f64;
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: return element::i16;
......@@ -396,8 +436,9 @@ namespace ngraph
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<bool>(element::boolean);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_ng_constant<float>(element::f32);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_ng_constant<ngraph::float16>(element::f16);
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return make_ng_constant<double>(element::f64);
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
......@@ -435,7 +476,5 @@ namespace ngraph
{
return (outs << "<Tensor: " << tensor.get_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
}
}
......@@ -48,7 +48,7 @@ namespace ngraph
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
return __make_ng_constant<ngraph::float16>(element::f16, tensor);
}
template <>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment