Unverified Commit fcdfc4ce authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

change all_close tests to return gtest AssertionResult instead of bool (#2195)

* change all_close tests to return gtest AssertionResult instead of bool to allow for better error messages

* change throw to return error

* address PR comments and fix compile error
parent 15d9b658
......@@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/type/element_type.hpp"
#include "test_tools.hpp"
......@@ -34,13 +35,14 @@ namespace ngraph
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_signed<T>::value, bool>::type
typename std::enable_if<std::is_signed<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8))
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
assert(a.size() == b.size());
size_t count = 0;
for (size_t i = 0; i < a.size(); ++i)
......@@ -50,19 +52,15 @@ namespace ngraph
{
if (count < 5)
{
NGRAPH_INFO
<< std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i << "\n";
}
count++;
rc = false;
}
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << count << " out of " << a.size();
}
return rc;
ar_fail << "diff count: " << count << " out of " << a.size() << "\n";
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
......@@ -72,24 +70,25 @@ namespace ngraph
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_unsigned<T>::value, bool>::type
typename std::enable_if<std::is_unsigned<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8))
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
assert(a.size() == b.size());
for (size_t i = 0; i < a.size(); ++i)
{
T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]);
if (abs_diff > atol + rtol * b[i])
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
ar_fail << a[i] << " is not close to " << b[i] << " at index " << i;
rc = false;
}
}
return rc;
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
......@@ -99,20 +98,22 @@ namespace ngraph
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
bool all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
const std::shared_ptr<ngraph::runtime::Tensor>& b,
T rtol = 1e-5f,
T atol = 1e-8f)
::testing::AssertionResult all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
const std::shared_ptr<ngraph::runtime::Tensor>& b,
T rtol = 1e-5f,
T atol = 1e-8f)
{
// Check that the layouts are compatible
if (*a->get_tensor_layout() != *b->get_tensor_layout())
{
throw ngraph_error("Cannot compare tensors with different layouts");
return ::testing::AssertionFailure()
<< "Cannot compare tensors with different layouts";
}
if (a->get_shape() != b->get_shape())
{
return false;
return ::testing::AssertionFailure()
<< "Cannot compare tensors with different shapes";
}
return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
......@@ -125,23 +126,26 @@ namespace ngraph
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
bool all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
T rtol,
T atol)
::testing::AssertionResult
all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
T rtol,
T atol)
{
if (as.size() != bs.size())
{
return false;
return ::testing::AssertionFailure()
<< "Cannot compare tensors with different sizes";
}
for (size_t i = 0; i < as.size(); ++i)
{
if (!all_close(as[i], bs[i], rtol, atol))
auto ar = all_close(as[i], bs[i], rtol, atol);
if (!ar)
{
return false;
return ar;
}
}
return true;
return ::testing::AssertionSuccess();
}
}
}
......@@ -235,15 +235,16 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
return matching_matissa_bits;
}
bool test::all_close_f(const vector<float>& a,
const vector<float>& b,
int mantissa_bits,
int tolerance_bits)
::testing::AssertionResult test::all_close_f(const vector<float>& a,
const vector<float>& b,
int mantissa_bits,
int tolerance_bits)
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
}
vector<uint32_t> distances = float_distances(a, b);
......@@ -274,8 +275,8 @@ bool test::all_close_f(const vector<float>& a,
{
if (diff_count < 5)
{
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i;
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< " is not close to " << b[i] << " at index " << i << "\n";
}
rc = false;
......@@ -284,7 +285,7 @@ bool test::all_close_f(const vector<float>& a,
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
......@@ -298,30 +299,32 @@ bool test::all_close_f(const vector<float>& a,
median_distance = median_sum / 2;
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc;
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits\n";
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
::testing::AssertionResult
test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{
constexpr int mantissa_bits = 53;
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
}
vector<uint64_t> distances = float_distances(a, b);
......@@ -352,17 +355,14 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol
{
if (diff_count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
ar_fail << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
}
rc = false;
diff_count++;
}
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
}
ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
......@@ -376,54 +376,56 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc;
ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
ar_fail << "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
ar_fail << "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits\n";
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
bool test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits,
int tolerance_bits)
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits,
int tolerance_bits)
{
// Check that the layouts are compatible
if (*a->get_tensor_layout() != *b->get_tensor_layout())
{
throw ngraph_error("Cannot compare tensors with different layouts");
return ::testing::AssertionFailure() << "Cannot compare tensors with different layouts";
}
if (a->get_shape() != b->get_shape())
{
return false;
return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
}
return test::all_close_f(
read_float_vector(a), read_float_vector(b), mantissa_bits, tolerance_bits);
}
bool test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits,
int tolerance_bits)
::testing::AssertionResult
test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits,
int tolerance_bits)
{
if (as.size() != bs.size())
{
return false;
return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
}
for (size_t i = 0; i < as.size(); ++i)
{
if (!test::all_close_f(as[i], bs[i], mantissa_bits, tolerance_bits))
auto ar = test::all_close_f(as[i], bs[i], mantissa_bits, tolerance_bits);
if (!ar)
{
return false;
return ar;
}
}
return true;
return ::testing::AssertionSuccess();
}
......@@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include "gtest/gtest.h"
#include "test_tools.hpp"
namespace ngraph
......@@ -143,20 +144,20 @@ namespace ngraph
/// \param b Second number to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close
bool all_close_f(const std::vector<float>& a,
const std::vector<float>& b,
int mantissa_bits = 8,
int tolerance_bits = 2);
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b,
int mantissa_bits = 8,
int tolerance_bits = 2);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close
bool all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = 2);
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = 2);
/// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare
......@@ -164,10 +165,10 @@ namespace ngraph
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float
bool all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits = 8,
int tolerance_bits = 2);
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits = 8,
int tolerance_bits = 2);
/// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare
......@@ -175,9 +176,10 @@ namespace ngraph
/// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float
bool all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits = 8,
int tolerance_bits = 2);
::testing::AssertionResult
all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int mantissa_bits = 8,
int tolerance_bits = 2);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment