.. batch_norm_training_backprop.rst:

#########################
BatchNormTrainingBackprop
#########################

.. code-block:: cpp

   BatchNormTrainingBackprop  // Compute the input, gamma, and beta backprop increments.


Description
===========

Computes the backprop increments for ``input``, ``gamma``, and ``beta``.


Inputs
------

+----------------------+-------------------------+------------------------------+
| Name                 | Element Type            | Shape                        |
+======================+=========================+==============================+
| ``input``            | real                    | :math:`(\bullet, C, \ldots)` |
+----------------------+-------------------------+------------------------------+
| ``gamma``            | same as ``input``       | :math:`(C)`                  |
+----------------------+-------------------------+------------------------------+
| ``beta``             | same as ``input``       | :math:`(C)`                  |
+----------------------+-------------------------+------------------------------+
| ``mean``             | same as ``input``       | :math:`(C)`                  |
+----------------------+-------------------------+------------------------------+
| ``variance``         | same as ``input``       | :math:`(C)`                  |
+----------------------+-------------------------+------------------------------+
| ``normalized_delta`` | same as ``input``       | same as ``input``            |
+----------------------+-------------------------+------------------------------+


Attributes
----------

+------------------+--------------------+--------------------------------------------------------+
| Name             | Type               | Notes                                                  |
+==================+====================+========================================================+
| ``epsilon``      | ``double``         | Small bias added to variance to avoid division by 0.   |
+------------------+--------------------+--------------------------------------------------------+

Outputs
-------

+---------------------+-------------------------+-----------------------------+
| Name                | Element Type            | Shape                       |
+=====================+=========================+=============================+
| ``input_delta``     | same as ``input``       | same as ``input``           |
+---------------------+-------------------------+-----------------------------+
| ``gamma_delta``     | same as ``gamma``       | :math:`(C)`                 |
+---------------------+-------------------------+-----------------------------+
| ``beta_delta``      | same as ``beta``        | :math:`(C)`                 |
+---------------------+-------------------------+-----------------------------+


Mathematical Definition
=======================

It is easiest to describe the computation for a single channel, flattening the
remaining axes into a vector, so that ``gamma`` and ``beta`` are scalars and
``input`` is an :math:`N`-element vector.

The step by step forward training computation is

.. math::
   
   \mathtt{mean} &= \frac{\sum{\mathtt{input}_i}}{N}\\
   \mathtt{centered}_i &= \mathtt{input}_i - \mathtt{mean}\\
   \mathtt{square}_i &= \mathtt{centered}_i^2\\
   \mathtt{variance} &= \frac{\sum \mathtt{square}_i}{N}\\
   \mathtt{invsqrt} &= \frac{1}{\sqrt{\mathtt{variance}+\epsilon}}\\
   \mathtt{gmul} &= \texttt{gamma}\cdot \mathtt{invsqrt}\\
   \mathtt{normed}_i &= \mathtt{centered}_i\mathtt{gmul}+\texttt{beta}
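
As a concrete sketch of these steps (not the nGraph kernel; the names
``ForwardResult`` and ``batch_norm_forward`` are illustrative only), the
per-channel forward pass can be written in plain C++:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the per-channel forward training pass, following
// the step-by-step equations above.  `input` is one channel flattened to a
// vector; `gamma` and `beta` are the per-channel scalars.
struct ForwardResult {
    double mean = 0.0;
    double variance = 0.0;
    double invsqrt = 0.0;
    std::vector<double> normed;
};

ForwardResult batch_norm_forward(const std::vector<double>& input,
                                 double gamma, double beta, double epsilon) {
    const std::size_t n = input.size();
    ForwardResult r;

    for (double x : input) r.mean += x;            // mean = sum(input_i) / N
    r.mean /= static_cast<double>(n);

    std::vector<double> centered(n);
    for (std::size_t i = 0; i < n; ++i) {
        centered[i] = input[i] - r.mean;           // centered_i
        r.variance += centered[i] * centered[i];   // accumulate square_i
    }
    r.variance /= static_cast<double>(n);          // variance = sum(square_i) / N

    r.invsqrt = 1.0 / std::sqrt(r.variance + epsilon);
    const double gmul = gamma * r.invsqrt;         // gmul = gamma * invsqrt

    r.normed.resize(n);
    for (std::size_t i = 0; i < n; ++i)
        r.normed[i] = centered[i] * gmul + beta;   // normed_i
    return r;
}
```

With ``gamma = 1`` and ``beta = 0``, the output has (approximately) zero mean
and, for ``epsilon = 0``, unit variance.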

Using the notation :math:`\overline{\texttt{name}}` for :math:`\texttt{name_delta}`,
and :math:`\overline{x} \leftarrow y` to mean that the backprop value for
:math:`\texttt{x_delta}` is a sum that includes the term :math:`y`, we work
backwards through the forward computation:

.. math::

   \overline{\texttt{beta}}&\leftarrow \sum \overline{\texttt{normed}}_i\\
   \overline{\texttt{gmul}}&\leftarrow \sum \overline{\texttt{normed}}_i\cdot\texttt{centered}_i\\
   \overline{\texttt{centered}}_i&\leftarrow\overline{\texttt{normed}}_i\texttt{gmul}\\
   \overline{\texttt{gamma}}&\leftarrow \overline{\texttt{gmul}}\cdot\texttt{invsqrt}\\
   \overline{\texttt{invsqrt}}&\leftarrow\texttt{gamma}\cdot\overline{\texttt{gmul}}\\
   \overline{\texttt{variance}}&\leftarrow -\frac{\overline{\texttt{invsqrt}}\cdot\texttt{invsqrt}}{2\cdot(\texttt{variance}+\epsilon)}\\
   \overline{\texttt{square}}_i&\leftarrow\frac{\overline{\texttt{variance}}}{N}\\
   \overline{\texttt{centered}}_i&\leftarrow 2\cdot\texttt{centered}_i\cdot\overline{\texttt{square}}_i\\
   \overline{\texttt{input}}_i&\leftarrow\overline{\texttt{centered}}_i\\
   \overline{\texttt{mean}}&\leftarrow-\sum\overline{\texttt{centered}}_i\\
   \overline{\texttt{input}}_i&\leftarrow\frac{\overline{\texttt{mean}}}{N}
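
The backward pass can likewise be sketched in plain C++ for one channel
(again, the names are illustrative and not part of the nGraph API; the
forward intermediates are simply recomputed from ``input``):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-channel backprop sketch: recompute the forward
// intermediates (mean, centered, variance, invsqrt, gmul), then apply the
// adjoint updates in the order derived above.
struct BackpropResult {
    std::vector<double> input_delta;
    double gamma_delta = 0.0;
    double beta_delta = 0.0;
};

BackpropResult batch_norm_backprop(const std::vector<double>& input,
                                   double gamma, double epsilon,
                                   const std::vector<double>& normed_delta) {
    const std::size_t n = input.size();
    const double dn = static_cast<double>(n);

    // Recompute the forward intermediates.
    double mean = 0.0;
    for (double x : input) mean += x;
    mean /= dn;
    std::vector<double> centered(n);
    double variance = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        centered[i] = input[i] - mean;
        variance += centered[i] * centered[i];
    }
    variance /= dn;
    const double invsqrt = 1.0 / std::sqrt(variance + epsilon);
    const double gmul = gamma * invsqrt;

    BackpropResult r;
    double gmul_delta = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        r.beta_delta += normed_delta[i];              // beta adjoint: sum over i
        gmul_delta += normed_delta[i] * centered[i];  // gmul adjoint
    }
    r.gamma_delta = gmul_delta * invsqrt;             // gmul = gamma * invsqrt
    const double invsqrt_delta = gamma * gmul_delta;
    const double variance_delta =
        -invsqrt_delta * invsqrt / (2.0 * (variance + epsilon));
    const double square_delta = variance_delta / dn;  // same for every i

    std::vector<double> centered_delta(n);
    double mean_delta = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        centered_delta[i] = normed_delta[i] * gmul            // via normed_i
                          + 2.0 * centered[i] * square_delta; // via square_i
        mean_delta -= centered_delta[i];  // centered_i = input_i - mean
    }
    r.input_delta.resize(n);
    for (std::size_t i = 0; i < n; ++i)
        r.input_delta[i] = centered_delta[i] + mean_delta / dn;
    return r;
}
```

Note that the adjoint for ``centered`` accumulates two contributions, one
through ``normed`` and one through ``square``, before the ``mean`` adjoint is
formed; a finite-difference check against the forward computation confirms
the resulting gradients.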


C++ Interface
==============

.. doxygenclass:: ngraph::op::BatchNormTrainingBackprop
   :project: ngraph
   :members: