Commit dc45b9db authored by Robert Kimball, committed by Scott Cyphers

Add numeric_limits for bfloat16 type (#2805)

* Add numeric_limits for bfloat16 type

* fix constexpr
parent 5c5690db
...@@ -42,8 +42,6 @@ using namespace ngraph; ...@@ -42,8 +42,6 @@ using namespace ngraph;
static_assert(sizeof(bfloat16) == 2, "class bfloat16 must be exactly 2 bytes"); static_assert(sizeof(bfloat16) == 2, "class bfloat16 must be exactly 2 bytes");
uint16_t bfloat16::BF16_NAN_VALUE = 0x7FC0;
bool float_isnan(const float& x) bool float_isnan(const float& x)
{ {
return std::isnan(x); return std::isnan(x);
...@@ -110,11 +108,6 @@ bfloat16::operator float() const ...@@ -110,11 +108,6 @@ bfloat16::operator float() const
return *f; return *f;
} }
/// Converts this bfloat16 to double.
///
/// m_value holds the raw 16-bit encoding, not a numeric quantity, so it must
/// be decoded through the float conversion operator first; casting m_value
/// directly would turn the bit pattern itself (e.g. 0x3F80 for 1.0) into the
/// number 16256.0. Widening the decoded float to double is exact.
bfloat16::operator double() const
{
    return static_cast<double>(static_cast<float>(*this));
}
uint16_t bfloat16::to_bits() const uint16_t bfloat16::to_bits() const
{ {
return m_value; return m_value;
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
class bfloat16 class bfloat16
{ {
public: public:
bfloat16() constexpr bfloat16()
: m_value{0} : m_value{0}
{ {
} }
...@@ -59,11 +59,10 @@ namespace ngraph ...@@ -59,11 +59,10 @@ namespace ngraph
bool operator>(const bfloat16& other) const; bool operator>(const bfloat16& other) const;
bool operator>=(const bfloat16& other) const; bool operator>=(const bfloat16& other) const;
operator float() const; operator float() const;
operator double() const;
static std::vector<float> to_float_vector(const std::vector<bfloat16>&); static std::vector<float> to_float_vector(const std::vector<bfloat16>&);
static std::vector<bfloat16> from_float_vector(const std::vector<float>&); static std::vector<bfloat16> from_float_vector(const std::vector<float>&);
static bfloat16 from_bits(uint16_t bits) { return bfloat16(bits, false); } static constexpr bfloat16 from_bits(uint16_t bits) { return bfloat16(bits, true); }
uint16_t to_bits() const; uint16_t to_bits() const;
friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj) friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj)
{ {
...@@ -85,6 +84,10 @@ namespace ngraph ...@@ -85,6 +84,10 @@ namespace ngraph
static uint16_t truncate(float x) { return static_cast<uint16_t>((cu32(x)) >> 16); } static uint16_t truncate(float x) { return static_cast<uint16_t>((cu32(x)) >> 16); }
private: private:
constexpr bfloat16(uint16_t x, bool)
: m_value{x}
{
}
union F32 { union F32 {
F32(float val) F32(float val)
: f{val} : f{val}
...@@ -97,15 +100,74 @@ namespace ngraph ...@@ -97,15 +100,74 @@ namespace ngraph
float f; float f;
uint32_t i; uint32_t i;
}; };
// This should be private since it is ugly. Need the bool so the signature can't match
// the float version of the ctor.
bfloat16(uint16_t value, bool)
: m_value{value}
{
}
uint16_t m_value; uint16_t m_value;
};
}
static uint16_t BF16_NAN_VALUE; namespace std
{
/// std::numeric_limits specialization for ngraph::bfloat16.
///
/// bfloat16 layout: 1 sign bit, 8 exponent bits (bias 127, same range as
/// float), 7 explicitly stored mantissa bits plus one implicit leading bit.
/// All special values are produced from their raw bit patterns through the
/// constexpr bfloat16::from_bits factory.
template <>
class numeric_limits<ngraph::bfloat16>
{
public:
    static constexpr bool is_specialized = true;

    /// Smallest positive *normal* value: exponent field 1, mantissa 0 (2^-126).
    /// (0x007F has a zero exponent field and is therefore subnormal.)
    static constexpr ngraph::bfloat16 min() noexcept
    {
        return ngraph::bfloat16::from_bits(0x0080);
    }

    /// Largest finite value: exponent field 0xFE, mantissa all ones.
    static constexpr ngraph::bfloat16 max() noexcept
    {
        return ngraph::bfloat16::from_bits(0x7F7F);
    }

    /// Most negative finite value, i.e. -max().
    static constexpr ngraph::bfloat16 lowest() noexcept
    {
        return ngraph::bfloat16::from_bits(0xFF7F);
    }

    // Radix-2 mantissa digits, counting the implicit leading bit (float is 24).
    // This keeps epsilon() == 2^(1 - digits) == 2^-7.
    static constexpr int digits = 8;
    // floor((digits - 1) * log10(2))
    static constexpr int digits10 = 2;
    // ceil(1 + digits * log10(2)): decimal digits needed for a lossless round trip.
    static constexpr int max_digits10 = 4;
    static constexpr bool is_signed = true;
    static constexpr bool is_integer = false;
    static constexpr bool is_exact = false;
    static constexpr int radix = 2;

    /// Difference between 1.0 and the next representable value: 2^-7.
    static constexpr ngraph::bfloat16 epsilon() noexcept
    {
        return ngraph::bfloat16::from_bits(0x3C00);
    }

    /// Maximum rounding error for round-to-nearest: 0.5.
    static constexpr ngraph::bfloat16 round_error() noexcept
    {
        return ngraph::bfloat16::from_bits(0x3F00);
    }

    // The exponent range is identical to float because the exponent field is.
    static constexpr int min_exponent = -125;
    static constexpr int min_exponent10 = -37;
    static constexpr int max_exponent = 128;
    static constexpr int max_exponent10 = 38;
    static constexpr bool has_infinity = true;
    static constexpr bool has_quiet_NaN = true;
    static constexpr bool has_signaling_NaN = true;
    static constexpr float_denorm_style has_denorm = denorm_absent;
    static constexpr bool has_denorm_loss = false;

    /// Positive infinity: exponent all ones, mantissa zero.
    static constexpr ngraph::bfloat16 infinity() noexcept
    {
        return ngraph::bfloat16::from_bits(0x7F80);
    }

    /// Quiet NaN: exponent all ones, most significant mantissa bit set.
    static constexpr ngraph::bfloat16 quiet_NaN() noexcept
    {
        return ngraph::bfloat16::from_bits(0x7FC0);
    }

    /// Signaling NaN: exponent all ones, most significant mantissa bit clear,
    /// remaining mantissa non-zero. Must differ from quiet_NaN() because
    /// has_signaling_NaN is true.
    static constexpr ngraph::bfloat16 signaling_NaN() noexcept
    {
        return ngraph::bfloat16::from_bits(0x7FA0);
    }

    /// has_denorm is denorm_absent, so per the numeric_limits contract this
    /// reports the minimum positive normal value rather than zero.
    static constexpr ngraph::bfloat16 denorm_min() noexcept
    {
        return ngraph::bfloat16::from_bits(0x0080);
    }

    static constexpr bool is_iec559 = false;
    // The set of representable values is finite.
    static constexpr bool is_bounded = true;
    static constexpr bool is_modulo = false;
    static constexpr bool traps = false;
    static constexpr bool tinyness_before = false;
    static constexpr float_round_style round_style = round_to_nearest;
};
} }
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <climits>
#include <random> #include <random>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -22,6 +23,7 @@ ...@@ -22,6 +23,7 @@
#include "ngraph/type/bfloat16.hpp" #include "ngraph/type/bfloat16.hpp"
#include "util/float_util.hpp" #include "util/float_util.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
template <typename T> template <typename T>
...@@ -137,6 +139,19 @@ TEST(bfloat16, to_float) ...@@ -137,6 +139,19 @@ TEST(bfloat16, to_float)
EXPECT_EQ(f, 1.03125f); EXPECT_EQ(f, 1.03125f);
} }
// Checks that the special values exposed by numeric_limits<bfloat16> are
// classified as infinities / NaNs when passed to isinf / isnan.
TEST(bfloat16, numeric_limits)
{
    using limits = numeric_limits<bfloat16>;

    // NaN encodings, quiet and signaling.
    const bfloat16 qnan = limits::quiet_NaN();
    const bfloat16 snan = limits::signaling_NaN();

    // Infinities of both signs.
    const bfloat16 pos_inf = limits::infinity();
    const bfloat16 neg_inf = -limits::infinity();

    EXPECT_TRUE(isinf(pos_inf));
    EXPECT_TRUE(isinf(neg_inf));
    EXPECT_TRUE(isnan(qnan));
    EXPECT_TRUE(isnan(snan));
}
TEST(benchmark, bfloat16) TEST(benchmark, bfloat16)
{ {
size_t buffer_size = 128 * 3 * 224 * 224; size_t buffer_size = 128 * 3 * 224 * 224;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment