Commit abd1c70d authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Added "min_signal" paramater to float testing (#2653)

min_signal is intended to optionally skip checking float
distances when numbers are close enough to 0.
Only skips when both numbers are < min_signal.
Intention is to allow tighter float testing in certain
cases where most values are not near 0, but values near
zero are differing by more bits than values farther from 0.
Should be used with caution in limited cases.
parent c6c2aafb
......@@ -151,6 +151,8 @@ protected:
uint32_t upper_bound_as_int = expected_as_int + targeted_bit;
upper_bound = FloatUnion(upper_bound_as_int).f;
past_upper_bound = FloatUnion(upper_bound_as_int + 1).f;
min_signal_too_low = expected;
min_signal_enables_passing = FloatUnion(upper_bound_as_int + 2).f;
uint32_t lower_bound_as_int = expected_as_int - targeted_bit;
lower_bound = FloatUnion(lower_bound_as_int).f;
......@@ -162,6 +164,8 @@ protected:
uint32_t lower_bound_as_int = expected_as_int + targeted_bit;
lower_bound = FloatUnion(lower_bound_as_int).f;
past_lower_bound = FloatUnion(lower_bound_as_int + 1).f;
min_signal_too_low = expected;
min_signal_enables_passing = FloatUnion(lower_bound_as_int + 2).f;
uint32_t upper_bound_as_int = expected_as_int - targeted_bit;
upper_bound = FloatUnion(upper_bound_as_int).f;
......@@ -174,6 +178,8 @@ protected:
upper_bound = FloatUnion(upper_bound_as_int).f;
uint32_t past_upper_bound_as_int = upper_bound_as_int + 1;
past_upper_bound = FloatUnion(past_upper_bound_as_int).f;
min_signal_too_low = expected;
min_signal_enables_passing = FloatUnion(upper_bound_as_int + 2).f;
lower_bound = FloatUnion(upper_bound_as_int | 0x80000000).f;
past_lower_bound = FloatUnion(past_upper_bound_as_int | 0x80000000).f;
......@@ -186,6 +192,8 @@ protected:
float lower_bound;
float past_upper_bound;
float past_lower_bound;
float min_signal_too_low;
float min_signal_enables_passing;
};
TEST_P(all_close_f_param_test, test_boundaries)
......@@ -216,13 +224,43 @@ TEST_P(all_close_f_param_test, test_boundaries)
test::all_close_f(vector<float>({expected}), vector<float>({lower_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, tolerance_bits, min_signal_too_low))
<< ss.str();
EXPECT_TRUE(
test::close_f(expected, past_upper_bound, tolerance_bits, min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({past_upper_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<float>({expected}),
vector<float>({past_upper_bound}),
tolerance_bits,
min_signal_too_low))
<< ss.str();
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({past_upper_bound}),
tolerance_bits,
min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, tolerance_bits, min_signal_too_low))
<< ss.str();
EXPECT_TRUE(
test::close_f(expected, past_lower_bound, tolerance_bits, min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({past_lower_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<float>({expected}),
vector<float>({past_lower_bound}),
tolerance_bits,
min_signal_too_low))
<< ss.str();
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({past_lower_bound}),
tolerance_bits,
min_signal_enables_passing))
<< ss.str();
}
INSTANTIATE_TEST_CASE_P(test_simple_floats_with_range_of_precisions,
......@@ -270,6 +308,8 @@ protected:
uint64_t upper_bound_as_int = expected_as_int + targeted_bit;
upper_bound = DoubleUnion(upper_bound_as_int).d;
past_upper_bound = DoubleUnion(upper_bound_as_int + 1).d;
min_signal_too_low = expected;
min_signal_enables_passing = DoubleUnion(upper_bound_as_int + 2).d;
uint64_t lower_bound_as_int = expected_as_int - targeted_bit;
lower_bound = DoubleUnion(lower_bound_as_int).d;
......@@ -281,6 +321,8 @@ protected:
uint64_t lower_bound_as_int = expected_as_int + targeted_bit;
lower_bound = DoubleUnion(lower_bound_as_int).d;
past_lower_bound = DoubleUnion(lower_bound_as_int + 1).d;
min_signal_too_low = expected;
min_signal_enables_passing = DoubleUnion(lower_bound_as_int + 2).d;
uint64_t upper_bound_as_int = expected_as_int - targeted_bit;
upper_bound = DoubleUnion(upper_bound_as_int).d;
......@@ -293,6 +335,8 @@ protected:
upper_bound = DoubleUnion(upper_bound_as_int).d;
uint64_t past_upper_bound_as_int = upper_bound_as_int + 1;
past_upper_bound = DoubleUnion(past_upper_bound_as_int).d;
min_signal_too_low = expected;
min_signal_enables_passing = DoubleUnion(upper_bound_as_int + 2).d;
lower_bound = DoubleUnion(upper_bound_as_int | 0x8000000000000000).d;
past_lower_bound = DoubleUnion(past_upper_bound_as_int | 0x8000000000000000).d;
......@@ -305,6 +349,8 @@ protected:
double lower_bound;
double past_upper_bound;
double past_lower_bound;
double min_signal_too_low;
double min_signal_enables_passing;
};
TEST_P(all_close_f_double_param_test, test_boundaries)
......@@ -316,7 +362,6 @@ TEST_P(all_close_f_double_param_test, test_boundaries)
}
// Format verbose info to only print out in case of test failure
stringstream ss;
ss << "Testing target of: " << expected << " (" << double_to_bits(expected) << ")\n";
ss << "Matching to targets with: " << tolerance_bits << " tolerance_bits\n";
......@@ -336,13 +381,43 @@ TEST_P(all_close_f_double_param_test, test_boundaries)
vector<double>({expected}), vector<double>({lower_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::close_f(expected, past_upper_bound, tolerance_bits, min_signal_too_low))
<< ss.str();
EXPECT_TRUE(
test::close_f(expected, past_upper_bound, tolerance_bits, min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::all_close_f(
vector<double>({expected}), vector<double>({past_upper_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<double>({expected}),
vector<double>({past_upper_bound}),
tolerance_bits,
min_signal_too_low))
<< ss.str();
EXPECT_TRUE(test::all_close_f(vector<double>({expected}),
vector<double>({past_upper_bound}),
tolerance_bits,
min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, tolerance_bits)) << ss.str();
EXPECT_FALSE(test::close_f(expected, past_lower_bound, tolerance_bits, min_signal_too_low))
<< ss.str();
EXPECT_TRUE(
test::close_f(expected, past_lower_bound, tolerance_bits, min_signal_enables_passing))
<< ss.str();
EXPECT_FALSE(test::all_close_f(
vector<double>({expected}), vector<double>({past_lower_bound}), tolerance_bits))
<< ss.str();
EXPECT_FALSE(test::all_close_f(vector<double>({expected}),
vector<double>({past_lower_bound}),
tolerance_bits,
min_signal_too_low))
<< ss.str();
EXPECT_TRUE(test::all_close_f(vector<double>({expected}),
vector<double>({past_lower_bound}),
tolerance_bits,
min_signal_enables_passing))
<< ss.str();
}
INSTANTIATE_TEST_CASE_P(
......@@ -397,6 +472,8 @@ TEST(all_close_f, mantissa_8_near_0)
// 0.f, the ground-truth value
float expected = bits_to_float("0 00000000 000 0000 0000 0000 0000 0000");
float computed;
float min_signal_too_low = bits_to_float("0 00000000 000 0100 0000 0000 0000 0001");
float min_signal_enables_passing = bits_to_float("0 00000000 000 0100 0000 0000 0000 0010");
// ~3.67342E-40, the exact upper bound
computed = bits_to_float("0 00000000 000 0100 0000 0000 0000 0000");
......@@ -407,8 +484,16 @@ TEST(all_close_f, mantissa_8_near_0)
// ~3.67343E-40, the next representable number bigger than upper bound
computed = bits_to_float("0 00000000 000 0100 0000 0000 0000 0001");
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits));
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::close_f(expected, computed, tolerance_bits, min_signal_enables_passing));
EXPECT_FALSE(
test::all_close_f(vector<float>({expected}), vector<float>({computed}), tolerance_bits));
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({computed}), tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({computed}),
tolerance_bits,
min_signal_enables_passing));
// ~-3.67342E-40, the exact lower bound
computed = bits_to_float("1 00000000 000 0100 0000 0000 0000 0000");
......@@ -419,8 +504,16 @@ TEST(all_close_f, mantissa_8_near_0)
// ~-3.67343E-40, the next representable number smaller than lower bound
computed = bits_to_float("1 00000000 000 0100 0000 0000 0000 0001");
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits));
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::close_f(expected, computed, tolerance_bits, min_signal_enables_passing));
EXPECT_FALSE(
test::all_close_f(vector<float>({expected}), vector<float>({computed}), tolerance_bits));
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({computed}), tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({computed}),
tolerance_bits,
min_signal_enables_passing));
}
// Test the exact bounds near -0.f
......@@ -467,6 +560,8 @@ TEST(all_close_f, mantissa_8_near_n0)
// 0.f, the ground-truth value
float expected = bits_to_float("1 00000000 000 0000 0000 0000 0000 0000");
float computed;
float min_signal_too_low = bits_to_float("0 00000000 000 0100 0000 0000 0000 0001");
float min_signal_enables_passing = bits_to_float("0 00000000 000 0100 0000 0000 0000 0010");
// ~3.67342E-40, the exact upper bound
computed = bits_to_float("0 00000000 000 0100 0000 0000 0000 0000");
......@@ -477,8 +572,16 @@ TEST(all_close_f, mantissa_8_near_n0)
// ~3.67343E-40, the next representable number bigger than upper bound
computed = bits_to_float("0 00000000 000 0100 0000 0000 0000 0001");
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits));
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::close_f(expected, computed, tolerance_bits, min_signal_enables_passing));
EXPECT_FALSE(
test::all_close_f(vector<float>({expected}), vector<float>({computed}), tolerance_bits));
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({computed}), tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({computed}),
tolerance_bits,
min_signal_enables_passing));
// ~-3.67342E-40, the exact lower bound
computed = bits_to_float("1 00000000 000 0100 0000 0000 0000 0000");
......@@ -489,8 +592,16 @@ TEST(all_close_f, mantissa_8_near_n0)
// ~-3.67343E-40, the next representable number smaller than lower bound
computed = bits_to_float("1 00000000 000 0100 0000 0000 0000 0001");
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits));
EXPECT_FALSE(test::close_f(expected, computed, tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::close_f(expected, computed, tolerance_bits, min_signal_enables_passing));
EXPECT_FALSE(
test::all_close_f(vector<float>({expected}), vector<float>({computed}), tolerance_bits));
EXPECT_FALSE(test::all_close_f(
vector<float>({expected}), vector<float>({computed}), tolerance_bits, min_signal_too_low));
EXPECT_TRUE(test::all_close_f(vector<float>({expected}),
vector<float>({computed}),
tolerance_bits,
min_signal_enables_passing));
}
// Test the exact bounds near 1.f
......
......@@ -32,15 +32,21 @@ union DoubleUnion {
uint64_t i;
};
uint32_t test::float_distance(float a, float b)
constexpr uint32_t FLOAT_BELOW_MIN_SIGNAL = UINT_MAX;
constexpr uint32_t FLOAT_MAX_DIFF = UINT_MAX - 1;
constexpr uint64_t DOUBLE_BELOW_MIN_SIGNAL = ULLONG_MAX;
constexpr uint64_t DOUBLE_MAX_DIFF = ULLONG_MAX - 1;
uint32_t test::float_distance(float a, float b, float min_signal)
{
if (!isfinite(a) || !isfinite(b))
{
return UINT_MAX;
return FLOAT_MAX_DIFF;
}
FloatUnion a_fu{a};
FloatUnion b_fu{b};
FloatUnion min_signal_fu{min_signal};
uint32_t a_uint = a_fu.i;
uint32_t b_uint = b_fu.i;
......@@ -48,22 +54,42 @@ uint32_t test::float_distance(float a, float b)
// - If negative: convert to two's complement
// - If positive: mask with sign bit
uint32_t sign_mask = static_cast<uint32_t>(1U) << 31;
uint32_t abs_value_bits_mask = ~sign_mask;
a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
uint32_t distance;
uint32_t a_uint_abs = (abs_value_bits_mask & a_fu.i);
uint32_t b_uint_abs = (abs_value_bits_mask & b_fu.i);
uint32_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_fu.i);
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
{
// Both a & b below minimum signal
distance = FLOAT_BELOW_MIN_SIGNAL;
}
else
{
distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
// We've reserved UINT_MAX to mean FLOAT_BELOW_MIN_SIGNAL
if (distance == UINT_MAX)
{
distance = FLOAT_MAX_DIFF;
}
}
return distance;
}
uint64_t test::float_distance(double a, double b)
uint64_t test::float_distance(double a, double b, double min_signal)
{
if (!isfinite(a) || !isfinite(b))
{
return ULLONG_MAX;
return DOUBLE_MAX_DIFF;
}
DoubleUnion a_du{a};
DoubleUnion b_du{b};
DoubleUnion min_signal_du{min_signal};
uint64_t a_uint = a_du.i;
uint64_t b_uint = b_du.i;
......@@ -71,14 +97,33 @@ uint64_t test::float_distance(double a, double b)
// - If negative: convert to two's complement
// - If positive: mask with sign bit
uint64_t sign_mask = static_cast<uint64_t>(1U) << 63;
uint64_t abs_value_bits_mask = ~sign_mask;
a_uint = (sign_mask & a_uint) ? (~a_uint + 1) : (sign_mask | a_uint);
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint64_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
uint64_t distance;
uint64_t a_uint_abs = (abs_value_bits_mask & a_du.i);
uint64_t b_uint_abs = (abs_value_bits_mask & b_du.i);
uint64_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_du.i);
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
{
// Both a & b below minimum signal
distance = DOUBLE_BELOW_MIN_SIGNAL;
}
else
{
distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
// We've reserved ULLONG_MAX to mean DOUBLE_BELOW_MIN_SIGNAL
if (distance == ULLONG_MAX)
{
distance = DOUBLE_MAX_DIFF;
}
}
return distance;
}
bool test::close_f(float a, float b, int tolerance_bits)
bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
......@@ -86,7 +131,7 @@ bool test::close_f(float a, float b, int tolerance_bits)
return false;
}
uint32_t distance = float_distance(a, b);
uint32_t distance = float_distance(a, b, min_signal);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
......@@ -94,10 +139,10 @@ bool test::close_f(float a, float b, int tolerance_bits)
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
return distance <= tolerance;
return (distance <= tolerance) || (distance == FLOAT_BELOW_MIN_SIGNAL);
}
bool test::close_f(double a, double b, int tolerance_bits)
bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
......@@ -105,7 +150,7 @@ bool test::close_f(double a, double b, int tolerance_bits)
return false;
}
uint64_t distance = float_distance(a, b);
uint64_t distance = float_distance(a, b, min_signal);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
......@@ -113,10 +158,11 @@ bool test::close_f(double a, double b, int tolerance_bits)
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
return distance <= tolerance;
return (distance <= tolerance) || (distance == DOUBLE_BELOW_MIN_SIGNAL);
}
vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
vector<uint32_t>
test::float_distances(const vector<float>& a, const vector<float>& b, float min_signal)
{
if (a.size() != b.size())
{
......@@ -125,13 +171,14 @@ vector<uint32_t> test::float_distances(const vector<float>& a, const vector<floa
vector<uint32_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
distances[i] = float_distance(a[i], b[i], min_signal);
}
return distances;
}
vector<uint64_t> test::float_distances(const vector<double>& a, const vector<double>& b)
vector<uint64_t>
test::float_distances(const vector<double>& a, const vector<double>& b, double min_signal)
{
if (a.size() != b.size())
{
......@@ -140,7 +187,7 @@ vector<uint64_t> test::float_distances(const vector<double>& a, const vector<dou
vector<uint64_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
distances[i] = float_distance(a[i], b[i], min_signal);
}
return distances;
......@@ -220,8 +267,10 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
return matching_matissa_bits;
}
::testing::AssertionResult
test::all_close_f(const vector<float>& a, const vector<float>& b, int tolerance_bits)
::testing::AssertionResult test::all_close_f(const vector<float>& a,
const vector<float>& b,
int tolerance_bits,
float min_signal)
{
if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS)
{
......@@ -242,7 +291,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
{
return ::testing::AssertionSuccess() << "No elements to compare";
}
vector<uint32_t> distances = float_distances(a, b);
vector<uint32_t> distances = float_distances(a, b, min_signal);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
......@@ -250,12 +299,20 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
uint32_t max_distance = 0;
uint32_t min_distance = UINT_MAX;
uint32_t min_distance = FLOAT_BELOW_MIN_SIGNAL;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
size_t below_min_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] == FLOAT_BELOW_MIN_SIGNAL)
{
// Special value that indicates both values were below min_signal
below_min_count++;
continue;
}
if (distances[i] > max_distance)
{
max_distance = distances[i];
......@@ -295,18 +352,35 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
median_distance = median_sum / 2;
}
bool all_below_min_signal = below_min_count == distances.size();
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of <= " << (FLOAT_MANTISSA_BITS - tolerance_bits)
<< " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - "
<< tolerance_bits << " tolerance). Loosest match found is "
<< matching_mantissa_bits(max_distance) << " mantissa bits.\n";
<< tolerance_bits << " tolerance). ";
if (all_below_min_signal)
{
std::cout << "All values below min_signal: " << min_signal << "\n";
}
else
{
std::cout << below_min_count << " value(s) below min_signal: " << min_signal
<< " Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
}
}
msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (FLOAT_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " tolerance bits)\n";
if (all_below_min_signal)
{
msg << "All values below min_signal: " << min_signal << "\n";
}
else
{
msg << below_min_count << " value(s) below min_signal: " << min_signal << "\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
......@@ -317,6 +391,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
<< " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(median_distance) << " or next bit\n";
}
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
......@@ -324,8 +399,10 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
return res;
}
::testing::AssertionResult
test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
::testing::AssertionResult test::all_close_f(const vector<double>& a,
const vector<double>& b,
int tolerance_bits,
double min_signal)
{
if (tolerance_bits < 0)
{
......@@ -346,8 +423,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
{
return ::testing::AssertionSuccess() << "No elements to compare";
}
vector<uint64_t> distances = float_distances(a, b);
vector<uint64_t> distances = float_distances(a, b, min_signal);
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
......@@ -355,12 +431,20 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
uint64_t max_distance = 0;
uint64_t min_distance = ULLONG_MAX;
uint64_t min_distance = DOUBLE_BELOW_MIN_SIGNAL;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
size_t below_min_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] == DOUBLE_BELOW_MIN_SIGNAL)
{
// Special value that indicates both values were below min_signal
below_min_count++;
continue;
}
if (distances[i] > max_distance)
{
max_distance = distances[i];
......@@ -383,7 +467,10 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
diff_count++;
}
}
if (!rc)
{
msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
......@@ -397,19 +484,36 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
bool all_below_min_signal = below_min_count == distances.size();
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= "
<< (DOUBLE_MANTISSA_BITS - tolerance_bits) << " mantissa bits ("
<< DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
<< " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
<< " tolerance). ";
if (all_below_min_signal)
{
std::cout << "All values below min_signal: " << min_signal << "\n";
}
else
{
std::cout << below_min_count << " value(s) below min_signal: " << min_signal
<< " Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
}
}
msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (DOUBLE_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " tolerance bits)\n";
if (all_below_min_signal)
{
msg << "All values below min_signal: " << min_signal << "\n";
}
else
{
msg << below_min_count << " value(s) below min_signal: " << min_signal << "\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
......@@ -420,6 +524,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
<< " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(median_distance) << " or next bit\n";
}
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
......@@ -429,7 +534,8 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits)
int tolerance_bits,
float min_signal)
{
// Check that the layouts are compatible
if (*a->get_tensor_layout() != *b->get_tensor_layout())
......@@ -441,13 +547,15 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
}
return test::all_close_f(read_float_vector(a), read_float_vector(b), tolerance_bits);
return test::all_close_f(
read_float_vector(a), read_float_vector(b), tolerance_bits, min_signal);
}
::testing::AssertionResult
test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits)
int tolerance_bits,
float min_signal)
{
if (as.size() != bs.size())
{
......@@ -455,7 +563,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
}
for (size_t i = 0; i < as.size(); ++i)
{
auto ar = test::all_close_f(as[i], bs[i], tolerance_bits);
auto ar = test::all_close_f(as[i], bs[i], tolerance_bits, min_signal);
if (!ar)
{
return ar;
......
......@@ -66,6 +66,7 @@ namespace ngraph
/// \brief Determine distance between two f32 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance
///
/// References:
......@@ -82,11 +83,12 @@ namespace ngraph
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
uint32_t float_distance(float a, float b);
uint32_t float_distance(float a, float b, float min_signal = 0.0f);
/// \brief Determine distance between two f64 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance
///
/// References:
......@@ -100,12 +102,13 @@ namespace ngraph
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
uint64_t float_distance(double a, double b);
uint64_t float_distance(double a, double b, double min_signal = 0.0);
/// \brief Check if the two f32 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
......@@ -122,12 +125,16 @@ namespace ngraph
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
bool close_f(float a, float b, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
bool close_f(float a,
float b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two f64 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
......@@ -141,25 +148,32 @@ namespace ngraph
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
bool close_f(double a, double b, int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS);
bool close_f(double a,
double b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare
/// \param b Vector of floats to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b);
const std::vector<float>& b,
float min_signal = 0.0f);
/// \brief Determine distances between two vectors of f64 numbers
/// \param a Vector of doubles to compare
/// \param b Vector of doubles to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint64_t> float_distances(const std::vector<double>& a,
const std::vector<double>& b);
const std::vector<double>& b,
double min_signal = 0.0);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
......@@ -179,37 +193,45 @@ namespace ngraph
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS);
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare
/// \param b Second Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare
/// \param bs Second vector of Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult
all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment