The Stein Gradient
Visualizing the simple yet powerful Stein gradient for sampling (with notebook)
Machine Learning is all about dealing with uncertainty of outcomes, and Bayesian inference provides a principled way to reason about it. We combine the observed data with priors to build (potentially complex) posteriors over the variables of interest and use those for answering subsequent questions. The ability to model the posterior distribution allows us to quantify uncertainty for any downstream task.
However, the posterior is a tricky term to compute because of the computationally intensive integral in the denominator (the evidence). Without access to the true closed form of the posterior, the next best thing to have is a representative set (more formally, the typical set [1]) from the posterior distribution's high-density regions. Over the years, researchers have developed a handful of umbrella techniques, all of which converge to the true distribution in the limit of the number of samples: conjugate priors for mathematical convenience [2], the Markov Chain Monte Carlo (MCMC) family of algorithms, especially Hamiltonian Monte Carlo (HMC) [3] [4], Variational Inference [5] [6] via approximate posteriors, and Normalizing Flows [7] for deterministic transforms. All these methods have been immensely successful in modern Bayesian analysis.
However, like all feasible methods, these make trade-offs: using conjugate priors limits the richness of the posterior densities and restricts us to mathematically convenient likelihood-prior combinations; the MCMC family requires us to carefully design the transition kernel of the Markov chain and the SDE integrator for numerical stability; Variational Inference can lead to posteriors that under-estimate the support of the true distribution because of the mode-seeking nature of the reverse KL divergence objective; Normalizing Flows demand an exact parametric form of the bijectors where the determinant of the Jacobian must be tractable. Today we will look at an approach which does away with all the previous pathologies (not to say it doesn't introduce new ones) called the kernelized Stein gradient.
We will first summarize the key background topics, then state the formal results underlying the Stein gradient, and finally visualize via code how this technique can be used to build powerful samplers.
Background
Kernels
Kernels are a very neat idea where the objective is to map our raw input space to an abstract vector space, so that we can define similarity in terms of dot products and get all the geometric niceties of angles, lengths and distances [8]. Formally, the most important equations can be summarized as

$$k(x, x^\prime) = \langle \phi(x), \phi(x^\prime) \rangle_{\mathcal{H}}, \qquad \sum_{i=1}^{n} \sum_{j=1}^{n} c_i c_j \, k(x_i, x_j) \geq 0 \quad \forall \, c \in \mathbb{R}^n,$$

i.e. the kernel computes a dot product in some feature space and its Gram matrix is positive semi-definite.
The key idea is that given a kernel function $k$, we can construct a feature space (more formally known as a Hilbert space in functional analysis) such that the kernel computes a dot product in that feature space. The terminology kernel comes from a function which gives rise to the following integral operator

$$(T_k f)(x) = \int_{\mathcal{X}} k(x, x^\prime) \, f(x^\prime) \, dx^\prime.$$
Intuitively, one can imagine this operator as a rich representation in terms of a linear combination over the full space, where the combination coefficients are defined by the kernel function. By extending this idea via a particular construction, we arrive at the representer of evaluation, which leads to the classic notion of a reproducing kernel Hilbert space (RKHS).
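To make the dot-product view concrete, here is a minimal sketch (my own illustration, not code from the notebook; `poly2_kernel`, `poly2_features` and `rbf_gram` are hypothetical helper names) that verifies a quadratic kernel against its explicit feature map and builds an RBF Gram matrix:

```python
import numpy as np

def poly2_kernel(x, y):
    """Homogeneous quadratic kernel k(x, y) = (x . y)^2."""
    return np.dot(x, y) ** 2

def poly2_features(x):
    """Explicit feature map for the 2D quadratic kernel:
    phi(x) = (x1^2, sqrt(2) x1 x2, x2^2)."""
    return np.array([x[0] ** 2, np.sqrt(2) * x[0] * x[1], x[1] ** 2])

x, y = np.array([1.0, 2.0]), np.array([3.0, -1.0])
# Kernel trick: same value as the feature-space dot product, without forming phi.
assert np.isclose(poly2_kernel(x, y), np.dot(poly2_features(x), poly2_features(y)))

def rbf_gram(X, h=1.0):
    """Gram matrix K_ij = exp(-||x_i - x_j||^2 / (2 h)) for the rows of X."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2 * h))

K = rbf_gram(np.random.randn(5, 2))
print(K.shape)  # (5, 5), symmetric and positive semi-definite
```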
Stein’s Identity
Let $p(x)$ be a smooth density supported on $\mathcal{X} \subseteq \mathbb{R}^d$ and let $\phi(x) = [\phi_1(x), \dots, \phi_d(x)]^T$ be a smooth vector function. Stein's identity states that under a Stein operator $\mathcal{A}_p$,

$$\mathbb{E}_{x \sim p}\left[ \mathcal{A}_p \phi(x) \right] = 0, \quad \text{where} \quad \mathcal{A}_p \phi(x) = \phi(x) \nabla_x \log p(x)^T + \nabla_x \phi(x).$$
We need mild boundary conditions for the above to be true: either $\mathcal{X}$ is compact, or $\lim_{\|x\| \to \infty} p(x) \phi(x) = 0$. A function $\phi$ belongs to the Stein class of $p$ if the Stein identity holds.
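As a quick sanity check (my own sketch, not part of the referenced papers), we can verify Stein's identity by Monte Carlo for a standard normal target, where $\nabla_x \log p(x) = -x$, using the test function $\phi(x) = x$:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)  # samples from p = N(0, 1)

# Stein operator applied to phi(x) = x in 1D:
# A_p phi(x) = phi(x) * d/dx log p(x) + d/dx phi(x) = x * (-x) + 1
stein_op = x * (-x) + 1.0

print(stein_op.mean())  # ~0 up to Monte Carlo error, as Stein's identity predicts
```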
Stein Discrepancy
Let us consider another smooth density $q(x)$ and a function $\phi$ that belongs to the Stein class of $q$ (but not necessarily of $p$). After some simple manipulations, we can see that

$$\mathbb{E}_{x \sim q}\left[ \mathcal{A}_p \phi(x) \right] = \mathbb{E}_{x \sim q}\left[ \phi(x) \left( \nabla_x \log p(x) - \nabla_x \log q(x) \right)^T \right].$$
Intuitively, this can be seen as the function $\phi$ weighted by the difference of the score functions of the two distributions. It has been shown [9] that

$$\mathbb{E}_{x \sim q}\left[ \text{trace}\left( \mathcal{A}_p \phi(x) \right) \right] = \mathbb{E}_{x \sim q}\left[ \text{trace}\left( \phi(x) \left( \nabla_x \log p(x) - \nabla_x \log q(x) \right)^T \right) \right].$$
Note that the use of the trace as the matrix norm is only meant to reduce the expression to a scalar; other matrix norms may be considered. If we take the maximum value over some family of test functions $\mathcal{F}$, we get the Stein discrepancy measure

$$\mathbb{S}(q, p) = \max_{\phi \in \mathcal{F}} \left\{ \mathbb{E}_{x \sim q}\left[ \text{trace}\left( \mathcal{A}_p \phi(x) \right) \right] \right\}^2.$$
This can be thought of as the maximum possible violation of Stein's identity under the family of test functions $\mathcal{F}$ when the expectation is taken with respect to $q$ instead of $p$. However, the computational tractability of this measure depends critically on the choice of the family of test functions, and for most choices the maximization is intractable.
Kernelized Stein Discrepancy
If we choose the family of test functions to be the unit ball of an RKHS $\mathcal{H}^d$ with kernel $k(\cdot, \cdot)$, it turns out we can find a closed form for the Stein discrepancy measure [10]:

$$\mathbb{S}(q, p) = \left\| \phi^\star_{q,p} \right\|^2_{\mathcal{H}^d}, \quad \text{where} \quad \phi^\star_{q,p}(\cdot) = \mathbb{E}_{x \sim q}\left[ \mathcal{A}_p k(x, \cdot) \right] = \mathbb{E}_{x \sim q}\left[ k(x, \cdot) \nabla_x \log p(x) + \nabla_x k(x, \cdot) \right].$$
This measure is zero if and only if $p = q$. The power of this result is realized by Stein Variational Gradient Descent.
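To ground this, here is a small self-contained sketch (again my own illustration, with analytic derivatives for a 1D RBF kernel; `ksd_vs_std_normal` is a hypothetical name, and the target is fixed to a standard normal) that estimates the kernelized Stein discrepancy of a batch of samples via the usual U-statistic:

```python
import numpy as np

def ksd_vs_std_normal(x, h=1.0):
    """U-statistic estimate of the kernelized Stein discrepancy between the
    1D samples x and p = N(0, 1), using k(a, b) = exp(-(a - b)^2 / (2 h)).
    Score of p: d/dx log p(x) = -x."""
    score = -x
    d = x[:, None] - x[None, :]                 # pairwise differences x_i - x_j
    k = np.exp(-d ** 2 / (2 * h))               # kernel matrix
    dk_dx = -(d / h) * k                        # d k / d x_i
    dk_dy = (d / h) * k                         # d k / d x_j
    d2k = (1.0 / h - d ** 2 / h ** 2) * k       # d^2 k / (d x_i d x_j)
    u = (score[:, None] * score[None, :] * k
         + score[:, None] * dk_dy
         + score[None, :] * dk_dx
         + d2k)
    np.fill_diagonal(u, 0.0)                    # U-statistic: drop i == j terms
    n = len(x)
    return u.sum() / (n * (n - 1))

rng = np.random.default_rng(0)
print(ksd_vs_std_normal(rng.standard_normal(500)))        # close to zero
print(ksd_vs_std_normal(rng.standard_normal(500) + 1.0))  # clearly positive
```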
Stein Variational Gradient Descent
As it turns out, under a deterministic transform (a one-step normalizing flow) $z = T(x) = x + \epsilon \, \phi(x)$, we have [11]

$$\nabla_\epsilon \, \text{KL}\left( q_{[T]} \,\|\, p \right) \Big|_{\epsilon = 0} = - \mathbb{E}_{x \sim q}\left[ \text{trace}\left( \mathcal{A}_p \phi(x) \right) \right].$$
The distribution of $z = T(x)$ is given by the change of variables formula for probability densities

$$q_{[T]}(z) = q\left( T^{-1}(z) \right) \left| \det \nabla_z T^{-1}(z) \right|.$$
Combining this result with the previous discussion, we can conclude that $\phi^\star_{q,p}(x) = \mathbb{E}_{x^\prime \sim q}\left[ k(x^\prime, x) \nabla_{x^\prime} \log p(x^\prime) + \nabla_{x^\prime} k(x^\prime, x) \right]$ is the optimal direction of perturbation in the RKHS. This is the direction of steepest descent that minimizes the KL divergence of the transformed distribution within the zero-centered ball of $\mathcal{H}^d$, and the corresponding magnitude of change is

$$\nabla_\epsilon \, \text{KL}\left( q_{[T]} \,\|\, p \right) \Big|_{\epsilon = 0} = - \mathbb{S}(q, p).$$
In practice, applying this identity-plus-perturbation transform at every timestep brings the KL divergence down by $\epsilon \, \mathbb{S}(q, p)$ (to first order). If we keep running this long enough, we should eventually converge to the true distribution $p$. Therefore, the ODE we are trying to simulate to convergence is

$$\frac{dx_t}{dt} = \phi^\star_{q_t, p}(x_t),$$

where $q_t$ is the distribution of the particles at time $t$.
$\phi^\star_{q,p}$ is an expectation and can be empirically estimated using a mean over a set of particles $\{x_j\}_{j=1}^n$,

$$\hat{\phi}^\star(x) = \frac{1}{n} \sum_{j=1}^{n} \left[ k(x_j, x) \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x) \right].$$

Before we go and see how this works in practice, it is important to see what the term $\nabla_{x_j} k(x_j, x)$ is really achieving.
If we consider just one particle ($n = 1$) and a kernel with $\nabla_x k(x, x) = 0$ (which is true for the RBF kernel), then what we get is plain old gradient descent on $-\log p(x)$, and the particle would simply reach a mode of $p$. In the presence of more particles, the gradient of the kernel acts like a repulsive force which encourages diversity of particles. This is actually a pretty neat result: we have established a sort of communication protocol between the particles via the gradient of the kernel function. A minimal implementation is sketched below.
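Putting the pieces together, here is a minimal NumPy sketch of the SVGD update (a simplified stand-in for the notebook code: `rbf_kernel_and_grad`, `svgd` and `score_fn` are my own names, a fixed step size replaces the adaptive optimizer, and the bandwidth uses a common variant of the median heuristic):

```python
import numpy as np

def rbf_kernel_and_grad(X, h=None):
    """RBF kernel matrix K_ij = exp(-||x_i - x_j||^2 / h) and the summed
    kernel gradients needed for the SVGD repulsive term."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    if h is None:
        h = np.median(sq_dists) / np.log(len(X) + 1) + 1e-8  # median heuristic
    K = np.exp(-sq_dists / h)
    # grad_K[i] = sum_j d/dx_j k(x_j, x_i)  (the repulsive term of SVGD)
    grad_K = 2.0 / h * (K.sum(axis=1, keepdims=True) * X - K @ X)
    return K, grad_K

def svgd(score_fn, X0, step=1e-1, n_iters=1000):
    """Simulate the update x <- x + step * phi_hat(x) for all particles,
    where score_fn(X) returns grad_x log p(x) for each row of X."""
    X = X0.copy()
    for _ in range(n_iters):
        K, grad_K = rbf_kernel_and_grad(X)
        phi_hat = (K @ score_fn(X) + grad_K) / len(X)  # driving + repulsive terms
        X = X + step * phi_hat
    return X

# Example: sample from a standard 2D Gaussian, whose score is simply -x.
rng = np.random.default_rng(0)
particles = svgd(lambda X: -X, rng.uniform(-3, 3, size=(100, 2)))
print(particles.mean(axis=0), particles.std(axis=0))  # roughly [0, 0] and [1, 1]
```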
Experiments
The experiments can be run via the Jupyter notebook (click the badge above). Here are some results to show the power of the Stein gradient.
All these results use the RBF kernel with the median bandwidth heuristic, for a total of 1000 gradient steps using Adam (the original work [11] uses Adagrad, but it should be noted that we can use any adaptive gradient descent scheme). See the notebook for more details.
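As a flavor of this setup, here is a hedged usage sketch that reuses the hypothetical `svgd` helper from the earlier snippet (with plain fixed-step updates rather than Adam; `mog_score` is my own example) to draw particles from an unnormalized two-component Gaussian mixture:

```python
import numpy as np

def mog_score(X, means=(-2.0, 2.0), weights=(1.0, 2.0)):
    """Score of an *unnormalized* 1D mixture of unit-variance Gaussians,
    p_tilde(x) = sum_k w_k exp(-(x - mu_k)^2 / 2); the normalization
    constant never appears since we only need grad log p_tilde."""
    X = X.reshape(-1, 1)
    comps = np.array([w * np.exp(-0.5 * (X - m) ** 2) for m, w in zip(means, weights)])
    dcomps = np.array([-(X - m) * c for m, c in zip(means, comps)])
    return dcomps.sum(axis=0) / comps.sum(axis=0)

rng = np.random.default_rng(1)
X0 = rng.normal(0.0, 1.0, size=(200, 1))
samples = svgd(mog_score, X0, step=5e-2, n_iters=2000)  # svgd from the sketch above
# The particles should spread across both modes, with more mass near x = +2.
```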
Conclusion
It is time to go back to the fundamental equation in Bayesian learning - Bayes' theorem

$$p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta) \, p(\theta)}{\int p(\mathcal{D} \mid \theta) \, p(\theta) \, d\theta}.$$
All the methods mentioned in the introduction make some assumption or other about the parametric nature of the posterior to keep it tractable. Using the Stein gradient, we are in a position to be non-parametric about the posterior. Additionally, we can work with an unnormalized density, because the score function does not depend on the normalization constant during the simulation of the ODE described above: if $p(x) = \tilde{p}(x) / Z$, then $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$. An example of this was seen in the experiments when we used a Mixture of Gaussians with unnormalized weights. It should however be noted that we may need to introduce a stochastic gradient if the likelihood term is costly to evaluate, just as in stochastic-gradient variants of Hamiltonian Monte Carlo. Overall, this is great news and an exciting approach to dig into!
Footnotes
1. Betancourt, M., Byrne, S., Livingstone, S., & Girolami, M. (2014). The Geometric Foundations of Hamiltonian Monte Carlo. arXiv: Methodology.
2. Gelman, A., Carlin, J., Stern, H., Dunson, D., Vehtari, A., & Rubin, D. (1995). Bayesian Data Analysis.
3. Neal, R. (2011). MCMC Using Hamiltonian Dynamics. arXiv: Computation, 139-188.
4. Brooks, S., Gelman, A., Jones, G., & Meng, X. (2011). Handbook of Markov Chain Monte Carlo.
5. Jordan, M.I., Ghahramani, Z., Jaakkola, T., & Saul, L. (2004). An Introduction to Variational Methods for Graphical Models. Machine Learning, 37, 183-233.
6. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black Box Variational Inference. AISTATS.
7. Schölkopf, B., & Smola, A. (2001). Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond. Journal of the American Statistical Association, 98, 489-489.
8. Schölkopf, B., & Smola, A. (2001). Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond.
9. Gorham, J., & Mackey, L. (2015). Measuring Sample Quality with Stein's Method. arXiv, abs/1506.03039.
10. Liu, Q., Lee, J., & Jordan, M.I. (2016). A Kernelized Stein Discrepancy for Goodness-of-fit Tests. ICML.
11. Liu, Q., & Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. NIPS.