The explicit reparameterization trick (ERT) is often used to train latent variable models because it makes gradients through continuous random variables easy to compute. By making it possible to backpropagate through computation graphs that contain certain continuous random variables (e.g., Normal and Logistic distributions), ERT serves as a powerful tool for learning. However, ERT requires a tractable, differentiable inverse of the transformation that standardizes the distribution (typically its CDF), so it is not applicable to several important continuous distributions, such as mixtures and the Gamma, Beta, and Dirichlet distributions.
An alternative method for calculating reparameterization gradients relies on implicit differentiation of cumulative distribution functions (CDFs). The implicit reparameterization trick (IRT), being a modification of ERT, is much more expressive and applicable to a wider class of distributions.
This article provides an overview of various reparameterization tricks and announces a new Python library, irt.distributions, for sampling from various distributions using the IRT.
Explicit reparameterization gradients
The vanilla ELBO estimation in generative models such as the VAE uses the reparameterization trick: a way of sampling random variables that yields low-variance gradient estimates with respect to the distribution's parameters. The problem with this method is that it is available only for a limited number of distributions. Let's look at this problem in more detail.
If we would like to optimize the expectation of some continuously differentiable function w.r.t. the parameters of the distribution, we are faced with the difficulty of doing this directly. The idea behind the reparameterization trick is to replace a probability distribution with an equivalent parameterization of it using a deterministic and differentiable transformation of some fixed distribution.
We assume that there exists a standardization function $S_\phi(z)$ that, when applied to a sample $z$ from the distribution $q_\phi(z)$, eliminates its dependence on the distribution's parameters $\phi$. This function should be continuously differentiable with respect to both its argument and its parameters, and it must be invertible:
$$S_\phi(z) = \varepsilon \sim q(\varepsilon), \qquad z = S_\phi^{-1}(\varepsilon).$$
Example: For a Gaussian distribution $\mathcal{N}(\mu, \sigma^2)$ we can use $S_{\mu,\sigma}(z) = (z - \mu)/\sigma \sim \mathcal{N}(0, 1)$ as a standardization function.
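Under the assumptions above, we can then represent the objective as an expectation w.r.t. the noise $\varepsilon$, shifting the dependence on the parameters $\phi$ into $f$:
$$\mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q(\varepsilon)}\left[f\left(S_\phi^{-1}(\varepsilon)\right)\right].$$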
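This enables us to calculate the gradient of the expectation as the expectation of the gradients:
$$\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q(\varepsilon)}\left[\nabla_\phi f\left(S_\phi^{-1}(\varepsilon)\right)\right] = \mathbb{E}_{q(\varepsilon)}\left[\nabla_z f\left(S_\phi^{-1}(\varepsilon)\right) \nabla_\phi S_\phi^{-1}(\varepsilon)\right].$$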
A standardization function satisfying these requirements exists for a wide range of continuous distributions; for a univariate distribution, the CDF is a natural choice. However, inverting the CDF is often complex and computationally intensive, and differentiating that inverse poses even greater challenges.
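For the Gaussian case the explicit estimator is simple enough to write out by hand. The following is a minimal sketch using only the Python standard library, not code from the library announced here; the test function f(z) = z^2 and the helper name `explicit_gaussian_grads` are assumptions chosen for illustration.

```python
import random


def explicit_gaussian_grads(mu, sigma, n_samples=200_000, seed=0):
    """Monte Carlo estimate of d/dmu and d/dsigma of E[z^2], z ~ N(mu, sigma^2).

    Hypothetical helper: uses the explicit reparameterization
    z = mu + sigma * eps with eps ~ N(0, 1).
    """
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)   # noise that does not depend on (mu, sigma)
        z = mu + sigma * eps        # inverse of the standardization function
        df_dz = 2.0 * z             # gradient of f(z) = z^2 w.r.t. z
        grad_mu += df_dz * 1.0      # chain rule with dz/dmu = 1
        grad_sigma += df_dz * eps   # chain rule with dz/dsigma = eps
    return grad_mu / n_samples, grad_sigma / n_samples


g_mu, g_sigma = explicit_gaussian_grads(mu=1.0, sigma=0.5)
# Analytic reference: E[z^2] = mu^2 + sigma^2, so the exact gradients
# are 2 * mu = 2.0 and 2 * sigma = 1.0.
print(g_mu, g_sigma)
```

The estimates should land close to the analytic gradients. The sketch works only because the Gaussian inverse transformation is available in closed form, which is the limitation discussed above.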
Implicit reparameterization gradients
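The IRT avoids the need to invert the standardization function. To accomplish this, we perform the change of variables $z = S_\phi^{-1}(\varepsilon)$ in the gradient of the expectation:
$$\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}\left[\nabla_z f(z)\, \nabla_\phi z\right], \qquad \nabla_\phi z = \nabla_\phi S_\phi^{-1}(\varepsilon)\big|_{\varepsilon = S_\phi(z)}.$$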
By applying the total gradient to the equality $S_\phi(z) = \varepsilon$ and expressing the result in terms of partial gradients using the chain rule, we derive:
$$\nabla_z S_\phi(z)\, \nabla_\phi z + \nabla_\phi S_\phi(z) = 0.$$
Now, let's solve the latter equation for $\nabla_\phi z$:
$$\nabla_\phi z = -\left(\nabla_z S_\phi(z)\right)^{-1} \nabla_\phi S_\phi(z).$$
It is important to note that this expression for the gradient $\nabla_\phi z$, calculated by implicit differentiation, only requires differentiation of the standardization function rather than its inversion.
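As a quick sanity check, the implicit formula can be worked through for the Gaussian case, where the explicit answer is known. This is only an illustrative derivation, assuming the standardization function $S_{\mu,\sigma}(z) = (z - \mu)/\sigma$ and the implicit gradient $\nabla_\phi z = -\left(\nabla_z S_\phi(z)\right)^{-1} \nabla_\phi S_\phi(z)$:

$$
\begin{aligned}
\nabla_z S_{\mu,\sigma}(z) &= \frac{1}{\sigma}, &
\nabla_\mu S_{\mu,\sigma}(z) &= -\frac{1}{\sigma}, &
\nabla_\sigma S_{\mu,\sigma}(z) &= -\frac{z - \mu}{\sigma^2}, \\
\nabla_\mu z &= -\sigma \cdot \left(-\frac{1}{\sigma}\right) = 1, &
\nabla_\sigma z &= -\sigma \cdot \left(-\frac{z - \mu}{\sigma^2}\right) = \frac{z - \mu}{\sigma}.
\end{aligned}
$$

Both results agree with differentiating the explicit form $z = \mu + \sigma \varepsilon$, since $(z - \mu)/\sigma = \varepsilon$, and neither step required inverting $S_{\mu,\sigma}$.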
The following table compares two types of reparameterization: ERT and IRT. Samples of $z$ in the case of IRT can be obtained, for instance, by rejection sampling, and the gradients of the standardization function can be calculated either analytically or using automatic differentiation.
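Step | Explicit reparameterization | Implicit reparameterization |
Forward pass | Sample $\varepsilon \sim q(\varepsilon)$; set $z \leftarrow S_\phi^{-1}(\varepsilon)$ | Sample $z \sim q_\phi(z)$ |
Backward pass | Set $\nabla_\phi z \leftarrow \nabla_\phi S_\phi^{-1}(\varepsilon)$; set $\nabla_\phi f(z) \leftarrow \nabla_z f(z)\, \nabla_\phi z$ | Set $\nabla_\phi z \leftarrow -\left(\nabla_z S_\phi(z)\right)^{-1} \nabla_\phi S_\phi(z)$; set $\nabla_\phi f(z) \leftarrow \nabla_z f(z)\, \nabla_\phi z$ |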
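The implicit column of the table can be sketched for a Gamma distribution with shape parameter alpha and unit scale, a case where the CDF has no closed-form inverse. This is a rough standard-library illustration under stated assumptions, not the implementation used in irt.distributions: the CDF is evaluated with its power series, its derivative w.r.t. alpha is approximated by a central finite difference (instead of the analytic or automatic differentiation mentioned above), and all function names are hypothetical.

```python
import math
import random


def gamma_cdf(z, alpha):
    """CDF of Gamma(alpha, scale=1): the regularized lower incomplete
    gamma function, evaluated with its power series."""
    term = 1.0 / alpha
    total = term
    n = 0
    while term > 1e-17 * total and n < 10_000:
        n += 1
        term *= z / (alpha + n)
        total += term
    return total * math.exp(alpha * math.log(z) - z - math.lgamma(alpha))


def gamma_pdf(z, alpha):
    """Density of Gamma(alpha, scale=1); equals dS/dz when S is the CDF."""
    return math.exp((alpha - 1.0) * math.log(z) - z - math.lgamma(alpha))


def implicit_grad_alpha(z, alpha, h=1e-4):
    """dz/dalpha = -(dS/dz)^(-1) * dS/dalpha, with S the Gamma CDF.

    dS/dalpha is approximated by a central finite difference
    (an assumption made for this sketch).
    """
    dS_dalpha = (gamma_cdf(z, alpha + h) - gamma_cdf(z, alpha - h)) / (2.0 * h)
    return -dS_dalpha / gamma_pdf(z, alpha)


def estimate_grad_of_mean(alpha, n_samples=20_000, seed=0):
    """Estimate d/dalpha E[z] for z ~ Gamma(alpha, 1) using f(z) = z."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gammavariate(alpha, 1.0)         # forward pass: any sampler works
        total += 1.0 * implicit_grad_alpha(z, alpha)  # df/dz = 1 times dz/dalpha
    return total / n_samples


grad = estimate_grad_of_mean(alpha=2.0)
# Analytic reference: E[z] = alpha, so the exact gradient is 1.
print(grad)
```

Note that the sampler is treated as a black box: the forward pass never touches the inverse CDF, and the backward pass needs only the CDF and the density at the sampled point.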
New Python Library: irt.distributions
We implement the following distributions in our library irt.distributions:
Normal (Gaussian) distribution
Dirichlet distribution (Beta distribution)
Gamma distribution
Mixture of distributions
Student's t-distribution
Our focus is on providing reliable, efficient, and user-friendly tools that will benefit both researchers and practitioners. We hope that the implementation of the library will contribute to the further development of this field of science.
Stay tuned for updates!
Git: https://github.com/intsystems/implicit-reparameterization-trick.git