Commit dc45b9db authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add numeric_limits for bfloat16 type (#2805)

* Add numeric_limits for bfloat16 type

* fix constexpr
parent 5c5690db
......@@ -42,8 +42,6 @@ using namespace ngraph;
static_assert(sizeof(bfloat16) == 2, "class bfloat16 must be exactly 2 bytes");
uint16_t bfloat16::BF16_NAN_VALUE = 0x7FC0;
bool float_isnan(const float& x)
{
return std::isnan(x);
......@@ -110,11 +108,6 @@ bfloat16::operator float() const
return *f;
}
bfloat16::operator double() const
{
return static_cast<float>(m_value);
}
uint16_t bfloat16::to_bits() const
{
return m_value;
......
......@@ -29,7 +29,7 @@ namespace ngraph
class bfloat16
{
public:
bfloat16()
constexpr bfloat16()
: m_value{0}
{
}
......@@ -59,11 +59,10 @@ namespace ngraph
bool operator>(const bfloat16& other) const;
bool operator>=(const bfloat16& other) const;
operator float() const;
operator double() const;
static std::vector<float> to_float_vector(const std::vector<bfloat16>&);
static std::vector<bfloat16> from_float_vector(const std::vector<float>&);
static bfloat16 from_bits(uint16_t bits) { return bfloat16(bits, false); }
static constexpr bfloat16 from_bits(uint16_t bits) { return bfloat16(bits, true); }
uint16_t to_bits() const;
friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj)
{
......@@ -85,6 +84,10 @@ namespace ngraph
static uint16_t truncate(float x) { return static_cast<uint16_t>((cu32(x)) >> 16); }
private:
constexpr bfloat16(uint16_t x, bool)
: m_value{x}
{
}
union F32 {
F32(float val)
: f{val}
......@@ -97,15 +100,74 @@ namespace ngraph
float f;
uint32_t i;
};
// This should be private since it is ugly. Need the bool so the signature can't match
// the float version of the ctor.
bfloat16(uint16_t value, bool)
: m_value{value}
{
}
uint16_t m_value;
};
}
static uint16_t BF16_NAN_VALUE;
namespace std
{
template <>
class numeric_limits<ngraph::bfloat16>
{
public:
static constexpr bool is_specialized = true;
static constexpr ngraph::bfloat16 min() noexcept
{
return ngraph::bfloat16::from_bits(0x007F);
}
static constexpr ngraph::bfloat16 max() noexcept
{
return ngraph::bfloat16::from_bits(0x7F7F);
}
static constexpr ngraph::bfloat16 lowest() noexcept
{
return ngraph::bfloat16::from_bits(0xFF7F);
}
static constexpr int digits = 7;
static constexpr int digits10 = 2;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = 2;
static constexpr ngraph::bfloat16 epsilon() noexcept
{
return ngraph::bfloat16::from_bits(0x3C00);
}
static constexpr ngraph::bfloat16 round_error() noexcept
{
return ngraph::bfloat16::from_bits(0x3F00);
}
static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_absent;
static constexpr bool has_denorm_loss = false;
static constexpr ngraph::bfloat16 infinity() noexcept
{
return ngraph::bfloat16::from_bits(0x7F80);
}
static constexpr ngraph::bfloat16 quiet_NaN() noexcept
{
return ngraph::bfloat16::from_bits(0x7FC0);
}
static constexpr ngraph::bfloat16 signaling_NaN() noexcept
{
return ngraph::bfloat16::from_bits(0x7FC0);
}
static constexpr ngraph::bfloat16 denorm_min() noexcept
{
return ngraph::bfloat16::from_bits(0);
}
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = false;
static constexpr bool is_modulo = false;
static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
}
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <climits>
#include <random>
#include "gtest/gtest.h"
......@@ -22,6 +23,7 @@
#include "ngraph/type/bfloat16.hpp"
#include "util/float_util.hpp"
using namespace std;
using namespace ngraph;
template <typename T>
......@@ -137,6 +139,19 @@ TEST(bfloat16, to_float)
EXPECT_EQ(f, 1.03125f);
}
TEST(bfloat16, numeric_limits)
{
bfloat16 infinity = numeric_limits<bfloat16>::infinity();
bfloat16 neg_infinity = -numeric_limits<bfloat16>::infinity();
bfloat16 quiet_nan = numeric_limits<bfloat16>::quiet_NaN();
bfloat16 signaling_nan = numeric_limits<bfloat16>::signaling_NaN();
EXPECT_TRUE(isinf(infinity));
EXPECT_TRUE(isinf(neg_infinity));
EXPECT_TRUE(isnan(quiet_nan));
EXPECT_TRUE(isnan(signaling_nan));
}
TEST(benchmark, bfloat16)
{
size_t buffer_size = 128 * 3 * 224 * 224;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment