Commit e9dd6087 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Jayaram Bobba

Added u1 precision for binary weights (#3914)

* Added U1 precision for binary weights

* Handle switch cases with u1 type

* Fixed code style

* Added convert_to_string support for u1 type

* Use real C type  for u1 type.
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 89b3019e
...@@ -238,6 +238,7 @@ mlir::Type NgDialectConversionPass::getMlirType(const element::Type& type) ...@@ -238,6 +238,7 @@ mlir::Type NgDialectConversionPass::getMlirType(const element::Type& type)
{ {
case ngraph::element::Type_t::undefined: case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic: case ngraph::element::Type_t::dynamic:
case ngraph::element::Type_t::u1:
default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break; default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break;
case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(m_context); case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(m_context);
case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(m_context); case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(m_context);
......
...@@ -94,6 +94,8 @@ namespace ngraph ...@@ -94,6 +94,8 @@ namespace ngraph
throw ngraph_error("make_constant: Unsupported element type 'dynamic'"); throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean: case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'"); throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::u1:
throw ngraph_error("make_constant: Unsupported element type 'u1'");
case element::Type_t::undefined: case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'"); throw ngraph_error("make_constant: Unsupported element type 'undefined'");
} }
......
...@@ -196,6 +196,7 @@ namespace ngraph ...@@ -196,6 +196,7 @@ namespace ngraph
case element::Type_t::u64: m_type = MPI_UNSIGNED_LONG; break; case element::Type_t::u64: m_type = MPI_UNSIGNED_LONG; break;
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type"); case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
} }
......
...@@ -114,18 +114,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types() ...@@ -114,18 +114,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
} }
} }
element::Type result_et;
PartialShape result_shape; PartialShape result_shape;
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
result_shape = result_shape =
infer_convolution_forward(this, infer_convolution_forward(this,
data_batch_shape, data_batch_shape,
...@@ -136,7 +125,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types() ...@@ -136,7 +125,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
m_strides, m_strides,
m_dilations); m_dilations);
set_output_type(0, result_et, result_shape); set_output_type(0, data_batch_et, result_shape);
} }
shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -74,6 +74,9 @@ string op::Constant::convert_value_to_string(size_t index) const ...@@ -74,6 +74,9 @@ string op::Constant::convert_value_to_string(size_t index) const
case element::Type_t::i16: rc = to_string(get_vector<int16_t>()[index]); break; case element::Type_t::i16: rc = to_string(get_vector<int16_t>()[index]); break;
case element::Type_t::i32: rc = to_string(get_vector<int32_t>()[index]); break; case element::Type_t::i32: rc = to_string(get_vector<int32_t>()[index]); break;
case element::Type_t::i64: rc = to_string(get_vector<int64_t>()[index]); break; case element::Type_t::i64: rc = to_string(get_vector<int64_t>()[index]); break;
case element::Type_t::u1:
rc = to_string((get_vector<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1);
break;
case element::Type_t::u8: rc = to_string(get_vector<uint8_t>()[index]); break; case element::Type_t::u8: rc = to_string(get_vector<uint8_t>()[index]); break;
case element::Type_t::u16: rc = to_string(get_vector<uint16_t>()[index]); break; case element::Type_t::u16: rc = to_string(get_vector<uint16_t>()[index]); break;
case element::Type_t::u32: rc = to_string(get_vector<uint32_t>()[index]); break; case element::Type_t::u32: rc = to_string(get_vector<uint32_t>()[index]); break;
...@@ -176,6 +179,7 @@ vector<string> op::Constant::get_value_strings() const ...@@ -176,6 +179,7 @@ vector<string> op::Constant::get_value_strings() const
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
break; break;
case element::Type_t::u1: throw runtime_error("unsupported type");
case element::Type_t::undefined: throw runtime_error("unsupported type"); case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::dynamic: throw runtime_error("unsupported type"); case element::Type_t::dynamic: throw runtime_error("unsupported type");
} }
...@@ -323,6 +327,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const ...@@ -323,6 +327,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
rc = test_bitwise_identical<uint64_t>(this); rc = test_bitwise_identical<uint64_t>(this);
break; break;
} }
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: break; case element::Type_t::dynamic: break;
} }
......
...@@ -338,6 +338,7 @@ namespace ngraph ...@@ -338,6 +338,7 @@ namespace ngraph
case element::Type_t::u64: case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count); write_buffer<uint64_t, T>(target, source, target_element_count);
break; break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type"); case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type"); case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
} }
......
...@@ -230,6 +230,7 @@ void op::Range::validate_and_infer_types() ...@@ -230,6 +230,7 @@ void op::Range::validate_and_infer_types()
case element::Type_t::u32: result_shape = infer_output_shape<uint32_t>(this, result_et); break; case element::Type_t::u32: result_shape = infer_output_shape<uint32_t>(this, result_et); break;
case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break; case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break;
case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break; case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break;
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::boolean: case element::Type_t::boolean:
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
......
...@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Max::get_default_value() const ...@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Max::get_default_value() const
case element::Type_t::u64: case element::Type_t::u64:
return make_constant_from_string( return make_constant_from_string(
to_string(numeric_limits<uint64_t>::min()), get_element_type(), get_shape()); to_string(numeric_limits<uint64_t>::min()), get_element_type(), get_shape());
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
default: throw runtime_error("Max default value not defined for type"); default: throw runtime_error("Max default value not defined for type");
......
...@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Min::get_default_value() const ...@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Min::get_default_value() const
case element::Type_t::u64: case element::Type_t::u64:
return make_constant_from_string( return make_constant_from_string(
to_string(numeric_limits<uint64_t>::max()), get_element_type(), get_shape()); to_string(numeric_limits<uint64_t>::max()), get_element_type(), get_shape());
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
default: throw runtime_error("Min default value not defined for type"); default: throw runtime_error("Min default value not defined for type");
......
...@@ -189,6 +189,9 @@ static shared_ptr<op::Constant> ...@@ -189,6 +189,9 @@ static shared_ptr<op::Constant>
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_arithmetic_reduction"); "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node); return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
case element::Type_t::bf16: case element::Type_t::bf16:
......
...@@ -620,6 +620,9 @@ void pass::ConstantFolding::construct_constant_binary() ...@@ -620,6 +620,9 @@ void pass::ConstantFolding::construct_constant_binary()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_binary_callback"); "Encountered 'dynamic' element type in constant_binary_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_binary_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = replacement =
fold_constant_binary_helper<char>(a_match, b_match, binary_match, func); fold_constant_binary_helper<char>(a_match, b_match, binary_match, func);
......
...@@ -117,6 +117,9 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -117,6 +117,9 @@ void pass::ConstantFolding::construct_constant_broadcast()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_broadcast_callback"); "Encountered 'dynamic' element type in constant_broadcast_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_broadcast_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func); replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func);
break; break;
......
...@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_concat() ...@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_concat()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_concat"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_concat");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_concat");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_concat_helper<char>(concat_node); replacement = fold_constant_concat_helper<char>(concat_node);
break; break;
......
...@@ -57,6 +57,9 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> ...@@ -57,6 +57,9 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
return fold_constant_convert_helper1<TI, char>(constant, output_element_type); return fold_constant_convert_helper1<TI, char>(constant, output_element_type);
case element::Type_t::bf16: case element::Type_t::bf16:
...@@ -114,6 +117,9 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c ...@@ -114,6 +117,9 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
return fold_constant_convert_helper0<char>(constant, output_element_type); return fold_constant_convert_helper0<char>(constant, output_element_type);
case element::Type_t::bf16: case element::Type_t::bf16:
......
...@@ -81,6 +81,9 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast() ...@@ -81,6 +81,9 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_broadcast_callback"); "Encountered 'dynamic' element type in constant_dyn_broadcast_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_broadcast_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_dyn_broadcast<char>( replacement = fold_constant_dyn_broadcast<char>(
constant_arg_match, constant_shape_match, constant_axes_match); constant_arg_match, constant_shape_match, constant_axes_match);
......
...@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_dyn_reshape() ...@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_reshape_callback"); "Encountered 'dynamic' element type in constant_dyn_reshape_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match); replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
break; break;
......
...@@ -114,6 +114,9 @@ void pass::ConstantFolding::construct_constant_dyn_slice() ...@@ -114,6 +114,9 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_dyn_slice"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_dyn_slice");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_dyn_slice");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = replacement =
fold_constant_dyn_slice<char>(data_node, lb_node, ub_node, strides_node, dyn_slice); fold_constant_dyn_slice<char>(data_node, lb_node, ub_node, strides_node, dyn_slice);
......
...@@ -81,6 +81,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta ...@@ -81,6 +81,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
case element::Type_t::f64: case element::Type_t::f64:
case element::Type_t::i8: case element::Type_t::i8:
case element::Type_t::i16: case element::Type_t::i16:
case element::Type_t::u1:
case element::Type_t::u8: case element::Type_t::u8:
case element::Type_t::u16: case element::Type_t::u16:
case element::Type_t::u32: case element::Type_t::u32:
...@@ -134,6 +135,9 @@ void pass::ConstantFolding::construct_constant_gather() ...@@ -134,6 +135,9 @@ void pass::ConstantFolding::construct_constant_gather()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_gather_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_gather_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_gather_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_gather<char>(data, indices, gather); replacement = fold_constant_gather<char>(data, indices, gather);
break; break;
......
...@@ -100,6 +100,9 @@ void pass::ConstantFolding::construct_constant_pad() ...@@ -100,6 +100,9 @@ void pass::ConstantFolding::construct_constant_pad()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_pad_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_pad<char>(constant_match, pad_match, func); replacement = fold_constant_pad<char>(constant_match, pad_match, func);
break; break;
......
...@@ -68,6 +68,9 @@ void pass::ConstantFolding::construct_constant_range() ...@@ -68,6 +68,9 @@ void pass::ConstantFolding::construct_constant_range()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_range_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_range_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_range_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_range<char>(start_node, step_node, range); replacement = fold_constant_range<char>(start_node, step_node, range);
break; break;
......
...@@ -87,6 +87,9 @@ void pass::ConstantFolding::construct_constant_reshape() ...@@ -87,6 +87,9 @@ void pass::ConstantFolding::construct_constant_reshape()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_reshape_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func); replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
break; break;
......
...@@ -52,6 +52,9 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c ...@@ -52,6 +52,9 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
return fold_constant_reverse_helper<char>(constant, reversed_axes); return fold_constant_reverse_helper<char>(constant, reversed_axes);
case element::Type_t::bf16: case element::Type_t::bf16:
......
...@@ -72,6 +72,9 @@ void pass::ConstantFolding::construct_constant_select() ...@@ -72,6 +72,9 @@ void pass::ConstantFolding::construct_constant_select()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_select_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_select_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_select_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_select<char>(selection_node, t_node, f_node, select); replacement = fold_constant_select<char>(selection_node, t_node, f_node, select);
break; break;
......
...@@ -67,6 +67,9 @@ void pass::ConstantFolding::construct_constant_slice() ...@@ -67,6 +67,9 @@ void pass::ConstantFolding::construct_constant_slice()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_slice<char>(data_node, slice); replacement = fold_constant_slice<char>(data_node, slice);
break; break;
......
...@@ -61,6 +61,9 @@ void pass::ConstantFolding::construct_constant_squeeze() ...@@ -61,6 +61,9 @@ void pass::ConstantFolding::construct_constant_squeeze()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_squeeze_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_squeeze_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_squeeze_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_squeeze<char>(constant_match, squeeze_match); replacement = fold_constant_squeeze<char>(constant_match, squeeze_match);
break; break;
......
...@@ -75,6 +75,9 @@ void pass::ConstantFolding::construct_constant_transpose() ...@@ -75,6 +75,9 @@ void pass::ConstantFolding::construct_constant_transpose()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_transpose_callback"); "Encountered 'dynamic' element type in constant_transpose_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_transpose_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_transpose<char>( replacement = fold_constant_transpose<char>(
constant_data_match, constant_perm_match, transpose_match); constant_data_match, constant_perm_match, transpose_match);
......
...@@ -169,6 +169,9 @@ void pass::ConstantFolding::construct_constant_unary() ...@@ -169,6 +169,9 @@ void pass::ConstantFolding::construct_constant_unary()
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_unary_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_unary_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_unary_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_unary<char>(constant_match, unary_match, func); replacement = fold_constant_unary<char>(constant_match, unary_match, func);
break; break;
......
...@@ -62,6 +62,9 @@ void pass::ConstantFolding::construct_constant_unsqueeze() ...@@ -62,6 +62,9 @@ void pass::ConstantFolding::construct_constant_unsqueeze()
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_unsqueeze_callback"); "Encountered 'dynamic' element type in constant_unsqueeze_callback");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_unsqueeze_callback");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_unsqueeze<char>(constant_match, unsqueeze_match); replacement = fold_constant_unsqueeze<char>(constant_match, unsqueeze_match);
break; break;
......
...@@ -428,6 +428,7 @@ void pass::DynElimination::construct_range() ...@@ -428,6 +428,7 @@ void pass::DynElimination::construct_range()
case element::Type_t::u64: case element::Type_t::u64:
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u1:
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::boolean: case element::Type_t::boolean:
......
...@@ -112,6 +112,9 @@ namespace ngraph ...@@ -112,6 +112,9 @@ namespace ngraph
NGRAPH_CHECK(false, NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_convert"); "Encountered 'dynamic' element type in fold_constant_convert");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
functor = prepare_functor<char>(node, args, out, external_function); functor = prepare_functor<char>(node, args, out, external_function);
break; break;
......
...@@ -1376,6 +1376,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace ...@@ -1376,6 +1376,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
case element::Type_t::f64: case element::Type_t::f64:
case element::Type_t::i16: case element::Type_t::i16:
case element::Type_t::i64: case element::Type_t::i64:
case element::Type_t::u1:
case element::Type_t::u16: case element::Type_t::u16:
case element::Type_t::u32: case element::Type_t::u32:
case element::Type_t::u64: case element::Type_t::u64:
......
...@@ -439,6 +439,10 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_functio ...@@ -439,6 +439,10 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_functio
false, false,
"Encountered 'dynamic' element type in construct_constant_convertlayout"); "Encountered 'dynamic' element type in construct_constant_convertlayout");
break; break;
case element::Type_t::u1:
NGRAPH_CHECK(
false, "Encountered 'u1' element type in construct_constant_convertlayout");
break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_convertlayout_helper<char>( replacement = fold_constant_convertlayout_helper<char>(
m_input, m_convertlayout, input_md, output_md); m_input, m_convertlayout, input_md, output_md);
......
...@@ -230,6 +230,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type, ...@@ -230,6 +230,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break; case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name(); ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
......
...@@ -626,6 +626,7 @@ private: ...@@ -626,6 +626,7 @@ private:
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert"; ss << "unsupported element type " << type << " op Convert";
......
...@@ -232,6 +232,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty ...@@ -232,6 +232,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break; case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name(); ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
......
...@@ -672,6 +672,7 @@ private: ...@@ -672,6 +672,7 @@ private:
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert"; ss << "unsupported element type " << type << " op Convert";
......
...@@ -34,6 +34,7 @@ NGRAPH_API const element::Type element::i8(element::Type_t::i8); ...@@ -34,6 +34,7 @@ NGRAPH_API const element::Type element::i8(element::Type_t::i8);
NGRAPH_API const element::Type element::i16(element::Type_t::i16); NGRAPH_API const element::Type element::i16(element::Type_t::i16);
NGRAPH_API const element::Type element::i32(element::Type_t::i32); NGRAPH_API const element::Type element::i32(element::Type_t::i32);
NGRAPH_API const element::Type element::i64(element::Type_t::i64); NGRAPH_API const element::Type element::i64(element::Type_t::i64);
NGRAPH_API const element::Type element::u1(element::Type_t::u1);
NGRAPH_API const element::Type element::u8(element::Type_t::u8); NGRAPH_API const element::Type element::u8(element::Type_t::u8);
NGRAPH_API const element::Type element::u16(element::Type_t::u16); NGRAPH_API const element::Type element::u16(element::Type_t::u16);
NGRAPH_API const element::Type element::u32(element::Type_t::u32); NGRAPH_API const element::Type element::u32(element::Type_t::u32);
...@@ -80,6 +81,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map() ...@@ -80,6 +81,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")}, {element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")}, {element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")}, {element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint8_t", "u1")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")}, {element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")}, {element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")}, {element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
...@@ -100,6 +102,7 @@ std::vector<const element::Type*> element::Type::get_known_types() ...@@ -100,6 +102,7 @@ std::vector<const element::Type*> element::Type::get_known_types()
&element::i16, &element::i16,
&element::i32, &element::i32,
&element::i64, &element::i64,
&element::u1,
&element::u8, &element::u8,
&element::u16, &element::u16,
&element::u32, &element::u32,
......
...@@ -49,6 +49,7 @@ namespace ngraph ...@@ -49,6 +49,7 @@ namespace ngraph
i16, i16,
i32, i32,
i64, i64,
u1,
u8, u8,
u16, u16,
u32, u32,
...@@ -139,6 +140,7 @@ namespace ngraph ...@@ -139,6 +140,7 @@ namespace ngraph
extern NGRAPH_API const Type i16; extern NGRAPH_API const Type i16;
extern NGRAPH_API const Type i32; extern NGRAPH_API const Type i32;
extern NGRAPH_API const Type i64; extern NGRAPH_API const Type i64;
extern NGRAPH_API const Type u1;
extern NGRAPH_API const Type u8; extern NGRAPH_API const Type u8;
extern NGRAPH_API const Type u16; extern NGRAPH_API const Type u16;
extern NGRAPH_API const Type u32; extern NGRAPH_API const Type u32;
......
...@@ -100,6 +100,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor) ...@@ -100,6 +100,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
case element::Type_t::u64: init_int_tensor<uint64_t>(tensor, 0, 1); break; case element::Type_t::u64: init_int_tensor<uint64_t>(tensor, 0, 1); break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16: case element::Type_t::f16:
default: throw runtime_error("unsupported type"); default: throw runtime_error("unsupported type");
......
...@@ -54,6 +54,7 @@ set(SRC ...@@ -54,6 +54,7 @@ set(SRC
constant_folding.cpp constant_folding.cpp
concat_fusion.cpp concat_fusion.cpp
control_dependencies.cpp control_dependencies.cpp
convert_u1_to_string.cpp
coordinate.cpp coordinate.cpp
copy.cpp copy.cpp
cpio.cpp cpio.cpp
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(convert_u1_to_string, convert_u1_to_string)
{
vector<uint8_t> values{171, 16};
auto constant = make_shared<op::Constant>(element::u1, Shape{12}, &values[0]);
vector<string> ref{"1", "0", "1", "0", "1", "0", "1", "1", "0", "0", "0", "1"};
for (size_t i = 0; i < 12; ++i)
{
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment