Commit 125f7242 authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Support for all_close_f w/ doubles (#2184)

* Double support for all_close_f

* all_close_f uses fixed number of mantissa bits now. Simplified testing code.

* Initialize test data members in constructor to values which will cause test failure. Setup then sets them correctly.

* Reduce info printed out during all_close_f unit tests.
parent 91c4b553
This diff is collapsed.
......@@ -27,6 +27,11 @@ union FloatUnion {
uint32_t i;
};
union DoubleUnion {
double d;
uint64_t i;
};
uint32_t test::float_distance(float a, float b)
{
if (!isfinite(a) || !isfinite(b))
......@@ -50,6 +55,29 @@ uint32_t test::float_distance(float a, float b)
return distance;
}
uint64_t test::float_distance(double a, double b)
{
if (!isfinite(a) || !isfinite(b))
{
return ULLONG_MAX;
}
DoubleUnion a_du{a};
DoubleUnion b_du{b};
uint64_t a_uint = a_du.i;
uint64_t b_uint = b_du.i;
// A trick to handle both positive and negative numbers, see https://goo.gl/YbdnFQ
// - If negative: convert to two's complement
// - If positive: mask with sign bit
uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint64_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
return distance;
}
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
{
// isfinite(a) => !isinf(a) && !isnan(a)
......@@ -69,6 +97,27 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
return distance <= tolerance;
}
bool test::close_f(double a, double b, int tolerance_bits)
{
constexpr int mantissa_bits = 53;
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
{
return false;
}
uint64_t distance = float_distance(a, b);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
return distance <= tolerance;
}
vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
{
if (a.size() != b.size())
......@@ -84,6 +133,21 @@ vector<uint32_t> test::float_distances(const vector<float>& a, const vector<floa
return distances;
}
vector<uint64_t> test::float_distances(const vector<double>& a, const vector<double>& b)
{
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint64_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
}
return distances;
}
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
uint32_t tolerance_needed = distance;
......@@ -127,6 +191,50 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
return matching_matissa_bits;
}
uint64_t test::matching_mantissa_bits(uint64_t distance)
{
uint64_t tolerance_needed = distance;
if (tolerance_needed < 0x8000000000000000)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
tolerance_needed |= tolerance_needed >> 32;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{
tolerance_needed = distance;
}
}
uint64_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
{
++tolerance_bit_shift;
}
// all_close_f calculation of tolerance_bit_shift:
// e.g. for double with 53 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp matching_matissa_bits implicit 1 tolerance_bits
//
// Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) )
uint64_t matching_matissa_bits =
tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
bool test::all_close_f(const vector<float>& a,
const vector<float>& b,
int mantissa_bits,
......@@ -206,6 +314,82 @@ bool test::all_close_f(const vector<float>& a,
return rc;
}
bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{
constexpr int mantissa_bits = 53;
bool rc = true;
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
}
vector<uint64_t> distances = float_distances(a, b);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
uint64_t max_distance = 0;
uint64_t min_distance = ULLONG_MAX;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] > max_distance)
{
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f)
{
if (diff_count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
}
rc = false;
diff_count++;
}
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint64_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
uint64_t median_distance2 = *max_element(distances.begin(), distances.begin() + middle);
uint64_t remainder1 = median_distance % 2;
uint64_t remainder2 = median_distance2 % 2;
median_distance =
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc;
}
bool test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits,
......
......@@ -46,6 +46,24 @@ namespace ngraph
/// bfloat and f32.
uint32_t float_distance(float a, float b);
/// \brief Determine distance between two f64 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
uint64_t float_distance(double a, double b);
/// \brief Check if the two f32 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
......@@ -56,7 +74,7 @@ namespace ngraph
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
......@@ -69,6 +87,25 @@ namespace ngraph
/// bfloat and f32.
bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2);
/// \brief Check if the two f64 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
bool close_f(double a, double b, int tolerance_bits = 2);
/// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare
/// \param b Vector of floats to compare
......@@ -78,6 +115,15 @@ namespace ngraph
std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b);
/// \brief Determine distances between two vectors of f64 numbers
/// \param a Vector of doubles to compare
/// \param b Vector of doubles to compare
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint64_t> float_distances(const std::vector<double>& a,
const std::vector<double>& b);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
......@@ -85,6 +131,13 @@ namespace ngraph
/// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint32_t distance);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint64_t matching_mantissa_bits(uint64_t distance);
/// \brief Check if the two floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
......@@ -96,6 +149,15 @@ namespace ngraph
int mantissa_bits = 8,
int tolerance_bits = 2);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close
bool all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = 2);
/// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare
/// \param b Second Tensor to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment