Commit 220288e3 authored by Scott Cyphers, committed by GitHub

Doc for Batchnorm (#2143)

parent 0c6590e7
...@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES ...@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES
# preprocessor. # preprocessor.
# This tag requires that the tag SEARCH_INCLUDES is set to YES. # This tag requires that the tag SEARCH_INCLUDES is set to YES.
INCLUDE_PATH = INCLUDE_PATH = ../../src
# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard # You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
# patterns (like *.h and *.hpp) to filter out the header-files in the # patterns (like *.h and *.hpp) to filter out the header-files in the
......
.. batch_norm_inference.rst:
##################
BatchNormInference
##################
.. code-block:: cpp
BatchNormInference // Adjust input for mean and variance
Description
===========
Inputs
------
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``variances`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.
Normalization of a channel sample is controlled by two values:
* the `mean` :math:`\mu`, and
* the `variance` :math:`\sigma^2`;
and by two inputs, the scale :math:`\gamma` and the shift :math:`\beta`.
.. math::
\mathtt{normalized}_{\bullet, c, \ldots} = \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
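As a rough illustration of the formula above, here is a minimal C++ sketch of the per-channel computation. It is not the nGraph implementation: the function name ``batch_norm_inference_ref``, the row-major :math:`(N, C, S)` layout (all trailing positional axes flattened into ``S``), and the use of ``std::vector<double>`` are assumptions made for this example.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper, not part of nGraph.
// input is row-major (N, C, S): batch, channel, flattened positions.
std::vector<double> batch_norm_inference_ref(const std::vector<double>& input,
                                             const std::vector<double>& gamma,
                                             const std::vector<double>& beta,
                                             const std::vector<double>& mean,
                                             const std::vector<double>& variance,
                                             double epsilon,
                                             std::size_t N,
                                             std::size_t C,
                                             std::size_t S)
{
    assert(input.size() == N * C * S);
    assert(gamma.size() == C && beta.size() == C);
    assert(mean.size() == C && variance.size() == C);

    std::vector<double> normalized(input.size());
    for (std::size_t c = 0; c < C; ++c)
    {
        // Per-channel factor gamma_c / sqrt(sigma^2_c + epsilon).
        const double scale = gamma[c] / std::sqrt(variance[c] + epsilon);
        for (std::size_t n = 0; n < N; ++n)
        {
            for (std::size_t s = 0; s < S; ++s)
            {
                const std::size_t i = (n * C + c) * S + s;
                normalized[i] = (input[i] - mean[c]) * scale + beta[c];
            }
        }
    }
    return normalized;
}
```

Because ``mean`` and ``variance`` are supplied as inputs, each output element depends only on its own input element and the four values for its channel.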
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm.rst: .. batch_norm_training.rst:
######### #################
BatchNorm BatchNormTraining
######### #################
.. code-block:: cpp .. code-block:: cpp
BatchNorm // Produces a normalized output BatchNormTraining // Compute mean and variance from the input.
Description Description
=========== ===========
Produces a normalized output.
Inputs Inputs
------ ------
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=====================+=========================+=============================+ +=====================+=========================+==============================+
| ``input`` | same as ``gamma`` | \(..., C, ...\) | | ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+------------------------------+
| ``gamma`` | any | \(C\) | | ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``gamma`` | \(C\) | | ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+------------------------------+
| ``global_mean`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``global_variance`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``use_global`` | ``bool`` | \(\) |
+---------------------+-------------------------+-----------------------------+
Attributes Attributes
---------- ----------
+------------------+--------------------+---------------------+ +------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes | | Name | Type | Notes |
+==================+====================+=====================+ +==================+====================+========================================================+
| ``epsilon`` | same as ``input`` | Bias for variance | | ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+---------------------+ +------------------+--------------------+--------------------------------------------------------+
| ``channel_axis`` | size_t | Channel axis |
+------------------+--------------------+---------------------+
Outputs Outputs
------- -------
...@@ -51,16 +43,15 @@ Outputs ...@@ -51,16 +43,15 @@ Outputs
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape | | Name | Element Type | Shape |
+=====================+=========================+=============================+ +=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | same as ``input`` | | ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+-----------------------------+
| ``batch_mean`` | same as ``gamma`` | \(C\) | | ``batch_mean`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+-----------------------------+
| ``batch_variance`` | same as ``gamma`` | \(C\) | | ``batch_variance`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+ +---------------------+-------------------------+-----------------------------+
The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from
``input``. The values only need to be computed if ``use_global`` is ``false``, ``input``.
or if they are used.
Mathematical Definition Mathematical Definition
...@@ -72,22 +63,29 @@ each normalized independently. ...@@ -72,22 +63,29 @@ each normalized independently.
Normalization of a channel sample is controlled by two values: Normalization of a channel sample is controlled by two values:
* the mean :math:`\mu`, and * the `batch_mean` :math:`\mu`, and
* the variance :math:`\sigma^2`;
* the `batch_variance` :math:`\sigma^2`;
and by two inputs, the scale :math:`\gamma` and the shift :math:`\beta`. and by two inputs, the scale :math:`\gamma` and the shift :math:`\beta`.
The values for :math:`\mu` and :math:`\sigma^2` come either from computing the The values for :math:`\mu` and :math:`\sigma^2` come from computing the
mean and variance of ``input``, or from ``global_mean`` and ``global_variance``, mean and variance of ``input``.
depending on the value of ``use_global``.
.. math::
\mu_c &= \mathop{\mathbb{E}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\sigma^2_c &= \mathop{\mathtt{Var}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\mathtt{normalized}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
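The three equations above can be sketched in C++ as follows. This is an illustrative reference only, not the nGraph kernel: the names ``BatchNormTrainingResult`` and ``batch_norm_training_ref`` and the row-major :math:`(N, C, S)` layout are assumptions made for the example. The variance is taken as the population variance (dividing by the number of samples per channel), matching the step-by-step forward computation given in the backprop documentation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result holder mirroring the three documented outputs.
struct BatchNormTrainingResult
{
    std::vector<double> normalized;     // same shape as input
    std::vector<double> batch_mean;     // shape (C)
    std::vector<double> batch_variance; // shape (C)
};

// Hypothetical reference helper, not part of nGraph.
// input is row-major (N, C, S): batch, channel, flattened positions.
BatchNormTrainingResult batch_norm_training_ref(const std::vector<double>& input,
                                                const std::vector<double>& gamma,
                                                const std::vector<double>& beta,
                                                double epsilon,
                                                std::size_t N,
                                                std::size_t C,
                                                std::size_t S)
{
    assert(N * S > 0);
    assert(input.size() == N * C * S);
    assert(gamma.size() == C && beta.size() == C);

    BatchNormTrainingResult r;
    r.normalized.resize(input.size());
    r.batch_mean.resize(C);
    r.batch_variance.resize(C);
    const double count = static_cast<double>(N * S);

    for (std::size_t c = 0; c < C; ++c)
    {
        // mu_c: mean over every axis except the channel axis.
        double sum = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s)
                sum += input[(n * C + c) * S + s];
        const double mu = sum / count;

        // sigma^2_c: population variance (divide by count, not count - 1).
        double sq = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s)
            {
                const double d = input[(n * C + c) * S + s] - mu;
                sq += d * d;
            }
        const double var = sq / count;

        // Normalize, then apply gamma_c and beta_c.
        const double scale = gamma[c] / std::sqrt(var + epsilon);
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t s = 0; s < S; ++s)
            {
                const std::size_t i = (n * C + c) * S + s;
                r.normalized[i] = (input[i] - mu) * scale + beta[c];
            }

        r.batch_mean[c] = mu;
        r.batch_variance[c] = var;
    }
    return r;
}
```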
Backprop
========
.. math:: .. math::
y_c = \frac{x_c-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c [\overline{\texttt{input}}, \overline{\texttt{gamma}}, \overline{\texttt{beta}}]=\\
\mathop{\texttt{BatchNormTrainingBackprop}}(\texttt{input},\texttt{gamma},\texttt{beta},\texttt{mean},\texttt{variance},\overline{\texttt{normalized}}).
The mean and variance can be arguments, or they may be computed for each channel
of ``input`` over the positional axes. When computed from ``input``, the mean
and variance per-channel are available as outputs.
C++ Interface C++ Interface
...@@ -98,8 +96,3 @@ C++ Interface ...@@ -98,8 +96,3 @@ C++ Interface
:members: :members:
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm_training_backprop.rst:
#########################
BatchNormTrainingBackprop
#########################
.. code-block:: cpp
BatchNormTrainingBackprop // Compute input, gamma and beta backprop increments.
Description
===========
Computes the ``input``, ``gamma`` and ``beta`` backprop increments.
Inputs
------
+----------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+======================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``variance`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input`` | same as ``input`` |
+----------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input_delta`` | same as ``input`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta`` | same as ``beta`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
The derivation is easiest to follow for a single channel with the remaining axes
flattened into a vector, so that ``gamma`` and ``beta`` are scalars and ``input`` is an
:math:`N`-element vector.
The step-by-step forward training computation is:
.. math::
\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\
\mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\
\mathtt{square}_i &= \mathtt{centered}_i^2\\
\mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\
\mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\
\mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\
\mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}
We write :math:`\overline{\texttt{name}}` for :math:`\texttt{name_delta}`, and
:math:`\overline{x} \leftarrow y` to mean that the backprop value :math:`\texttt{x_delta}`
is a sum that includes :math:`y`.
Working backwards through the forward steps:
.. math::
\overline{\texttt{beta}}&\leftarrow \sum \overline{\texttt{normed}}_i\\
\overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\cdot\texttt{centered}_i\\
\overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\texttt{gmul}\\
\overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\
\overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\
\overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\
\overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\
\overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\
\overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\
\overline{\texttt{mean}}&\leftarrow -\sum\overline{\texttt{centered}}_i\\
\overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}
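To make the accumulation concrete, the following C++ sketch handles the single-channel case by applying the chain rule to each forward step in reverse order. It is only an illustration under stated assumptions: ``batch_norm_forward_1ch``, ``batch_norm_backprop_1ch`` and ``BatchNormBackpropResult`` are hypothetical names, not nGraph APIs. The backward helper recomputes the mean and variance from ``input`` instead of taking them as arguments, and it omits ``beta`` because the increments do not depend on it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical holder for the three documented outputs (single channel).
struct BatchNormBackpropResult
{
    std::vector<double> input_delta;
    double gamma_delta;
    double beta_delta;
};

// Forward pass for one channel, following the step-by-step definition.
std::vector<double> batch_norm_forward_1ch(const std::vector<double>& input,
                                           double gamma,
                                           double beta,
                                           double epsilon)
{
    assert(!input.empty());
    const double N = static_cast<double>(input.size());
    double mean = 0.0;
    for (double x : input)
        mean += x;
    mean /= N;
    double variance = 0.0;
    for (double x : input)
        variance += (x - mean) * (x - mean);
    variance /= N;
    const double gmul = gamma / std::sqrt(variance + epsilon);
    std::vector<double> normed(input.size());
    for (std::size_t i = 0; i < input.size(); ++i)
        normed[i] = (input[i] - mean) * gmul + beta;
    return normed;
}

// Backward pass for one channel: chain rule through the forward steps.
BatchNormBackpropResult batch_norm_backprop_1ch(const std::vector<double>& input,
                                                double gamma,
                                                double epsilon,
                                                const std::vector<double>& normed_delta)
{
    assert(!input.empty());
    assert(normed_delta.size() == input.size());
    const std::size_t n = input.size();
    const double N = static_cast<double>(n);

    // Recompute the forward intermediates.
    double mean = 0.0;
    for (double x : input)
        mean += x;
    mean /= N;
    std::vector<double> centered(n);
    double variance = 0.0;
    for (std::size_t i = 0; i < n; ++i)
    {
        centered[i] = input[i] - mean;
        variance += centered[i] * centered[i];
    }
    variance /= N;
    const double invsqrt = 1.0 / std::sqrt(variance + epsilon);
    const double gmul = gamma * invsqrt;

    // normed_i = centered_i * gmul + beta
    double beta_delta = 0.0;
    double gmul_delta = 0.0;
    for (std::size_t i = 0; i < n; ++i)
    {
        beta_delta += normed_delta[i];
        gmul_delta += normed_delta[i] * centered[i];
    }
    // gmul = gamma * invsqrt
    const double gamma_delta = gmul_delta * invsqrt;
    const double invsqrt_delta = gamma * gmul_delta;
    // invsqrt = (variance + epsilon)^(-1/2)
    const double variance_delta = -invsqrt_delta * invsqrt / (2.0 * (variance + epsilon));
    // variance = sum(square_i) / N, so every square_i gets the same increment.
    const double square_delta = variance_delta / N;

    // centered_i receives one term from normed_i and one from square_i.
    std::vector<double> centered_delta(n);
    double mean_delta = 0.0;
    for (std::size_t i = 0; i < n; ++i)
    {
        centered_delta[i] = normed_delta[i] * gmul + 2.0 * centered[i] * square_delta;
        mean_delta -= centered_delta[i]; // centered_i = input_i - mean
    }

    BatchNormBackpropResult r;
    r.input_delta.resize(n);
    for (std::size_t i = 0; i < n; ++i)
        r.input_delta[i] = centered_delta[i] + mean_delta / N;
    r.gamma_delta = gamma_delta;
    r.beta_delta = beta_delta;
    return r;
}
```

One way to sanity-check such a sketch is to compare ``input_delta`` against central finite differences of :math:`\sum_i \overline{\texttt{normed}}_i \cdot \texttt{normed}_i` computed with the forward helper.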
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
:project: ngraph
:members:
...@@ -56,7 +56,9 @@ Not currently a comprehensive list. ...@@ -56,7 +56,9 @@ Not currently a comprehensive list.
* :doc:`atan` * :doc:`atan`
* :doc:`avg_pool` * :doc:`avg_pool`
* :doc:`avg_pool_backprop` * :doc:`avg_pool_backprop`
* :doc:`batch_norm` * :doc:`batch_norm_inference`
* :doc:`batch_norm_training`
* :doc:`batch_norm_training_backprop`
* :doc:`broadcast` * :doc:`broadcast`
* :doc:`ceiling` * :doc:`ceiling`
* :doc:`concat` * :doc:`concat`
...@@ -123,7 +125,9 @@ Not currently a comprehensive list. ...@@ -123,7 +125,9 @@ Not currently a comprehensive list.
atan.rst atan.rst
avg_pool.rst avg_pool.rst
avg_pool_backprop.rst avg_pool_backprop.rst
batch_norm.rst batch_norm_inference.rst
batch_norm_training.rst
batch_norm_training_backprop.rst
broadcast.rst broadcast.rst
ceiling.rst ceiling.rst
concat.rst concat.rst
......