Commit 125f7242 authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Support for all_close_f w/ doubles (#2184)

* Double support for all_close_f

* all_close_f uses fixed number of mantissa bits now. Simplified testing code.

* Initialize test data members in constructor to values which will cause test failure. Setup then sets them correctly.

* Reduce info printed out during all_close_f unit tests.
parent 91c4b553
...@@ -28,10 +28,21 @@ using namespace std; ...@@ -28,10 +28,21 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
union FloatUnion { union FloatUnion {
FloatUnion() { i = 0; }
FloatUnion(float val) { f = val; }
FloatUnion(uint32_t val) { i = val; }
float f; float f;
uint32_t i; uint32_t i;
}; };
union DoubleUnion {
DoubleUnion() { i = 0; }
DoubleUnion(double val) { d = val; }
DoubleUnion(uint64_t val) { i = val; }
double d;
uint64_t i;
};
string float_to_bits(float f) string float_to_bits(float f)
{ {
FloatUnion fu{f}; FloatUnion fu{f};
...@@ -40,6 +51,14 @@ string float_to_bits(float f) ...@@ -40,6 +51,14 @@ string float_to_bits(float f)
return ss.str(); return ss.str();
} }
string double_to_bits(double d)
{
DoubleUnion du{d};
stringstream ss;
ss << bitset<64>(du.i);
return ss.str();
}
float bits_to_float(const string& s) float bits_to_float(const string& s)
{ {
if (s.size() != 32) if (s.size() != 32)
...@@ -52,6 +71,263 @@ float bits_to_float(const string& s) ...@@ -52,6 +71,263 @@ float bits_to_float(const string& s)
return fu.f; return fu.f;
} }
double bits_to_double(const string& s)
{
if (s.size() != 64)
{
throw ngraph_error("Input length must be 64");
}
bitset<64> bs(s);
DoubleUnion du;
du.i = static_cast<uint64_t>(bs.to_ullong());
return du.d;
}
class all_close_f_param_test : public testing::TestWithParam<::std::tuple<float, int, int>>
{
protected:
all_close_f_param_test()
: upper_bound(FLT_MAX)
, lower_bound(-FLT_MAX)
, past_upper_bound(FLT_MAX)
, past_lower_bound(-FLT_MAX)
{
std::tie(expected, mantissa_bits, tolerance_bits) = GetParam();
}
void SetUp() override
{
uint32_t expected_as_int = FloatUnion(expected).i;
// Turn on targeted bit
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
uint32_t targeted_bit = (1u << tolerance_bit_shift);
if (expected > 0.f)
{
uint32_t upper_bound_as_int = expected_as_int + targeted_bit;
upper_bound = FloatUnion(upper_bound_as_int).f;
past_upper_bound = FloatUnion(upper_bound_as_int + 1).f;
uint32_t lower_bound_as_int = expected_as_int - targeted_bit;
lower_bound = FloatUnion(lower_bound_as_int).f;
past_lower_bound = FloatUnion(lower_bound_as_int - 1).f;
}
else if (expected < 0.f)
{
// Same logic/math as above, but reversed variable name order
uint32_t lower_bound_as_int = expected_as_int + targeted_bit;
lower_bound = FloatUnion(lower_bound_as_int).f;
past_lower_bound = FloatUnion(lower_bound_as_int + 1).f;
uint32_t upper_bound_as_int = expected_as_int - targeted_bit;
upper_bound = FloatUnion(upper_bound_as_int).f;
past_upper_bound = FloatUnion(upper_bound_as_int - 1).f;
}
else // (expected == 0.f) || (expected == -0.f)
{
// Special handling of 0 / -0 which get same bounds
uint32_t upper_bound_as_int = targeted_bit;
upper_bound = FloatUnion(upper_bound_as_int).f;
uint32_t past_upper_bound_as_int = upper_bound_as_int + 1;
past_upper_bound = FloatUnion(past_upper_bound_as_int).f;
lower_bound = FloatUnion(upper_bound_as_int | 0x80000000).f;
past_lower_bound = FloatUnion(past_upper_bound_as_int | 0x80000000).f;
}
}
float expected;
int mantissa_bits;
int tolerance_bits;
float upper_bound;
float lower_bound;
float past_upper_bound;
float past_lower_bound;
};
TEST_P(all_close_f_param_test, test_boundaries)
{
// Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << mantissa_bits << ", "
<< tolerance_bits << ")\n";
// Format verbose info to only print out in case of test failure
stringstream ss;
ss << "Testing target of: " << expected << " (" << float_to_bits(expected) << ")\n";
ss << "Matching to targets with: " << mantissa_bits << " mantissa_bits and " << tolerance_bits
<< " tolerance_bits\n";
ss << "upper_bound: " << upper_bound << " (" << float_to_bits(upper_bound) << ")\n";
ss << "lower_bound: " << lower_bound << " (" << float_to_bits(lower_bound) << ")\n";
ss << "past_upper_bound: " << past_upper_bound << " (" << float_to_bits(past_upper_bound)
<< ")\n";
ss << "past_lower_bound: " << past_lower_bound << " (" << float_to_bits(past_lower_bound)
<< ")\n";
EXPECT_TRUE(test::close_f(expected, upper_bound, mantissa_bits, tolerance_bits)) << ss.str();
EXPECT_TRUE(test::all_close_f(
vector<float>({expected}), vector<float>({upper_bound}), mantissa_bits, tolerance_bits))
<< ss.str();
EXPECT_TRUE(test::close_f(expected, lower_bound, mantissa_bits, tolerance_bits)) << ss.str();
EXPECT_TRUE(test::all_close_f(
vector<float>({expected}), vector<float>({lower_bound}), mantissa_bits, tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, mantissa_bits, tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<float>({expected}),
vector<float>({past_upper_bound}),
mantissa_bits,
tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, mantissa_bits, tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<float>({expected}),
vector<float>({past_lower_bound}),
mantissa_bits,
tolerance_bits))
<< ss.str();
}
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-variable-declarations"
INSTANTIATE_TEST_CASE_P(
test_simple_floats_with_range_of_precisions,
all_close_f_param_test,
testing::Combine(testing::Values(0.f,
-0.f,
1.f,
-1.f,
10.f,
-10.f,
0.75f,
-0.75f,
0.5f,
-0.5f,
0.25f,
-0.25f,
0.125f,
-0.125f),
testing::Values(8,
24), // For broader range of testing use testing::Range(8, 25)
testing::Range(0, 5)), );
#pragma GCC diagnostic pop
class all_close_f_double_param_test : public testing::TestWithParam<::std::tuple<double, int>>
{
protected:
all_close_f_double_param_test()
: mantissa_bits(53)
, upper_bound(DBL_MAX)
, lower_bound(-DBL_MAX)
, past_upper_bound(DBL_MAX)
, past_lower_bound(-DBL_MAX)
{
std::tie(expected, tolerance_bits) = GetParam();
}
void SetUp() override
{
uint64_t expected_as_int = DoubleUnion(expected).i;
// Turn on targeted bit
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (52 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
uint64_t targeted_bit = (1ull << tolerance_bit_shift);
if (expected > 0.)
{
uint64_t upper_bound_as_int = expected_as_int + targeted_bit;
upper_bound = DoubleUnion(upper_bound_as_int).d;
past_upper_bound = DoubleUnion(upper_bound_as_int + 1).d;
uint64_t lower_bound_as_int = expected_as_int - targeted_bit;
lower_bound = DoubleUnion(lower_bound_as_int).d;
past_lower_bound = DoubleUnion(lower_bound_as_int - 1).d;
}
else if (expected < 0.)
{
// Same logic/math as above, but reversed variable name order
uint64_t lower_bound_as_int = expected_as_int + targeted_bit;
lower_bound = DoubleUnion(lower_bound_as_int).d;
past_lower_bound = DoubleUnion(lower_bound_as_int + 1).d;
uint64_t upper_bound_as_int = expected_as_int - targeted_bit;
upper_bound = DoubleUnion(upper_bound_as_int).d;
past_upper_bound = DoubleUnion(upper_bound_as_int - 1).d;
}
else // (expected == 0.) || (expected == -0.)
{
// Special handling of 0 / -0 which get same bounds
uint64_t upper_bound_as_int = targeted_bit;
upper_bound = DoubleUnion(upper_bound_as_int).d;
uint64_t past_upper_bound_as_int = upper_bound_as_int + 1;
past_upper_bound = DoubleUnion(past_upper_bound_as_int).d;
lower_bound = DoubleUnion(upper_bound_as_int | 0x8000000000000000).d;
past_lower_bound = DoubleUnion(past_upper_bound_as_int | 0x8000000000000000).d;
}
}
double expected;
int mantissa_bits;
int tolerance_bits;
double upper_bound;
double lower_bound;
double past_upper_bound;
double past_lower_bound;
};
TEST_P(all_close_f_double_param_test, test_boundaries)
{
// Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << tolerance_bits << ")\n";
// Format verbose info to only print out in case of test failure
stringstream ss;
ss << "Testing target of: " << expected << " (" << double_to_bits(expected) << ")\n";
ss << "Matching to targets with: " << mantissa_bits << " mantissa_bits and " << tolerance_bits
<< " tolerance_bits\n";
ss << "upper_bound: " << upper_bound << " (" << double_to_bits(upper_bound) << ")\n";
ss << "lower_bound: " << lower_bound << " (" << double_to_bits(lower_bound) << ")\n";
ss << "past_upper_bound: " << past_upper_bound << " (" << double_to_bits(past_upper_bound)
<< ")\n";
ss << "past_lower_bound: " << past_lower_bound << " (" << double_to_bits(past_lower_bound)
<< ")\n";
EXPECT_TRUE(test::close_f(expected, upper_bound, tolerance_bits)) << ss.str();
EXPECT_TRUE(test::all_close_f(
vector<double>({expected}), vector<double>({upper_bound}), tolerance_bits))
<< ss.str();
EXPECT_TRUE(test::close_f(expected, lower_bound, tolerance_bits)) << ss.str();
EXPECT_TRUE(test::all_close_f(
vector<double>({expected}), vector<double>({lower_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::all_close_f(
vector<double>({expected}), vector<double>({past_upper_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::all_close_f(
vector<double>({expected}), vector<double>({past_lower_bound}), tolerance_bits))
<< ss.str();
}
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-variable-declarations"
INSTANTIATE_TEST_CASE_P(
test_simple_doubles_with_range_of_precisions,
all_close_f_double_param_test,
testing::Combine(
testing::Values(
0., -0., 1., -1., 10., -10., 0.75, -0.75, 0.5, -0.5, 0.25, -0.25, 0.125, -0.125),
testing::Range(0, 17)), );
#pragma GCC diagnostic pop
// Test the exact bounds near +0.f // Test the exact bounds near +0.f
// //
// With mantissa_bits = 8, tolerance_bits = 2 // With mantissa_bits = 8, tolerance_bits = 2
...@@ -516,3 +792,31 @@ TEST(all_close_f, inf_nan) ...@@ -516,3 +792,31 @@ TEST(all_close_f, inf_nan)
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan)); EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan}))); EXPECT_FALSE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan})));
} }
TEST(all_close_f, double_inf_nan)
{
double zero = 0.f;
double infinity = numeric_limits<double>::infinity();
double neg_infinity = -numeric_limits<double>::infinity();
double quiet_nan = numeric_limits<double>::quiet_NaN();
double signaling_nan = numeric_limits<double>::signaling_NaN();
EXPECT_FALSE(test::close_f(zero, infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({infinity})));
EXPECT_FALSE(test::close_f(zero, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({neg_infinity})));
EXPECT_FALSE(test::close_f(zero, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({quiet_nan})));
EXPECT_FALSE(test::close_f(zero, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({signaling_nan})));
EXPECT_FALSE(test::close_f(infinity, infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({infinity}), vector<double>({infinity})));
EXPECT_FALSE(test::close_f(neg_infinity, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({neg_infinity}), vector<double>({neg_infinity})));
EXPECT_FALSE(test::close_f(quiet_nan, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({quiet_nan}), vector<double>({quiet_nan})));
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE(
test::all_close_f(vector<double>({signaling_nan}), vector<double>({signaling_nan})));
}
...@@ -27,6 +27,11 @@ union FloatUnion { ...@@ -27,6 +27,11 @@ union FloatUnion {
uint32_t i; uint32_t i;
}; };
union DoubleUnion {
double d;
uint64_t i;
};
uint32_t test::float_distance(float a, float b) uint32_t test::float_distance(float a, float b)
{ {
if (!isfinite(a) || !isfinite(b)) if (!isfinite(a) || !isfinite(b))
...@@ -50,6 +55,29 @@ uint32_t test::float_distance(float a, float b) ...@@ -50,6 +55,29 @@ uint32_t test::float_distance(float a, float b)
return distance; return distance;
} }
uint64_t test::float_distance(double a, double b)
{
if (!isfinite(a) || !isfinite(b))
{
return ULLONG_MAX;
}
DoubleUnion a_du{a};
DoubleUnion b_du{b};
uint64_t a_uint = a_du.i;
uint64_t b_uint = b_du.i;
// A trick to handle both positive and negative numbers, see https://goo.gl/YbdnFQ
// - If negative: convert to two's complement
// - If positive: mask with sign bit
uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint64_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
return distance;
}
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
{ {
// isfinite(a) => !isinf(a) && !isnan(a) // isfinite(a) => !isinf(a) && !isnan(a)
...@@ -69,6 +97,27 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) ...@@ -69,6 +97,27 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
return distance <= tolerance; return distance <= tolerance;
} }
bool test::close_f(double a, double b, int tolerance_bits)
{
constexpr int mantissa_bits = 53;
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
{
return false;
}
uint64_t distance = float_distance(a, b);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
return distance <= tolerance;
}
vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b) vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
{ {
if (a.size() != b.size()) if (a.size() != b.size())
...@@ -84,6 +133,21 @@ vector<uint32_t> test::float_distances(const vector<float>& a, const vector<floa ...@@ -84,6 +133,21 @@ vector<uint32_t> test::float_distances(const vector<float>& a, const vector<floa
return distances; return distances;
} }
vector<uint64_t> test::float_distances(const vector<double>& a, const vector<double>& b)
{
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint64_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
}
return distances;
}
uint32_t test::matching_mantissa_bits(uint32_t distance) uint32_t test::matching_mantissa_bits(uint32_t distance)
{ {
uint32_t tolerance_needed = distance; uint32_t tolerance_needed = distance;
...@@ -127,6 +191,50 @@ uint32_t test::matching_mantissa_bits(uint32_t distance) ...@@ -127,6 +191,50 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
return matching_matissa_bits; return matching_matissa_bits;
} }
uint64_t test::matching_mantissa_bits(uint64_t distance)
{
uint64_t tolerance_needed = distance;
if (tolerance_needed < 0x8000000000000000)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
tolerance_needed |= tolerance_needed >> 32;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{
tolerance_needed = distance;
}
}
uint64_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
{
++tolerance_bit_shift;
}
// all_close_f calculation of tolerance_bit_shift:
// e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp matching_matissa_bits implicit 1 tolerance_bits
//
// Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) )
uint64_t matching_matissa_bits =
tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
bool test::all_close_f(const vector<float>& a, bool test::all_close_f(const vector<float>& a,
const vector<float>& b, const vector<float>& b,
int mantissa_bits, int mantissa_bits,
...@@ -206,6 +314,82 @@ bool test::all_close_f(const vector<float>& a, ...@@ -206,6 +314,82 @@ bool test::all_close_f(const vector<float>& a,
return rc; return rc;
} }
bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{
constexpr int mantissa_bits = 53;
bool rc = true;
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
}
vector<uint64_t> distances = float_distances(a, b);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
uint64_t max_distance = 0;
uint64_t min_distance = ULLONG_MAX;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] > max_distance)
{
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f)
{
if (diff_count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
}
rc = false;
diff_count++;
}
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint64_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
uint64_t median_distance2 = *max_element(distances.begin(), distances.begin() + middle);
uint64_t remainder1 = median_distance % 2;
uint64_t remainder2 = median_distance2 % 2;
median_distance =
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc;
}
bool test::all_close_f(const std::shared_ptr<runtime::Tensor>& a, bool test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits, int mantissa_bits,
......
...@@ -46,6 +46,24 @@ namespace ngraph ...@@ -46,6 +46,24 @@ namespace ngraph
/// bfloat and f32. /// bfloat and f32.
uint32_t float_distance(float a, float b); uint32_t float_distance(float a, float b);
/// \brief Determine distance between two f64 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
uint64_t float_distance(double a, double b);
/// \brief Check if the two f32 numbers are close /// \brief Check if the two f32 numbers are close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
...@@ -56,7 +74,7 @@ namespace ngraph ...@@ -56,7 +74,7 @@ namespace ngraph
/// References: /// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place /// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2 /// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison /// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
/// ///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m /// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------| /// |------------bfloat-----------|
...@@ -69,6 +87,25 @@ namespace ngraph ...@@ -69,6 +87,25 @@ namespace ngraph
/// bfloat and f32. /// bfloat and f32.
bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2); bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2);
/// \brief Check if the two f64 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
bool close_f(double a, double b, int tolerance_bits = 2);
/// \brief Determine distances between two vectors of f32 numbers /// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare /// \param a Vector of floats to compare
/// \param b Vector of floats to compare /// \param b Vector of floats to compare
...@@ -78,6 +115,15 @@ namespace ngraph ...@@ -78,6 +115,15 @@ namespace ngraph
std::vector<uint32_t> float_distances(const std::vector<float>& a, std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b); const std::vector<float>& b);
/// \brief Determine distances between two vectors of f64 numbers
/// \param a Vector of doubles to compare
/// \param b Vector of doubles to compare
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint64_t> float_distances(const std::vector<double>& a,
const std::vector<double>& b);
/// \brief Determine number of matching mantissa bits given a distance /// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance /// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits /// \returns Number of matching mantissa bits
...@@ -85,6 +131,13 @@ namespace ngraph ...@@ -85,6 +131,13 @@ namespace ngraph
/// See float_distance for limitations and assumptions. /// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint32_t distance); uint32_t matching_mantissa_bits(uint32_t distance);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint64_t matching_mantissa_bits(uint64_t distance);
/// \brief Check if the two floating point vectors are all close /// \brief Check if the two floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
...@@ -96,6 +149,15 @@ namespace ngraph ...@@ -96,6 +149,15 @@ namespace ngraph
int mantissa_bits = 8, int mantissa_bits = 8,
int tolerance_bits = 2); int tolerance_bits = 2);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close
bool all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = 2);
/// \brief Check if the two TensorViews are all close in float /// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare /// \param a First Tensor to compare
/// \param b Second Tensor to compare /// \param b Second Tensor to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment