Unverified Commit 37380b8d authored by Fenglei's avatar Fenglei Committed by GitHub

Merge branch 'master' into master

parents b3d70927 db554b5d
......@@ -106,9 +106,6 @@ non-device-specific optimizations:
with nGraph.
- **Memory management** -- Prevent peak memory usage by intercepting
a graph with or by a "saved checkpoint," and to enable data auditing.
- **Data layout abstraction** -- Make abstraction easier and faster
with nGraph translating element order to work best for whatever given
or available device.
Beta Limitations
----------------
......
......@@ -16,7 +16,7 @@ workloads on CPU for inference, please refer to the links below.
| Framework (Version) | Installation guide | Notes
|----------------------------|----------------------------------------|-----------------------------------
| TensorFlow* 1.12 | [Pip install](https://github.com/NervanaSystems/ngraph-tf) or [Build from source](https://github.com/NervanaSystems/ngraph-tf) | 20 [Validated workloads]
| TensorFlow* 1.12 | [Pip install](https://github.com/NervanaSystems/ngraph-tf/tree/v0.8.0#option-1-use-a-pre-built-ngraph-tensorflow-bridge) or [Build from source](https://github.com/NervanaSystems/ngraph-tf/tree/v0.8.0#option-2-build-ngraph-bridge-from-source-using-tensorflow-source) | 20 [Validated workloads]
| MXNet* 1.3 | [Pip install](https://github.com/NervanaSystems/ngraph-mxnet#Installation) or [Build from source](https://github.com/NervanaSystems/ngraph-mxnet#building-with-ngraph-support)| 18 [Validated workloads]
| ONNX 1.3 | [Pip install](https://github.com/NervanaSystems/ngraph-onnx#installation) | 14 [Validated workloads]
......@@ -93,7 +93,7 @@ to improve it:
[contrib guide]: https://ngraph.nervanasys.com/docs/latest/project/code-contributor-README.html
[pull request]: https://github.com/NervanaSystems/ngraph/pulls
[how to import]: https://ngraph.nervanasys.com/docs/latest/howto/import.html
[ngraph_wireframes_with_notice]: doc/sphinx/source/graphics/ngraph_wireframes_with_notice.png "nGraph wireframe"
[ngraph_wireframes_with_notice]: doc/sphinx/source/graphics/ngraph_wireframes_with_notice_updated.png "nGraph wireframe"
[ngraph-compiler-stack-readme]: doc/sphinx/source/graphics/ngraph-compiler-stack-readme.png "nGraph Compiler Stack"
[build-status]: https://travis-ci.org/NervanaSystems/ngraph/branches
[build-status-badge]: https://travis-ci.org/NervanaSystems/ngraph.svg?branch=master
......
......@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES
# preprocessor.
# This tag requires that the tag SEARCH_INCLUDES is set to YES.
INCLUDE_PATH =
INCLUDE_PATH = ../../src
# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
# patterns (like *.h and *.hpp) to filter out the header-files in the
......
.. batch_norm_inference.rst:
##################
BatchNormInference
##################
.. code-block:: cpp
BatchNormInference // Adjust input for mean and variance
Description
===========
Inputs
------
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``variances`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.
Normalization of a channel sample is controlled by two values:
* the `mean` :math:`\mu`, and
* the `variance` :math:`\sigma^2`;
and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
.. math::
\mathtt{normalized}_{\bullet, c, \ldots} = \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm.rst:
.. batch_norm_training.rst:
#########
BatchNorm
#########
#################
BatchNormTraining
#################
.. code-block:: cpp
BatchNorm // Produces a normalized output
BatchNormTraining // Compute mean and variance from the input.
Description
===========
Produces a normalized output.
Inputs
------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input`` | same as ``gamma`` | \(..., C, ...\) |
+---------------------+-------------------------+-----------------------------+
| ``gamma`` | any | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``beta`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``global_mean`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``global_variance`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``use_global`` | ``bool`` | \(\) |
+---------------------+-------------------------+-----------------------------+
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+---------------------+
| Name | Type | Notes |
+==================+====================+=====================+
| ``epsilon`` | same as ``input`` | Bias for variance |
+------------------+--------------------+---------------------+
| ``channel_axis`` | size_t | Channel axis |
+------------------+--------------------+---------------------+
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
......@@ -51,16 +43,15 @@ Outputs
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | same as ``input`` |
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``batch_mean`` | same as ``gamma`` | \(C\) |
| ``batch_mean`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``batch_variance`` | same as ``gamma`` | \(C\) |
| ``batch_variance`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from
``input``. The values only need to be computed if ``use_global`` is ``false``,
or if they are used.
``input``.
Mathematical Definition
......@@ -72,22 +63,29 @@ each normalized independently.
Normalization of a channel sample is controlled by two values:
* the mean :math:`\mu`, and
* the variance :math:`\sigma^2`;
* the `batch_mean` :math:`\mu`, and
* the `batch_variance` :math:`\sigma^2`;
and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
The values for :math:`\mu` and :math:`\sigma^2` come either from computing the
mean and variance of ``input``, or from ``global_mean`` and ``global_variance``,
depending on the value of ``use_global``.
The values for :math:`\mu` and :math:`\sigma^2` come from computing the
mean and variance of ``input``.
.. math::
\mu_c &= \mathop{\mathbb{E}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\sigma^2_c &= \mathop{\mathtt{Var}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\mathtt{normlized}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
Backprop
========
.. math::
y_c = \frac{x_c-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
[\overline{\texttt{input}}, \overline{\texttt{gamma}}, \overline{\texttt{beta}}]=\\
\mathop{\texttt{BatchNormTrainingBackprop}}(\texttt{input},\texttt{gamma},\texttt{beta},\texttt{mean},\texttt{variance},\overline{\texttt{normed_input}}).
The mean and variance can be arguments, or they may be computed for each channel
of ``input`` over the positional axes. When computed from ``input``, the mean
and variance per-channel are available as outputs.
C++ Interface
......@@ -98,8 +96,3 @@ C++ Interface
:members:
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm_training_backprop.rst:
#########################
BatchNormTrainingBackprop
#########################
.. code-block:: cpp
BatchNormTrainingBackprop // Compute mean and variance backprop from the input.
Description
===========
Computes the ``input``, ``gamma`` and ``beta`` backprop increments.
Inputs
------
+----------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+======================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``variance`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input`` | same as ``input`` |
+----------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input_delta`` | same as ``input`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta`` | same as ``beta`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
It is easiest to simplify by looking at a single channel and flattening the
remaining axes into a vector; so ``gamma`` and ``beta`` are scalars, and ``input`` is an
:math:`N`-element vector.
The step by step forward training computation is
.. math::
\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\
\mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\
\mathtt{square}_i &= \mathtt{centered}_i^2\\
\mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\
\mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\
\mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\
\mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}
Using the notation :math:`\overline{\texttt{name}}` for :math:`\texttt{name_delta}`
and :math:`\overline{x} \leftarrow y`
to mean the backprop value for :math:`\texttt{x_delta}` is a sum that includes :math:`y`.
We work backwards
.. math::
\overline{\texttt{beta}}&\leftarrow \overline{\texttt{normed}}\\
\overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\\
\overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\texttt{gmul}\\
\overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\
\overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\
\overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\
\overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\
\overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\
\overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\
\overline{\texttt{mean}}&\leftarrow\sum\overline{\texttt{centered}}_i\\
\overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
:project: ngraph
:members:
......@@ -56,7 +56,9 @@ Not currently a comprehensive list.
* :doc:`atan`
* :doc:`avg_pool`
* :doc:`avg_pool_backprop`
* :doc:`batch_norm`
* :doc:`batch_norm_inference`
* :doc:`batch_norm_training`
* :doc:`batch_norm_training_backprop`
* :doc:`broadcast`
* :doc:`ceiling`
* :doc:`concat`
......@@ -123,7 +125,9 @@ Not currently a comprehensive list.
atan.rst
avg_pool.rst
avg_pool_backprop.rst
batch_norm.rst
batch_norm_inference.rst
batch_norm_training.rst
batch_norm_training_backprop.rst
broadcast.rst
ceiling.rst
concat.rst
......
......@@ -27,7 +27,7 @@ and multi-device support of nGraph Compiler, please refer to [Framework integrat
| Framework & Runtime | Supported | Validated
|----------------------------|--------------------|-------------
| TensorFlow* 1.12 | :heavy_check_mark: | :heavy_check_mark:
| MXNet* 1.4 | :heavy_check_mark: | :heavy_check_mark:
| MXNet* 1.3 | :heavy_check_mark: | :heavy_check_mark:
| ONNX 1.3 | :heavy_check_mark: | :heavy_check_mark:
| ONNX Runtime Functional | Functional | No
| PyTorch (via ONNXIFI) | Functional | No
......@@ -56,7 +56,7 @@ stack, and early adopters will be able test them in 2019.
| Backend | supported
| Backend | Supported
|-----------------------------------------------|-------------------
| Intel® Architecture CPU | :heavy_check_mark:
| Intel® Architecture GPUs | Functional via clDNN and PlaidML
......
......@@ -35,21 +35,24 @@ namespace ngraph
const AxisSet& reduction_axes)
{
CoordinateTransform output_transform(out_shape);
std::vector<T> c(shape_size(out_shape));
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 0;
c[output_transform.index(output_coord)] = 0;
}
CoordinateTransform input_transform(in_shape);
T c = 0;
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = reduce(input_coord, reduction_axes);
T y = arg[input_transform.index(input_coord)] - c;
T y = arg[input_transform.index(input_coord)] -
c[output_transform.index(output_coord)];
T t = out[output_transform.index(output_coord)] + y;
c = (t - out[output_transform.index(output_coord)]) - y;
c[output_transform.index(output_coord)] =
(t - out[output_transform.index(output_coord)]) - y;
out[output_transform.index(output_coord)] = t;
}
}
......
This diff is collapsed.
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <climits>
#include <cmath>
#include "util/all_close_f.hpp"
......@@ -26,12 +27,11 @@ union FloatUnion {
uint32_t i;
};
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
uint32_t test::float_distance(float a, float b)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
{
return false;
return UINT_MAX;
}
FloatUnion a_fu{a};
......@@ -47,6 +47,18 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
return distance;
}
bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
{
return false;
}
uint32_t distance = float_distance(a, b);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
......@@ -57,6 +69,64 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
return distance <= tolerance;
}
vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
{
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint32_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
distances[i] = float_distance(a[i], b[i]);
}
return distances;
}
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
uint32_t tolerance_needed = distance;
if (tolerance_needed < 0x80000000)
{
// Set up the dominos - turn on all the bits below maximal bit
tolerance_needed |= tolerance_needed >> 1;
tolerance_needed |= tolerance_needed >> 2;
tolerance_needed |= tolerance_needed >> 4;
tolerance_needed |= tolerance_needed >> 8;
tolerance_needed |= tolerance_needed >> 16;
// Tumble the dominos so we end up with next highest bit
++tolerance_needed;
// all_close_f is <= test for tolerance
if ((tolerance_needed >> 1) == distance)
{
tolerance_needed = distance;
}
}
uint32_t tolerance_bit_shift = 0;
while (tolerance_needed >>= 1)
{
++tolerance_bit_shift;
}
// all_close_f calculation of tolerance_bit_shift:
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp matching_matissa_bits implicit 1 tolerance_bits
//
// Assuming 0 tolerance_bits and solving for matching_matissa_bits yields:
// tolerance_bit_shift = 32 - (1 + 8 + (matching_matissa_bits - 1 ) - 0 )
// tolerance_bit_shift = 32 - (1 + 8 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1 ) )
uint32_t matching_matissa_bits =
tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
bool test::all_close_f(const vector<float>& a,
const vector<float>& b,
int mantissa_bits,
......@@ -65,27 +135,70 @@ bool test::all_close_f(const vector<float>& a,
bool rc = true;
if (a.size() != b.size())
{
throw ngraph_error("a.size() != b.size() for all_close comparison.");
throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
}
size_t count = 0;
vector<uint32_t> distances = float_distances(a, b);
// e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
// tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
// float_length sign exp mantissa implicit 1 tolerance_bits
uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
uint32_t max_distance = 0;
uint32_t min_distance = UINT_MAX;
size_t max_distance_index = 0;
size_t min_distance_index = 0;
size_t diff_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
bool is_close_f = close_f(a[i], b[i], mantissa_bits, tolerance_bits);
if (distances[i] > max_distance)
{
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f)
{
if (count < 5)
if (diff_count < 5)
{
NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
}
rc = false;
count++;
diff_count++;
}
}
if (!rc)
{
NGRAPH_INFO << "diff count: " << count << " out of " << a.size();
NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint32_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
// Find middle-1 value
uint64_t median_sum = static_cast<uint64_t>(median_distance) +
*max_element(distances.begin(), distances.begin() + middle);
median_distance = median_sum / 2;
}
NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
<< mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
<< " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
<< " at [" << min_distance_index << "])";
NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
<< " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
<< " at [" << max_distance_index << "])";
NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
<< " mantissa bits";
return rc;
}
......
......@@ -25,6 +25,27 @@ namespace ngraph
{
namespace test
{
/// \brief Determine distance between two f32 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
/// |----------------------------float----------------------------|
///
/// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
/// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
uint32_t float_distance(float a, float b);
/// \brief Check if the two f32 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
......@@ -48,6 +69,22 @@ namespace ngraph
/// bfloat and f32.
bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2);
/// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare
/// \param b Vector of floats to compare
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint32_t distance);
/// \brief Check if the two floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment