Commit 0eaa960c authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Simplified all_close_f interface and tightened default criteria (#2285)

* Simplified & tightened all_close_f parameters

Removed specification of mantissa bits for all_close_f in favor
of just specifying tolerance bits. Tightened up all_close_f default.
Fixed LRN unit test which had insufficient result precision to pass
tighter all_close_f tolerance.

* Addressed PR comments.

Reworked mantissa bit and tolerance constants.
Clarified and improved graph comparison tolerance calculation flexibility.
Clarified unit test tolerance testing.
parent 15a0bf19
This diff is collapsed.
...@@ -104,13 +104,30 @@ public: ...@@ -104,13 +104,30 @@ public:
msg << "Test backed op run w/ original graph dependencies:" msg << "Test backed op run w/ original graph dependencies:"
<< "\n"; << "\n";
msg << get_results_str(ref_data_vector, bk_data_vector); msg << get_results_str(ref_data_vector, bk_data_vector);
bool all_close_graph = test::all_close_f(ref_data_vector, bk_data_vector); // Future work will better determine useful graph comparison thresholds.
// For a very small sample of tested graphs initial criteria is:
// * Comparison of ops using inputs from preceeding ops (original
// graph dependencies) allows for a little better than 1/3 of
// the possible bits to match
// * Isolated operation allows for 2/3 of the possible bits to match
constexpr int one_third_of_available_bits = (MAX_FLOAT_BITS + 1) / 3;
constexpr int in_graph_tolerance =
FLOAT_MANTISSA_BITS - one_third_of_available_bits;
constexpr int isolated_tolerance =
FLOAT_MANTISSA_BITS - (one_third_of_available_bits * 2);
::testing::AssertionResult all_close_graph =
test::all_close_f(ref_data_vector, bk_data_vector, in_graph_tolerance);
msg << "Test backed op run isolated w/ inputs from ref graph run:" msg << "Test backed op run isolated w/ inputs from ref graph run:"
<< "\n"; << "\n";
msg << get_results_str(ref_data_vector, bk_isolated_data_vector); msg << get_results_str(ref_data_vector, bk_isolated_data_vector);
bool all_close_isolated = ::testing::AssertionResult all_close_isolated =
test::all_close_f(ref_data_vector, bk_isolated_data_vector); test::all_close_f(ref_data_vector, bk_isolated_data_vector, isolated_tolerance);
EXPECT_TRUE(all_close_graph && all_close_isolated) << msg.str(); if (!all_close_graph || !all_close_isolated)
{
cout << msg.str();
}
EXPECT_TRUE(all_close_graph);
EXPECT_TRUE(all_close_isolated);
} }
else if (et == element::f64) else if (et == element::f64)
{ {
...@@ -123,16 +140,21 @@ public: ...@@ -123,16 +140,21 @@ public:
// When testing with original graph dependencies test w/ loose f64 tolerance // When testing with original graph dependencies test w/ loose f64 tolerance
constexpr int tolerance_bits = 30; constexpr int tolerance_bits = 30;
bool all_close_graph = ::testing::AssertionResult all_close_graph =
test::all_close_f(ref_data_vector, bk_data_vector, tolerance_bits); test::all_close_f(ref_data_vector, bk_data_vector, tolerance_bits);
msg << "Test backed op run isolated w/ inputs from ref graph run:" msg << "Test backed op run isolated w/ inputs from ref graph run:"
<< "\n"; << "\n";
msg << get_results_str(ref_data_vector, bk_isolated_data_vector); msg << get_results_str(ref_data_vector, bk_isolated_data_vector);
// When testing with isolated graph dependencies test w/ default (tight) f64 tolerance // When testing with isolated graph dependencies test w/ default (tight) f64 tolerance
bool all_close_isolated = ::testing::AssertionResult all_close_isolated =
test::all_close_f(ref_data_vector, bk_isolated_data_vector); test::all_close_f(ref_data_vector, bk_isolated_data_vector);
EXPECT_TRUE(all_close_graph && all_close_isolated) << msg.str(); if (!all_close_graph || !all_close_isolated)
{
cout << msg.str();
}
EXPECT_TRUE(all_close_graph);
EXPECT_TRUE(all_close_isolated);
} }
else if (et == element::i8) else if (et == element::i8)
{ {
......
...@@ -553,7 +553,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc) ...@@ -553,7 +553,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc)
auto ref_results = execute(ref_func, args, "INTERPRETER"); auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}"); auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 3)); EXPECT_TRUE(
test::all_close_f(ref_results.at(0), bk_results.at(0), DEFAULT_FLOAT_TOLERANCE_BITS + 1));
} }
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc_double) NGRAPH_TEST(${BACKEND_NAME}, sum_stable_acc_double)
...@@ -611,7 +612,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_float) ...@@ -611,7 +612,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_float)
auto ref_results = execute(ref_func, args, "INTERPRETER"); auto ref_results = execute(ref_func, args, "INTERPRETER");
auto bk_results = execute(bk_func, args, "${BACKEND_NAME}"); auto bk_results = execute(bk_func, args, "${BACKEND_NAME}");
EXPECT_TRUE(test::all_close_f(ref_results.at(0), bk_results.at(0), 24, 1)); EXPECT_TRUE(
test::all_close_f(ref_results.at(0), bk_results.at(0), DEFAULT_FLOAT_TOLERANCE_BITS - 1));
} }
NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_double) NGRAPH_TEST(${BACKEND_NAME}, sum_stable_simple_double)
......
...@@ -1243,7 +1243,11 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn) ...@@ -1243,7 +1243,11 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn)
{ {
Shape shape{2, 3, 2, 1}; Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape); auto A = make_shared<op::Parameter>(element::f32, shape);
auto lrn = make_shared<op::LRN>(A, 1., 2., 1., 3); double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A}); auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
...@@ -1257,17 +1261,17 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn) ...@@ -1257,17 +1261,17 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn)
backend->call_with_validate(handle, {result}, {a}); backend->call_with_validate(handle, {result}, {a});
vector<float> expected{0.f, vector<float> expected{0.f,
0.05325444f, 0.3015113f,
0.03402646f, 0.4364357f,
0.01869806f, 0.5f,
0.06805293f, 0.8728715f,
0.03287071f, 0.8451542f,
0.00509002f, 0.5970223f,
0.00356153f, 0.6115928f,
0.00174719f, 0.5642765f,
0.0012555f, 0.5669467f,
0.00322708f, 0.7784989f,
0.00235574f}; 0.7720487f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result))); EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
} }
......
...@@ -606,7 +606,8 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_3d_large_input_max) ...@@ -606,7 +606,8 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_3d_large_input_max)
for (size_t i = 0; i < gpu_results_1.size(); i++) for (size_t i = 0; i < gpu_results_1.size(); i++)
{ {
EXPECT_TRUE(test::all_close_f(gpu_results_1.at(i), interp_results_1.at(i), 24, 0)); EXPECT_TRUE(test::all_close_f(
gpu_results_1.at(i), interp_results_1.at(i), MIN_FLOAT_TOLERANCE_BITS));
} }
} }
...@@ -644,7 +645,8 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_3d_large_input_min) ...@@ -644,7 +645,8 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_3d_large_input_min)
for (size_t i = 0; i < gpu_results_1.size(); i++) for (size_t i = 0; i < gpu_results_1.size(); i++)
{ {
EXPECT_TRUE(test::all_close_f(gpu_results_1.at(i), interp_results_1.at(i), 24, 0)); EXPECT_TRUE(test::all_close_f(
gpu_results_1.at(i), interp_results_1.at(i), MIN_FLOAT_TOLERANCE_BITS));
} }
} }
......
...@@ -212,10 +212,10 @@ TEST(gpu_test, topk_fanout_graph_transform) ...@@ -212,10 +212,10 @@ TEST(gpu_test, topk_fanout_graph_transform)
EXPECT_EQ((vector<int32_t>{2, 1, 1, 2, 1, 2, 0, 1}), read_vector<int32_t>(r0)); EXPECT_EQ((vector<int32_t>{2, 1, 1, 2, 1, 2, 0, 1}), read_vector<int32_t>(r0));
EXPECT_EQ((vector<int32_t>{2, 1, 1, 2, 1, 2, 0, 1}), read_vector<int32_t>(r1)); EXPECT_EQ((vector<int32_t>{2, 1, 1, 2, 1, 2, 0, 1}), read_vector<int32_t>(r1));
EXPECT_TRUE( EXPECT_TRUE(test::all_close_f(
test::all_close_f(vector<float>{4, 4, 3, 3, 3, 4, 2, 3}, read_vector<float>(r2), 24, 0)); vector<float>{4, 4, 3, 3, 3, 4, 2, 3}, read_vector<float>(r2), MIN_FLOAT_TOLERANCE_BITS));
EXPECT_TRUE( EXPECT_TRUE(test::all_close_f(
test::all_close_f(vector<float>{4, 4, 3, 3, 3, 4, 2, 3}, read_vector<float>(r3), 24, 0)); vector<float>{4, 4, 3, 3, 3, 4, 2, 3}, read_vector<float>(r3), MIN_FLOAT_TOLERANCE_BITS));
auto reshape_count = count_ops_of_type<ngraph::op::Reshape>(gpu_f); auto reshape_count = count_ops_of_type<ngraph::op::Reshape>(gpu_f);
EXPECT_EQ(reshape_count, 10); EXPECT_EQ(reshape_count, 10);
} }
......
...@@ -78,7 +78,7 @@ uint64_t test::float_distance(double a, double b) ...@@ -78,7 +78,7 @@ uint64_t test::float_distance(double a, double b)
return distance; return distance;
} }
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) bool test::close_f(float a, float b, int tolerance_bits)
{ {
// isfinite(a) => !isinf(a) && !isnan(a) // isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b)) if (!isfinite(a) || !isfinite(b))
...@@ -91,7 +91,7 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) ...@@ -91,7 +91,7 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 ) // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits // float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits); uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift; uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
return distance <= tolerance; return distance <= tolerance;
...@@ -99,8 +99,6 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits) ...@@ -99,8 +99,6 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
bool test::close_f(double a, double b, int tolerance_bits) bool test::close_f(double a, double b, int tolerance_bits)
{ {
constexpr int mantissa_bits = 53;
// isfinite(a) => !isinf(a) && !isnan(a) // isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b)) if (!isfinite(a) || !isfinite(b))
{ {
...@@ -112,7 +110,7 @@ bool test::close_f(double a, double b, int tolerance_bits) ...@@ -112,7 +110,7 @@ bool test::close_f(double a, double b, int tolerance_bits)
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits // e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 ) // tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits // double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits); uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift; uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
return distance <= tolerance; return distance <= tolerance;
...@@ -222,11 +220,18 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -222,11 +220,18 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
return matching_matissa_bits; return matching_matissa_bits;
} }
::testing::AssertionResult test::all_close_f(const vector<float>& a, ::testing::AssertionResult
const vector<float>& b, test::all_close_f(const vector<float>& a, const vector<float>& b, int tolerance_bits)
int mantissa_bits,
int tolerance_bits)
{ {
if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS)
{
tolerance_bits = MIN_FLOAT_TOLERANCE_BITS;
}
if (tolerance_bits >= FLOAT_MANTISSA_BITS)
{
tolerance_bits = FLOAT_MANTISSA_BITS - 1;
}
bool rc = true; bool rc = true;
stringstream msg; stringstream msg;
if (a.size() != b.size()) if (a.size() != b.size())
...@@ -238,7 +243,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -238,7 +243,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 ) // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits // float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits); uint32_t tolerance_bit_shift = 32 - (1 + 8 + (FLOAT_MANTISSA_BITS - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift; uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
uint32_t max_distance = 0; uint32_t max_distance = 0;
uint32_t min_distance = UINT_MAX; uint32_t min_distance = UINT_MAX;
...@@ -289,15 +294,15 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -289,15 +294,15 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr)) if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{ {
// Short unobtrusive message when passing // Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits) std::cout << "[ INFO ] Verifying match of <= " << (FLOAT_MANTISSA_BITS - tolerance_bits)
<< " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits << " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - "
<< " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance) << tolerance_bits << " tolerance). Loosest match found is "
<< " mantissa bits.\n"; << matching_mantissa_bits(max_distance) << " mantissa bits.\n";
} }
msg << "passing criteria - mismatch allowed @ mantissa bit: " msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (mantissa_bits - tolerance_bits) << " or later (" << mantissa_bits << (FLOAT_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n"; << " tolerance bits)\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: " << "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index] << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
...@@ -318,7 +323,14 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -318,7 +323,14 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
::testing::AssertionResult ::testing::AssertionResult
test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits) test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{ {
constexpr int mantissa_bits = 53; if (tolerance_bits < 0)
{
tolerance_bits = 0;
}
if (tolerance_bits >= DOUBLE_MANTISSA_BITS)
{
tolerance_bits = DOUBLE_MANTISSA_BITS - 1;
}
bool rc = true; bool rc = true;
stringstream msg; stringstream msg;
...@@ -331,7 +343,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -331,7 +343,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
// e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits // e.g. for double with 52 bit mantissa, 2 bit accuracy, and hard-coded 11 bit exponent_bits
// tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 ) // tolerance_bit_shift = 64 - (1 + 11 + (53 - 1 ) - 2 )
// double_length sign exp mantissa implicit 1 tolerance_bits // double_length sign exp mantissa implicit 1 tolerance_bits
uint64_t tolerance_bit_shift = 64 - (1 + 11 + (mantissa_bits - 1) - tolerance_bits); uint64_t tolerance_bit_shift = 64 - (1 + 11 + (DOUBLE_MANTISSA_BITS - 1) - tolerance_bits);
uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift; uint64_t tolerance = static_cast<uint64_t>(1U) << tolerance_bit_shift;
uint64_t max_distance = 0; uint64_t max_distance = 0;
uint64_t min_distance = ULLONG_MAX; uint64_t min_distance = ULLONG_MAX;
...@@ -379,15 +391,16 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -379,15 +391,16 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr)) if (rc && (std::getenv("NGRAPH_GTEST_INFO") != nullptr))
{ {
// Short unobtrusive message when passing // Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= " << (mantissa_bits - tolerance_bits) std::cout << "[ INFO ] Verifying match of >= "
<< " mantissa bits (" << mantissa_bits << " bits precision - " << tolerance_bits << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " mantissa bits ("
<< DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
<< " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance) << " tolerance). Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n"; << " mantissa bits.\n";
} }
msg << "passing criteria - mismatch allowed @ mantissa bit: " msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (mantissa_bits - tolerance_bits) << " or later (" << mantissa_bits << (DOUBLE_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n"; << " tolerance bits)\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: " << "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index] << matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
...@@ -407,7 +420,6 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -407,7 +420,6 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits,
int tolerance_bits) int tolerance_bits)
{ {
// Check that the layouts are compatible // Check that the layouts are compatible
...@@ -420,14 +432,12 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -420,14 +432,12 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes"; return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
} }
return test::all_close_f( return test::all_close_f(read_float_vector(a), read_float_vector(b), tolerance_bits);
read_float_vector(a), read_float_vector(b), mantissa_bits, tolerance_bits);
} }
::testing::AssertionResult ::testing::AssertionResult
test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as, test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs, const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits,
int tolerance_bits) int tolerance_bits)
{ {
if (as.size() != bs.size()) if (as.size() != bs.size())
...@@ -436,7 +446,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance) ...@@ -436,7 +446,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
} }
for (size_t i = 0; i < as.size(); ++i) for (size_t i = 0; i < as.size(); ++i)
{ {
auto ar = test::all_close_f(as[i], bs[i], mantissa_bits, tolerance_bits); auto ar = test::all_close_f(as[i], bs[i], tolerance_bits);
if (!ar) if (!ar)
{ {
return ar; return ar;
......
...@@ -22,6 +22,43 @@ ...@@ -22,6 +22,43 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_tools.hpp" #include "test_tools.hpp"
static constexpr int BFLOAT_MANTISSA_BITS = 8;
static constexpr int FLOAT_MANTISSA_BITS = 24;
static constexpr int DOUBLE_MANTISSA_BITS = 53;
// Maximum available float bits
#ifndef MAX_FLOAT_BITS
#define MAX_FLOAT_BITS FLOAT_MANTISSA_BITS
#endif
// Minimum float tolerance bits possible
#ifndef MIN_FLOAT_TOLERANCE_BITS
#define MIN_FLOAT_TOLERANCE_BITS (FLOAT_MANTISSA_BITS - MAX_FLOAT_BITS)
#endif
static_assert((MAX_FLOAT_BITS > 0) && (MAX_FLOAT_BITS <= FLOAT_MANTISSA_BITS),
"MAX_FLOAT_BITS must be in range (0, 24]");
static_assert((MIN_FLOAT_TOLERANCE_BITS >= 0) && (MIN_FLOAT_TOLERANCE_BITS < FLOAT_MANTISSA_BITS),
"MIN_FLOAT_TOLERANCE_BITS must be in range [0, 24)");
// Default float tolerance bits
#ifndef DEFAULT_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS (MIN_FLOAT_TOLERANCE_BITS + 2)
#endif
// Default float tolerance bits
#ifndef DEFAULT_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS 2
#endif
static_assert((DEFAULT_FLOAT_TOLERANCE_BITS >= 0) &&
(DEFAULT_FLOAT_TOLERANCE_BITS < FLOAT_MANTISSA_BITS),
"DEFAULT_FLOAT_TOLERANCE_BITS must be in range [0, 24)");
static_assert((DEFAULT_DOUBLE_TOLERANCE_BITS >= 0) &&
(DEFAULT_DOUBLE_TOLERANCE_BITS < DOUBLE_MANTISSA_BITS),
"DEFAULT_DOUBLE_TOLERANCE_BITS must be in range [0, 53)");
namespace ngraph namespace ngraph
{ {
namespace test namespace test
...@@ -68,7 +105,6 @@ namespace ngraph ...@@ -68,7 +105,6 @@ namespace ngraph
/// \brief Check if the two f32 numbers are close /// \brief Check if the two f32 numbers are close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP /// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
/// ///
...@@ -86,7 +122,7 @@ namespace ngraph ...@@ -86,7 +122,7 @@ namespace ngraph
/// ///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for /// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32. /// bfloat and f32.
bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2); bool close_f(float a, float b, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
/// \brief Check if the two f64 numbers are close /// \brief Check if the two f64 numbers are close
/// \param a First number to compare /// \param a First number to compare
...@@ -105,7 +141,7 @@ namespace ngraph ...@@ -105,7 +141,7 @@ namespace ngraph
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision /// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
/// ///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64. /// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
bool close_f(double a, double b, int tolerance_bits = 2); bool close_f(double a, double b, int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS);
/// \brief Determine distances between two vectors of f32 numbers /// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare /// \param a Vector of floats to compare
...@@ -142,13 +178,11 @@ namespace ngraph ...@@ -142,13 +178,11 @@ namespace ngraph
/// \brief Check if the two floating point vectors are all close /// \brief Check if the two floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a, ::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b, const std::vector<float>& b,
int mantissa_bits = 8, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = 2);
/// \brief Check if the two double floating point vectors are all close /// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
...@@ -157,29 +191,25 @@ namespace ngraph ...@@ -157,29 +191,25 @@ namespace ngraph
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a, ::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b, const std::vector<double>& b,
int tolerance_bits = 2); int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS);
/// \brief Check if the two TensorViews are all close in float /// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare /// \param a First Tensor to compare
/// \param b Second Tensor to compare /// \param b Second Tensor to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits = 8, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = 2);
/// \brief Check if the two vectors of TensorViews are all close in float /// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare /// \param as First vector of Tensor to compare
/// \param bs Second vector of Tensor to compare /// \param bs Second vector of Tensor to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult ::testing::AssertionResult
all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as, all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs, const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits = 8, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
int tolerance_bits = 2);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment