Commit abd1c70d authored by gcwenger's avatar gcwenger Committed by Scott Cyphers

Added "min_signal" paramater to float testing (#2653)

min_signal is intended to optionally skip checking float
distances when numbers are close enough to 0.
Only skips when both numbers are < min_signal.
Intention is to allow tighter float testing in certain
cases where most values are not near 0, but values near
zero are differing by more bits than values farther from 0.
Should be used with caution in limited cases.
parent c6c2aafb
This diff is collapsed.
This diff is collapsed.
...@@ -66,6 +66,7 @@ namespace ngraph ...@@ -66,6 +66,7 @@ namespace ngraph
/// \brief Determine distance between two f32 numbers /// \brief Determine distance between two f32 numbers
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance /// \returns Distance
/// ///
/// References: /// References:
...@@ -82,11 +83,12 @@ namespace ngraph ...@@ -82,11 +83,12 @@ namespace ngraph
/// ///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for /// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32. /// bfloat and f32.
uint32_t float_distance(float a, float b); uint32_t float_distance(float a, float b, float min_signal = 0.0f);
/// \brief Determine distance between two f64 numbers /// \brief Determine distance between two f64 numbers
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance /// \returns Distance
/// ///
/// References: /// References:
...@@ -100,12 +102,13 @@ namespace ngraph ...@@ -100,12 +102,13 @@ namespace ngraph
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision /// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
/// ///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64. /// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
uint64_t float_distance(double a, double b); uint64_t float_distance(double a, double b, double min_signal = 0.0);
/// \brief Check if the two f32 numbers are close /// \brief Check if the two f32 numbers are close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP /// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
/// ///
/// References: /// References:
...@@ -122,12 +125,16 @@ namespace ngraph ...@@ -122,12 +125,16 @@ namespace ngraph
/// ///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for /// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32. /// bfloat and f32.
bool close_f(float a, float b, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS); bool close_f(float a,
float b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two f64 numbers are close /// \brief Check if the two f64 numbers are close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP /// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
/// ///
/// References: /// References:
...@@ -141,25 +148,32 @@ namespace ngraph ...@@ -141,25 +148,32 @@ namespace ngraph
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision /// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
/// ///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64. /// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
bool close_f(double a, double b, int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS); bool close_f(double a,
double b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Determine distances between two vectors of f32 numbers /// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare /// \param a Vector of floats to compare
/// \param b Vector of floats to compare /// \param b Vector of floats to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances /// \returns Vector of distances
/// ///
/// See float_distance for limitations and assumptions. /// See float_distance for limitations and assumptions.
std::vector<uint32_t> float_distances(const std::vector<float>& a, std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b); const std::vector<float>& b,
float min_signal = 0.0f);
/// \brief Determine distances between two vectors of f64 numbers /// \brief Determine distances between two vectors of f64 numbers
/// \param a Vector of doubles to compare /// \param a Vector of doubles to compare
/// \param b Vector of doubles to compare /// \param b Vector of doubles to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances /// \returns Vector of distances
/// ///
/// See float_distance for limitations and assumptions. /// See float_distance for limitations and assumptions.
std::vector<uint64_t> float_distances(const std::vector<double>& a, std::vector<uint64_t> float_distances(const std::vector<double>& a,
const std::vector<double>& b); const std::vector<double>& b,
double min_signal = 0.0);
/// \brief Determine number of matching mantissa bits given a distance /// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance /// \param distance Distance calculated by float_distance
...@@ -179,37 +193,45 @@ namespace ngraph ...@@ -179,37 +193,45 @@ namespace ngraph
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a, ::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b, const std::vector<float>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS); int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two double floating point vectors are all close /// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare /// \param a First number to compare
/// \param b Second number to compare /// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close /// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a, ::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b, const std::vector<double>& b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS); int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Check if the two TensorViews are all close in float /// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare /// \param a First Tensor to compare
/// \param b Second Tensor to compare /// \param b Second Tensor to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a, ::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b, const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS); int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two vectors of TensorViews are all close in float /// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare /// \param as First vector of Tensor to compare
/// \param bs Second vector of Tensor to compare /// \param bs Second vector of Tensor to compare
/// \param tolerance_bits Bit tolerance error /// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float /// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult ::testing::AssertionResult
all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as, all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs, const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS); int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment