After searching innumerable times, I finally have it handy. This was originally
discussed in Issue #394 on Github.

Vectorized Pairwise Distances

For $\mathbf{X} \in \mathbb{R}^{... \times m \times d}, \mathbf{Y} \in \mathbb{R}^{... \times n \times d}$, the pairwise distance matrix between each pair of these batched matrices is $\mathbf{D} \in \mathbb{R}^{... \times m \times n}$, where ... represent arbitrary batch dimension (think batches of pairs of $m$ and $n$ samples of dimension $d$).

## in PyTorchdefpairwise_dist(x,y): xx = (x * x).sum(dim=-1).unsqueeze(-1) yy = (y * y).sum(dim=-1).unsqueeze(-2) xy = torch.einsum('...ji,...ki->...jk', x, y) d = xx + yy -2.* xyreturn d