Unverified Commit d2be2c76 authored by Katarzyna Mitrus, committed by GitHub

[ONNX] Enable boolean tensor type (#4355)

* get_data boolean support

* compare_values for boolean char

* Register comparator boolean

* Bool type model prototxt

* Test debug

* Test debug cleaning

* Pass output test

* Style apply

* Test update

* Exclude LogicalAnd test on GPU

* Make ng constant for bool

* Style apply

* Copy data for bool

* Tests update

* Style apply

* GPU manifest update

* Test bool constant op

* Store bool as char vector

* Comment update
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 363cfafa
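
The common thread of the changes below is that ONNX BOOL values are materialized as std::vector<char> rather than std::vector<bool>. A minimal standalone sketch of the std::vector<bool> behaviour the importer is avoiding (illustrative only, not taken from the patch):

#include <cstring>
#include <vector>

int main()
{
    std::vector<char> flags{1, 0, 0, 1}; // one byte per element, contiguous storage
    char buffer[4];
    std::memcpy(buffer, flags.data(), flags.size()); // fine: data() exists

    // std::vector<bool> is a bit-packed specialization:
    // std::vector<bool> packed{true, false};
    // packed.data();         // does not compile: no data() member
    // bool& ref = packed[0]; // does not compile: operator[] returns a proxy
    return buffer[0] == 1 ? 0 : 1;
}
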
@@ -104,7 +104,7 @@ namespace ngraph
 template <typename T, typename Container>
 inline std::vector<T> __get_data(const Container& container)
 {
-    return {std::begin(container), std::end(container)};
+    return std::vector<T>(std::begin(container), std::end(container));
 }

 /// Returns the size in bytes of an ONNX data type.
@@ -119,7 +119,7 @@ namespace ngraph
 case onnx::TensorProto_DataType_INT16: return sizeof(int16_t);
 case onnx::TensorProto_DataType_INT32: return sizeof(int32_t);
 case onnx::TensorProto_DataType_INT64: return sizeof(int64_t);
-case onnx::TensorProto_DataType_BOOL: return sizeof(bool);
+case onnx::TensorProto_DataType_BOOL: return sizeof(char);
 case onnx::TensorProto_DataType_FLOAT16: return 2;
 case onnx::TensorProto_DataType_DOUBLE: return sizeof(double);
 case onnx::TensorProto_DataType_UINT32: return sizeof(uint32_t);
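
A small standalone sketch of why the BOOL case now reports sizeof(char): as shown in the next hunk, __get_raw_data divides the raw byte count by this size to obtain the element count, and the raw payload carries one byte per boolean element (the example below is illustrative, not part of the patch):

#include <cstddef>
#include <string>

int main()
{
    // One byte per BOOL element; sizeof(char) is 1 by definition, while
    // sizeof(bool) is implementation-defined (merely usually 1).
    std::string raw_data("\x01\x00\x01", 3);
    std::size_t element_count = raw_data.size() / sizeof(char);
    return element_count == 3 ? 0 : 1;
}
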
@@ -135,8 +135,8 @@ namespace ngraph
                                      int onnx_data_type)
 {
     auto it = reinterpret_cast<const T*>(raw_data.data());
-    return {it,
-            it + (raw_data.size() / __get_onnx_data_size(onnx_data_type))};
+    return std::vector<T>(
+        it, it + (raw_data.size() / __get_onnx_data_size(onnx_data_type)));
 }
 }
 }
@@ -334,6 +334,22 @@ namespace ngraph
     }
     return detail::__get_data<uint64_t>(tensor.uint64_data());
 }
+
+template <>
+inline std::vector<char> get_data(const onnx::TensorProto& tensor)
+{
+    // Boolean values are stored as char because std::vector<bool>
+    // can behave differently from other vector containers.
+    if (tensor.has_raw_data())
+    {
+        return detail::__get_raw_data<char>(tensor.raw_data(), tensor.data_type());
+    }
+    if (tensor.data_type() == onnx::TensorProto_DataType_BOOL)
+    {
+        return detail::__get_data<char>(tensor.int32_data());
+    }
+    throw error::tensor::invalid_data_type{tensor.data_type()};
+}
 }
 }
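
For reference, a standalone sketch of the two decode paths the new get_data<char> specialization takes; the helper names below are hypothetical stand-ins, not the importer's actual API:

#include <cstdint>
#include <string>
#include <vector>

// Raw path: the payload already holds one byte per BOOL element.
std::vector<char> decode_bool_from_raw(const std::string& raw_data)
{
    return std::vector<char>(raw_data.begin(), raw_data.end());
}

// int32_data path: each element arrives as an int32 and is narrowed to char.
std::vector<char> decode_bool_from_int32(const std::vector<int32_t>& int32_data)
{
    return std::vector<char>(int32_data.begin(), int32_data.end());
}

int main()
{
    auto from_raw = decode_bool_from_raw(std::string("\x01\x00\x01", 3));
    auto from_int32 = decode_bool_from_int32({1, 0, 0, 1});
    return (from_raw.size() == 3 && from_int32.size() == 4) ? 0 : 1;
}
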
@@ -441,7 +457,7 @@ namespace ngraph
 switch (m_tensor_proto->data_type())
 {
 case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
-    return make_ng_constant<bool>(element::boolean);
+    return make_ng_constant<char>(element::boolean);
 case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
     return make_ng_constant<float>(element::f32);
 case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
......
@@ -121,6 +121,13 @@ namespace ngraph
     return __make_ng_constant<uint64_t>(element::u64, tensor);
 }

+template <>
+inline std::shared_ptr<default_opset::Constant>
+    make_ng_constant<Tensor::Type::boolean>(const Tensor& tensor)
+{
+    return __make_ng_constant<char>(element::boolean, tensor);
+}
+
 inline std::shared_ptr<default_opset::Constant>
     make_constant(const Tensor& tensor)
 {
@@ -140,6 +147,7 @@ namespace ngraph
     MAKE_NG_CONSTANT(Tensor::Type::uint16);
     MAKE_NG_CONSTANT(Tensor::Type::uint32);
     MAKE_NG_CONSTANT(Tensor::Type::uint64);
+    MAKE_NG_CONSTANT(Tensor::Type::boolean);
     default: throw error::tensor::invalid_data_type{tensor};
 }
 }
......
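
A hedged usage sketch of what the new boolean branch produces: an element::boolean constant backed by char data. It assumes the public ngraph::op::Constant::create helper; the importer itself goes through the internal __make_ng_constant/default_opset path shown above:

#include <memory>
#include <vector>

#include "ngraph/ngraph.hpp"

int main()
{
    std::vector<char> values{1, 0, 0, 1}; // boolean payload kept as char
    auto constant = ngraph::op::Constant::create(
        ngraph::element::boolean, ngraph::Shape{2, 2}, values);
    return constant->get_shape().size() == 2 ? 0 : 1;
}
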
@@ -166,6 +166,8 @@ create_tensor_2_output
 # Not implemented
 batch_mat_mul_forward
 backwards_batchmatmul_tensor2_tensor2
+bool_init_and
+bool_input_or
 erf
 zero_sized_erf
 model_erf
......
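# New model file (its graph matches what the bool_const_op test further below loads as onnx/bool_const_op.prototxt).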
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
name: "test_graph"
node {
output: "A"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2
dims: 2
data_type: 9
int32_data: 1
int32_data: 0
int32_data: 0
int32_data: 1
name: "const_bool_tensor"
}
type: TENSOR
}
}
output {
name: "A"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 4
}
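# New model file (its graph matches what the bool_init_and test further below loads as onnx/bool_init_and.prototxt).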
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
name: "test_graph"
node {
input: "A"
input: "B"
output: "Y"
name: "node"
op_type: "And"
}
initializer {
data_type: 9
name: "A"
int32_data: 1
}
input {
name: "A"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
initializer {
data_type: 9
name: "B"
raw_data: "\001"
}
input {
name: "B"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
}
opset_import {
version: 4
}
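# New model file (its graph matches what the bool_init_raw test further below loads as onnx/bool_init_raw.prototxt).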
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
name: "test_graph"
initializer {
dims: 1
dims: 3
data_type: 9
name: "A"
raw_data: "\001\000\001"
}
input {
name: "A"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "A"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 4
}
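# New model file (its graph matches what the bool_input_or test further below loads as onnx/bool_input_or.prototxt).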
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
name: "test_graph"
node {
input: "A"
input: "B"
output: "Y"
name: "node"
op_type: "Or"
}
input {
name: "A"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 4
}
@@ -131,6 +131,48 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_binary_add_abc)
     test_case.run();
 }

+NGRAPH_TEST(onnx_${BACKEND_NAME}, bool_const_op)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/bool_const_op.prototxt"));
+
+    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
+    test_case.add_expected_output(std::vector<bool>{1, 0, 0, 1});
+    test_case.run();
+}
+
+NGRAPH_TEST(onnx_${BACKEND_NAME}, bool_init_and)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/bool_init_and.prototxt"));
+
+    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
+    test_case.add_expected_output(std::vector<bool>{1});
+    test_case.run();
+}
+
+NGRAPH_TEST(onnx_${BACKEND_NAME}, bool_input_or)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/bool_input_or.prototxt"));
+
+    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
+    test_case.add_input(std::vector<bool>{true, false, true, false});
+    test_case.add_input(std::vector<bool>{false, false, true, true});
+    test_case.add_expected_output(std::vector<bool>{1, 0, 1, 1});
+    test_case.run();
+}
+
+NGRAPH_TEST(onnx_${BACKEND_NAME}, bool_init_raw)
+{
+    auto function = onnx_import::import_onnx_model(
+        file_util::path_join(SERIALIZED_ZOO, "onnx/bool_init_raw.prototxt"));
+
+    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
+    test_case.add_expected_output(std::vector<bool>{true, false, true});
+    test_case.run();
+}
+
 NGRAPH_TEST(onnx_${BACKEND_NAME}, model_add_abc_initializers)
 {
     auto function = onnx_import::import_onnx_model(
......
@@ -242,6 +242,7 @@ namespace ngraph
     REGISTER_COMPARATOR(u16, uint16_t),
     REGISTER_COMPARATOR(u32, uint32_t),
     REGISTER_COMPARATOR(u64, uint64_t),
+    REGISTER_COMPARATOR(boolean, char),
 };

 #undef REGISTER_COMPARATOR
......
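
A standalone sketch of what pairing element::boolean with char means for the test comparisons; the helper below is hypothetical and only mirrors the idea, not the REGISTER_COMPARATOR machinery:

#include <vector>

// Expected std::vector<bool> values are widened to one char per element and
// compared byte-for-byte against the tensor contents read back as char.
bool boolean_outputs_match(const std::vector<bool>& expected,
                           const std::vector<char>& actual_bytes)
{
    std::vector<char> expected_bytes(expected.begin(), expected.end());
    return expected_bytes == actual_bytes;
}

int main()
{
    std::vector<char> read_back{1, 0, 1, 1}; // e.g. the Or test output
    return boolean_outputs_match({true, false, true, true}, read_back) ? 0 : 1;
}
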
@@ -79,6 +79,13 @@ shared_ptr<Function> make_test_graph()
     return f0;
 }

+template <>
+void copy_data<bool>(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<bool>& data)
+{
+    std::vector<char> data_char(data.begin(), data.end());
+    copy_data(tv, data_char);
+}
+
 template <>
 void init_int_tv<char>(ngraph::runtime::Tensor* tv,
                        std::default_random_engine& engine,
......
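
A standalone sketch of the conversion performed by the copy_data<bool> specialization above: std::vector<bool> has no contiguous byte storage to hand to a byte-oriented write, so it is widened into a std::vector<char> first (write_bytes below is a hypothetical stand-in for ngraph::runtime::Tensor::write):

#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for ngraph::runtime::Tensor::write(data, size).
void write_bytes(void* destination, const void* source, std::size_t size_in_bytes)
{
    std::memcpy(destination, source, size_in_bytes);
}

int main()
{
    std::vector<bool> data{true, false, true, false};
    std::vector<char> data_char(data.begin(), data.end()); // one byte per value

    char device_buffer[4];
    write_bytes(device_buffer, data_char.data(), data_char.size());
    return device_buffer[0] == 1 ? 0 : 1;
}
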
@@ -58,6 +58,9 @@ void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>
     tv->write(data.data(), data_size);
 }

+template <>
+void copy_data<bool>(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<bool>& data);
+
 template <typename T>
 std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::Tensor> tv)
 {
......