Commit 656b2f72 authored by gcwenger's avatar gcwenger Committed by Robert Kimball

Addressed some all_close_f loose ends (#2200)

* Simplified all_close_f tolerance determination. Made all_close_f less verbose. Transitioned graph comparison to use double comparison when appropriate.

* Reworked all_close_f AssertionResult returning code to include message in AssertionResult on success also.

* Added CMake option NGRAPH_GTEST_INFO, which controls whether or not any gtest info cout is done. Quiets all_close_f tests by default.

* Moved NGRAPH_GTEST_INFO from build option to environment variable.
parent dd41bb62
......@@ -150,9 +150,12 @@ protected:
TEST_P(all_close_f_param_test, test_boundaries)
{
if (std::getenv("NGRAPH_GTEST_INFO") != nullptr)
{
// Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << mantissa_bits << ", "
<< tolerance_bits << ")\n";
}
// Format verbose info to only print out in case of test failure
stringstream ss;
......@@ -190,7 +193,6 @@ TEST_P(all_close_f_param_test, test_boundaries)
<< ss.str();
}
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
INSTANTIATE_TEST_CASE_P(
test_simple_floats_with_range_of_precisions,
all_close_f_param_test,
......@@ -279,8 +281,11 @@ protected:
TEST_P(all_close_f_double_param_test, test_boundaries)
{
if (std::getenv("NGRAPH_GTEST_INFO") != nullptr)
{
// Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << tolerance_bits << ")\n";
}
// Format verbose info to only print out in case of test failure
......@@ -313,7 +318,6 @@ TEST_P(all_close_f_double_param_test, test_boundaries)
<< ss.str();
}
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
INSTANTIATE_TEST_CASE_P(
test_simple_doubles_with_range_of_precisions,
all_close_f_double_param_test,
......
......@@ -92,7 +92,7 @@ public:
test::all_close<char>(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated);
}
else if ((et == element::f32) || (et == element::f64))
else if (et == element::f32)
{
vector<float> ref_data_vector = read_float_vector(ref_data);
vector<float> bk_data_vector = read_float_vector(bk_data);
......@@ -106,6 +106,26 @@ public:
test::all_close_f(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated);
}
else if (et == element::f64)
{
vector<double> ref_data_vector = read_vector<double>(ref_data);
vector<double> bk_data_vector = read_vector<double>(bk_data);
vector<double> bk_isolated_data_vector = read_vector<double>(bk_isolated_data);
cout << "Test backed op run w/ original graph dependencies:" << endl;
print_results(ref_data_vector, bk_data_vector);
// When testing with original graph dependencies test w/ loose f64 tolerance
constexpr int tolerance_bits = 30;
bool all_close_graph =
test::all_close_f(ref_data_vector, bk_data_vector, tolerance_bits);
cout << "Test backed op run isolated w/ inputs from ref graph run:" << endl;
print_results(ref_data_vector, bk_isolated_data_vector);
// When testing with isolated graph dependencies test w/ default (tight) f64 tolerance
bool all_close_isolated =
test::all_close_f(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated);
}
else if (et == element::i8)
{
vector<int8_t> ref_data_vector = read_vector<int8_t>(ref_data);
......
......@@ -150,29 +150,23 @@ vector<uint64_t> test::float_distances(const vector<double>& a, const vector<dou
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
uint32_t tolerance_needed = distance;
uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
if (tolerance_needed < 0x80000000)
// Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
if (distance & (1 << check_bit))
{
tolerance_needed = distance;
tolerance_bit_shift = check_bit;
++num_bits_on;
}
}
uint32_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
// all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
// So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{
++tolerance_bit_shift;
}
......@@ -191,32 +185,25 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
return matching_matissa_bits;
}
uint64_t test::matching_mantissa_bits(uint64_t distance)
uint32_t test::matching_mantissa_bits(uint64_t distance)
{
uint64_t tolerance_needed = distance;
uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
if (tolerance_needed < 0x8000000000000000)
// Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 64; ++check_bit)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
tolerance_needed |= tolerance_needed >> 32;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
if (distance & (1ull << check_bit))
{
tolerance_needed = distance;
tolerance_bit_shift = check_bit;
++num_bits_on;
}
}
uint64_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
// all_close_f is <= test for tolerance (where tolerance is uint64_t with single bit on)
// So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{
++tolerance_bit_shift;
}
......@@ -230,7 +217,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) )
uint64_t matching_matissa_bits =
uint32_t matching_matissa_bits =
tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
......@@ -241,7 +228,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
int tolerance_bits)
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
stringstream msg;
if (a.size() != b.size())
{
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
......@@ -275,7 +262,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
{
if (diff_count < 5)
{
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< " is not close to " << b[i] << " at index " << i << "\n";
}
......@@ -285,7 +272,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
}
if (!rc)
{
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
......@@ -299,20 +286,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
median_distance = median_sum / 2;
}
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
<< " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
<< " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
}
msg << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at [" << min_distance_index
<< "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits\n";
return rc ? ::testing::AssertionSuccess() : ar_fail;
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at [" << max_distance_index
<< "])\n";
msg << "median match: " << matching_mantissa_bits(median_distance) << " mantissa bits\n";
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
}
::testing::AssertionResult
......@@ -321,7 +319,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
constexpr int mantissa_bits = 53;
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
stringstream msg;
if (a.size() != b.size())
{
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
......@@ -355,14 +353,14 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
{
if (diff_count < 5)
{
ar_fail << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
msg << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
}
rc = false;
diff_count++;
}
}
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
......@@ -376,18 +374,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
<< " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
<< " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
}
msg << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
ar_fail << "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
ar_fail << "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits\n";
return rc ? ::testing::AssertionSuccess() : ar_fail;
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at [" << min_distance_index
<< "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at [" << max_distance_index
<< "])\n";
msg << "median match: " << matching_mantissa_bits(median_distance) << " mantissa bits\n";
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
}
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
......
......@@ -137,7 +137,7 @@ namespace ngraph
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint64_t matching_mantissa_bits(uint64_t distance);
uint32_t matching_mantissa_bits(uint64_t distance);
/// \brief Check if the two floating point vectors are all close
/// \param a First number to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment