Unverified commit fcdfc4ce, authored by Robert Kimball, committed by GitHub

change all_close tests to return gtest AssertionResult instead of bool (#2195)

* change all_close tests to return gtest AssertionResult instead of bool to allow for better error messages

* change throw to return error

* address PR comments and fix compile error
parent 15d9b658
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "gtest/gtest.h"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
...@@ -34,13 +35,14 @@ namespace ngraph ...@@ -34,13 +35,14 @@ namespace ngraph
/// \param atol Absolute tolerance /// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|. /// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T> template <typename T>
typename std::enable_if<std::is_signed<T>::value, bool>::type typename std::enable_if<std::is_signed<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a, all_close(const std::vector<T>& a,
const std::vector<T>& b, const std::vector<T>& b,
T rtol = static_cast<T>(1e-5), T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8)) T atol = static_cast<T>(1e-8))
{ {
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
assert(a.size() == b.size()); assert(a.size() == b.size());
size_t count = 0; size_t count = 0;
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
...@@ -50,19 +52,15 @@ namespace ngraph ...@@ -50,19 +52,15 @@ namespace ngraph
{ {
if (count < 5) if (count < 5)
{ {
NGRAPH_INFO ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
<< a[i] << " is not close to " << b[i] << " at index " << i;
} }
count++; count++;
rc = false; rc = false;
} }
} }
if (!rc) ar_fail << "diff count: " << count << " out of " << a.size() << "\n";
{ return rc ? ::testing::AssertionSuccess() : ar_fail;
NGRAPH_INFO << "diff count: " << count << " out of " << a.size();
}
return rc;
} }
/// \brief Same as numpy.allclose /// \brief Same as numpy.allclose
...@@ -72,24 +70,25 @@ namespace ngraph ...@@ -72,24 +70,25 @@ namespace ngraph
/// \param atol Absolute tolerance /// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|. /// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T> template <typename T>
typename std::enable_if<std::is_unsigned<T>::value, bool>::type typename std::enable_if<std::is_unsigned<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a, all_close(const std::vector<T>& a,
const std::vector<T>& b, const std::vector<T>& b,
T rtol = static_cast<T>(1e-5), T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8)) T atol = static_cast<T>(1e-8))
{ {
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
assert(a.size() == b.size()); assert(a.size() == b.size());
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
{ {
T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]); T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]);
if (abs_diff > atol + rtol * b[i]) if (abs_diff > atol + rtol * b[i])
{ {
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i; ar_fail << a[i] << " is not close to " << b[i] << " at index " << i;
rc = false; rc = false;
} }
} }
return rc; return rc ? ::testing::AssertionSuccess() : ar_fail;
} }
/// \brief Same as numpy.allclose /// \brief Same as numpy.allclose
...@@ -99,20 +98,22 @@ namespace ngraph ...@@ -99,20 +98,22 @@ namespace ngraph
/// \param atol Absolute tolerance /// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|. /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T> template <typename T>
bool all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a, ::testing::AssertionResult all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
const std::shared_ptr<ngraph::runtime::Tensor>& b, const std::shared_ptr<ngraph::runtime::Tensor>& b,
T rtol = 1e-5f, T rtol = 1e-5f,
T atol = 1e-8f) T atol = 1e-8f)
{ {
// Check that the layouts are compatible // Check that the layouts are compatible
if (*a->get_tensor_layout() != *b->get_tensor_layout()) if (*a->get_tensor_layout() != *b->get_tensor_layout())
{ {
throw ngraph_error("Cannot compare tensors with different layouts"); return ::testing::AssertionFailure()
<< "Cannot compare tensors with different layouts";
} }
if (a->get_shape() != b->get_shape()) if (a->get_shape() != b->get_shape())
{ {
return false; return ::testing::AssertionFailure()
<< "Cannot compare tensors with different shapes";
} }
return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol); return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
...@@ -125,23 +126,26 @@ namespace ngraph ...@@ -125,23 +126,26 @@ namespace ngraph
/// \param atol Absolute tolerance /// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|. /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T> template <typename T>
bool all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as, ::testing::AssertionResult
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs, all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
T rtol, const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
T atol) T rtol,
T atol)
{ {
if (as.size() != bs.size()) if (as.size() != bs.size())
{ {
return false; return ::testing::AssertionFailure()
<< "Cannot compare tensors with different sizes";
} }
for (size_t i = 0; i < as.size(); ++i) for (size_t i = 0; i < as.size(); ++i)
{ {
if (!all_close(as[i], bs[i], rtol, atol)) auto ar = all_close(as[i], bs[i], rtol, atol);
if (!ar)
{ {
return false; return ar;
} }
} }
return true; return ::testing::AssertionSuccess();
} }
} }
} }
...@@ -235,15 +235,16 @@ uint64_t test::matching_mantissa_bits(uint64_t distance) ...@@ -235,15 +235,16 @@ uint64_t test::matching_mantissa_bits(uint64_t distance)
return matching_matissa_bits; return matching_matissa_bits;
} }
bool test::all_close_f(const vector<float>& a, ::testing::AssertionResult test::all_close_f(const vector<float>& a,
const vector<float>& b, const vector<float>& b,
int mantissa_bits, int mantissa_bits,
int tolerance_bits) int tolerance_bits)
{ {
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size()) if (a.size() != b.size())
{ {
throw ngraph_error("a.size() != b.size() for all_close_f comparison."); return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
} }
vector<uint32_t> distances = float_distances(a, b); vector<uint32_t> distances = float_distances(a, b);
...@@ -274,8 +275,8 @@ bool test::all_close_f(const vector<float>& a, ...@@ -274,8 +275,8 @@ bool test::all_close_f(const vector<float>& a,
{ {
if (diff_count < 5) if (diff_count < 5)
{ {
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1) ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< a[i] << " is not close to " << b[i] << " at index " << i; << " is not close to " << b[i] << " at index " << i << "\n";
} }
rc = false; rc = false;
...@@ -284,7 +285,7 @@ bool test::all_close_f(const vector<float>& a, ...@@ -284,7 +285,7 @@ bool test::all_close_f(const vector<float>& a,
} }
if (!rc) if (!rc)
{ {
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size(); ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
} }
// Find median value via partial sorting // Find median value via partial sorting
size_t middle = distances.size() / 2; size_t middle = distances.size() / 2;
...@@ -298,30 +299,32 @@ bool test::all_close_f(const vector<float>& a, ...@@ -298,30 +299,32 @@ bool test::all_close_f(const vector<float>& a,
median_distance = median_sum / 2; median_distance = median_sum / 2;
} }
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits (" ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)"; << mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1) ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match: " << matching_mantissa_bits(min_distance) << "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index] << a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< " at [" << min_distance_index << "])"; << min_distance_index << "])\n";
NGRAPH_INFO << std::setprecision(std::numeric_limits<long double>::digits10 + 1) ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match: " << matching_mantissa_bits(max_distance) << "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index] << a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< " at [" << max_distance_index << "])"; << max_distance_index << "])\n";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance) ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits"; << " mantissa bits\n";
return rc; return rc ? ::testing::AssertionSuccess() : ar_fail;
} }
bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits) ::testing::AssertionResult
test::all_close_f(const vector<double>& a, const vector<double>& b, int tolerance_bits)
{ {
constexpr int mantissa_bits = 53; constexpr int mantissa_bits = 53;
bool rc = true; bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size()) if (a.size() != b.size())
{ {
throw ngraph_error("a.size() != b.size() for all_close_f comparison."); return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
} }
vector<uint64_t> distances = float_distances(a, b); vector<uint64_t> distances = float_distances(a, b);
...@@ -352,17 +355,14 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol ...@@ -352,17 +355,14 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol
{ {
if (diff_count < 5) if (diff_count < 5)
{ {
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i; ar_fail << a[i] << " is not close to " << b[i] << " at index " << i << "\n";
} }
rc = false; rc = false;
diff_count++; diff_count++;
} }
} }
if (!rc) ar_fail << "diff count: " << diff_count << " out of " << a.size() << "\n";
{
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
}
// Find median value via partial sorting // Find median value via partial sorting
size_t middle = distances.size() / 2; size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end()); std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
...@@ -376,54 +376,56 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol ...@@ -376,54 +376,56 @@ bool test::all_close_f(const vector<double>& a, const vector<double>& b, int tol
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2); (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
} }
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits (" ar_fail << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)"; << mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)\n";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance) ar_fail << "tightest match: " << matching_mantissa_bits(min_distance) << " mantissa bits ("
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index] << a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< " at [" << min_distance_index << "])"; << min_distance_index << "])\n";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance) ar_fail << "loosest match: " << matching_mantissa_bits(max_distance) << " mantissa bits ("
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index] << a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< " at [" << max_distance_index << "])"; << max_distance_index << "])\n";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance) ar_fail << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits"; << " mantissa bits\n";
return rc; return rc ? ::testing::AssertionSuccess() : ar_fail;
} }
bool test::all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits, int mantissa_bits,
int tolerance_bits) int tolerance_bits)
{ {
// Check that the layouts are compatible // Check that the layouts are compatible
if (*a->get_tensor_layout() != *b->get_tensor_layout()) if (*a->get_tensor_layout() != *b->get_tensor_layout())
{ {
throw ngraph_error("Cannot compare tensors with different layouts"); return ::testing::AssertionFailure() << "Cannot compare tensors with different layouts";
} }
if (a->get_shape() != b->get_shape()) if (a->get_shape() != b->get_shape())
{ {
return false; return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
} }
return test::all_close_f( return test::all_close_f(
read_float_vector(a), read_float_vector(b), mantissa_bits, tolerance_bits); read_float_vector(a), read_float_vector(b), mantissa_bits, tolerance_bits);
} }
bool test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as, ::testing::AssertionResult
const std::vector<std::shared_ptr<runtime::Tensor>>& bs, test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
int mantissa_bits, const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits) int mantissa_bits,
int tolerance_bits)
{ {
if (as.size() != bs.size()) if (as.size() != bs.size())
{ {
return false; return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
} }
for (size_t i = 0; i < as.size(); ++i) for (size_t i = 0; i < as.size(); ++i)
{ {
if (!test::all_close_f(as[i], bs[i], mantissa_bits, tolerance_bits)) auto ar = test::all_close_f(as[i], bs[i], mantissa_bits, tolerance_bits);
if (!ar)
{ {
return false; return ar;
} }
} }
return true; return ::testing::AssertionSuccess();
} }
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "gtest/gtest.h"
#include "test_tools.hpp" #include "test_tools.hpp"
namespace ngraph namespace ngraph
...@@ -143,20 +144,20 @@ namespace ngraph ...@@ -143,20 +144,20 @@ namespace ngraph
/// \param b Second number to compare /// \param b Second number to compare
/// \param mantissa_bits The mantissa width of the underlying number before casting to float /// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
bool all_close_f(const std::vector<float>& a, ::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b, const std::vector<float>& b,
int mantissa_bits = 8, int mantissa_bits = 8,
int tolerance_bits = 2); int tolerance_bits = 2);
/// \brief Check if the two double floating point vectors are all close /// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \returns true iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
bool all_close_f(const std::vector<double>& a, ::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b, const std::vector<double>& b,
int tolerance_bits = 2); int tolerance_bits = 2);
/// \brief Check if the two TensorViews are all close in float /// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare /// \param a First Tensor to compare
...@@ -164,10 +165,10 @@ namespace ngraph ...@@ -164,10 +165,10 @@ namespace ngraph
/// \param mantissa_bits The mantissa width of the underlying number before casting to float /// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
bool all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int mantissa_bits = 8, int mantissa_bits = 8,
int tolerance_bits = 2); int tolerance_bits = 2);
/// \brief Check if the two vectors of TensorViews are all close in float /// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare /// \param as First vector of Tensor to compare
...@@ -175,9 +176,10 @@ namespace ngraph ...@@ -175,9 +176,10 @@ namespace ngraph
/// \param mantissa_bits The mantissa width of the underlying number before casting to float /// \param mantissa_bits The mantissa width of the underlying number before casting to float
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
bool all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as, ::testing::AssertionResult
const std::vector<std::shared_ptr<runtime::Tensor>>& bs, all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
int mantissa_bits = 8, const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits = 2); int mantissa_bits = 8,
int tolerance_bits = 2);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment