Unverified Commit 220288e3 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Doc for Batchnorm (#2143)

parent 0c6590e7
......@@ -1807,7 +1807,7 @@ SEARCH_INCLUDES = YES
# preprocessor.
# This tag requires that the tag SEARCH_INCLUDES is set to YES.
INCLUDE_PATH =
INCLUDE_PATH = ../../src
# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
# patterns (like *.h and *.hpp) to filter out the header-files in the
......
.. batch_norm_inference.rst:
##################
BatchNormInference
##################
.. code-block:: cpp
BatchNormInference // Adjust input for mean and variance
Description
===========
Inputs
------
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``variances`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
The axes of the input fall into two categories: positional and channel, with
channel being axis 1. For each position, there are :math:`C` channel values,
each normalized independently.
Normalization of a channel sample is controlled by two values:
* the `mean` :math:`\mu`, and
* the `variance` :math:`\sigma^2`;
and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
.. math::
\mathtt{normalized}_{\bullet, c, \ldots} = \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm.rst:
.. batch_norm_training.rst:
#########
BatchNorm
#########
#################
BatchNormTraining
#################
.. code-block:: cpp
BatchNorm // Produces a normalized output
BatchNormTraining // Compute mean and variance from the input.
Description
===========
Produces a normalized output.
Inputs
------
+---------------------+-------------------------+-----------------------------+
+---------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input`` | same as ``gamma`` | \(..., C, ...\) |
+---------------------+-------------------------+-----------------------------+
| ``gamma`` | any | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``beta`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``global_mean`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``global_variance`` | same as ``gamma`` | \(C\) |
+---------------------+-------------------------+-----------------------------+
| ``use_global`` | ``bool`` | \(\) |
+---------------------+-------------------------+-----------------------------+
+=====================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+---------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+---------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+---------------------+
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+=====================+
| ``epsilon`` | same as ``input`` | Bias for variance |
+------------------+--------------------+---------------------+
| ``channel_axis`` | size_t | Channel axis |
+------------------+--------------------+---------------------+
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
......@@ -51,16 +43,15 @@ Outputs
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``normalized`` | same as ``gamma`` | same as ``input`` |
| ``normalized`` | same as ``gamma`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``batch_mean`` | same as ``gamma`` | \(C\) |
| ``batch_mean`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``batch_variance`` | same as ``gamma`` | \(C\) |
| ``batch_variance`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
The ``batch_mean`` and ``batch_variance`` outputs are computed per-channel from
``input``. The values only need to be computed if ``use_global`` is ``false``,
or if they are used.
``input``.
Mathematical Definition
......@@ -72,22 +63,29 @@ each normalized independently.
Normalization of a channel sample is controlled by two values:
* the mean :math:`\mu`, and
* the variance :math:`\sigma^2`;
* the `batch_mean` :math:`\mu`, and
* the `batch_variance` :math:`\sigma^2`;
and by two scaling attributes: :math:`\gamma` and :math:`\beta`.
The values for :math:`\mu` and :math:`\sigma^2` come either from computing the
mean and variance of ``input``, or from ``global_mean`` and ``global_variance``,
depending on the value of ``use_global``.
The values for :math:`\mu` and :math:`\sigma^2` come from computing the
mean and variance of ``input``.
.. math::
\mu_c &= \mathop{\mathbb{E}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\sigma^2_c &= \mathop{\mathtt{Var}}\left(\mathtt{input}_{\bullet, c, \ldots}\right)\\
\mathtt{normlized}_{\bullet, c, \ldots} &= \frac{\mathtt{input}_{\bullet, c, \ldots}-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
Backprop
========
.. math::
y_c = \frac{x_c-\mu_c}{\sqrt{\sigma^2_c+\epsilon}}\gamma_c+\beta_c
[\overline{\texttt{input}}, \overline{\texttt{gamma}}, \overline{\texttt{beta}}]=\\
\mathop{\texttt{BatchNormTrainingBackprop}}(\texttt{input},\texttt{gamma},\texttt{beta},\texttt{mean},\texttt{variance},\overline{\texttt{normed_input}}).
The mean and variance can be arguments, or they may be computed for each channel
of ``input`` over the positional axes. When computed from ``input``, the mean
and variance per-channel are available as outputs.
C++ Interface
......@@ -98,8 +96,3 @@ C++ Interface
:members:
.. doxygenclass:: ngraph::op::BatchNormInference
:project: ngraph
:members:
.. batch_norm_training_backprop.rst:
#########################
BatchNormTrainingBackprop
#########################
.. code-block:: cpp
BatchNormTrainingBackprop // Compute mean and variance backprop from the input.
Description
===========
Computes the ``input``, ``gamma`` and ``beta`` backprop increments.
Inputs
------
+----------------------+-------------------------+------------------------------+
| Name | Element Type | Shape |
+======================+=========================+==============================+
| ``input`` | real | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``beta`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``mean`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``variance`` | same as ``input`` | :math:`(C)` |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input`` | same as ``input`` |
+----------------------+-------------------------+------------------------------+
Attributes
----------
+------------------+--------------------+--------------------------------------------------------+
| Name | Type | Notes |
+==================+====================+========================================================+
| ``epsilon`` | ``double`` | Small bias added to variance to avoid division by 0. |
+------------------+--------------------+--------------------------------------------------------+
Outputs
-------
+---------------------+-------------------------+-----------------------------+
| Name | Element Type | Shape |
+=====================+=========================+=============================+
| ``input_delta`` | same as ``input`` | Same as ``input`` |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta`` | same as ``gamma`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta`` | same as ``beta`` | :math:`(C)` |
+---------------------+-------------------------+-----------------------------+
Mathematical Definition
=======================
It is easiest to simplify by looking at a single channel and flattening the
remaining axes into a vector; so ``gamma`` and ``beta`` are scalars, and ``input`` is an
:math:`N`-element vector.
The step by step forward training computation is
.. math::
\mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\
\mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\
\mathtt{square}_i &= \mathtt{centered}_i^2\\
\mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\
\mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\
\mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\
\mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}
Using the notation :math:`\overline{\texttt{name}}` for :math:`\texttt{name_delta}`
and :math:`\overline{x} \leftarrow y`
to mean the backprop value for :math:`\texttt{x_delta}` is a sum that includes :math:`y`.
We work backwards
.. math::
\overline{\texttt{beta}}&\leftarrow \overline{\texttt{normed}}\\
\overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\\
\overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\texttt{gmul}\\
\overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\
\overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\
\overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\
\overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\
\overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\
\overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\
\overline{\texttt{mean}}&\leftarrow\sum\overline{\texttt{centered}}_i\\
\overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}
C++ Interface
==============
.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
:project: ngraph
:members:
......@@ -56,7 +56,9 @@ Not currently a comprehensive list.
* :doc:`atan`
* :doc:`avg_pool`
* :doc:`avg_pool_backprop`
* :doc:`batch_norm`
* :doc:`batch_norm_inference`
* :doc:`batch_norm_training`
* :doc:`batch_norm_training_backprop`
* :doc:`broadcast`
* :doc:`ceiling`
* :doc:`concat`
......@@ -123,7 +125,9 @@ Not currently a comprehensive list.
atan.rst
avg_pool.rst
avg_pool_backprop.rst
batch_norm.rst
batch_norm_inference.rst
batch_norm_training.rst
batch_norm_training_backprop.rst
broadcast.rst
ceiling.rst
concat.rst
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment