Commit 656b2f72 authored by gcwenger's avatar gcwenger Committed by Robert Kimball

Addressed some all_close_f loose ends (#2200)

* Simplified all_close_f tolerance determination. Made all_close_f less verbose. Transitioned graph comparison to use double comparison when appropriate.

* Reworked all_close_f AssertionResult returning code to include message in AssertionResult on success also.

* Added CMake option NGRAPH_GTEST_INFO, which controls whether or not any gtest info cout is done. Quiets all_close_f tests by default.

* Moved NGRAPH_GTEST_INFO from build option to environment variable.
parent dd41bb62
...@@ -150,9 +150,12 @@ protected: ...@@ -150,9 +150,12 @@ protected:
TEST_P(all_close_f_param_test, test_boundaries) TEST_P(all_close_f_param_test, test_boundaries)
{ {
// Print short string documenting which test is being run if (std::getenv("NGRAPH_GTEST_INFO") != nullptr)
std::cout << "[ INFO ] Test params: (" << expected << ", " << mantissa_bits << ", " {
<< tolerance_bits << ")\n"; // Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << mantissa_bits << ", "
<< tolerance_bits << ")\n";
}
// Format verbose info to only print out in case of test failure // Format verbose info to only print out in case of test failure
stringstream ss; stringstream ss;
...@@ -190,7 +193,6 @@ TEST_P(all_close_f_param_test, test_boundaries) ...@@ -190,7 +193,6 @@ TEST_P(all_close_f_param_test, test_boundaries)
<< ss.str(); << ss.str();
} }
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
test_simple_floats_with_range_of_precisions, test_simple_floats_with_range_of_precisions,
all_close_f_param_test, all_close_f_param_test,
...@@ -279,8 +281,11 @@ protected: ...@@ -279,8 +281,11 @@ protected:
TEST_P(all_close_f_double_param_test, test_boundaries) TEST_P(all_close_f_double_param_test, test_boundaries)
{ {
// Print short string documenting which test is being run if (std::getenv("NGRAPH_GTEST_INFO") != nullptr)
std::cout << "[ INFO ] Test params: (" << expected << ", " << tolerance_bits << ")\n"; {
// Print short string documenting which test is being run
std::cout << "[ INFO ] Test params: (" << expected << ", " << tolerance_bits << ")\n";
}
// Format verbose info to only print out in case of test failure // Format verbose info to only print out in case of test failure
...@@ -313,7 +318,6 @@ TEST_P(all_close_f_double_param_test, test_boundaries) ...@@ -313,7 +318,6 @@ TEST_P(all_close_f_double_param_test, test_boundaries)
<< ss.str(); << ss.str();
} }
// Avoid warning with how gtest defines INSTANTIATE_TEST_CASE_P
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
test_simple_doubles_with_range_of_precisions, test_simple_doubles_with_range_of_precisions,
all_close_f_double_param_test, all_close_f_double_param_test,
......
...@@ -92,7 +92,7 @@ public: ...@@ -92,7 +92,7 @@ public:
test::all_close<char>(ref_data_vector, bk_isolated_data_vector); test::all_close<char>(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated); EXPECT_TRUE(all_close_graph && all_close_isolated);
} }
else if ((et == element::f32) || (et == element::f64)) else if (et == element::f32)
{ {
vector<float> ref_data_vector = read_float_vector(ref_data); vector<float> ref_data_vector = read_float_vector(ref_data);
vector<float> bk_data_vector = read_float_vector(bk_data); vector<float> bk_data_vector = read_float_vector(bk_data);
...@@ -106,6 +106,26 @@ public: ...@@ -106,6 +106,26 @@ public:
test::all_close_f(ref_data_vector, bk_isolated_data_vector); test::all_close_f(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated); EXPECT_TRUE(all_close_graph && all_close_isolated);
} }
else if (et == element::f64)
{
vector<double> ref_data_vector = read_vector<double>(ref_data);
vector<double> bk_data_vector = read_vector<double>(bk_data);
vector<double> bk_isolated_data_vector = read_vector<double>(bk_isolated_data);
cout << "Test backed op run w/ original graph dependencies:" << endl;
print_results(ref_data_vector, bk_data_vector);
// When testing with original graph dependencies test w/ loose f64 tolerance
constexpr int tolerance_bits = 30;
bool all_close_graph =
test::all_close_f(ref_data_vector, bk_data_vector, tolerance_bits);
cout << "Test backed op run isolated w/ inputs from ref graph run:" << endl;
print_results(ref_data_vector, bk_isolated_data_vector);
// When testing with isolated graph dependencies test w/ default (tight) f64 tolerance
bool all_close_isolated =
test::all_close_f(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated);
}
else if (et == element::i8) else if (et == element::i8)
{ {
vector<int8_t> ref_data_vector = read_vector<int8_t>(ref_data); vector<int8_t> ref_data_vector = read_vector<int8_t>(ref_data);
......
...@@ -150,29 +150,23 @@ vector<uint64_t> test::float_distances(const vector<double>& a, const vector<dou ...@@ -150,29 +150,23 @@ vector<uint64_t> test::float_distances(const vector<double>& a, const vector<dou
uint32_t test::matching_mantissa_bits(uint32_t distance) uint32_t test::matching_mantissa_bits(uint32_t distance)
{ {
uint32_t tolerance_needed = distance; uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
if (tolerance_needed < 0x80000000) // Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
{ {
// Set up the dominos - turn on all the bits below maximal bit if (distance & (1 << check_bit))
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{ {
tolerance_needed = distance; tolerance_bit_shift = check_bit;
++num_bits_on;
} }
} }
uint32_t tolerance_bit_shift = 0; // all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
while (tolerance_needed >>= 1) // So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{ {
++tolerance_bit_shift; ++tolerance_bit_shift;
} }
...@@ -191,32 +185,25 @@ uint32_t test::matching_mantissa_bits(uint32_t distance) ...@@ -191,32 +185,25 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
return matching_matissa_bits; return matching_matissa_bits;
} }
uint64_t test::matching_mantissa_bits(uint64_t distance) uint32_t test::matching_mantissa_bits(uint64_t distance)
{ {
uint64_t tolerance_needed = distance; uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
if (tolerance_needed < 0x8000000000000000) // Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 64; ++check_bit)
{ {
// Set up the dominos - turn on all the bits below maximal bit if (distance & (1ull << check_bit))
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
tolerance_needed |= tolerance_needed >> 32;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{ {
tolerance_needed = distance; tolerance_bit_shift = check_bit;
++num_bits_on;
} }
} }
uint64_t tolerance_bit_shift = 0; // all_close_f is <= test for tolerance (where tolerance is uint64_t with single bit on)
while (tolerance_needed >>= 1) // So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{ {
++tolerance_bit_shift; ++tolerance_bit_shift;
} }
...@@ -230,7 +217,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -230,7 +217,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) - 0 ) // tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) ) // tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) ) // matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) )
uint64_t matching_matissa_bits = uint32_t matching_matissa_bits =
tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0; tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits; return matching_matissa_bits;
} }
...@@ -241,7 +228,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -241,7 +228,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
int tolerance_bits) int tolerance_bits)
{ {
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure(); stringstream msg;
if (a.size() != b.size()) if (a.size() != b.size())
{ {
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison."; return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
...@@ -275,8 +262,8 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -275,8 +262,8 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
{ {
if (diff_count < 5) if (diff_count < 5)
{ {
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i] msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< " is not close to " << b[i] << " at index " << i << "\n"; << " is not close to " << b[i] << " at index " << i << "\n";
} }
rc = false; rc = false;
...@@ -285,7 +272,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -285,7 +272,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
} }
if (!rc) if (!rc)
{ {
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n"; msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
} }
// Find median value via partial sorting // Find median value via partial sorting
size_t middle = distances.size() / 2; size_t middle = distances.size() / 2;
...@@ -299,20 +286,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -299,20 +286,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
median_distance = median_sum / 2; median_distance = median_sum / 2;
} }
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits (" if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n"; {
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) // Short unobtrusive message when passing
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits (" std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at [" << " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
<< min_distance_index << "])\n"; << " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << " mantissa bits.\n";
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits (" }
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n"; msg << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
ar_fail << "median match: " << matching_mantissa_bits(median_distance) << mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
<< " mantissa bits\n"; msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
return rc ? ::testing::AssertionSuccess() : ar_fail; << a[min_distance_index] << " vs " << b[min_distance_index] << " at [" << min_distance_index
<< "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at [" << max_distance_index
<< "])\n";
msg << "median match: " << matching_mantissa_bits(median_distance) << " mantissa bits\n";
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
} }
::testing::AssertionResult ::testing::AssertionResult
...@@ -321,7 +319,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -321,7 +319,7 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
constexpr int mantissa_bits = 53; constexpr int mantissa_bits = 53;
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure(); stringstream msg;
if (a.size() != b.size()) if (a.size() != b.size())
{ {
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison."; return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
...@@ -355,14 +353,14 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -355,14 +353,14 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
{ {
if (diff_count < 5) if (diff_count < 5)
{ {
ar_fail << a[i] << " is not close to " << b[i] << " at index " << i << "\n"; msg << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
} }
rc = false; rc = false;
diff_count++; diff_count++;
} }
} }
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n"; msg << "diff count: " << diff_count << " out of " << a.size() << "\n";
// Find median value via partial sorting // Find median value via partial sorting
size_t middle = distances.size() / 2; size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end()); std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
...@@ -376,18 +374,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -376,18 +374,31 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2); (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
} }
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits (" if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n"; {
ar_fail << "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits (" // Short unobtrusive message when passing
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at [" std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits)
<< min_distance_index << "])\n"; << " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits
ar_fail << "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits (" << " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at [" << " mantissa bits.\n";
<< max_distance_index << "])\n"; }
ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits\n"; msg << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
return rc ? ::testing::AssertionSuccess() : ar_fail; msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at [" << min_distance_index
<< "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at [" << max_distance_index
<< "])\n";
msg << "median match: " << matching_mantissa_bits(median_distance) << " mantissa bits\n";
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
} }
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
......
...@@ -137,7 +137,7 @@ namespace ngraph ...@@ -137,7 +137,7 @@ namespace ngraph
/// \returns Number of matching mantissa bits /// \returns Number of matching mantissa bits
/// ///
/// See float_distance for limitations and assumptions. /// See float_distance for limitations and assumptions.
uint64_t matching_mantissa_bits(uint64_t distance); uint32_t matching_mantissa_bits(uint64_t distance);
/// \brief Check if the two floating point vectors are all close /// \brief Check if the two floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment