

Variational Inference: A Mathematical Explanation

Variational Inference (VI)

Variational Inference (VI) is a mathematical framework for approximating complex posterior distributions in probabilistic models. Instead of sampling (as in Monte Carlo methods), VI transforms inference into an optimization problem, making it scalable and efficient for deep learning.

1. Problem Setup

We consider a probabilistic model with observed variables \( x \) and latent variables \( z \). The joint distribution is defined as:

$$ p_\theta(x, z) = p_\theta(x \mid z)\, p_\theta(z) $$
Notation:
\( x \): observed data
\( z \): latent (hidden) variable
\( p_\theta(x \mid z) \): likelihood (decoder)
\( p_\theta(z) \): prior distribution
\( \theta \): model parameters

The goal is to find parameters that maximize the marginal likelihood of the observed data:

$$ p_\theta(x) = \int p_\theta(x, z)\, dz $$

However, this integral is usually intractable: it runs over a high-dimensional latent space, and the likelihood depends on \( z \) through a nonlinear function (e.g., a neural network), so no closed form exists.
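
To see concretely why, here is a small Monte Carlo sketch (the toy decoder, its weights, and all dimensions are assumptions made up for this illustration, not part of the model above). It estimates \( p_\theta(x) \) by averaging the likelihood over samples from the prior; in higher-dimensional latent spaces almost all prior samples have negligible likelihood, so the estimator becomes far too noisy to rely on.

```python
import numpy as np

# Toy model (all names and sizes are assumptions for illustration):
# p(z) = N(0, I), p(x|z) = N(f(z), sigma^2 I) with a nonlinear f.
rng = np.random.default_rng(0)
d_z, d_x, sigma = 8, 16, 0.5
W = rng.normal(size=(d_x, d_z)) / np.sqrt(d_z)   # stand-in decoder weights
f = lambda z: np.tanh(z @ W.T)                   # nonlinear "decoder" f_theta

x = f(rng.normal(size=(1, d_z))) + sigma * rng.normal(size=(1, d_x))  # one observation

def log_px_estimate(n_samples):
    """Naive Monte Carlo estimate of log p(x) = log E_{p(z)}[p(x|z)]."""
    z = rng.normal(size=(n_samples, d_z))                        # z ~ p(z)
    log_lik = (-0.5 * np.sum((x - f(z)) ** 2, axis=1) / sigma**2
               - 0.5 * d_x * np.log(2 * np.pi * sigma**2))       # log p(x|z)
    m = log_lik.max()                                            # log-sum-exp trick
    return m + np.log(np.mean(np.exp(log_lik - m)))

# Repeated runs use fresh prior samples; because most samples contribute
# almost nothing, the estimates fluctuate, and they get worse as d_z grows.
for _ in range(3):
    print(f"naive Monte Carlo log p(x) estimate: {log_px_estimate(10_000):.2f}")
```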

2. The Posterior Distribution

The posterior distribution represents our belief about \( z \) after observing \( x \):

$$ p_\theta(z \mid x) = \frac{p_\theta(x, z)}{p_\theta(x)} $$

Directly computing this posterior is difficult because it requires evaluating \( p_\theta(x) \), which involves the intractable integral above.

3. Variational Approximation

We introduce an approximate posterior distribution \( q_\phi(z \mid x) \), parameterized by \( \phi \), to approximate the true posterior:

$$ q_\phi(z \mid x) \approx p_\theta(z \mid x) $$
Notation:
\( q_\phi(z \mid x) \): approximate posterior (encoder)
\( \phi \): parameters of the variational distribution
Goal: make \( q_\phi(z \mid x) \) close to \( p_\theta(z \mid x) \)
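
In practice \( q_\phi(z \mid x) \) is usually amortized: a neural network (the encoder) maps each \( x \) to the parameters of a simple distribution over \( z \). A minimal sketch follows; the architecture, layer sizes, and the diagonal-Gaussian form are illustrative assumptions rather than anything prescribed by the text, and the last line draws a reparameterized sample \( z = \mu + \sigma \odot \epsilon \) as in Kingma & Welling (2014).

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Amortized q_phi(z|x): x -> (mu_phi(x), log sigma_phi^2(x)). Sizes are illustrative."""
    def __init__(self, d_x=16, d_z=8, d_hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(d_x, d_hidden), nn.ReLU())
        self.mu = nn.Linear(d_hidden, d_z)        # mu_phi(x)
        self.log_var = nn.Linear(d_hidden, d_z)   # log sigma_phi^2(x)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_var(h)

encoder = Encoder()
x = torch.randn(4, 16)                  # a batch of observations (synthetic stand-in)
mu, log_var = encoder(x)                # parameters of q_phi(z|x)
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)   # reparameterized sample z ~ q_phi(z|x)
```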

4. Kullback–Leibler Divergence

We measure the difference between the two distributions using the KL divergence:

$$ \mathrm{KL}\!\left(q_\phi(z \mid x) \, \| \, p_\theta(z \mid x)\right) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z \mid x)}\right] $$

The KL divergence is non-negative and equals zero only when the two distributions coincide, so minimizing it with respect to \( \phi \) drives \( q_\phi \) toward the true posterior.
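
As a quick numerical illustration of this definition (the two univariate Gaussians below are arbitrary choices, not part of the model above), the Monte Carlo average of \( \log q(z)/p(z) \) under samples from \( q \) matches the closed-form KL between Gaussians:

```python
import numpy as np

# KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2.
# The specific parameters are assumptions for this example.
rng = np.random.default_rng(0)
m1, s1 = 0.5, 0.8      # q
m2, s2 = 0.0, 1.0      # p

def log_normal(z, m, s):
    return -0.5 * np.log(2 * np.pi * s**2) - 0.5 * ((z - m) / s) ** 2

z = rng.normal(m1, s1, size=1_000_000)                   # z ~ q
kl_mc = np.mean(log_normal(z, m1, s1) - log_normal(z, m2, s2))   # E_q[log q/p]
kl_exact = np.log(s2 / s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5

print(f"Monte Carlo KL ≈ {kl_mc:.4f},  closed form = {kl_exact:.4f}")
```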

5. Deriving the Evidence Lower Bound (ELBO)

Using the identity \( p_\theta(z \mid x) = \frac{p_\theta(x, z)}{p_\theta(x)} \), and noting that \( \log p_\theta(x) \) does not depend on \( z \) (so it can be pulled out of the expectation), we can rewrite:

$$ \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z \mid x)) = \mathbb{E}_{q_\phi(z \mid x)}[\log q_\phi(z \mid x) - \log p_\theta(x, z)] + \log p_\theta(x) $$

Rearranging terms gives:

$$ \log p_\theta(x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x, z) - \log q_\phi(z \mid x)] + \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z \mid x)) $$

Because the KL term is always non-negative, we obtain a lower bound on the log evidence:

$$ \log p_\theta(x) \ge \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x, z) - \log q_\phi(z \mid x)] $$

This lower bound is called the Evidence Lower Bound (ELBO).
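
The identity \( \log p_\theta(x) = \text{ELBO} + \mathrm{KL}(q_\phi \Vert p_\theta(z \mid x)) \) can be checked numerically in a model small enough that everything is tractable. The sketch below uses a 1D linear-Gaussian toy model and a deliberately imperfect \( q \); all specific numbers are assumptions for the example.

```python
import numpy as np

# Tractable toy model: p(z) = N(0, 1), p(x|z) = N(z, sigma^2).
# Then p(x) = N(0, 1 + sigma^2) and p(z|x) is Gaussian in closed form.
rng = np.random.default_rng(0)
sigma2 = 0.5**2
x = 1.3                                   # a single observation

def log_normal(v, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (v - mean) ** 2 / var

log_px = log_normal(x, 0.0, 1.0 + sigma2)                        # exact log evidence
post_mean, post_var = x / (1 + sigma2), sigma2 / (1 + sigma2)    # exact posterior p(z|x)

# An arbitrary (deliberately imperfect) variational distribution q(z) = N(m, s^2).
m, s2 = 0.8, 0.3
kl_q_post = 0.5 * (np.log(post_var / s2) + (s2 + (m - post_mean) ** 2) / post_var - 1)

# Monte Carlo ELBO = E_q[ log p(x, z) - log q(z) ].
z = rng.normal(m, np.sqrt(s2), size=1_000_000)
elbo = np.mean(log_normal(x, z, sigma2) + log_normal(z, 0.0, 1.0) - log_normal(z, m, s2))

print(f"log p(x) = {log_px:.4f},  ELBO + KL = {elbo + kl_q_post:.4f}")
```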

6. ELBO Simplification

We can expand \( p_\theta(x, z) = p_\theta(x \mid z)\, p_\theta(z) \) inside the expectation:

$$ \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x, z) - \log q_\phi(z \mid x)] = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{q_\phi(z \mid x)}{p_\theta(z)}\right] $$

The second expectation is exactly \( \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z)) \), so the ELBO can be written as:

$$ \text{ELBO}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z)) $$
Notation:
\( \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] \): reconstruction term
\( \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z)) \): regularization term
The ELBO trades off reconstruction accuracy against keeping the approximate posterior close to the prior (illustrated numerically below)
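
The two ELBO forms, \( \mathbb{E}_{q}[\log p_\theta(x, z) - \log q_\phi(z \mid x)] \) and reconstruction minus KL, are numerically identical, which the same kind of 1D toy model can confirm (again, every specific value below is an assumption of the example):

```python
import numpy as np

# Same 1D toy model as before: p(z) = N(0, 1), p(x|z) = N(z, sigma^2),
# with an arbitrary q(z) = N(m, s^2).
rng = np.random.default_rng(0)
sigma2, x = 0.25, 1.3
m, s2 = 0.8, 0.3

def log_normal(v, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (v - mean) ** 2 / var

z = rng.normal(m, np.sqrt(s2), size=1_000_000)            # z ~ q
joint_form = np.mean(log_normal(x, z, sigma2) + log_normal(z, 0, 1) - log_normal(z, m, s2))

recon = np.mean(log_normal(x, z, sigma2))                        # reconstruction term
kl_q_prior = 0.5 * (np.log(1.0 / s2) + (s2 + m**2) - 1)          # KL(N(m, s2) || N(0, 1))
decomposed = recon - kl_q_prior

print(f"joint form ≈ {joint_form:.4f},  reconstruction - KL ≈ {decomposed:.4f}")
```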

7. Optimization Objective

We maximize the ELBO (or equivalently minimize its negative) with respect to both \(\theta\) and \(\phi\):

$$ \min_{\theta, \phi} \Big[ - \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] + \mathrm{KL}(q_\phi(z \mid x) \Vert p_\theta(z)) \Big] $$
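
A minimal end-to-end sketch of this optimization follows. It assumes a diagonal-Gaussian \( q_\phi \), a standard normal prior, a unit-variance Gaussian likelihood, and synthetic data; these choices, along with the layer sizes and hyperparameters, are assumptions for illustration. Both \( \theta \) (decoder) and \( \phi \) (encoder) are updated jointly by stochastic gradient descent on the negative ELBO, with a reparameterized sample keeping the reconstruction term differentiable with respect to \( \phi \).

```python
import torch
import torch.nn as nn

d_x, d_z, d_h = 16, 4, 64

encoder = nn.Sequential(nn.Linear(d_x, d_h), nn.ReLU(), nn.Linear(d_h, 2 * d_z))  # -> (mu, log_var)
decoder = nn.Sequential(nn.Linear(d_z, d_h), nn.ReLU(), nn.Linear(d_h, d_x))      # -> f_theta(z)
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

x_data = torch.randn(512, d_x)            # stand-in dataset

for step in range(200):
    x = x_data[torch.randint(0, len(x_data), (64,))]            # minibatch
    mu, log_var = encoder(x).chunk(2, dim=-1)
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)    # z ~ q_phi(z|x), reparameterized

    recon = -0.5 * ((x - decoder(z)) ** 2).sum(dim=-1)          # log p_theta(x|z) up to constants (sigma^2 = 1)
    kl = 0.5 * (mu**2 + log_var.exp() - 1 - log_var).sum(dim=-1)  # KL(q_phi(z|x) || N(0, I))
    loss = (-(recon - kl)).mean()                               # negative ELBO

    opt.zero_grad()
    loss.backward()
    opt.step()
```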

8. Gaussian Example

Assume all distributions are Gaussian:

$$ p_\theta(z) = \mathcal{N}(0, I), \quad q_\phi(z \mid x) = \mathcal{N}\!\big(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\big), \quad p_\theta(x \mid z) = \mathcal{N}(f_\theta(z), \sigma^2 I) $$

Then, using the closed-form KL divergence between a diagonal Gaussian and the standard normal prior,

$$ \mathrm{KL}\!\big(q_\phi(z \mid x) \,\Vert\, p_\theta(z)\big) = \frac{1}{2}\sum_i \big(\mu_{\phi,i}^2 + \sigma_{\phi,i}^2 - 1 - \log \sigma_{\phi,i}^2 \big), $$

the ELBO becomes (up to additive constants, with the reconstruction term evaluated at a sample \( z \sim q_\phi(z \mid x) \)):

$$ \text{ELBO} \approx -\frac{1}{2\sigma^2} \|x - f_\theta(z)\|^2 + \frac{1}{2}\sum_i \big(1 + \log \sigma_{\phi,i}^2 - \mu_{\phi,i}^2 - \sigma_{\phi,i}^2 \big) $$
Notation:
\( f_\theta(z) \): decoder network mapping latent \(z\) to reconstructed \(x\)
\( \mu_\phi(x), \sigma_\phi^2(x) \): encoder outputs (mean and per-dimension variance of the diagonal Gaussian)
\( \sigma^2 \): decoder noise variance
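
The expression above translates almost line for line into code. The sketch below is one way to write it (the tensor shapes, the single-sample reconstruction estimate, and the torch.distributions cross-check are assumptions of this example, not part of the text):

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_elbo(x, f_theta_z, mu, log_var, sigma2=1.0):
    """Gaussian ELBO per example, up to additive constants.
    x, f_theta_z: (batch, d_x); mu, log_var: (batch, d_z)."""
    recon = -0.5 * ((x - f_theta_z) ** 2).sum(dim=-1) / sigma2          # -||x - f_theta(z)||^2 / (2 sigma^2)
    neg_kl = 0.5 * (1 + log_var - mu**2 - log_var.exp()).sum(dim=-1)    # -KL(q_phi(z|x) || N(0, I))
    return recon + neg_kl

# Quick check that the analytic KL term matches torch's KL for diagonal Gaussians.
mu, log_var = torch.randn(3, 4), torch.randn(3, 4)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum(dim=-1)
kl_ours = -0.5 * (1 + log_var - mu**2 - log_var.exp()).sum(dim=-1)
print(torch.allclose(kl_ref, kl_ours, atol=1e-5))   # True
```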

9. Summary

Key ideas:
1. \(p_\theta(x, z) = p_\theta(x \mid z)p_\theta(z)\): defines the generative model.
2. \(q_\phi(z \mid x)\): approximates the true posterior.
3. The ELBO provides a computable lower bound on the data likelihood.
4. Training maximizes the ELBO to learn both encoder (\(\phi\)) and decoder (\(\theta\)) parameters.

10. References

Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. arXiv preprint, arXiv:1312.6114.

Nakajima, S., & Watanabe, K. (2019). Variational Bayesian Learning Theory. Springer Nature.
