The Stein Gradient

Visualizing the simple yet powerful Stein gradient for sampling (with notebook)



Machine Learning is all about dealing with the uncertainty of outcomes, and Bayesian inference provides a principled way to reason about it. We combine the observed data with priors to build (potentially complex) posteriors over the variables of interest and use those to answer subsequent questions. The ability to model the posterior distribution allows us to quantify uncertainty for any downstream task.

$$\overbrace{p(\Theta|\mathbf{X})}^{\text{posterior}} = \frac{\overbrace{p(\mathbf{X}|\Theta)}^{\text{likelihood}}\,\overbrace{p(\Theta)}^{\text{prior}}}{\underbrace{p(\mathbf{X})}_{\text{evidence}}}$$

However, the posterior is a tricky term to compute because of the computationally intensive integral in the denominator (the evidence). Without access to the true closed form of the posterior, the next best thing to have is a representative set (more formally, the typical set [1]) drawn from the high-density regions of the posterior distribution. Over the years, researchers have developed a handful of umbrella techniques, all of which converge to the true distribution in the limit of the number of samples - conjugate priors for mathematical convenience [2], the Markov Chain Monte Carlo (MCMC) family of algorithms, especially Hamiltonian Monte Carlo (HMC) [3] [4], Variational Inference [5] [6] via approximate posteriors, and Normalizing Flows [7] via deterministic transforms. All these methods have been immensely successful in modern Bayesian analysis.

However, like all feasible methods, these make trade-offs: using conjugate priors limits the richness of the posterior densities and restricts us to mathematically convenient likelihood-prior combinations; the MCMC family requires us to carefully design the transition kernel of the Markov chain and the SDE integrator for numerical stability; Variational Inference can lead to posteriors that under-estimate the support of the true distribution because of the mode-seeking nature of the reverse KL divergence objective; Normalizing Flows demand an exact parametric form for the bijectors, whose Jacobian determinant must be tractable to compute. Today we will look at an approach that does away with all the previous pathologies (not to say it doesn't introduce new ones), called the kernelized Stein gradient.

We will first summarize the key background topics, then state the formal results underlying the Stein gradient, and finally visualize via code how this technique can be used to build powerful samplers.

Background

Kernels

Kernels are a very neat idea where the objective is to map our raw input space $\mathcal{X}$ to an abstract vector space $\mathcal{H}$, so that we can define similarity in terms of dot products and get all the geometric niceties of angles, lengths and distances [8]. Formally, the most important equations can be summarized as

$$\begin{aligned} \mathbf{\Phi} \colon\ & \mathcal{X} \to \mathcal{H} \\ & x \mapsto \mathbf{x} \triangleq \mathbf{\Phi}(x) \\ k(x, x^\prime) &= \langle \mathbf{\Phi}(x), \mathbf{\Phi}(x^\prime) \rangle \end{aligned}$$

The key idea is that given a kernel function, we can construct a feature space (more formally known as a Hilbert space in functional analysis) such that the kernel computes a dot product in that feature space. The terminology kernel comes from a function $k$ which gives rise to the following integral operator $T_k$

$$(T_k f)(x) = \int_{\mathcal{X}} k(x, x^\prime) f(x^\prime)\, dx^\prime$$

Intuitively, one can imagine this operator as building a rich representation in terms of a linear combination over the full space, where the combination coefficients are given by the kernel function. By extending this idea with a particular construction, we arrive at the notion of the representer of evaluation, which leads to the classic reproducing kernel Hilbert space (RKHS).
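To make this concrete, here is a minimal sketch, assuming NumPy and a Gaussian RBF kernel $k(x, x^\prime) = \exp(-\|x - x^\prime\|^2 / (2h))$ (the helper name `rbf_kernel` is illustrative, not from the accompanying notebook), of how a kernel turns raw points into a Gram matrix of pairwise similarities, i.e. dot products in the implicit feature space.

```python
import numpy as np

def rbf_kernel(X, Y, h=1.0):
    """Gram matrix K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 * h))."""
    # Pairwise squared Euclidean distances via broadcasting.
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * h))

# Dot products in the (implicit) feature space for a few 2-D points.
X = np.random.randn(5, 2)
K = rbf_kernel(X, X)        # 5 x 5, symmetric, positive semi-definite
print(np.allclose(K, K.T))  # True
```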

Stein’s Identity

Let $p(x)$ be a smooth density supported on $\mathcal{X} \subseteq \mathbb{R}^d$ and let $\mathbf{\phi}(x)$ be a smooth vector-valued function. Stein's identity states that, under a Stein operator $\mathcal{A}_p$,

$$\begin{aligned} \mathbb{E}_{x \sim p}\left[ \mathcal{A}_p\phi(x) \right] &= 0 \\ \text{where } \mathcal{A}_p\phi(x) &= \phi(x) \nabla_x \log{p(x)}^T + \nabla_x \phi(x) \end{aligned}$$

We need mild boundary conditions for the above to be true - either $p(x)\phi(x) = 0\ \forall x \in \partial\mathcal{X}$ when $\mathcal{X}$ is compact, or $\lim_{\|x\| \to \infty} p(x)\phi(x) = 0$ when $\mathcal{X} = \mathbb{R}^d$. A function $\phi$ is said to belong to the Stein class of $p$ if Stein's identity holds.
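As a quick sanity check (a small sketch of my own, not from the original notebook), we can verify Stein's identity by Monte Carlo for the one-dimensional standard normal, whose score is $\nabla_x \log p(x) = -x$, with the test function $\phi(x) = \tanh(x)$:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)

# Stein operator for p = N(0, 1): A_p phi(x) = phi(x) * score(x) + phi'(x),
# with score(x) = -x, phi(x) = tanh(x) and phi'(x) = 1 - tanh(x)^2.
stein_term = np.tanh(x) * (-x) + (1.0 - np.tanh(x) ** 2)

print(stein_term.mean())  # ~0 up to Monte Carlo error, as the identity predicts
```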

Stein Discrepancy

Let us consider another smooth density $q(x)$, with $\phi$ in the Stein class of $q$ (but not necessarily of $p$). Since $\mathbb{E}_q\left[\mathcal{A}_q\phi(x)\right] = 0$ by Stein's identity, after some simple manipulations we can see that

$$\begin{aligned} \mathbb{E}_q\left[\mathcal{A}_p\phi(x)\right] &= \mathbb{E}_q\left[\mathcal{A}_p\phi(x) - \mathcal{A}_q\phi(x)\right] \\ &= \mathbb{E}_q\left[ (\nabla_x \log{p(x)} - \nabla_x \log{q(x)}) \phi(x)^T \right] \end{aligned}$$

Intuitively, this can be seen as a function weighted by the difference between the score functions of the two distributions. It has been shown [9] that

$$\mathbb{E}_q\left[ (\nabla_x \log{p(x)} - \nabla_x \log{q(x)}) \phi(x)^T \right] = \mathbb{E}_q\left[\mathrm{trace}(\mathcal{A}_p\phi(x))\right]$$

Note that the trace is used here simply to reduce the matrix to a scalar; other matrix norms could be considered instead. If we take the maximum value over some family of test functions, we get the Stein discrepancy measure.

$$\mathbb{S}(q, p) = \underset{\phi \in \mathcal{F}}{\max} \left\{ \left(\mathbb{E}_q\left[\mathrm{trace}(\mathcal{A}_p\phi(x))\right]\right)^2 \right\}$$

This can be thought of as the maximum possible violation of Stein's identity, over the family of test functions, when the expectation is taken under $q$ instead of $p$. However, the computational tractability of this measure depends critically on the choice of the family of test functions $\mathcal{F}$, and for most choices the maximization is intractable.

Kernelized Stein Discrepancy

If we choose the family of test functions to be the unit-norm ball of an RKHS, it turns out we can find a closed form for the Stein discrepancy measure [10].

$$\begin{aligned} \mathbb{S}(q, p) &= \underset{\phi \in \mathcal{H}^d}{\max} \left\{ \left(\mathbb{E}_q\left[\mathrm{trace}(\mathcal{A}_p\phi(x))\right]\right)^2, \quad ||\phi||_{\mathcal{H}^d} \leq 1 \right\} \\ &= ||\mathbf{\phi}^\star_{q,p}||_{\mathcal{H}^d}^2 \\ \text{where } \mathbf{\phi}^\star_{q,p}(\cdot) &= \mathbb{E}_q\left[ \mathcal{A}_p k(x, \cdot) \right] \end{aligned}$$

This measure is zero if and only if $q = p$. The power of this result is realized by Stein Variational Gradient Descent.
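For reference, following [10] (notation lightly adapted here, with $s_p(x) = \nabla_x \log p(x)$ denoting the score of $p$), this closed form can also be written as a double expectation under $q$, which is what makes an empirical estimate from samples straightforward:

$$\begin{aligned} \mathbb{S}(q, p) &= \mathbb{E}_{x, x^\prime \sim q}\left[ u_p(x, x^\prime) \right] \\ u_p(x, x^\prime) &= s_p(x)^T k(x, x^\prime) s_p(x^\prime) + s_p(x)^T \nabla_{x^\prime} k(x, x^\prime) + \nabla_{x} k(x, x^\prime)^T s_p(x^\prime) + \mathrm{trace}\left( \nabla_{x} \nabla_{x^\prime} k(x, x^\prime) \right) \end{aligned}$$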

Stein Variational Gradient Descent

As it turns out, under a deterministic transform (a one-step normalizing flow) $z = \mathbf{T}(x) = x + \epsilon \mathbf{\phi}(x)$, where $x \sim q(x)$, we have [11]

$$\nabla_\epsilon KL(q_{[\mathbf{T}]}\ ||\ p) \bigg|_{\epsilon = 0} = - \mathbb{E}_q\left[\mathrm{trace}(\mathcal{A}_p\phi(x))\right]$$

The distribution $q_{[\mathbf{T}]}$ is given by the change-of-variables formula for probability distributions

$$q_{[\mathbf{T}]}(z) = q(\mathbf{T}^{-1}(z)) \cdot \left|\det \nabla_z\mathbf{T}^{-1}(z)\right|$$
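As a quick numerical check of this formula (an illustrative sketch of mine, assuming SciPy is available), take $q = \mathcal{N}(0, 1)$ and the affine transform $\mathbf{T}(x) = 2x + 1$, so that $q_{[\mathbf{T}]}$ should be $\mathcal{N}(1, 4)$:

```python
import numpy as np
from scipy.stats import norm

# Change of variables for T(x) = 2x + 1 with q = N(0, 1):
# q_[T](z) = q(T^{-1}(z)) * |det dT^{-1}/dz|, with T^{-1}(z) = (z - 1) / 2.
z = 0.7
lhs = norm.pdf((z - 1.0) / 2.0) * 0.5          # change-of-variables formula
rhs = norm.pdf(z, loc=1.0, scale=2.0)          # analytic density of N(1, 4)
print(np.isclose(lhs, rhs))                    # True
```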

Combining this result with the previous discussion, we can conclude that $\phi^\star_{q,p}$ is the optimal direction of perturbation in the zero-centered unit ball of $\mathcal{H}^d$. It is the direction of steepest descent on the KL divergence of the transformed distribution $q_{[\mathbf{T}]}$ from $p$, and the magnitude of the change is $\nabla_\epsilon KL(q_{[\mathbf{T}]}\ ||\ p) \big|_{\epsilon = 0} = - \mathbb{S}(q, p)$.

In practice, applying this near-identity perturbation at every time step $\epsilon$ decreases the KL divergence by roughly $\epsilon\, \mathbb{S}(q,p)$. If we keep running this long enough, we should eventually converge to the true distribution $p$. Therefore, the ODE we are trying to simulate to convergence is

$$\dot{x} = \phi^\star_{q,p}(x)$$
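In discrete time, this is simulated with a simple Euler step of size $\epsilon$, which is exactly the update rule of [11]:

$$x^{t+1} \leftarrow x^{t} + \epsilon\, \phi^\star_{q,p}(x^{t})$$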

$\phi^\star_{q,p}$ is an expectation and can be empirically estimated by a mean over $n$ particles. Before we go and see how this works in practice, it is important to see what the term $\phi^\star_{q,p}$ is really achieving:

$$\hat{\phi}^\star_{q,p}(x) = \frac{1}{n} \sum_{j = 1}^n \left[ k(x_j, x) \nabla_{x_j} \log{p(x_j)} + \nabla_{x_j} k(x_j, x) \right]$$

If we consider just one particle, $n = 1$, and a kernel with $\nabla_x k(x,x) = 0$ (which is true for the RBF kernel), then what we recover is plain old gradient ascent on $\log p$, and the particle would simply reach a mode of $p$. In the presence of more particles, the gradient of the kernel acts like a repulsive force which encourages diversity among the particles. This is actually a pretty neat result: we have established a sort of communication protocol between the particles via the gradient of the kernel function.
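Putting the pieces together, here is a minimal sketch of one SVGD update in NumPy, assuming the RBF kernel with the median bandwidth heuristic from [11]; the helper names (`rbf_kernel_and_grad`, `svgd_step`) and the plain constant step size are mine, not the notebook's.

```python
import numpy as np

def rbf_kernel_and_grad(particles):
    """RBF kernel matrix K and the repulsive term sum_j grad_{x_j} k(x_j, x_i)."""
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]   # (n, n, d)
    sq_dists = np.sum(diffs ** 2, axis=-1)                  # (n, n)

    # Median bandwidth heuristic: h = med^2 / log n (log(n + 1) guards n = 1).
    h = np.median(sq_dists) / np.log(n + 1)
    K = np.exp(-sq_dists / h)                               # k(x_i, x_j)

    # sum_j grad_{x_j} k(x_j, x_i) = (2 / h) * sum_j K[i, j] * (x_i - x_j)
    grad_K = (2.0 / h) * (K.sum(axis=1, keepdims=True) * particles - K @ particles)
    return K, grad_K

def svgd_step(particles, score_fn, step_size=1e-1):
    """One SVGD update. particles: (n, d); score_fn: (n, d) -> (n, d) grad log p."""
    K, grad_K = rbf_kernel_and_grad(particles)
    phi = (K @ score_fn(particles) + grad_K) / particles.shape[0]
    return particles + step_size * phi

# Example: 100 particles drifting towards a standard 2-D Gaussian target.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=1.0, size=(100, 2))
score = lambda x: -x                                        # grad log N(0, I)
for _ in range(500):
    x = svgd_step(x, score)
print(x.mean(axis=0), x.std(axis=0))                        # roughly [0, 0] and [1, 1]
```

The `K @ score_fn(particles)` term is the kernel-smoothed gradient that pulls particles towards high-density regions of $p$, while `grad_K` is the repulsive term discussed above.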

Experiments

Open In Colab

The experiments can be run via the Jupyter notebook - click the badge above. Here are some results that show the power of the Stein gradient.

Figure: Stein particles on a Gaussian distribution.
Figure: Stein particles on a mixture of two Gaussians.
Figure: Stein particles on a mixture of six Gaussians.
Figure: MAP behavior with one particle on a mixture of six Gaussians (the particle may fall into any of the modes depending on its initial position).

All these results use the RBF kernel with the median bandwidth heuristic, run for a total of 1000 gradient steps using Adam (the original work [11] uses Adagrad, but any adaptive gradient descent scheme can be used). See the notebook for more details.

Conclusion

It is time to go back to the fundamental equation in Bayesian learning - Bayes' theorem

$$\overbrace{p(\Theta|\mathbf{X})}^{\text{posterior}} = \frac{\overbrace{p(\mathbf{X}|\Theta)}^{\text{likelihood}}\,\overbrace{p(\Theta)}^{\text{prior}}}{\underbrace{p(\mathbf{X})}_{\text{evidence}}}$$

All the methods mentioned in the introduction make one assumption or another about the parametric form of the posterior in order to stay tractable. Using the Stein gradient, we are in a position to be non-parametric about the posterior. Additionally, we can work with an unnormalized density, because the score function does not depend on the normalization constant during the simulation of the ODE described above: $\nabla_x \log{p(x)} = \nabla_x \log{\tilde{p}(x)} - \nabla_x \log{Z} = \nabla_x \log{\tilde{p}(x)}$. We saw an example of this in the experiments, where the mixtures of Gaussians were unnormalized. It should, however, be noted that we may need to introduce a stochastic gradient if the likelihood term is costly to evaluate, just as in stochastic-gradient variants of Hamiltonian Monte Carlo. Overall, this is great news and an exciting approach to dig into!
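As a tiny illustration of the normalization constant dropping out (a sketch of mine, not from the notebook), scaling an unnormalized two-component Gaussian mixture $\tilde{p}$ by any constant leaves the score, here computed by central differences, unchanged:

```python
import numpy as np

# Unnormalized density: a mixture of two Gaussian bumps, scaled by a constant c.
def log_p_tilde(x, c=1.0):
    return np.log(c * (np.exp(-0.5 * (x - 2.0) ** 2) + np.exp(-0.5 * (x + 2.0) ** 2)))

# Numerical score via central differences: the constant c cancels out exactly.
def score(x, c=1.0, eps=1e-5):
    return (log_p_tilde(x + eps, c) - log_p_tilde(x - eps, c)) / (2 * eps)

x = np.linspace(-4.0, 4.0, 9)
print(np.allclose(score(x, c=1.0), score(x, c=123.4)))  # True
```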

Footnotes

  1. Betancourt, M., Byrne, S., Livingstone, S., & Girolami, M. (2014). The Geometric Foundations of Hamiltonian Monte Carlo. arXiv: Methodology.

  2. Gelman, A., Carlin, J., Stern, H., Dunson, D., Vehtari, A., & Rubin, D. (1995). Bayesian Data Analysis.

  3. Neal, R. (2011). MCMC Using Hamiltonian Dynamics. arXiv: Computation, 139-188.

  4. Brooks, S., Gelman, A., Jones, G., & Meng, X. (2011). Handbook of Markov Chain Monte Carlo.

  5. Jordan, M.I., Ghahramani, Z., Jaakkola, T., & Saul, L. (2004). An Introduction to Variational Methods for Graphical Models. Machine Learning, 37, 183-233.

  6. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black Box Variational Inference. AISTATS.

  7. Rezende, D.J., & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML.

  8. Schölkopf, B., & Smola, A. (2001). Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond. Journal of the American Statistical Association, 98, 489-489.

  9. Gorham, J., & Mackey, L. (2015). Measuring Sample Quality with Stein’s Method. ArXiv, abs/1506.03039.

  10. Liu, Q., Lee, J., & Jordan, M.I. (2016). A Kernelized Stein Discrepancy for Goodness-of-fit Tests. ICML.

  11. Liu, Q., & Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. NIPS.