.. derive-for-training.rst


#########################
Derive a trainable model
#########################

Documentation in this section describes one of the possible ways to turn a
:abbr:`DL (Deep Learning)` model for inference into one that can be used
for training.

Additionally, and to provide a more complete walk-through that *also* trains
the model, our example includes the use of a simple data loader for
uncompressed MNIST data.

* :ref:`model_overview`
* :ref:`code_structure`

  - :ref:`inference`
  - :ref:`loss`
  - :ref:`backprop`
  - :ref:`update`


.. _automating_graph_construction:

Automating graph construction
=============================

In a :abbr:`ML (Machine Learning)` ecosystem, it makes sense to use automation
and abstraction wherever possible. nGraph was designed to automatically use
the "ops" of tensors provided by a framework when constructing graphs. However,
nGraph's graph-construction API operates at a fundamentally lower level than a
typical framework's API, and writing a model directly in nGraph would be
somewhat akin to programming in assembly language: not impossible, but not
the easiest thing for humans to do.

To make the task easier for developers who need to customize the "automatic"
construction of graphs, we've provided some demonstration code for how this
could be done. We know, for example, that a trainable model can be derived
from any graph that has been constructed with weight-based updates.

The following example, named ``mnist_mlp.cpp``, represents a hand-designed
inference model being converted to a model that can be trained with nGraph.


.. _model_overview:

Model overview
==============

Due to the lower-level nature of the graph-construction API, the example we've
selected to document here is a relatively simple model: a fully-connected
topology with one hidden layer followed by ``Softmax``.

Remember that in nGraph, the graph is stateless; values for the weights must
be provided as parameters along with the normal inputs. Starting with the
graph for inference, we will use it to create a graph for training. The
training function will return tensors for the updated weights.

.. note:: This example illustrates how to convert an inference model into one
   that can be trained. Depending on the framework, bridge code may do
   something similar, or the framework might do this operation itself. Here
   we do the conversion with nGraph because the computation for training a
   model is significantly larger than for inference, and doing the conversion
   manually is tedious and error-prone.


.. _code_structure:

Code structure
==============

.. _inference:

Inference
---------

We begin by building the graph, starting with the input parameter ``X``. We
also define a fully-connected layer, including parameters for the weights and
bias:

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 127-139

Repeat the process for the next layer,

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 141-149

and normalize everything with a ``softmax``.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 151-153


.. _loss:

Loss
----

We use cross-entropy to compute the loss. nGraph does not currently have a
core op for cross-entropy, so we implement it directly, adding clipping to
prevent underflow.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 154-169

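Written out, the quantity this code assembles is the standard cross-entropy
between the softmax output :math:`Y` and the one-hot labels :math:`T`, with
the softmax output clipped from below so that the ``log`` never sees a zero.
(The exact placement of the clip and the symbol :math:`\epsilon` are details
of this sketch rather than of the example source.)

.. math::

   \mathrm{loss}(Y, T) = -\sum_{i} T_i \, \log\bigl(\max(Y_i, \epsilon)\bigr)

Here :math:`\epsilon` is a small positive constant; without it, a softmax
output that underflows to zero would make the loss infinite.
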
.. _backprop:

Backprop
--------

We want to reduce the loss by adjusting the weights. We compute the
adjustments using the reverse-mode autodiff algorithm, commonly referred to
as "backprop" because of the way it is implemented in interpreted frameworks.
In nGraph, we augment the loss computation with computations for the weight
adjustments, which allows the calculations for the adjustments to be further
optimized.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 171-175

For any node ``N``, if the update for ``loss`` is ``delta``, the update
computation for ``N`` will be given by the node

.. code-block:: cpp

   auto update = loss->backprop_node(N, delta);

The different update nodes will share intermediate computations. So, to get
the updated values for the weights as computed with the specified
:doc:`backend <../programmable/index>`:

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 177-217


.. _update:

Update
------

Since nGraph is stateless, we train by making a function that has the
original weights among its inputs and the updated weights among the results.
For training, we'll also need the labeled training data as inputs, and we'll
return the loss as an additional result. To track how well we are doing, we
also make a function that takes the labeled testing data as input and returns
the loss. Although we can use the same nodes in different functions, nGraph
currently does not allow the same nodes to be compiled in different functions,
so we compile clones of the nodes.

.. literalinclude:: ../../../examples/mnist_mlp/mnist_mlp.cpp
   :language: cpp
   :lines: 221-226
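
To make the state-threading explicit, here is a minimal, self-contained
sketch of the pattern the compiled functions follow. It is plain C++ with
stubbed-out computation and hypothetical shapes, not the actual nGraph calls:
the weights go in as arguments, the updated weights come back as results, and
the caller feeds them into the next step.

.. code-block:: cpp

   #include <iostream>
   #include <utility>
   #include <vector>

   using Tensor = std::vector<float>;

   struct StepResult
   {
       Tensor updated_w0; // updated weights, first layer
       Tensor updated_w1; // updated weights, second layer
       float loss;        // scalar loss for this step
   };

   // Stand-in for the compiled training function; in the example above,
   // this role is played by the cloned, compiled nGraph function.
   StepResult train_step(const Tensor& image, const Tensor& label,
                         Tensor w0, Tensor w1)
   {
       (void)image;        // forward pass, loss, and backprop elided
       (void)label;
       float loss = 0.0f;
       return {std::move(w0), std::move(w1), loss};
   }

   int main()
   {
       Tensor w0(784 * 100, 0.01f); // hypothetical initial weights
       Tensor w1(100 * 10, 0.01f);
       for (int step = 0; step < 3; ++step)
       {
           Tensor image(784, 0.0f), label(10, 0.0f); // one MNIST example
           StepResult r = train_step(image, label,
                                     std::move(w0), std::move(w1));
           w0 = std::move(r.updated_w0); // thread the new state back in
           w1 = std::move(r.updated_w1);
           std::cout << "step " << step << " loss " << r.loss << "\n";
       }
   }

Because the compiled function owns no state, checkpointing the model reduces
to saving the weight tensors held by the caller.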