Any machine learning problem is generally formulated in roughly the following steps:
Model the outputs as some function of the inputs and parameters, $\hat{y} = f(x; \theta)$.
Come up with a loss function $\mathcal{L}(y, \hat{y})$ that quantifies how well the model fits the data.
Finally, solve the minimization problem
$$\theta^* = \arg\min_{\theta} \mathcal{L}\big(y, f(x; \theta)\big)$$
Anybody with some elementary calculus experience knows that solving the above problem involves taking gradients, $\nabla_{\theta} \mathcal{L}$. Having established the importance of differentiation, we will first briefly summarize the limitations of a few approaches and then discuss in detail how automatic differentiation (AD) overcomes them, computing exact derivatives while remaining computationally efficient. It might excite you to know that most modern general-purpose machine learning frameworks, like TensorFlow and Torch, treat AD as a first-class citizen.
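To make the role of the gradient concrete, here is a minimal sketch (not from the original post) that fits a one-parameter model $\hat{y} = w x$ by gradient descent, with the gradient of the squared-error loss derived by hand. The data, learning rate, step count and the helper names `loss`, `grad_loss` and `fit` are illustrative assumptions.

```python
def loss(w, xs, ys):
    """Squared-error loss L(w) = sum_i (w * x_i - y_i)^2."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys))


def grad_loss(w, xs, ys):
    """Hand-derived gradient dL/dw = sum_i 2 * x_i * (w * x_i - y_i)."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys))


def fit(xs, ys, lr=0.01, steps=200):
    """Plain gradient descent: repeatedly step against the gradient."""
    w = 0.0
    for _ in range(steps):
        w -= lr * grad_loss(w, xs, ys)
    return w


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0]
    ys = [2.0, 4.0, 6.0]  # generated by y = 2x
    w = fit(xs, ys)
    print(w, loss(w, xs, ys))
```

The weak spot is `grad_loss`: every time the model or the loss changes, it has to be re-derived by hand.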
Let us take an example that can be calculated by hand but is non-trivial enough for the purpose of this post. Off the top of my head:
That is an absolutely nutty equation, and any resemblance to something meaningful is purely coincidental. The aim here is to include all the standard operators and a core set of commonly used functions. Now suppose we want to find how this function changes as we vary one of its inputs. We write that as the following partial derivative:
OK, this is already getting out of hand. Even though calculating the derivative is a straightforward application of the chain rule, that doesn't change the fact that it is laborious. Now imagine hand-coding this for every possible function type, interleaved with control flow statements in a programming language. That gets complicated very quickly, but this approach does give us the exact form of the derivative.
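A slightly less laborious method, numerical differentiation, only approximates the derivative, using the function's values at a small perturbation from the point of interest:
$$\frac{\partial f}{\partial x_i}(\mathbf{x}) \approx \frac{f(\mathbf{x} + h\,\mathbf{e}_i) - f(\mathbf{x})}{h}$$
This follows easily from the first-order Taylor series expansion of $f$ around $\mathbf{x}$, where $\mathbf{e}_i$ is the $i$-th unit vector and $h > 0$ is a small step size.
A more accurate variant, the central difference, whose truncation error is $O(h^2)$ rather than $O(h)$, is given by
$$\frac{\partial f}{\partial x_i}(\mathbf{x}) \approx \frac{f(\mathbf{x} + h\,\mathbf{e}_i) - f(\mathbf{x} - h\,\mathbf{e}_i)}{2h}$$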
Both approaches are nevertheless plagued by truncation error, from the higher-order Taylor terms they drop, and by round-off error, because they commit the cardinal sins of numerical analysis: "thou shalt not add small numbers to big numbers" and "thou shalt not subtract numbers which are approximately equal".
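A small experiment shows both error sources. This sketch uses $\sin$ at $x = 1$ as an arbitrary test function (an assumption; its exact derivative $\cos(1)$ is known), and the helper names `forward_diff` and `central_diff` are hypothetical. Shrinking $h$ first reduces truncation error, but for very small $h$ the subtraction $f(x+h) - f(x)$ cancels almost all significant digits and round-off dominates.

```python
import math


def forward_diff(f, x, h):
    """Forward difference: first-order Taylor approximation, O(h) truncation error."""
    return (f(x + h) - f(x)) / h


def central_diff(f, x, h):
    """Central difference: O(h^2) truncation error."""
    return (f(x + h) - f(x - h)) / (2 * h)


if __name__ == "__main__":
    x = 1.0
    exact = math.cos(x)  # d/dx sin(x) = cos(x)
    for h in (1e-1, 1e-3, 1e-5, 1e-8, 1e-12, 1e-15):
        fwd_err = abs(forward_diff(math.sin, x, h) - exact)
        ctr_err = abs(central_diff(math.sin, x, h) - exact)
        print(f"h={h:.0e}  forward error={fwd_err:.2e}  central error={ctr_err:.2e}")
```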
Symbolic differentiation is used by computer algebra systems like Mathematica, where a predefined set of rules (such as the product rule and the chain rule) is available and the symbolic form of each derivative is mechanically produced before evaluation. This requires complete knowledge of all control flow in the program and essentially means building the complete computational graph (discussed later). It also leaves open the question of how many of these rules a designer should hand-code.
While useful in their own right, such algebraic systems are plagued by a problem called expression swell, where the exact form of the derivative can grow exponentially large and get out of control. Note that in the above example, I manually cancelled out many terms while working by hand; a purely mechanical application of the rules will not find such simplifications by itself. So despite all the effort of encoding a basic set of differentiation rules into data structures, we may still end up evaluating a swollen expression that is far more expensive than the original function.
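To see expression swell concretely, here is a toy symbolic differentiator (hypothetical helpers `diff`, `evaluate`, `size` and `logistic`) that applies the textbook rules to expression trees without any simplification. That is the key assumption: real computer algebra systems simplify and share subexpressions, so their numbers would be smaller. It differentiates iterates of the logistic map $l_{k+1} = 4 l_k (1 - l_k)$, a common illustration of swell.

```python
# Expressions are tuples: ("x",), ("const", c), or (op, left, right)
# with op in {"add", "sub", "mul"}. No simplification is ever performed.

def diff(e):
    """Differentiate expression e with respect to x using textbook rules."""
    op = e[0]
    if op == "x":
        return ("const", 1.0)
    if op == "const":
        return ("const", 0.0)
    a, b = e[1], e[2]
    if op in ("add", "sub"):
        return (op, diff(a), diff(b))  # sum / difference rule
    if op == "mul":
        # product rule: (ab)' = a'b + ab', which copies both operands
        return ("add", ("mul", diff(a), b), ("mul", a, diff(b)))
    raise ValueError(f"unknown op {op}")


def evaluate(e, x):
    """Numerically evaluate expression e at the point x."""
    op = e[0]
    if op == "x":
        return x
    if op == "const":
        return e[1]
    a, b = evaluate(e[1], x), evaluate(e[2], x)
    return {"add": a + b, "sub": a - b, "mul": a * b}[op]


def size(e):
    """Number of nodes in the expression tree."""
    return 1 + sum(size(c) for c in e[1:] if isinstance(c, tuple))


def logistic(n):
    """n iterations of l_{k+1} = 4 * l_k * (1 - l_k), starting from l_0 = x."""
    l = ("x",)
    for _ in range(n):
        l = ("mul", ("mul", ("const", 4.0), l), ("sub", ("const", 1.0), l))
    return l


if __name__ == "__main__":
    for n in range(1, 6):
        f = logistic(n)
        print(n, size(f), size(diff(f)))
```

The derivative tree grows from 19 to 73 to 217 nodes over the first three iterations, noticeably faster than the function's own tree (7, 19, 43 nodes), because every product rule application duplicates its operands.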
We will now introduce what is known as automatic differentiation. This approach places no restriction on the nature of the function and works across complicated control flow statements; all it cares about is the sequence of elementary operations actually performed. Before we dive into the details, let us understand computational graphs.
Computational graphs are a very simple idea, closely related to dynamic programming: each intermediate value is computed once and reused by everything that depends on it. Any computation can be viewed as a directed graph of basic operations (e.g. addition, multiplication, division, trigonometric functions and so on). The values propagated along each path in the graph build on one another to produce the final value. (Does that remind you of neural networks? Hold on to that thought.) Consider the graph for our example function.
Intuitively, think of each node as a special gate function that we have implemented. For instance, we have a gate that just passes the incoming value through (we'll call it the "=" gate). Elsewhere we have a gate that takes the log of the incoming value. The function that each gate applies to its incoming value is written on top of the corresponding node.
It is also important to appreciate the full power of these graphs: they do not care about complicated logic in our code, such as loops and conditionals. Given a set of input values, whatever computation actually happens gets recorded in the graph, irrespective of what "might" have happened in the other branch of a conditional.
Composing all these gates, from the input nodes to the final output node, gives us the value of the function. You should easily be able to convince yourself of that; it is the natural flow of composition familiar to everybody.
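The recording behaviour described above can be sketched with operator overloading. In this minimal sketch the class `Traced`, the function `f` and the helper `run` are hypothetical; each operation appends `(result, op, inputs)` to a shared list as it executes, so the graph reflects only the branch that was actually taken and the loop appears unrolled.

```python
import math


class Traced:
    """A scalar value that appends every operation applied to it to a shared trace."""

    def __init__(self, value, name, trace):
        self.value = value
        self.name = name
        self.trace = trace  # list shared by all values of one evaluation

    def _emit(self, op, value, *inputs):
        # Record (result name, operation, input names), then wrap the result.
        name = f"v{len(self.trace)}"
        self.trace.append((name, op, tuple(i.name for i in inputs)))
        return Traced(value, name, self.trace)

    def __add__(self, other):
        return self._emit("add", self.value + other.value, self, other)

    def __mul__(self, other):
        return self._emit("mul", self.value * other.value, self, other)

    def sin(self):
        return self._emit("sin", math.sin(self.value), self)


def f(x, y):
    """Hypothetical function mixing a data-dependent branch and a loop."""
    if x.value > 0:          # ordinary Python control flow
        z = x * y
    else:
        z = x.sin()
    for _ in range(2):       # the loop is simply unrolled into the trace
        z = z + y
    return z


def run(x_value, y_value):
    """Evaluate f on fresh inputs and return (result, recorded trace)."""
    trace = []
    out = f(Traced(x_value, "x", trace), Traced(y_value, "y", trace))
    return out.value, trace


if __name__ == "__main__":
    for inputs in [(2.0, 3.0), (-1.0, 3.0)]:
        value, trace = run(*inputs)
        print(inputs, value, trace)
```

For `(2.0, 3.0)` the trace contains a `mul` followed by two `add`s; for `(-1.0, 3.0)` it starts with `sin` instead. The untaken branch never appears.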
To make this explicit, we will now build what is called the Forward Primal Trace and the corresponding Forward Tangent Trace. I recommend that you don't look at the full symbolic form of the derivative and instead trust the computations below; but to complete the intuition, you might work it out alongside. Remember that we are calculating the same partial derivative as above.
The order of evaluation in forward mode is top to bottom. If you observe carefully, each lower term can be calculated from operands that have already been calculated in some higher term, and each intermediate derivative is calculated using elementary differentiation rules like the sum rule or the product rule. Substituting all the values back should give exactly the same result as above. Interestingly, we never needed the full symbolic form of the derivative, only local "gate"-level formulas from our repository of basic rules.
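For a runnable version of the two traces, here is a minimal sketch for a different, simpler function chosen purely for illustration (an assumption, not the function above): $f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin(x_2)$. Each primal line is followed by its tangent line, which uses only a local rule; the seed $(\dot{x}_1, \dot{x}_2) = (1, 0)$ selects $\partial f / \partial x_1$.

```python
import math


def forward_tangent_trace(x1, x2, dx1, dx2):
    """Evaluate f(x1, x2) = ln(x1) + x1*x2 - sin(x2) together with its
    directional derivative along the seed (dx1, dx2), top to bottom."""
    v1 = math.log(x1)
    dv1 = dx1 / x1                   # d ln(u) = du / u
    v2 = x1 * x2
    dv2 = dx1 * x2 + x1 * dx2        # product rule
    v3 = math.sin(x2)
    dv3 = math.cos(x2) * dx2         # d sin(u) = cos(u) du
    v4 = v1 + v2
    dv4 = dv1 + dv2                  # sum rule
    y = v4 - v3
    dy = dv4 - dv3                   # difference rule
    return y, dy


if __name__ == "__main__":
    # Seed (1, 0) gives df/dx1 = 1/x1 + x2, which is 5.5 at (2, 5).
    print(forward_tangent_trace(2.0, 5.0, 1.0, 0.0))
```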
In the larger picture, it is important to understand what the forward tangent trace actually calculates: the rate of change of the function with respect to a single input. In practical optimization, we are usually interested in how the output behaves with respect to all the input parameters. For instance, in neural networks these parameters comprise the weight matrices between each pair of layers. It isn't hard to see that calculating the gradient with respect to a total of $n$ parameters requires $n$ forward-mode passes, one per parameter. In deep neural networks, $n$ is enormous, which renders this technique terribly inefficient.
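The "one pass per parameter" cost can be sketched directly: seed the tangent with each unit vector $\mathbf{e}_i$ in turn. The example function $f(\mathbf{x}) = \sum_i \sin(x_i)\, x_{i+1}$ and the helper names are hypothetical choices for this sketch.

```python
import math


def f_and_directional_derivative(x, dx):
    """Hypothetical example f(x) = sum_i sin(x_i) * x_{i+1}. Returns f(x) and
    its directional derivative along dx, computed in one forward sweep."""
    y, dy = 0.0, 0.0
    for i in range(len(x) - 1):
        a, da = math.sin(x[i]), math.cos(x[i]) * dx[i]  # sin gate
        b, db = x[i + 1], dx[i + 1]
        y += a * b
        dy += da * b + a * db                            # product rule
    return y, dy


def gradient_forward_mode(x):
    """Full gradient by forward mode: one sweep per input, seeded with the
    unit vector e_i. Returns the gradient and the number of sweeps used."""
    n = len(x)
    grad = []
    for i in range(n):
        seed = [1.0 if j == i else 0.0 for j in range(n)]
        _, d = f_and_directional_derivative(x, seed)
        grad.append(d)
    return grad, n


if __name__ == "__main__":
    print(gradient_forward_mode([0.0, 1.0, 2.0]))
```

With a million parameters, the loop in `gradient_forward_mode` would run a million full sweeps of the function.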
This is a digression and not really important for the purpose of this post; it can be safely skipped. But for the adventurous, it shows how elegantly mathematicians devise definitions that capture efficient computational patterns.
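Mathematically, forward mode can be seen as a set of operations on dual numbers. By definition, a dual number is like a Taylor series expansion truncated after the first-order term: any value $x$ is augmented as
$$x + \dot{x}\,\epsilon$$
where $\epsilon \neq 0$ but $\epsilon^2 = 0$. Now consider some operations, like adding and multiplying two dual numbers:
$$(x + \dot{x}\epsilon) + (y + \dot{y}\epsilon) = (x + y) + (\dot{x} + \dot{y})\epsilon$$
$$(x + \dot{x}\epsilon)(y + \dot{y}\epsilon) = xy + (x\dot{y} + \dot{x}y)\epsilon$$
The coefficients of $\epsilon$ are exactly the symbolic derivatives (the sum rule and the product rule). We must add one more definition before this concept can be used universally:
$$f(x + \dot{x}\epsilon) = f(x) + f'(x)\,\dot{x}\,\epsilon$$
This implies that for a composition $f(g(x))$,
$$f\big(g(x + \dot{x}\epsilon)\big) = f\big(g(x) + g'(x)\,\dot{x}\,\epsilon\big) = f\big(g(x)\big) + f'\big(g(x)\big)\,g'(x)\,\dot{x}\,\epsilon$$
which is in fact the symbolic derivative formula for the chain rule.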
This is analogous to what we do with complex numbers, except that instead of $i^2 = -1$ we have $\epsilon^2 = 0$. In practice, this technique encodes the function and its derivative together, either by transforming the original (non-dual) code or through operator overloading.
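The operator-overloading route can be sketched in a few lines. The class `Dual` and the helpers `sin`, `log` and `derivative` are hypothetical names for this sketch; each method implements one of the dual-number identities above, and `derivative` seeds $\dot{x} = 1$ and reads the coefficient of $\epsilon$.

```python
import math


class Dual:
    """Dual number real + dual * eps with eps**2 == 0; dual carries the derivative."""

    def __init__(self, real, dual=0.0):
        self.real = real
        self.dual = dual

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + b eps) + (c + d eps) = (a + c) + (b + d) eps
        return Dual(self.real + other.real, self.dual + other.dual)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, since eps**2 = 0
        return Dual(self.real * other.real,
                    self.real * other.dual + self.dual * other.real)

    __rmul__ = __mul__


def sin(x):
    # f(a + b eps) = f(a) + f'(a) b eps
    return Dual(math.sin(x.real), math.cos(x.real) * x.dual)


def log(x):
    return Dual(math.log(x.real), x.dual / x.real)


def derivative(f, x):
    """Seed the dual part with 1 and read off f'(x) from the result."""
    return f(Dual(x, 1.0)).dual


if __name__ == "__main__":
    print(derivative(lambda x: x * x + sin(x), 2.0))  # 2x + cos(x) at x = 2
    print(derivative(lambda x: sin(log(x)), 3.0))     # chain rule: cos(ln x) / x
```

Composition needs no special handling: `sin(log(x))` applies the two rules in turn, which is exactly the chain-rule identity above.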
Reverse mode is the general form of a more familiar technique known as backpropagation, which is at the heart of neural networks. We define the adjoint of an intermediate variable $v_i$ as
$$\bar{v}_i = \frac{\partial y}{\partial v_i}$$
which intuitively describes the sensitivity of the output $y$ with respect to the intermediate variable $v_i$. For instance, in neural networks the output is usually the loss function applied to the output of the forward pass, and we are interested in how the loss varies with respect to the parameters of the network. This time, along with the Forward Primal Trace, we will build the Reverse Adjoint Trace. Again, I recommend not looking at the full symbolic form of the derivatives and instead trusting the computations below.
This result is truly magical: in a single pass we have computed the sensitivity of the output with respect to every input variable. The trace on the left is the one you should be familiar with by now. The trace on the right is evaluated from bottom to top. In contrast to the Forward Tangent Trace, in the Reverse Adjoint Trace each lower term is calculated before a higher term, so the operands of each higher term are already available when it is needed. Note that the adjoints are calculated in the reverse order of occurrence in the Forward Primal Trace (i.e. by traversing the computational graph in reverse topological order).
You may also have observed that the adjoint of one of the inputs is calculated twice in the Reverse Adjoint Trace, and that its final value is the sum of both contributions. Intuitively, the computational graph shows that this input can influence the output via two paths (indirectly through an intermediate variable and directly), and accumulating its adjoint twice reflects exactly that. Consequently, reverse mode is also known as reverse accumulation mode.
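Here is a minimal sketch of a Forward Primal Trace followed by a Reverse Adjoint Trace, written out line by line for a function chosen purely for illustration (an assumption, not the example above): $f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin(x_2)$. Both inputs feed two operations, so both adjoints are accumulated twice with `+=`.

```python
import math


def reverse_adjoint_trace(x1, x2):
    """Value and gradient of f(x1, x2) = ln(x1) + x1*x2 - sin(x2) in one reverse sweep."""
    # Forward primal trace (top to bottom); intermediates are kept for the reverse sweep.
    v1 = math.log(x1)
    v2 = x1 * x2
    v3 = math.sin(x2)
    v4 = v1 + v2
    y = v4 - v3

    # Reverse adjoint trace (bottom to top), where bar_v = dy/dv.
    y_bar = 1.0
    v4_bar = y_bar * 1.0            # y = v4 - v3
    v3_bar = y_bar * -1.0
    v1_bar = v4_bar * 1.0           # v4 = v1 + v2
    v2_bar = v4_bar * 1.0
    x2_bar = v3_bar * math.cos(x2)  # v3 = sin(x2): first path into x2
    x1_bar = v2_bar * x2            # v2 = x1 * x2: first path into x1
    x2_bar += v2_bar * x1           # v2 = x1 * x2: second path into x2, accumulate
    x1_bar += v1_bar / x1           # v1 = ln(x1): second path into x1, accumulate
    return y, (x1_bar, x2_bar)


if __name__ == "__main__":
    print(reverse_adjoint_trace(2.0, 5.0))
```

One reverse sweep yields both partial derivatives, $1/x_1 + x_2$ and $x_1 - \cos(x_2)$.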
From a computational standpoint, observe again that we only care about the local derivative flowing through each gate, and the rest falls into place. This property comes from the chain rule and makes the computation extremely efficient and exact (up to floating-point error, of course).
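The "local derivative at each gate" idea is enough to build a general reverse-mode engine. In this minimal sketch (the class `Var` and the helpers `log` and `sin` are hypothetical names, not any library's API), every node stores its parents along with the local derivative of the node with respect to each parent. `backward()` orders the graph topologically and pushes adjoints from the output back to the inputs, accumulating with `+=`.

```python
import math


class Var:
    """Scalar node in a dynamically built graph. Each node stores its parents
    together with the local derivative d(self)/d(parent)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent_var, local_gradient)
        self.grad = 0.0         # adjoint, filled in by backward()

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __sub__(self, other):
        return Var(self.value - other.value, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        """Reverse sweep: visit nodes in reverse topological order and push
        each adjoint to its parents, accumulating with +=."""
        order, seen = [], set()

        def visit(node):  # post-order DFS gives a topological order
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += node.grad * local


def log(x):
    return Var(math.log(x.value), ((x, 1.0 / x.value),))


def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))


if __name__ == "__main__":
    x1, x2 = Var(2.0), Var(5.0)
    y = log(x1) + x1 * x2 - sin(x2)  # the graph is recorded as this line runs
    y.backward()
    print(y.value, x1.grad, x2.grad)
```

Like PyTorch, this sketch accumulates into `.grad` across calls, so a training loop would reset the adjoints before each reverse pass. The recursive `visit` is fine for small graphs but would hit Python's recursion limit on very deep ones.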
This technique has been independently invented multiple times and is perhaps one of the most influential research outputs of 20th-century mathematics. AD has become ubiquitous in machine learning since neural networks burst onto the scene, bringing with them an enormous amount of computation. Deep learning frameworks like TensorFlow and PyTorch use AD extensively.
Let $f: \mathbb{R}^n \to \mathbb{R}^m$ be a function with $n$ inputs and $m$ outputs, and let the total number of operations needed to evaluate it be $\text{ops}(f)$.
To calculate the sensitivity of every output with respect to every input, forward mode has a total cost of $O(n \cdot \text{ops}(f))$, because one forward pass calculates the derivatives of all outputs with respect to one input. Conversely, reverse mode runs in $O(m \cdot \text{ops}(f))$, as one reverse pass calculates the derivatives of one output with respect to all inputs. It is easy to see that when $n \gg m$ (for example, a scalar loss and many parameters), reverse-mode AD performs better overall, roughly by a factor of $n/m$.
Here is a toy neural network in PyTorch that uses automatic differentiation in the backward() phase of the network.
```python
import torch


def main():
    ##
    # N - Number of sample instances
    # D_in - Dimension of input sample
    # H - Number of neurons in the hidden layer
    # D_out - Number of neurons in the output layer
    #
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Randomly initialize some inputs and desired outputs
    # (plain tensors; the old Variable wrapper is no longer needed)
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # ReLU activation function for hidden layer (rest linear layers)
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )

    # Sum-of-squared-errors loss for output layer
    # (reduction='sum' replaces the deprecated size_average=False)
    loss_fn = torch.nn.MSELoss(reduction='sum')

    # Stochastic Gradient Descent with some learning rate
    lr = 1e-4
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for t in range(500):
        # Forward pass: compute predictions while PyTorch records
        # the computational graph as the operations execute
        y_pred = model(x)

        # Compute Loss at the output layer w.r.t desired outputs
        loss = loss_fn(y_pred, y)
        print(t, loss.item())

        ##
        # Flush the gradient accumulator so gradients from the
        # last reverse pass do not accumulate into this one
        #
        optimizer.zero_grad()

        ##
        # Reverse Accumulation Mode using Automatic Differentiation
        #
        loss.backward()

        ##
        # Update the model parameters (the weight matrices) using
        # the loss gradients from Reverse Mode and the learning rate
        #
        optimizer.step()


if __name__ == '__main__':
    main()
```
It is interesting to note that Torch (and many other frameworks) builds dynamic computational graphs on the fly, as the operations are executed. For custom operations, one needs to implement a backward() function that dictates how the gradients are calculated.
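In PyTorch, the documented way to do this is to subclass `torch.autograd.Function` with static `forward` and `backward` methods and call it through `.apply`. The sketch below wraps `exp` purely for illustration, since autograd already differentiates `torch.exp` on its own; the class name `Exp` is an arbitrary choice. `backward` receives the adjoint of the output and returns the adjoint of the input, which is the local gate rule from the reverse-mode discussion.

```python
import torch


class Exp(torch.autograd.Function):
    """Custom op y = exp(x) with a hand-written backward()."""

    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x)
        ctx.save_for_backward(y)  # stash what backward() will need
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Local chain rule at this gate: dL/dx = dL/dy * dy/dx = grad_output * exp(x)
        return grad_output * y


if __name__ == "__main__":
    x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    Exp.apply(x).sum().backward()
    print(x.grad)  # matches torch.exp(x)
```

`torch.autograd.gradcheck` can be used to compare such a hand-written `backward()` against finite differences in double precision.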