.. derive-for-training.rst

#########################
Derive a trainable model 
#########################

Documentation in this section describes one of the possible ways to turn a 
:abbr:`DL (Deep Learning)` model for inference into one that can be used  
for training.

Additionally, and to provide a more complete walk-through that *also* trains the 
model, our example includes the use of a simple data loader for uncompressed
MNIST data.

* :ref:`model_overview`
* :ref:`code_structure`

  - :ref:`inference`
  - :ref:`loss`
  - :ref:`backprop`
  - :ref:`update`


.. _automating_graph_construction:

Automating graph construction
==============================

In a :abbr:`Machine Learning (ML)` ecosystem, it makes sense to use automation 
and abstraction wherever possible. nGraph was designed to automatically use
the "ops" of tensors provided by a framework when constructing graphs. However, 
nGraph's graph-construction API operates at a fundamentally lower level than a 
typical framework's API, and writing a model directly in nGraph would be somewhat 
akin to programming in assembly language: not impossible, but not the easiest 
thing for humans to do. 

To make the task easier for developers who need to customize the "automatic"
construction of graphs, we've provided some demonstration code for how this 
could be done. We know, for example, that a trainable model can be derived from 
any graph that has been constructed with weight-based updates. 

The following example named ``mnist_mlp.cpp`` represents a hand-designed 
inference model being converted to a model that can be trained with nGraph. 


.. _model_overview:

Model overview 
===============

Due to the lower-level nature of the graph-construction API, the example we've 
selected to document here is a relatively simple model: a fully-connected 
topology with one hidden layer followed by ``Softmax``.

Remember that in nGraph, the graph is stateless; values for the weights must
be provided as parameters along with the normal inputs. Starting with the graph
for inference, we will use it to create a graph for training. The training
function will return tensors for the updated weights. 
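
As a minimal sketch of what "stateless" means in practice (the names and shapes
below are illustrative only, not the ones used in ``mnist_mlp.cpp``), every
weight is just another ``op::Parameter``:

.. code-block:: cpp

   #include <memory>

   #include <ngraph/ngraph.hpp>

   using namespace ngraph;

   // Both the data and the weights enter the graph as parameters;
   // the graph itself stores no state.
   auto X = std::make_shared<op::Parameter>(element::f32, Shape{128, 784});
   auto W = std::make_shared<op::Parameter>(element::f32, Shape{784, 10});

   // The inference output is a pure function of X and W.
   auto Z = std::make_shared<op::Dot>(X, W);

A training function built from such a graph simply lists the updated weight
tensors among its results, alongside the loss.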

.. note:: This example illustrates how to convert an inference model into one 
   that can be trained. Depending on the framework, bridge code may do something 
   similar, or the framework might do this operation itself. Here we do the 
   conversion with nGraph because the computation for training a model is 
   significantly larger than for inference, and doing the conversion manually 
   is tedious and error-prone.


.. _code_structure:

Code structure
==============


.. _inference:

Inference
---------

We begin by building the graph, starting with the input parameter 
``X``. We also define a fully-connected layer, including parameters for
weights and bias:

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 127-139

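In other words, the fully-connected layer computes an affine map of its input:

.. math::

   L_1 = X W_0 + b_0

where :math:`W_0` and :math:`b_0` are the weight and bias parameters (the
symbol names are for exposition; see the included code for the exact ops used).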

Repeat the process for the next layer,

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 141-149

and normalize everything with a ``softmax``.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 151-153
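
For reference, ``Softmax`` turns each row of its input into a probability
distribution:

.. math::

   \mathrm{softmax}(z)_j = \frac{e^{z_j}}{\sum_k e^{z_k}}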


.. _loss:

Loss
----

We use cross-entropy to compute the loss. nGraph does not currently have a core
op for cross-entropy, so we implement it directly, adding clipping to prevent 
underflow.
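
Written out, the clipped cross-entropy has the form

.. math::

   L = -\sum_{i} Y_i \, \log \bigl( \max(\hat{Y}_i, \epsilon) \bigr)

where :math:`Y` is the one-hot label, :math:`\hat{Y}` is the softmax output,
and :math:`\epsilon` is a small clipping constant (a sketch of the idea; the
exact clipping used is in the included code).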

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 154-169


.. _backprop:

Backprop
--------

We want to reduce the loss by adjusting the weights. We compute the adjustments 
using the reverse-mode autodiff algorithm, commonly referred to as "backprop" 
because of the way it is implemented in interpreted frameworks. In nGraph, we 
augment the loss computation with computations for the weight adjustments. This 
allows the calculations for the adjustments to be further optimized.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 171-175


For any node ``N``, if the update for ``loss`` is ``delta``, the
update computation for ``N`` will be given by the node

.. code-block:: cpp

   auto update = loss->backprop_node(N, delta);
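
For a weight tensor :math:`W`, applying such an update with plain gradient
descent amounts to

.. math::

   W \leftarrow W - \eta \, \frac{\partial L}{\partial W}

where :math:`\eta` is the learning rate (a general formula for orientation; the
update actually built by the example is whatever the included code constructs).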

The different update nodes will share intermediate computations. So to
get the updated values for the weights as computed with the specified
:doc:`backend <../programmable/index>`, the example proceeds as follows:

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 177-217


.. _update:

Update
------

Since nGraph is stateless, we train by making a function that has the
original weights among its inputs and the updated weights among the
results. For training, we'll also need the labeled training data as
inputs, and we'll return the loss as an additional result.  We'll also
want to track how well we are doing; for that, we make a second function that
returns the loss and takes the labeled testing data as input. Although we can
use the same nodes in different functions, nGraph currently does not
allow the same nodes to be compiled in different functions, so we
compile clones of the nodes.
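
One way to obtain an independent copy of a graph is ``clone_function`` from
``ngraph/graph_util.hpp``. The sketch below is illustrative only (the function
and shapes are placeholders, and some vector type names vary slightly between
nGraph releases); the example's own cloning step is in the code that follows.

.. code-block:: cpp

   #include <ngraph/ngraph.hpp>

   using namespace ngraph;

   // Build a small function, then clone it so that each copy can be
   // compiled by a backend independently.
   auto X = std::make_shared<op::Parameter>(element::f32, Shape{128, 10});
   auto R = std::make_shared<op::Softmax>(X, AxisSet{1});
   auto f = std::make_shared<Function>(NodeVector{R}, ParameterVector{X});

   // clone_function returns a structurally identical copy with fresh nodes.
   auto f_clone = clone_function(*f);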

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 221-226