Commit e9dd6087 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Jayaram Bobba

Added u1 precision for binary weights (#3914)

* Added U1 precision for binary weights

* Handle switch cases with u1 type

* Fixed code style

* Added convert_to_string support for u1 type

* Use real C type  for u1 type.
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>
parent 89b3019e
......@@ -238,6 +238,7 @@ mlir::Type NgDialectConversionPass::getMlirType(const element::Type& type)
{
case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic:
case ngraph::element::Type_t::u1:
default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break;
case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(m_context);
case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(m_context);
......
......@@ -94,6 +94,8 @@ namespace ngraph
throw ngraph_error("make_constant: Unsupported element type 'dynamic'");
case element::Type_t::boolean:
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::u1:
throw ngraph_error("make_constant: Unsupported element type 'u1'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
}
......
......@@ -196,6 +196,7 @@ namespace ngraph
case element::Type_t::u64: m_type = MPI_UNSIGNED_LONG; break;
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
......
......@@ -114,18 +114,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
}
}
element::Type result_et;
PartialShape result_shape;
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, data_batch_et, filters_et),
"Element types for data batch and filters do not match (data batch element type: ",
data_batch_et,
", filters element type: ",
filters_et,
").");
result_shape =
infer_convolution_forward(this,
data_batch_shape,
......@@ -136,7 +125,7 @@ void op::v1::BinaryConvolution::validate_and_infer_types()
m_strides,
m_dilations);
set_output_type(0, result_et, result_shape);
set_output_type(0, data_batch_et, result_shape);
}
shared_ptr<Node> op::v1::BinaryConvolution::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -74,6 +74,9 @@ string op::Constant::convert_value_to_string(size_t index) const
case element::Type_t::i16: rc = to_string(get_vector<int16_t>()[index]); break;
case element::Type_t::i32: rc = to_string(get_vector<int32_t>()[index]); break;
case element::Type_t::i64: rc = to_string(get_vector<int64_t>()[index]); break;
case element::Type_t::u1:
rc = to_string((get_vector<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1);
break;
case element::Type_t::u8: rc = to_string(get_vector<uint8_t>()[index]); break;
case element::Type_t::u16: rc = to_string(get_vector<uint16_t>()[index]); break;
case element::Type_t::u32: rc = to_string(get_vector<uint32_t>()[index]); break;
......@@ -176,6 +179,7 @@ vector<string> op::Constant::get_value_strings() const
rc.push_back(to_string(value));
}
break;
case element::Type_t::u1: throw runtime_error("unsupported type");
case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::dynamic: throw runtime_error("unsupported type");
}
......@@ -323,6 +327,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
rc = test_bitwise_identical<uint64_t>(this);
break;
}
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}
......
......@@ -338,6 +338,7 @@ namespace ngraph
case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
......
......@@ -230,6 +230,7 @@ void op::Range::validate_and_infer_types()
case element::Type_t::u32: result_shape = infer_output_shape<uint32_t>(this, result_et); break;
case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break;
case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break;
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::boolean:
NODE_VALIDATION_CHECK(
......
......@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Max::get_default_value() const
case element::Type_t::u64:
return make_constant_from_string(
to_string(numeric_limits<uint64_t>::min()), get_element_type(), get_shape());
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic:
default: throw runtime_error("Max default value not defined for type");
......
......@@ -75,6 +75,7 @@ shared_ptr<Node> op::v0::Min::get_default_value() const
case element::Type_t::u64:
return make_constant_from_string(
to_string(numeric_limits<uint64_t>::max()), get_element_type(), get_shape());
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic:
default: throw runtime_error("Min default value not defined for type");
......
......@@ -189,6 +189,9 @@ static shared_ptr<op::Constant>
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_arithmetic_reduction");
break;
case element::Type_t::boolean:
return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
case element::Type_t::bf16:
......
......@@ -620,6 +620,9 @@ void pass::ConstantFolding::construct_constant_binary()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_binary_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_binary_callback");
break;
case element::Type_t::boolean:
replacement =
fold_constant_binary_helper<char>(a_match, b_match, binary_match, func);
......
......@@ -117,6 +117,9 @@ void pass::ConstantFolding::construct_constant_broadcast()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_broadcast_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_broadcast_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func);
break;
......
......@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_concat()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_concat");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_concat");
break;
case element::Type_t::boolean:
replacement = fold_constant_concat_helper<char>(concat_node);
break;
......
......@@ -57,6 +57,9 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper1<TI, char>(constant, output_element_type);
case element::Type_t::bf16:
......@@ -114,6 +117,9 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_convert_helper0<char>(constant, output_element_type);
case element::Type_t::bf16:
......
......@@ -81,6 +81,9 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_broadcast_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_broadcast_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_broadcast<char>(
constant_arg_match, constant_shape_match, constant_axes_match);
......
......@@ -79,6 +79,9 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
break;
......
......@@ -114,6 +114,9 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_dyn_slice");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_dyn_slice");
break;
case element::Type_t::boolean:
replacement =
fold_constant_dyn_slice<char>(data_node, lb_node, ub_node, strides_node, dyn_slice);
......
......@@ -81,6 +81,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
case element::Type_t::f64:
case element::Type_t::i8:
case element::Type_t::i16:
case element::Type_t::u1:
case element::Type_t::u8:
case element::Type_t::u16:
case element::Type_t::u32:
......@@ -134,6 +135,9 @@ void pass::ConstantFolding::construct_constant_gather()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_gather_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_gather_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_gather<char>(data, indices, gather);
break;
......
......@@ -100,6 +100,9 @@ void pass::ConstantFolding::construct_constant_pad()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_pad_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_pad<char>(constant_match, pad_match, func);
break;
......
......@@ -68,6 +68,9 @@ void pass::ConstantFolding::construct_constant_range()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_range_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_range_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_range<char>(start_node, step_node, range);
break;
......
......@@ -87,6 +87,9 @@ void pass::ConstantFolding::construct_constant_reshape()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
break;
......
......@@ -52,6 +52,9 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
return fold_constant_reverse_helper<char>(constant, reversed_axes);
case element::Type_t::bf16:
......
......@@ -72,6 +72,9 @@ void pass::ConstantFolding::construct_constant_select()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_select_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_select_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_select<char>(selection_node, t_node, f_node, select);
break;
......
......@@ -67,6 +67,9 @@ void pass::ConstantFolding::construct_constant_slice()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_slice");
break;
case element::Type_t::boolean:
replacement = fold_constant_slice<char>(data_node, slice);
break;
......
......@@ -61,6 +61,9 @@ void pass::ConstantFolding::construct_constant_squeeze()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_squeeze_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_squeeze_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_squeeze<char>(constant_match, squeeze_match);
break;
......
......@@ -75,6 +75,9 @@ void pass::ConstantFolding::construct_constant_transpose()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_transpose_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_transpose_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_transpose<char>(
constant_data_match, constant_perm_match, transpose_match);
......
......@@ -169,6 +169,9 @@ void pass::ConstantFolding::construct_constant_unary()
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_unary_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_unary_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_unary<char>(constant_match, unary_match, func);
break;
......
......@@ -62,6 +62,9 @@ void pass::ConstantFolding::construct_constant_unsqueeze()
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_unsqueeze_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_unsqueeze_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_unsqueeze<char>(constant_match, unsqueeze_match);
break;
......
......@@ -428,6 +428,7 @@ void pass::DynElimination::construct_range()
case element::Type_t::u64:
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::boolean:
......
......@@ -112,6 +112,9 @@ namespace ngraph
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in fold_constant_convert");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in fold_constant_convert");
break;
case element::Type_t::boolean:
functor = prepare_functor<char>(node, args, out, external_function);
break;
......
......@@ -1376,6 +1376,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
case element::Type_t::f64:
case element::Type_t::i16:
case element::Type_t::i64:
case element::Type_t::u1:
case element::Type_t::u16:
case element::Type_t::u32:
case element::Type_t::u64:
......
......@@ -439,6 +439,10 @@ bool ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding::run_on_functio
false,
"Encountered 'dynamic' element type in construct_constant_convertlayout");
break;
case element::Type_t::u1:
NGRAPH_CHECK(
false, "Encountered 'u1' element type in construct_constant_convertlayout");
break;
case element::Type_t::boolean:
replacement = fold_constant_convertlayout_helper<char>(
m_input, m_convertlayout, input_md, output_md);
......
......@@ -230,6 +230,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
......
......@@ -626,6 +626,7 @@ private:
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert";
......
......@@ -232,6 +232,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
case element::Type_t::u64: op_engine<uint64_t>(op, out, in); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
......
......@@ -672,6 +672,7 @@ private:
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert";
......
......@@ -34,6 +34,7 @@ NGRAPH_API const element::Type element::i8(element::Type_t::i8);
NGRAPH_API const element::Type element::i16(element::Type_t::i16);
NGRAPH_API const element::Type element::i32(element::Type_t::i32);
NGRAPH_API const element::Type element::i64(element::Type_t::i64);
NGRAPH_API const element::Type element::u1(element::Type_t::u1);
NGRAPH_API const element::Type element::u8(element::Type_t::u8);
NGRAPH_API const element::Type element::u16(element::Type_t::u16);
NGRAPH_API const element::Type element::u32(element::Type_t::u32);
......@@ -80,6 +81,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint8_t", "u1")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
......@@ -100,6 +102,7 @@ std::vector<const element::Type*> element::Type::get_known_types()
&element::i16,
&element::i32,
&element::i64,
&element::u1,
&element::u8,
&element::u16,
&element::u32,
......
......@@ -49,6 +49,7 @@ namespace ngraph
i16,
i32,
i64,
u1,
u8,
u16,
u32,
......@@ -139,6 +140,7 @@ namespace ngraph
extern NGRAPH_API const Type i16;
extern NGRAPH_API const Type i32;
extern NGRAPH_API const Type i64;
extern NGRAPH_API const Type u1;
extern NGRAPH_API const Type u8;
extern NGRAPH_API const Type u16;
extern NGRAPH_API const Type u32;
......
......@@ -100,6 +100,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
case element::Type_t::u64: init_int_tensor<uint64_t>(tensor, 0, 1); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
default: throw runtime_error("unsupported type");
......
......@@ -54,6 +54,7 @@ set(SRC
constant_folding.cpp
concat_fusion.cpp
control_dependencies.cpp
convert_u1_to_string.cpp
coordinate.cpp
copy.cpp
cpio.cpp
......
//*****************************************************************************
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(convert_u1_to_string, convert_u1_to_string)
{
vector<uint8_t> values{171, 16};
auto constant = make_shared<op::Constant>(element::u1, Shape{12}, &values[0]);
vector<string> ref{"1", "0", "1", "0", "1", "0", "1", "1", "0", "0", "0", "1"};
for (size_t i = 0; i < 12; ++i)
{
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment