Commit e3a3ad2d authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Additional info calculated & printed for all_close_f. (#2127)

* Additional info calculated & printed for all_close_f.

* Fixed format

* Tweaked test calls to all_close_f to make unambiguous.

* Tweaked test calls to all_close_f to match close_f for last unit test.
parent a4b9e6b7
This diff is collapsed.
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <climits>
#include <cmath> #include <cmath>
#include "util/all_close_f.hpp" #include "util/all_close_f.hpp"
...@@ -26,12 +27,11 @@ union FloatUnion { ...@@ -26,12 +27,11 @@ union FloatUnion {
uint32_t i; uint32_t i;
}; };
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) uint32_t test::float_distance(float a, float b)
{ {
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b)) if (!isfinite(a) || !isfinite(b))
{ {
return false; return UINT_MAX;
} }
FloatUnion a_fu{a}; FloatUnion a_fu{a};
...@@ -47,6 +47,18 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) ...@@ -47,6 +47,18 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint); b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint); uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
return distance;
}
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
{
return false;
}
uint32_t distance = float_distance(a, b);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 ) // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
...@@ -57,6 +69,64 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) ...@@ -57,6 +69,64 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
return distance <= tolerance; return distance <= tolerance;
} }
vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
{
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint32_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
}
return distances;
}
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
uint32_t tolerance_needed = distance;
if (tolerance_needed < 0x80000000)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{
tolerance_needed = distance;
}
}
uint32_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
{
++tolerance_bit_shift;
}
// all_close_f calculation of tolerance_bit_shift:
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp matching_matissa_bits implicit 1 tolerance_bits
//
// Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
// tolerance_bit_shift = 32 - (1 + 8 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 32 - (1 + 8 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1 ) )
uint32_t matching_matissa_bits =
tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
bool test::all_close_f(const vector<float>& a, bool test::all_close_f(const vector<float>& a,
const vector<float>& b, const vector<float>& b,
int mantissa_bits, int mantissa_bits,
...@@ -65,27 +135,70 @@ bool test::all_close_f(const vector<float>& a, ...@@ -65,27 +135,70 @@ bool test::all_close_f(const vector<float>& a,
bool rc = true; bool rc = true;
if (a.size() != b.size()) if (a.size() != b.size())
{ {
throw ngraph_error("a.size() != b.size() for all_close comparison."); throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
} }
size_t count = 0; vector<uint32_t> distances = float_distances(a, b);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
uint32_t max_distance = 0;
uint32_t min_distance = UINT_MAX;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
{ {
bool is_close_f = close_f(a[i], b[i], mantissa_bits, tolerance_bits); if (distances[i] > max_distance)
{
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f) if (!is_close_f)
{ {
if (count < 5) if (diff_count < 5)
{ {
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i; NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
} }
rc = false; rc = false;
count++; diff_count++;
} }
} }
if (!rc) if (!rc)
{ {
NGRAPH_INFO << "diff count: " << count << " out of " << a.size(); NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
} }
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint32_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
// Find middle-1 value
uint64_t median_sum = static_cast<uint64_t>(median_distance) +
*max_element(distances.begin(), distances.begin() + middle);
median_distance = median_sum / 2;
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc; return rc;
} }
......
...@@ -25,6 +25,27 @@ namespace ngraph ...@@ -25,6 +25,27 @@ namespace ngraph
{ {
namespace test namespace test
{ {
/// \brief Determine distance between two f32 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
/// |----------------------------float----------------------------|
///
/// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
/// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
uint32_t float_distance(float a, float b);
/// \brief Check if the two f32 numbers are close /// \brief Check if the two f32 numbers are close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
...@@ -48,6 +69,22 @@ namespace ngraph ...@@ -48,6 +69,22 @@ namespace ngraph
/// bfloat and f32. /// bfloat and f32.
bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2); bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2);
/// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare
/// \param b Vector of floats to compare
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint32_t distance);
/// \brief Check if the two floating point vectors are all close /// \brief Check if the two floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment