Unverified commit 37380b8d authored by Fenglei, committed by GitHub

Merge branch 'master' into master

parents b3d70927 db554b5d
@@ -106,9 +106,6 @@ non-device-specific optimizations:
   with nGraph.
 - **Memory management** -- Prevent peak memory usage by intercepting
   a graph with or by a "saved checkpoint," and to enable data auditing.
-- **Data layout abstraction** -- Make abstraction easier and faster
-  with nGraph translating element order to work best for whatever given
-  or available device.
 Beta Limitations
 ----------------
...
@@ -16,7 +16,7 @@ workloads on CPU for inference, please refer to the links below.
 | Framework (Version)        | Installation guide                     | Notes
 |----------------------------|----------------------------------------|-----------------------------------
-| TensorFlow* 1.12           | [Pip install](https://github.com/NervanaSystems/ngraph-tf) or [Build from source](https://github.com/NervanaSystems/ngraph-tf) | 20 [Validated workloads]
+| TensorFlow* 1.12           | [Pip install](https://github.com/NervanaSystems/ngraph-tf/tree/v0.8.0#option-1-use-a-pre-built-ngraph-tensorflow-bridge) or [Build from source](https://github.com/NervanaSystems/ngraph-tf/tree/v0.8.0#option-2-build-ngraph-bridge-from-source-using-tensorflow-source) | 20 [Validated workloads]
 | MXNet* 1.3                 | [Pip install](https://github.com/NervanaSystems/ngraph-mxnet#Installation) or [Build from source](https://github.com/NervanaSystems/ngraph-mxnet#building-with-ngraph-support) | 18 [Validated workloads]
 | ONNX 1.3                   | [Pip install](https://github.com/NervanaSystems/ngraph-onnx#installation) | 14 [Validated workloads]
@@ -93,7 +93,7 @@ to improve it:
 [contrib guide]: https://ngraph.nervanasys.com/docs/latest/project/code-contributor-README.html
 [pull request]: https://github.com/NervanaSystems/ngraph/pulls
 [how to import]: https://ngraph.nervanasys.com/docs/latest/howto/import.html
-[ngraph_wireframes_with_notice]: doc/sphinx/source/graphics/ngraph_wireframes_with_notice.png "nGraph wireframe"
+[ngraph_wireframes_with_notice]: doc/sphinx/source/graphics/ngraph_wireframes_with_notice_updated.png "nGraph wireframe"
 [ngraph-compiler-stack-readme]: doc/sphinx/source/graphics/ngraph-compiler-stack-readme.png "nGraph Compiler Stack"
 [build-status]: https://travis-ci.org/NervanaSystems/ngraph/branches
 [build-status-badge]: https://travis-ci.org/NervanaSystems/ngraph.svg?branch=master
...
@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES
 # preprocessor.
 # This tag requires that the tag SEARCH_INCLUDES is set to YES.
-INCLUDE_PATH           =
+INCLUDE_PATH           = ../../src
 # You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
 # patterns (like *.h and *.hpp) to filter out the header-files in the
...
.. batch_norm_inference.rst:
##################
BatchNormInference
##################
.. code-block:: cpp
BatchNormInference // Adjust input for mean and variance
Description
===========
Inputs
------
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``variance``        | same as ``input``       | :math:`(C)`                  |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.
Normalization of a channel sample is controlled by two values:
* the `mean` :math:`\mu`, and
* the `variance` :math:`\sigma^2`;
and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
.. math::
\mathtt{normalized}_{\bullet, c, \ldots} = \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
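The formula reads as a plain loop over one channel. The following is an
illustrative sketch only; the helper name and the flattened per-channel layout
are assumptions, not the nGraph implementation:

.. code-block:: cpp

   #include <cmath>
   #include <cstddef>
   #include <vector>

   // Normalize one channel of a flattened input with precomputed
   // mean and variance, per the formula above.
   std::vector<float> normalize_channel(const std::vector<float>& input,
                                        float gamma,
                                        float beta,
                                        float mean,
                                        float variance,
                                        double epsilon)
   {
       std::vector<float> normalized(input.size());
       float scale = gamma / std::sqrt(variance + static_cast<float>(epsilon));
       for (std::size_t i = 0; i < input.size(); ++i)
       {
           normalized[i] = (input[i] - mean) * scale + beta;
       }
       return normalized;
   }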
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
-.. batch_norm.rst:
+.. batch_norm_training.rst:
-#########
-BatchNorm
-#########
+#################
+BatchNormTraining
+#################
 .. code-block:: cpp
-   BatchNorm  // Produces a normalized output
+   BatchNormTraining  // Compute mean and variance from the input.
 Description
 ===========
-Produces a normalized output.
 Inputs
 ------
-+---------------------+-------------------------+-----------------------------+
-| Name                | Element Type            | Shape                       |
-+=====================+=========================+=============================+
-| ``input``           | same as ``gamma``       | \(..., C, ...\)             |
-+---------------------+-------------------------+-----------------------------+
-| ``gamma``           | any                     | \(C\)                       |
-+---------------------+-------------------------+-----------------------------+
-| ``beta``            | same as ``gamma``       | \(C\)                       |
-+---------------------+-------------------------+-----------------------------+
-| ``global_mean``     | same as ``gamma``       | \(C\)                       |
-+---------------------+-------------------------+-----------------------------+
-| ``global_variance`` | same as ``gamma``       | \(C\)                       |
-+---------------------+-------------------------+-----------------------------+
-| ``use_global``      | ``bool``                | \(\)                        |
-+---------------------+-------------------------+-----------------------------+
++---------------------+-------------------------+------------------------------+
+| Name                | Element Type            | Shape                        |
++=====================+=========================+==============================+
+| ``input``           | real                    | :math:`(\bullet, C, \ldots)` |
++---------------------+-------------------------+------------------------------+
+| ``gamma``           | same as ``input``       | :math:`(C)`                  |
++---------------------+-------------------------+------------------------------+
+| ``beta``            | same as ``input``       | :math:`(C)`                  |
++---------------------+-------------------------+------------------------------+
 Attributes
 ----------
-+------------------+--------------------+---------------------+
-| Name             | Type               | Notes               |
-+==================+====================+=====================+
-| ``epsilon``      | same as ``input``  | Bias for variance   |
-+------------------+--------------------+---------------------+
-| ``channel_axis`` | size_t             | Channel axis        |
-+------------------+--------------------+---------------------+
++------------------+--------------------+--------------------------------------------------------+
+| Name             | Type               | Notes                                                  |
++==================+====================+========================================================+
+| ``epsilon``      | ``double``         | Small bias added to variance to avoid division by 0.  |
++------------------+--------------------+--------------------------------------------------------+
 Outputs
 -------
@@ -51,16 +43,15 @@ Outputs
 +---------------------+-------------------------+-----------------------------+
 | Name                | Element Type            | Shape                       |
 +=====================+=========================+=============================+
-| ``normalized``      | same as ``gamma``       | same as ``input``           |
+| ``normalized``      | same as ``gamma``       | Same as ``input``           |
 +---------------------+-------------------------+-----------------------------+
-| ``batch_mean``      | same as ``gamma``       | \(C\)                       |
+| ``batch_mean``      | same as ``gamma``       | :math:`(C)`                 |
 +---------------------+-------------------------+-----------------------------+
-| ``batch_variance``  | same as ``gamma``       | \(C\)                       |
+| ``batch_variance``  | same as ``gamma``       | :math:`(C)`                 |
 +---------------------+-------------------------+-----------------------------+
 The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from
-``input``. The values only need to be computed if ``use_global`` is ``false``,
-or if they are used.
+``input``.
 Mathematical Definition
@@ -72,22 +63,29 @@ each normalized independently.
 Normalization of a channel sample is controlled by two values:
-* the mean :math:`\mu`, and
-* the variance :math:`\sigma^2`;
+* the `batch_mean` :math:`\mu`, and
+* the `batch_variance` :math:`\sigma^2`;
 and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
-The values for :math:`\mu` and :math:`\sigma^2` come either from computing the
-mean and variance of ``input``, or from ``global_mean`` and ``global_variance``,
-depending on the value of ``use_global``.
+The values for :math:`\mu` and :math:`\sigma^2` come from computing the
+mean and variance of ``input``.
+.. math::
+   \mu_c &= \mathop{\mathbb{E}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
+   \sigma^2_c &= \mathop{\mathtt{Var}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
+   \mathtt{normalized}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
+Backprop
+========
 .. math::
-   y_c = \frac{x_c-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
+   [\overline{\texttt{input}}, \overline{\texttt{gamma}}, \overline{\texttt{beta}}]=\\
+   \mathop{\texttt{BatchNormTrainingBackprop}}(\texttt{input},\texttt{gamma},\texttt{beta},\texttt{mean},\texttt{variance},\overline{\texttt{normed_input}}).
-The mean and variance can be arguments, or they may be computed for each channel
-of ``input`` over the positional axes. When computed from ``input``, the mean
-and variance per-channel are available as outputs.
 C++ Interface
@@ -98,8 +96,3 @@ C++ Interface
    :members:
-.. doxygenclass:: ngraph::op::BatchNormInference
-   :project: ngraph
-   :members:
.. batch_norm_training_backprop.rst:
#########################
BatchNormTrainingBackprop
#########################
.. code-block:: cpp
BatchNormTrainingBackprop // Compute the input, gamma, and beta backprop increments.
Description
===========
Computes the ``input``, ``gamma`` and ``beta`` backprop increments.
Inputs
------
+----------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+======================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``variance`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input`` | same as ``input`` |
+----------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input_delta`` | same as ``input`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta`` | same as ``beta`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
It is easiest to simplify by looking at a single channel and flattening the
remaining axes into a vector; so ``gamma`` and ``beta`` are scalars, and ``input`` is an
:math:`N`-element vector.
The step-by-step forward training computation is:
.. math::
\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\
\mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\
\mathtt{square}_i &= \mathtt{centered}_i^2\\
\mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\
\mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\
\mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\
\mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}
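Transcribed directly, these steps look as follows for a single channel. This is
an illustrative sketch; the ``Forward`` struct and function name are
hypothetical, not the nGraph kernel:

.. code-block:: cpp

   #include <cmath>
   #include <cstddef>
   #include <vector>

   // Intermediate values of the forward computation, kept for backprop.
   struct Forward
   {
       float mean, variance, invsqrt, gmul;
       std::vector<float> centered, normed;
   };

   Forward bn_forward(const std::vector<float>& input,
                      float gamma, float beta, float epsilon)
   {
       const std::size_t N = input.size();
       Forward f;
       f.mean = 0;
       for (float x : input)
       {
           f.mean += x / N; // mean
       }
       f.centered.resize(N);
       f.variance = 0;
       for (std::size_t i = 0; i < N; ++i)
       {
           f.centered[i] = input[i] - f.mean;               // centered_i
           f.variance += f.centered[i] * f.centered[i] / N; // mean of square_i
       }
       f.invsqrt = 1.0f / std::sqrt(f.variance + epsilon);  // invsqrt
       f.gmul = gamma * f.invsqrt;                          // gmul
       f.normed.resize(N);
       for (std::size_t i = 0; i < N; ++i)
       {
           f.normed[i] = f.centered[i] * f.gmul + beta;     // normed_i
       }
       return f;
   }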
Using the notation :math:`\overline{\texttt{name}}` for :math:`\texttt{name_delta}`,
and :math:`\overline{x} \leftarrow y`
to mean that the backprop value :math:`\overline{x}` is a sum that includes :math:`y`,
we work backwards:
.. math::
   \overline{\texttt{beta}}&\leftarrow \sum \overline{\texttt{normed}}_i\\
   \overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\cdot\texttt{centered}_i\\
   \overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\cdot\texttt{gmul}\\
   \overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\
   \overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\
   \overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\
   \overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\
   \overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\
   \overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\
   \overline{\texttt{mean}}&\leftarrow -\sum\overline{\texttt{centered}}_i\\
   \overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}
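Working the same single-channel sketch backwards gives the following. Again
this is illustrative only, reusing the hypothetical ``Forward`` struct from
the sketch above, and is not the nGraph kernel:

.. code-block:: cpp

   // Accumulate the backprop values, walking the forward steps in
   // reverse; normed_delta is the incoming gradient.
   std::vector<float> bn_backprop(const std::vector<float>& normed_delta,
                                  const Forward& f,
                                  float gamma, float epsilon,
                                  float& gamma_delta, float& beta_delta)
   {
       const std::size_t N = normed_delta.size();
       beta_delta = 0;
       float gmul_delta = 0;
       std::vector<float> centered_delta(N);
       for (std::size_t i = 0; i < N; ++i)
       {
           beta_delta += normed_delta[i];
           gmul_delta += normed_delta[i] * f.centered[i];
           centered_delta[i] = normed_delta[i] * f.gmul;
       }
       gamma_delta = gmul_delta * f.invsqrt;
       float invsqrt_delta = gamma * gmul_delta;
       float variance_delta =
           -invsqrt_delta * f.invsqrt / (2 * (f.variance + epsilon));
       float mean_delta = 0;
       std::vector<float> input_delta(N);
       for (std::size_t i = 0; i < N; ++i)
       {
           // square_delta = variance_delta / N folded in here
           centered_delta[i] += 2 * f.centered[i] * (variance_delta / N);
           input_delta[i] = centered_delta[i];
           mean_delta -= centered_delta[i];
       }
       for (std::size_t i = 0; i < N; ++i)
       {
           input_delta[i] += mean_delta / N;
       }
       return input_delta;
   }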
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
:project: ngraph
:members:
@@ -56,7 +56,9 @@ Not currently a comprehensive list.
 * :doc:`atan`
 * :doc:`avg_pool`
 * :doc:`avg_pool_backprop`
-* :doc:`batch_norm`
+* :doc:`batch_norm_inference`
+* :doc:`batch_norm_training`
+* :doc:`batch_norm_training_backprop`
 * :doc:`broadcast`
 * :doc:`ceiling`
 * :doc:`concat`
@@ -123,7 +125,9 @@ Not currently a comprehensive list.
    atan.rst
    avg_pool.rst
    avg_pool_backprop.rst
-   batch_norm.rst
+   batch_norm_inference.rst
+   batch_norm_training.rst
+   batch_norm_training_backprop.rst
    broadcast.rst
    ceiling.rst
    concat.rst
...
@@ -27,7 +27,7 @@ and multi-device support of nGraph Compiler, please refer to [Framework integrat
 | Framework & Runtime        | Supported          | Validated
 |----------------------------|--------------------|-------------
 | TensorFlow* 1.12           | :heavy_check_mark: | :heavy_check_mark:
-| MXNet* 1.4                 | :heavy_check_mark: | :heavy_check_mark:
+| MXNet* 1.3                 | :heavy_check_mark: | :heavy_check_mark:
 | ONNX 1.3                   | :heavy_check_mark: | :heavy_check_mark:
 | ONNX Runtime               | Functional         | No
 | PyTorch (via ONNXIFI)      | Functional         | No
@@ -56,7 +56,7 @@ stack, and early adopters will be able test them in 2019.
-| Backend                                       | supported
+| Backend                                       | Supported
 |-----------------------------------------------|-------------------
 | Intel® Architecture CPU                       | :heavy_check_mark:
 | Intel® Architecture GPUs                      | Functional via clDNN and PlaidML
...
@@ -35,21 +35,24 @@ namespace ngraph
                       const AxisSet& reduction_axes)
         {
             CoordinateTransform output_transform(out_shape);
+            std::vector<T> c(shape_size(out_shape));
             for (const Coordinate& output_coord : output_transform)
             {
                 out[output_transform.index(output_coord)] = 0;
+                c[output_transform.index(output_coord)] = 0;
             }
             CoordinateTransform input_transform(in_shape);
-            T c = 0;
             for (const Coordinate& input_coord : input_transform)
             {
                 Coordinate output_coord = reduce(input_coord, reduction_axes);
-                T y = arg[input_transform.index(input_coord)] - c;
+                T y = arg[input_transform.index(input_coord)] -
+                      c[output_transform.index(output_coord)];
                 T t = out[output_transform.index(output_coord)] + y;
-                c = (t - out[output_transform.index(output_coord)]) - y;
+                c[output_transform.index(output_coord)] =
+                    (t - out[output_transform.index(output_coord)]) - y;
                 out[output_transform.index(output_coord)] = t;
             }
         }
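The change keeps one compensation term per output element rather than a single shared `c`, since each reduced output accumulates its own rounding error. For reference, Kahan (compensated) summation in its textbook scalar form; this is a minimal sketch with a hypothetical `kahan_sum` helper, not part of this change:

```cpp
#include <vector>

// Kahan summation: c tracks the low-order bits lost when sum + y is
// rounded, and feeds them back into the next addition.
float kahan_sum(const std::vector<float>& values)
{
    float sum = 0.0f;
    float c = 0.0f; // running compensation
    for (float v : values)
    {
        float y = v - c;   // apply previously lost low-order bits
        float t = sum + y; // low bits of y may be lost here
        c = (t - sum) - y; // recover exactly what was lost
        sum = t;
    }
    return sum;
}
```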
...
@@ -14,6 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
+#include <climits>
 #include <cmath>
 #include "util/all_close_f.hpp"
@@ -26,12 +27,11 @@ union FloatUnion {
     uint32_t i;
 };
-bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
+uint32_t test::float_distance(float a, float b)
 {
-    // isfinite(a) => !isinf(a) && !isnan(a)
     if (!isfinite(a) || !isfinite(b))
     {
-        return false;
+        return UINT_MAX;
     }
     FloatUnion a_fu{a};
@@ -47,6 +47,18 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
     b_uint = (sign_mask & b_uint) ? (~b_uint + 1) : (sign_mask | b_uint);
     uint32_t distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
+    return distance;
+}
+
+bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
+{
+    // isfinite(a) => !isinf(a) && !isnan(a)
+    if (!isfinite(a) || !isfinite(b))
+    {
+        return false;
+    }
+    uint32_t distance = float_distance(a, b);
     // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
     // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1 ) - 2 )
@@ -57,6 +69,64 @@ bool test::close_f(float a, float b, int mantissa_bits, int tolerance_bits)
     return distance <= tolerance;
 }
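For readers without the elided lines between the hunks: the distance is the ULP-style gap between the two bit patterns. A self-contained sketch of the same idea, with a hypothetical `ulp_distance` helper that uses `memcpy` in place of the file's `FloatUnion`:

```cpp
#include <cstdint>
#include <cstring>

// Map sign-magnitude float bits onto a monotonically ordered unsigned
// scale, then take the absolute difference; adjacent finite floats are
// distance 1 apart, and +0.0f and -0.0f are distance 0.
uint32_t ulp_distance(float a, float b)
{
    uint32_t a_bits, b_bits;
    std::memcpy(&a_bits, &a, sizeof(a)); // type-pun without UB
    std::memcpy(&b_bits, &b, sizeof(b));
    const uint32_t sign_mask = 0x80000000u;
    a_bits = (a_bits & sign_mask) ? (~a_bits + 1) : (sign_mask | a_bits);
    b_bits = (b_bits & sign_mask) ? (~b_bits + 1) : (sign_mask | b_bits);
    return (a_bits >= b_bits) ? (a_bits - b_bits) : (b_bits - a_bits);
}
```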
+vector<uint32_t> test::float_distances(const vector<float>& a, const vector<float>& b)
+{
+    if (a.size() != b.size())
+    {
+        throw ngraph_error("a.size() != b.size() for float_distances comparison.");
+    }
+    vector<uint32_t> distances(a.size());
+    for (size_t i = 0; i < a.size(); ++i)
+    {
+        distances[i] = float_distance(a[i], b[i]);
+    }
+    return distances;
+}
+
+uint32_t test::matching_mantissa_bits(uint32_t distance)
+{
+    uint32_t tolerance_needed = distance;
+    if (tolerance_needed < 0x80000000)
+    {
+        // Set up the dominos - turn on all the bits below maximal bit
+        tolerance_needed |= tolerance_needed >> 1;
+        tolerance_needed |= tolerance_needed >> 2;
+        tolerance_needed |= tolerance_needed >> 4;
+        tolerance_needed |= tolerance_needed >> 8;
+        tolerance_needed |= tolerance_needed >> 16;
+        // Tumble the dominos so we end up with next highest bit
+        ++tolerance_needed;
+        // all_close_f is <= test for tolerance
+        if ((tolerance_needed >> 1) == distance)
+        {
+            tolerance_needed = distance;
+        }
+    }
+    uint32_t tolerance_bit_shift = 0;
+    while (tolerance_needed >>= 1)
+    {
+        ++tolerance_bit_shift;
+    }
+    // all_close_f calculation of tolerance_bit_shift:
+    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
+    // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1) - 2)
+    //              float_length  sign  exp  matching_mantissa_bits  implicit 1  tolerance_bits
+    //
+    // Assuming 0 tolerance_bits and solving for matching_mantissa_bits yields:
+    // tolerance_bit_shift = 32 - (1 + 8 + (matching_mantissa_bits - 1) - 0)
+    // tolerance_bit_shift = 32 - (1 + 8 + (matching_mantissa_bits - 1))
+    // matching_mantissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1))
+    uint32_t matching_mantissa_bits =
+        tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
+    return matching_mantissa_bits;
+}
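A worked example of the round-up behavior, with values computed by hand from the formula above; the asserts are illustrative, not part of the test suite:

```cpp
#include <cassert>
#include "util/all_close_f.hpp"

int main()
{
    // distance 3: bit-smearing gives 0b11, +1 -> 4; 4 >> 1 != 3, so the
    // tolerance stays 4. tolerance_bit_shift = 2, and
    // 32 - (1 + 8 + (2 - 1)) = 22 matching mantissa bits.
    assert(ngraph::test::matching_mantissa_bits(3) == 22);

    // distance 4 is already a power of two: 8 >> 1 == 4 resets the
    // tolerance back to 4, giving the same shift of 2 and again 22 bits.
    assert(ngraph::test::matching_mantissa_bits(4) == 22);

    // distance 1 -> shift 0 -> 24 matching mantissa bits (a full float
    // mantissa including the implicit 1).
    assert(ngraph::test::matching_mantissa_bits(1) == 24);
    return 0;
}
```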
 bool test::all_close_f(const vector<float>& a,
                        const vector<float>& b,
                        int mantissa_bits,
@@ -65,27 +135,70 @@ bool test::all_close_f(const vector<float>& a,
     bool rc = true;
     if (a.size() != b.size())
     {
-        throw ngraph_error("a.size() != b.size() for all_close comparison.");
+        throw ngraph_error("a.size() != b.size() for all_close_f comparison.");
     }
-    size_t count = 0;
+    vector<uint32_t> distances = float_distances(a, b);
+    // e.g. for float with 24 bit mantissa, 2 bit accuracy, and hard-coded 8 bit exponent_bits
+    // tolerance_bit_shift = 32 - (1 + 8 + (24 - 1) - 2)
+    //              float_length  sign  exp  mantissa  implicit 1  tolerance_bits
+    uint32_t tolerance_bit_shift = 32 - (1 + 8 + (mantissa_bits - 1) - tolerance_bits);
+    uint32_t tolerance = static_cast<uint32_t>(1U) << tolerance_bit_shift;
+    uint32_t max_distance = 0;
+    uint32_t min_distance = UINT_MAX;
+    size_t max_distance_index = 0;
+    size_t min_distance_index = 0;
+    size_t diff_count = 0;
     for (size_t i = 0; i < a.size(); ++i)
     {
-        bool is_close_f = close_f(a[i], b[i], mantissa_bits, tolerance_bits);
+        if (distances[i] > max_distance)
+        {
+            max_distance = distances[i];
+            max_distance_index = i;
+        }
+        if (distances[i] < min_distance)
+        {
+            min_distance = distances[i];
+            min_distance_index = i;
+        }
+        bool is_close_f = distances[i] <= tolerance;
         if (!is_close_f)
         {
-            if (count < 5)
+            if (diff_count < 5)
             {
                 NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
             }
             rc = false;
-            count++;
+            diff_count++;
         }
     }
     if (!rc)
     {
-        NGRAPH_INFO << "diff count: " << count << " out of " << a.size();
+        NGRAPH_INFO << "diff count: " << diff_count << " out of " << a.size();
     }
+    // Find median value via partial sorting
+    size_t middle = distances.size() / 2;
+    std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
+    uint32_t median_distance = distances[middle];
+    if (distances.size() % 2 == 0)
+    {
+        // Find middle-1 value
+        uint64_t median_sum = static_cast<uint64_t>(median_distance) +
+                              *max_element(distances.begin(), distances.begin() + middle);
+        median_distance = median_sum / 2;
+    }
+    NGRAPH_INFO << "passing criteria: " << (mantissa_bits - tolerance_bits) << " mantissa bits ("
+                << mantissa_bits << " mantissa bits w/ " << tolerance_bits << " tolerance bits)";
+    NGRAPH_INFO << "tightest match: " << matching_mantissa_bits(min_distance)
+                << " mantissa bits (" << a[min_distance_index] << " vs " << b[min_distance_index]
+                << " at [" << min_distance_index << "])";
+    NGRAPH_INFO << "loosest match: " << matching_mantissa_bits(max_distance)
+                << " mantissa bits (" << a[max_distance_index] << " vs " << b[max_distance_index]
+                << " at [" << max_distance_index << "])";
+    NGRAPH_INFO << "median match: " << matching_mantissa_bits(median_distance)
+                << " mantissa bits";
     return rc;
 }
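A usage sketch of the vector comparison; the test values are hypothetical, and with these arguments `all_close_f` logs the match statistics shown above as it runs:

```cpp
#include <vector>
#include "util/all_close_f.hpp"

int main()
{
    std::vector<float> expected{1.0f, 2.0f, 3.0f};
    std::vector<float> actual{1.0f, 2.0000002f, 3.0f}; // ~1 ulp off at index 1
    // Require 24-bit mantissas to agree within 2 tolerance bits
    // (tolerance = 4 ulps), so this comparison passes.
    bool ok = ngraph::test::all_close_f(actual, expected, 24, 2);
    return ok ? 0 : 1;
}
```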
...
@@ -25,6 +25,27 @@ namespace ngraph
 {
     namespace test
     {
+        /// \brief Determine distance between two f32 numbers
+        /// \param a First number to compare
+        /// \param b Second number to compare
+        /// \returns Distance
+        ///
+        /// References:
+        /// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
+        /// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
+        /// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
+        ///
+        /// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
+        /// |------------bfloat-----------|
+        /// |----------------------------float----------------------------|
+        ///
+        /// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
+        /// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
+        ///
+        /// This function uses a hard-coded value of 8 exponent_bits, so it's only valid for
+        /// bfloat and f32.
+        uint32_t float_distance(float a, float b);
+
         /// \brief Check if the two f32 numbers are close
         /// \param a First number to compare
         /// \param b Second number to compare
@@ -48,6 +69,22 @@ namespace ngraph
         /// bfloat and f32.
         bool close_f(float a, float b, int mantissa_bits = 8, int tolerance_bits = 2);
+
+        /// \brief Determine distances between two vectors of f32 numbers
+        /// \param a Vector of floats to compare
+        /// \param b Vector of floats to compare
+        /// \returns Vector of distances
+        ///
+        /// See float_distance for limitations and assumptions.
+        std::vector<uint32_t> float_distances(const std::vector<float>& a,
+                                              const std::vector<float>& b);
+
+        /// \brief Determine number of matching mantissa bits given a distance
+        /// \param distance Distance calculated by float_distance
+        /// \returns Number of matching mantissa bits
+        ///
+        /// See float_distance for limitations and assumptions.
+        uint32_t matching_mantissa_bits(uint32_t distance);
+
         /// \brief Check if the two floating point vectors are all close
         /// \param a First number to compare
         /// \param b Second number to compare
...