Unverified Commit 5d80f203 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Implement ngraph::element::Type as a wrapper around an enum (#2120)

* cleanup and add enum

* new type working

* enum works

* use type enum

* cleanup

* fix errant past to source file

* fix type

* safely construct the type map

* fix get_type_info_map return type
parent 3f2cd153
......@@ -235,53 +235,23 @@ void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
const vector<const void*>& inputs,
FunctionInstance& instance)
{
if (type == element::boolean)
stringstream ss;
switch (type.get_type_enum())
{
op_engine<char>(op, outputs, inputs, instance);
}
else if (type == element::f32)
{
op_engine<float>(op, outputs, inputs, instance);
}
else if (type == element::f64)
{
op_engine<double>(op, outputs, inputs, instance);
}
else if (type == element::i8)
{
op_engine<int8_t>(op, outputs, inputs, instance);
}
else if (type == element::i16)
{
op_engine<int16_t>(op, outputs, inputs, instance);
}
else if (type == element::i32)
{
op_engine<int32_t>(op, outputs, inputs, instance);
}
else if (type == element::i64)
{
op_engine<int64_t>(op, outputs, inputs, instance);
}
else if (type == element::u8)
{
op_engine<uint8_t>(op, outputs, inputs, instance);
}
else if (type == element::u16)
{
op_engine<uint16_t>(op, outputs, inputs, instance);
}
else if (type == element::u32)
{
op_engine<uint32_t>(op, outputs, inputs, instance);
}
else if (type == element::u64)
{
op_engine<uint64_t>(op, outputs, inputs, instance);
}
else
{
stringstream ss;
case element::Type_t::boolean: op_engine<char>(op, outputs, inputs, instance); break;
case element::Type_t::f32: op_engine<float>(op, outputs, inputs, instance); break;
case element::Type_t::f64: op_engine<double>(op, outputs, inputs, instance); break;
case element::Type_t::i8: op_engine<int8_t>(op, outputs, inputs, instance); break;
case element::Type_t::i16: op_engine<int16_t>(op, outputs, inputs, instance); break;
case element::Type_t::i32: op_engine<int32_t>(op, outputs, inputs, instance); break;
case element::Type_t::i64: op_engine<int64_t>(op, outputs, inputs, instance); break;
case element::Type_t::u8: op_engine<uint8_t>(op, outputs, inputs, instance); break;
case element::Type_t::u16: op_engine<uint16_t>(op, outputs, inputs, instance); break;
case element::Type_t::u32: op_engine<uint32_t>(op, outputs, inputs, instance); break;
case element::Type_t::u64: op_engine<uint64_t>(op, outputs, inputs, instance); break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
ss << "unsupported element type " << type << " op " << op.get_node().get_name();
throw ngraph_error(ss.str());
}
......
......@@ -493,65 +493,57 @@ private:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
std::stringstream ss;
size_t element_count = shape_size(node.get_output_shape(0));
if (type == element::boolean)
switch (type.get_type_enum())
{
case element::Type_t::boolean:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count);
}
else if (type == element::f32)
{
break;
case element::Type_t::f32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count);
}
else if (type == element::f64)
{
break;
case element::Type_t::f64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count);
}
else if (type == element::i8)
{
break;
case element::Type_t::i8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count);
}
else if (type == element::i16)
{
break;
case element::Type_t::i16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count);
}
else if (type == element::i32)
{
break;
case element::Type_t::i32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count);
}
else if (type == element::i64)
{
break;
case element::Type_t::i64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count);
}
else if (type == element::u8)
{
break;
case element::Type_t::u8:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count);
}
else if (type == element::u16)
{
break;
case element::Type_t::u16:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count);
}
else if (type == element::u32)
{
break;
case element::Type_t::u32:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count);
}
else if (type == element::u64)
{
break;
case element::Type_t::u64:
reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count);
}
else
{
std::stringstream ss;
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
......
......@@ -16,28 +16,73 @@
#include <cmath>
#include <iostream>
#include <map>
#include "ngraph/log.hpp"
#include "ngraph/type/element_type.hpp"
using namespace ngraph;
using namespace std;
NGRAPH_API const element::Type element::dynamic(0, false, false, false, "dynamic");
NGRAPH_API const element::Type element::boolean(8, false, true, false, "char");
NGRAPH_API const element::Type element::bf16(16, true, true, false, "bfloat16");
NGRAPH_API const element::Type element::f32(32, true, true, false, "float");
NGRAPH_API const element::Type element::f64(64, true, true, false, "double");
NGRAPH_API const element::Type element::i8(8, false, true, true, "int8_t");
NGRAPH_API const element::Type element::i16(16, false, true, false, "int16_t");
NGRAPH_API const element::Type element::i32(32, false, true, true, "int32_t");
NGRAPH_API const element::Type element::i64(64, false, true, false, "int64_t");
NGRAPH_API const element::Type element::u8(8, false, false, true, "uint8_t");
NGRAPH_API const element::Type element::u16(16, false, false, false, "uint16_t");
NGRAPH_API const element::Type element::u32(32, false, false, false, "uint32_t");
NGRAPH_API const element::Type element::u64(64, false, false, false, "uint64_t");
NGRAPH_API const element::Type element::dynamic(element::Type_t::dynamic);
NGRAPH_API const element::Type element::boolean(element::Type_t::boolean);
NGRAPH_API const element::Type element::bf16(element::Type_t::bf16);
NGRAPH_API const element::Type element::f32(element::Type_t::f32);
NGRAPH_API const element::Type element::f64(element::Type_t::f64);
NGRAPH_API const element::Type element::i8(element::Type_t::i8);
NGRAPH_API const element::Type element::i16(element::Type_t::i16);
NGRAPH_API const element::Type element::i32(element::Type_t::i32);
NGRAPH_API const element::Type element::i64(element::Type_t::i64);
NGRAPH_API const element::Type element::u8(element::Type_t::u8);
NGRAPH_API const element::Type element::u16(element::Type_t::u16);
NGRAPH_API const element::Type element::u32(element::Type_t::u32);
NGRAPH_API const element::Type element::u64(element::Type_t::u64);
class TypeInfo
{
public:
TypeInfo(
size_t bitwidth, bool is_real, bool is_signed, bool is_quantized, const std::string& cname)
: m_bitwidth{bitwidth}
, m_is_real{is_real}
, m_is_signed{is_signed}
, m_is_quantized{is_quantized}
, m_cname{cname}
{
}
size_t m_bitwidth;
bool m_is_real;
bool m_is_signed;
bool m_is_quantized;
std::string m_cname;
};
static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{
static map<element::Type_t, const TypeInfo> s_type_info_map{
{element::Type_t::undefined,
TypeInfo(std::numeric_limits<size_t>::max(), false, false, false, "undefined")},
{element::Type_t::dynamic, TypeInfo(0, false, false, false, "dynamic")},
{element::Type_t::boolean, TypeInfo(8, false, true, false, "char")},
{element::Type_t::bf16, TypeInfo(16, true, true, false, "bfloat16")},
{element::Type_t::f32, TypeInfo(32, true, true, false, "float")},
{element::Type_t::f64, TypeInfo(64, true, true, false, "double")},
{element::Type_t::i8, TypeInfo(8, false, true, true, "int8_t")},
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t")},
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t")},
};
return s_type_info_map;
};
std::vector<const element::Type*> element::Type::get_known_types()
{
std::vector<const element::Type*> rc = {&element::boolean,
std::vector<const element::Type*> rc = {&element::dynamic,
&element::boolean,
&element::bf16,
&element::f32,
&element::f64,
......@@ -54,63 +99,42 @@ std::vector<const element::Type*> element::Type::get_known_types()
element::Type::Type(
size_t bitwidth, bool is_real, bool is_signed, bool is_quantized, const std::string& cname)
: m_bitwidth{bitwidth}
, m_is_real{is_real}
, m_is_signed{is_signed}
, m_is_quantized{is_quantized}
, m_cname{cname}
{
}
element::Type& element::Type::operator=(const element::Type& t)
{
m_bitwidth = t.m_bitwidth;
m_is_real = t.m_is_real;
m_is_signed = t.m_is_signed;
m_is_quantized = t.m_is_quantized;
m_cname = t.m_cname;
return *this;
for (const pair<element::Type_t, TypeInfo>& t : get_type_info_map())
{
const TypeInfo& info = t.second;
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
is_signed == info.m_is_signed && is_quantized == info.m_is_quantized)
{
m_type = t.first;
return;
}
}
}
const std::string& element::Type::c_type_string() const
{
return m_cname;
return get_type_info_map().at(m_type).m_cname;
}
bool element::Type::operator==(const element::Type& other) const
{
return m_bitwidth == other.m_bitwidth && m_is_real == other.m_is_real &&
m_is_signed == other.m_is_signed && m_is_quantized == other.m_is_quantized &&
m_cname == other.m_cname;
return m_type == other.m_type;
}
bool element::Type::operator<(const Type& other) const
{
size_t v1 = m_bitwidth << 3;
v1 |= static_cast<size_t>(m_is_real ? 4 : 0);
v1 |= static_cast<size_t>(m_is_signed ? 2 : 0);
v1 |= static_cast<size_t>(m_is_quantized ? 1 : 0);
size_t v2 = other.m_bitwidth << 3;
v2 |= static_cast<size_t>(other.m_is_real ? 4 : 0);
v2 |= static_cast<size_t>(other.m_is_signed ? 2 : 0);
v2 |= static_cast<size_t>(other.m_is_quantized ? 1 : 0);
return v1 < v2;
return m_type < other.m_type;
}
size_t element::Type::size() const
{
return std::ceil(static_cast<float>(m_bitwidth) / 8.0f);
return std::ceil(static_cast<float>(bitwidth()) / 8.0f);
}
size_t element::Type::hash() const
{
size_t h1 = std::hash<size_t>{}(m_bitwidth);
size_t h2 = std::hash<bool>{}(m_is_real);
size_t h3 = std::hash<bool>{}(m_is_signed);
size_t h4 = std::hash<bool>{}(m_is_quantized);
return h1 ^ ((h2 ^ ((h3 ^ (h4 << 1)) << 1)) << 1);
return static_cast<size_t>(m_type);
}
namespace ngraph
......@@ -120,7 +144,7 @@ namespace ngraph
template <>
const Type& from<char>()
{
return boolean;
return element::boolean;
}
template <>
const Type& from<bool>()
......@@ -187,12 +211,12 @@ namespace ngraph
std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
{
out << "element::Type{" << obj.m_bitwidth << ", " << obj.m_is_real << ", " << obj.m_is_signed
<< ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}";
out << "element::Type{" << obj.bitwidth() << ", " << obj.is_real() << ", " << obj.is_signed()
<< ", " << obj.is_quantized() << ", \"" << obj.c_type_string() << "\"}";
return out;
}
bool element::Type::compatible(element::Type t) const
bool element::Type::compatible(const element::Type& t) const
{
return (is_dynamic() || t.is_dynamic() || *this == t);
}
......@@ -222,5 +246,25 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele
bool element::Type::is_static() const
{
return (*this != dynamic);
return get_type_info_map().at(m_type).m_bitwidth != 0;
}
bool element::Type::is_real() const
{
return get_type_info_map().at(m_type).m_is_real;
}
bool element::Type::is_signed() const
{
return get_type_info_map().at(m_type).m_is_signed;
}
bool element::Type::is_quantized() const
{
return get_type_info_map().at(m_type).m_is_quantized;
}
size_t element::Type::bitwidth() const
{
return get_type_info_map().at(m_type).m_bitwidth;
}
......@@ -21,6 +21,7 @@
#pragma once
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>
......@@ -33,39 +34,79 @@ namespace ngraph
{
namespace element
{
enum class Type_t
{
undefined,
dynamic,
boolean,
bf16,
f32,
f64,
i8,
i16,
i32,
i64,
u8,
u16,
u32,
u64
};
class Type;
extern const Type dynamic;
extern const Type boolean;
extern const Type bf16;
extern const Type f32;
extern const Type f64;
extern const Type i8;
extern const Type i16;
extern const Type i32;
extern const Type i64;
extern const Type u8;
extern const Type u16;
extern const Type u32;
extern const Type u64;
class Type
{
public:
Type() {}
Type()
: m_type{element::Type_t::undefined}
{
}
Type(const Type&) = default;
Type(const Type_t t)
: m_type{t}
{
}
Type(size_t bitwidth,
bool is_real,
bool is_signed,
bool is_quantized,
const std::string& cname);
Type& operator=(const Type&);
virtual ~Type() {}
~Type() {}
Type& operator=(const Type&) = default;
Type_t get_type_enum() const { return m_type; }
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const;
bool is_static() const;
bool is_dynamic() const { return !is_static(); }
bool is_real() const { return m_is_real; }
bool is_signed() const { return m_is_signed; }
bool is_quantized() const { return m_is_quantized; }
size_t bitwidth() const { return m_bitwidth; }
bool is_real() const;
bool is_signed() const;
bool is_quantized() const;
size_t bitwidth() const;
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
bool operator<(const Type& other) const;
friend std::ostream& operator<<(std::ostream&, const Type&);
static std::vector<const Type*> get_known_types();
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
/// \brief Checks whether this element type is merge-compatible with `t`.
/// \param t The element type to compare this element type to.
/// \return `true` if this element type is compatible with `t`, else `false`.
bool compatible(element::Type t) const;
bool compatible(const element::Type& t) const;
/// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false.
......@@ -88,11 +129,7 @@ namespace ngraph
static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
private:
size_t m_bitwidth{0};
bool m_is_real{false};
bool m_is_signed{false};
bool m_is_quantized{false};
std::string m_cname{"dynamic"};
Type_t m_type;
};
extern NGRAPH_API const Type dynamic;
......
......@@ -45,46 +45,6 @@ TEST(element_type, mapable)
test_map.insert({element::f32, "float"});
}
TEST(element_type, size)
{
{
element::Type t1{1, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{2, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{3, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{4, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{5, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{6, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{7, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{8, false, false, false, ""};
EXPECT_EQ(1, t1.size());
}
{
element::Type t1{9, false, false, false, ""};
EXPECT_EQ(2, t1.size());
}
}
TEST(element_type, merge_both_dynamic)
{
element::Type t;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment