Higher-order gradients in PyTorch, Parallelized
Handling meta-learning in distributed PyTorch
with Ramakrishna Vedantam.
Machine learning algorithms often require differentiating through a sequence of first-order gradient updates, for instance in meta-learning. While it is easy to build learning algorithms with first-order gradient updates using PyTorch Modules, these do not natively support differentiation through first-order gradient updates.
We will see how to build a PyTorch pipeline that resembles the familiar simplicity of first-order gradient updates, but also supports differentiating through the updates using a library called higher.
Further, most modern machine learning workloads rely on distributed training, which higher does not support^{1} as of this writing. However, we will see a solution that supports a distributed training pipeline compatible with PyTorch, despite the lack of support in higher. And we will be able to convert any PyTorch module to support parallelized higher-order gradients.
The Standard PyTorch Pipeline
The standard recipe to build a gradient-based pipeline in PyTorch is: (i) set up a stateful Module object (e.g. a neural network), (ii) run a forward pass to build the computational graph, and (iii) call the resultant tensor's (e.g. training loss) backward to populate the gradients of the module parameters. In addition, this pipeline can be easily parallelized using Distributed Data Parallel (DDP). Here's an example code skeleton:
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        ## Setup modules.

    def forward(self, inputs):
        ## Run inputs through modules.
        return

model = MyModule()
model = DistributedDataParallel(model, device_ids=[device_id])
optimizer = optim.SGD(model.parameters(), lr=1e-2)

loss = loss_fn(model(inputs), targets)
loss.backward()  ## Automatically syncs gradients across distributed processes, if needed.
optimizer.step()
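The skeleton above can be run end to end in a single process by dropping the DDP wrapper. Here is a minimal sketch, assuming a hypothetical linear model and random data as stand-ins for MyModule and the real dataset:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)                       # stand-in for MyModule
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()                                # populates .grad on parameters
optimizer.step()                               # applies the SGD update
```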
This approach, however, only works for first-order gradients. What do we do when we need to differentiate through first-order gradient updates?
As an illustration of the need to differentiate through first-order gradient updates, let us tackle a toy meta-learning problem.
A Meta-Learning Problem
Learning to learn is a hallmark of intelligence. Once a child learns the concept of a wheel, it is much easier to identify a different wheel, whether it be on a toy or a car. Intelligent beings can adapt to a similar task much faster than learning from scratch. We call this ability meta-learning.
To mimic a meta-learning setting, consider a toy learning-to-learn problem:
How do we learn a dataset which can be perfectly classified by a linear binary classifier with only a few gradient updates?
Unlike standard settings where we learn a classifier using a fixed dataset, we instead want to learn a dataset such that any subsequent classifier is easy to learn. (We want to learn a dataset, to learn a classifier quickly.)
Intuitively, one solution is a dataset which has two far-separated clusters such that any randomly initialized classifier (in the 2D case a line) can be adapted in very few steps to classify the points perfectly, as in the figure below. We do indeed find later that the solutions look similar to two separated clusters.
One algorithm for learning to learn is known as MAML, which we describe next.
MAML
Model-Agnostic Meta-Learning, or MAML,^{2} formulates the meta-learning problem as learning the parameters of a task via first-order gradients, such that adapting the parameters for a novel task takes only a few gradient steps.
We want to learn a dataset of 100 2D points $\theta \in \mathbb{R}^{100 \times 2}$, such that they can be perfectly classified by a linear classifier with independent parameters $\beta$, in only a few gradient steps.
MAML prescribes $T$ “inner loop” updates for every “outer loop” update. For a given state of parameters $\theta$ and loss function $\ell$, the inner loop gradient updates using SGD step size $\eta$ look like,
$\begin{aligned} \theta_t^\prime = \theta_{t-1}^\prime - \eta \nabla_\theta \ell(\theta) \big|_{\theta=\theta_{t-1}^\prime},\text{ for } t \in \{1,\dots,T \},\text{ s.t. } \theta_{0}^\prime = \theta. \end{aligned}$

The resulting $\theta_T^\prime$ is then used to construct the SGD step size $\alpha$ update for the corresponding outer loop as,

$\theta \leftarrow \theta - \alpha \nabla_\theta \ell(\theta_T^\prime).$

The key operation of note here is that $\theta_T^\prime$ is itself a function of $\theta$, say $\theta_T^\prime = g(\theta)$. Since $g(\theta)$ already involves a sequence of first-order gradient updates through time, MAML therefore requires second-order gradients in the outer loop that differentiate through the inner loop updates.
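To make the second-order dependence concrete, here is a tiny worked example in plain Python (an illustration, not the full algorithm), assuming a one-dimensional quadratic loss $\ell(\theta) = \theta^2/2$ chosen so that every gradient has a closed form:

```python
# One inner SGD step followed by one outer update, for l(theta) = 0.5 * theta**2
# (so l'(theta) = theta). All values below are hypothetical choices.
eta = 0.1    # inner-loop step size
alpha = 0.5  # outer-loop step size
theta = 2.0

# Inner update: theta' = theta - eta * l'(theta) = theta * (1 - eta).
theta_prime = theta - eta * theta

# Outer gradient differentiates THROUGH the inner update:
# d l(theta') / d theta = l'(theta') * d theta'/d theta = theta' * (1 - eta).
outer_grad = theta_prime * (1 - eta)

# A first-order approximation (treating theta' as a constant w.r.t. theta)
# would miss the extra (1 - eta) factor supplied by the second-order term.
theta_new = theta - alpha * outer_grad
```

Even in this one-dimensional case, the chain rule through the inner update is what distinguishes MAML from simply taking a gradient at the adapted parameters.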
More generally, outer parameters $\theta$ and inner parameters $\theta^\prime$ can be shared (i.e., $\theta = \theta^\prime$) or completely separate sets of parameters. For our toy problem, we take the inner loop parameters to be $\beta$ and outer loop parameters to be the dataset $\theta$. Such an algorithm with independent inner and outer loop parameters was proposed as CAVIA.^{3}
Meta-Learning a Dataset
For our toy problem, the parameters we learn are in fact the dataset $X = \theta \in \mathbb{R}^{100 \times 2}$ itself. In code, we randomly initialize a MetaDatasetModule where the parameters are self.X as:
import torch
import torch.nn as nn

class MetaDatasetModule(nn.Module):
    def __init__(self, n=100, d=2):
        super().__init__()
        self.X = nn.Parameter(torch.randn(n, d))
        self.register_buffer('Y', torch.cat([
            torch.zeros(n // 2), torch.ones(n // 2)]))
self.Y is constructed to contain an equal number of samples of each class, labeled as zeros and ones.
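As a quick standalone sanity check of the label layout (assuming the default n = 100):

```python
import torch

n = 100
# Recreate the label buffer: first half labeled 0, second half labeled 1.
Y = torch.cat([torch.zeros(n // 2), torch.ones(n // 2)])
print(Y.shape)  # torch.Size([100])
```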
For our toy problem, we want to learn a linear classifier which we represent with weights $w \in \mathbb{R}^2$ and bias $b \in \mathbb{R}$ in the inner loop, i.e. $\beta$ is the combination of $w$ and $b$. More importantly, the dataset $X$ should be such that the classifier is learnable in a few gradient updates (we choose three). We abstract away this inner loop by implementing it in the forward pass as:
import higher
import torch.optim as optim

class MetaDatasetModule(nn.Module):
    # ...
    def forward(self, device, n_inner_opt=3):
        ## Hot-patch meta parameters.
        self.register_parameter('w',
            nn.Parameter(torch.randn(self.X.size(1))))
        self.register_parameter('b',
            nn.Parameter(torch.randn(1)))

        inner_loss_fn = nn.BCEWithLogitsLoss()
        inner_optimizer = optim.SGD([self.w, self.b], lr=1e-1)

        with higher.innerloop_ctx(self, inner_optimizer,
                device=device, copy_initial_weights=False,
                track_higher_grads=self.training) as (fmodel, diffopt):

            for _ in range(n_inner_opt):
                logits = fmodel.X @ fmodel.w + fmodel.b
                inner_loss = inner_loss_fn(logits, self.Y)
                diffopt.step(inner_loss)

            return fmodel.X @ fmodel.w + fmodel.b
Let us break down the key ingredients of this forward pass:
- Parameters self.w and self.b are hot-patched into the model during the forward pass using PyTorch's register_parameter function. These parameters are required only in the inner loop, and are therefore initialized locally in the forward pass only.
- An inner optimizer, SGD with learning rate $\eta = 10^{-1}$, points to the hot-patched inner parameters of the linear classifier.
- Using a higher inner loop context, higher.innerloop_ctx, we monkey-patch the PyTorch module containing the outer parameter variable self.X. Most importantly, we set copy_initial_weights=False so that we keep using the original parameters in the subsequent computational graph.
- For memory efficiency during evaluation, we set track_higher_grads to False when the module is not training, so that the computational graph is not constructed. This flag is modified using the .eval() call to the module.
- The loop represents multiple steps of SGD where the logits are constructed as the matrix operation $Xw + b$, and the loss is the standard binary cross-entropy loss BCEWithLogitsLoss. Notably, we use the monkey-patched version of the original model, represented by fmodel, and a differentiable version of the optimizer, diffopt.
- Within the context, we now execute the final forward pass such that its output contains the full computational graph of inner gradient updates of the inner parameters $w$ (fmodel.w) and $b$ (fmodel.b).
Therefore, a forward pass of the MetaDatasetModule returns the logits, which can now be sent to the binary cross-entropy loss transparently. The computational graph automatically unrolls through the inner loop gradient updates, preserved thanks to the reliance on higher's fmodel. The rest of the training can operate the same as our skeleton PyTorch pipeline in the introduction.
Parallelizing Meta-Learning
Preparing a PyTorch model for distributed training only requires the Distributed Data Parallel (DDP) wrapper:
model = DistributedDataParallel(model, device_ids=[device_id])
DDP works under the assumption that any parameters registered in the model are not modified after being wrapped. A key reason is that every parameter gets a communication hook attached, which is used to sync the gradients across different processes during distributed training. Any parameters modified on the fly would not inherit such hooks, forcing us to handle distributed communication manually. This is error-prone and best avoided.
In the context of our meta-learning setup, we modify the existing parameters on the fly inside the higher inner loop context. By creating a local copy, we would violate the assumption above that the parameters registered before wrapping with DDP are not modified. And therefore, higher does not support distributed training^{1} out of the box.
To remain compatible with DDP, we must construct the computational graph on top of the originally registered parameters. This is why setting copy_initial_weights=False is important. Any additional parameters introduced in the inner loop do not interfere.
The computational graph constructed by the gradient updates in the inner loop is preserved in the returned variable. This design enables transparent usage of the forward pass, such that parallelizing the MetaDatasetModule is exactly the same as operating with a standard PyTorch model: wrapping in DDP.
Visualizing Results
To verify that our parallelized meta-learning setup works, we do a complete run with different numbers of GPUs.
For a fixed budget of 500 outer loop steps, we expect that increasing the number of GPUs available should lead to more effective learning. This is because, even though the outer loop updates are fixed in number, each outer loop update sees more examples by a factor of the number of GPUs: every GPU corresponds to an independent random initialization of the inner loop parameters ($w$ and $b$).
For instance, training with 4 GPUs sees 4 random initializations of the inner loop parameters in each outer loop gradient step, as compared to just 1 when training with a single GPU.
In the figures above, we see a gradual increase in the effectiveness of the dataset given a fixed outer loop budget of 500 steps: when training with 8 GPUs, we effectively get a dataset which can be perfectly classified by a linear classifier. Our toy meta-learning task is solved!
See activatedgeek/higher-distributed for the complete code.
A General Recipe
Finally, we can summarize a recipe to convert your PyTorch module to support differentiation through gradient updates:
- Create a MetaModule that wraps the original module:
import torch.nn as nn

class MetaModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        ## Automatically registers the module's parameters.
        self.module = module
- In the forward pass, create the higher context with copy_initial_weights=False, and make sure to send a final forward pass applied on the monkey-patched model fmodel.
class MetaModule(nn.Module):
    # ...
    def forward(self, inputs):
        ## Patch meta parameters.
        self.register_parameter('inner_params', nn.Parameter(...))
        inner_optimizer = self.inner_optimizer_fn([self.inner_params])

        with higher.innerloop_ctx(self, inner_optimizer,
                device=inputs.device, copy_initial_weights=False,
                track_higher_grads=self.training) as (fmodel, diffopt):
            ## Operate as usual on fmodel.
            ## ...
            ## Return a final forward pass on the monkey-patched fmodel.
            return fmodel(inputs)
- Apply optimizers, distributed training, etc. as usual.
- Profit!
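The wrapper in step 1 registers the wrapped module's parameters automatically, which is what lets DDP attach its communication hooks to them. A quick check of this behavior, assuming a hypothetical nn.Linear as the wrapped module:

```python
import torch.nn as nn

class MetaModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        ## Assigning the submodule as an attribute registers its parameters.
        self.module = module

wrapped = MetaModule(nn.Linear(2, 1))
names = [name for name, _ in wrapped.named_parameters()]
print(names)  # ['module.weight', 'module.bias']
```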
Acknowledgments
Thanks to Ed Grefenstette, Karan Desai, Tanmay Rajpurohit, Ashwin Kalyan, David Schwab and Ari Morcos for discussions around this approach.
Footnotes

2. Finn, Chelsea et al. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks." ArXiv abs/1703.03400 (2017). https://arxiv.org/abs/1703.03400 ↩
3. Zintgraf, Luisa M. et al. "Fast Context Adaptation via Meta-Learning." International Conference on Machine Learning (2019). https://proceedings.mlr.press/v97/zintgraf19a.html ↩