Fitting Generative Models
How to Fit a Generative Model?
- The goal of a generative model \(\mathcal{M}\) is to generate new data points that resemble the training data.
- We achieve this by minimizing some measure of distance/divergence between the true data distribution \(p_{\text{data}}(\mathbf{y})\) and the model distribution (the marginal distribution of the response) \(p(\mathbf{y} \mid \mathcal{M})\).
- We only have access to samples from the true data distribution (the training data \(\mathcal{D} = \{\mathbf{y}^{(i)}\}_{i=1}^N\)), not to the distribution itself.
How can we align the two distributions, or their samples?
Kullback-Leibler Divergence
- The Kullback-Leibler (KL) divergence measures how one probability distribution diverges from a second one. It is not symmetric: in general \(D_{KL}(P \| Q) \neq D_{KL}(Q \| P)\).
- For two distributions \(P\) and \(Q\) defined on the same probability space, the KL divergence from \(P\) to \(Q\) is defined as: \[
D_{KL}(P \| Q) = \int p(x) \log\left(\frac{p(x)}{q(x)}\right) dx
\]
- One valid way to align the model and data distributions is to minimize the KL divergence from the data distribution to the model distribution. \[
D_{KL}(p_{\text{data}}(\mathbf{y}) \| p(\mathbf{y} \mid \mathcal{M}))
\]
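To make the definition concrete, here is a minimal sketch for discrete distributions on a finite support (an assumption; the slides use densities). It checks that \(D_{KL}\) is zero for identical distributions, positive otherwise, and changes when the two arguments are swapped. The probability vectors are arbitrary illustrative values.

```python
import numpy as np

def kl_divergence(p, q):
    """Discrete D_KL(p || q) = sum_x p(x) * log(p(x) / q(x)).

    Assumes q(x) > 0 wherever p(x) > 0; terms with p(x) = 0 contribute 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = np.array([0.7, 0.2, 0.1])  # stand-in for p_data
q = np.array([0.4, 0.4, 0.2])  # stand-in for the model distribution

print(kl_divergence(p, q))  # D_KL(p || q)
print(kl_divergence(q, p))  # D_KL(q || p): a different number (not symmetric)
print(kl_divergence(p, p))  # 0.0: identical distributions
```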
Kullback-Leibler Divergence
Denote the best model that minimizes the KL divergence by \(\mathcal{M}^*\).
\[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} D_{KL}(p_{\text{data}}(\mathbf{y}) \| p(\mathbf{y} \mid \mathcal{M})) \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \log\left(\frac{p_{\text{data}}(\mathbf{y})}{p(\mathbf{y} \mid \mathcal{M})}\right) d\mathbf{y} \\
&= \arg\min_{\mathcal{M}} \left[ \int p_{\text{data}}(\mathbf{y}) \log(p_{\text{data}}(\mathbf{y})) d\mathbf{y} - \int p_{\text{data}}(\mathbf{y}) \log(p(\mathbf{y} \mid \mathcal{M})) d\mathbf{y} \right] \\
&= \arg\max_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \log(p(\mathbf{y} \mid \mathcal{M})) d\mathbf{y} \\
& \approx \arg\max_{\mathcal{M}} \frac{1}{N} \sum_{i=1}^N \log(p(\mathbf{y}^{(i)} \mid \mathcal{M}))
\end{aligned}
\]
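- The first term in the third line, \(\int p_{\text{data}}(\mathbf{y}) \log p_{\text{data}}(\mathbf{y})\, d\mathbf{y}\) (the negative entropy of the data distribution), is independent of the model \(\mathcal{M}\) and can be ignored during optimization.
- The last step replaces the expectation under \(p_{\text{data}}\) with an average over the \(N\) training samples.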
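The derivation can be checked numerically. The sketch below assumes a hypothetical Gaussian data distribution \(\mathcal{N}(1, 2^2)\) and a Gaussian model family \(\mathcal{N}(m, 1.5^2)\), for which the KL divergence has a closed form. For every candidate \(m\), the KL divergence plus the average training log-likelihood stays (up to Monte Carlo error) at the constant \(-H(p_{\text{data}})\), so the candidate with the highest average log-likelihood is also the one with the smallest KL divergence.

```python
import numpy as np

def kl_gauss(m1, s1, m2, s2):
    """Closed-form D_KL(N(m1, s1^2) || N(m2, s2^2))."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2 * s2**2) - 0.5

def avg_log_lik(y, m, s):
    """(1/N) * sum_i log N(y_i | m, s^2), the Monte Carlo average in the last line."""
    return np.mean(-0.5 * np.log(2 * np.pi * s**2) - (y - m) ** 2 / (2 * s**2))

rng = np.random.default_rng(0)
m_data, s_data = 1.0, 2.0                      # hypothetical p_data = N(1, 2^2)
y = rng.normal(m_data, s_data, size=200_000)   # "training data"
s_model = 1.5                                  # model family: N(m, 1.5^2), m is fitted
candidates = [-1.0, 0.5, 1.0, 2.5]

# Model-independent term: int p_data log p_data = -H(p_data)
neg_entropy = -0.5 * np.log(2 * np.pi * np.e * s_data**2)

for m in candidates:
    kl = kl_gauss(m_data, s_data, m, s_model)
    ll = avg_log_lik(y, m, s_model)
    # KL = -H(p_data) - E[log p_model], so KL + avg log-lik is roughly constant
    print(f"m={m:+.1f}  KL={kl:.4f}  avg log-lik={ll:.4f}  KL+ll={kl + ll:.4f}")
print(f"-H(p_data) = {neg_entropy:.4f}")
```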
Kullback-Leibler Divergence
\[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} D_{KL}(p_{\text{data}}(\mathbf{y}) \| p(\mathbf{y} \mid \mathcal{M})) \\
& \approx \arg\max_{\mathcal{M}} \frac{1}{N} \sum_{i=1}^N \log(p(\mathbf{y}^{(i)} \mid \mathcal{M}))
\end{aligned}
\]
- \(\log(p(\mathbf{y}^{(i)} \mid \mathcal{M}))\) is the contribution of the \(i\)-th data point to the log marginal likelihood of the model \(\mathcal{M}\).
Maximizing the average log marginal likelihood over the training data is asymptotically equivalent to minimizing the KL divergence from the data distribution to the model distribution, as the number of data points \(N\) goes to infinity.
Other Distances and Divergences
Other common distances/divergences used to align the two distributions are:
Jensen–Shannon divergence, used in generative adversarial networks (GANs): \[
\begin{aligned}
D_{JS}(p_{\text{data}}(\mathbf{y}) \| p(\mathbf{y} \mid \mathcal{M})) &= \tfrac{1}{2} D_{KL}\!\left(p_{\text{data}}(\mathbf{y}) \| m\right) + \tfrac{1}{2} D_{KL}\!\left(p(\mathbf{y} \mid \mathcal{M}) \| m\right) \\
m &= \tfrac{1}{2}\big(p_{\text{data}}(\mathbf{y}) + p(\mathbf{y} \mid \mathcal{M})\big)
\end{aligned}
\]
Wasserstein (Earth Mover’s) distance, used in Wasserstein GANs: \[
W\!\left(p_{\text{data}}(\mathbf{y}),\, p(\mathbf{y} \mid \mathcal{M})\right)
= \inf_{\gamma \in \Pi \left(p_{\text{data}},\, p(\cdot \mid \mathcal{M})\right)}
\mathbb{E}_{(\mathbf{y}, \tilde{\mathbf{y}})\sim \gamma} \big[ \| \mathbf{y} - \tilde{\mathbf{y}} \| \big]
\] where \(\Pi\!\left(p_{\text{data}}, p(\cdot \mid \mathcal{M})\right)\) is the set of all couplings (joint distributions) with the specified marginals.
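A small discrete sketch contrasting the two: the JS divergence between discrete histograms, and \(W_1\) (the Wasserstein distance with the Euclidean cost above) between two equal-size 1D samples, where the optimal coupling simply matches sorted samples. The distributions are toy point masses chosen only for illustration.

```python
import numpy as np

def kl(p, q):
    """Discrete KL divergence; terms with p(x) = 0 contribute 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def js_divergence(p, q):
    """D_JS(p || q) = 0.5 KL(p || m) + 0.5 KL(q || m) with m = (p + q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(x, y):
    """W_1 between two equal-size 1D empirical distributions.

    In 1D the optimal coupling pairs the i-th smallest sample of x with the
    i-th smallest sample of y, so W_1 is the mean absolute difference of sorted samples.
    """
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))

x = np.array([0, 1])  # samples from a hypothetical "data" distribution on {0, ..., 9}
p = np.bincount(x, minlength=10) / len(x)
for shift in [1, 3, 6]:
    y = x + shift     # samples from a shifted "model" distribution
    q = np.bincount(y, minlength=10) / len(y)
    print(f"shift={shift}: JS={js_divergence(p, q):.4f}  W1={wasserstein_1d(x, y):.1f}")
# Once the supports are disjoint (shift >= 2), JS stays at log 2 ~ 0.6931,
# while W1 keeps growing with the distance between the distributions.
```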
Other Distances and Divergences
Another common divergence used to align the two distributions is:
- Fisher divergence, used in energy-based and (Stein) score-based generative models: \[
D_F\!\left(p_{\text{data}}(\mathbf{y}) \,\|\, p(\mathbf{y} \mid \mathcal{M})\right) = \int p_{\text{data}}(\mathbf{y})\, \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) - \nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) \right\|^2 d\mathbf{y}
\]
The gradient \(\nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M})\) is the model’s Stein score. \(\nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y})\) is the Stein score of the unknown data distribution.
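As an illustration of why only scores are needed, the sketch below estimates the Fisher divergence by Monte Carlo for a hypothetical pair of 1D Gaussians, \(p_{\text{data}} = \mathcal{N}(0, 1)\) and a model \(\mathcal{N}(\mu, 1)\), computing the model score by finite differences from a log density with and without its normalizing constant. Both give the same value (\(\mu^2\) for this pair), because an additive constant in \(\log p\) has zero gradient. The data score is known here only because the example is synthetic.

```python
import numpy as np

def score_fd(log_density, y, h=1e-5):
    """Central finite-difference score d/dy log p(y); any additive constant (e.g. -log Z) cancels."""
    return (log_density(y + h) - log_density(y - h)) / (2 * h)

def fisher_divergence(y_data, score_data, score_model):
    """Monte Carlo estimate of E_{p_data}[(s_data(y) - s_model(y))^2] for 1D data."""
    return float(np.mean((score_data(y_data) - score_model(y_data)) ** 2))

def score_data(y):
    return -y  # exact score of p_data = N(0, 1)

mu = 1.5  # model: N(mu, 1)

def log_unnormalized(y):
    return -0.5 * (y - mu) ** 2  # log of exp(-E(y)) without the normalizer

def log_normalized(y):
    return log_unnormalized(y) - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(1)
y = rng.normal(0.0, 1.0, size=100_000)  # samples from p_data

fd_unnorm = fisher_divergence(y, score_data, lambda v: score_fd(log_unnormalized, v))
fd_norm = fisher_divergence(y, score_data, lambda v: score_fd(log_normalized, v))
print(fd_unnorm, fd_norm)  # both close to mu**2 = 2.25 for this pair of Gaussians
```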
Which Divergence Goes With Which Model?
The choice of divergence to minimize during training is closely tied to which operations the model supports tractably:
- KL divergence (forward) \(\to\) maximum likelihood. Requires evaluating \(\log p(\mathbf{y} \mid \mathcal{M})\). Natural for normalizing flows, PPCA, classical distributions; via ELBO for VAEs and DDPMs.
- Jensen–Shannon divergence \(\to\) GANs. Avoids evaluating \(\log p(\mathbf{y})\) entirely by training an auxiliary discriminator instead.
- Wasserstein distance \(\to\) Wasserstein GANs. Still provides a useful training signal when the two distributions have disjoint supports, where the JS divergence saturates at \(\log 2\); typically more stable than JS.
- Fisher divergence \(\to\) score matching, energy-based models, score-based diffusion. Only needs the score of the model, not the normalized density.
Understanding the Marginal Likelihood
Marginal Likelihood
The marginal likelihood, also known as the model evidence, is a key quantity in latent variable models. It is the probability of the observed data under a given model, integrating over all possible values of the latent variables:
\[
p(\mathbf{y} \mid \mathcal{M}) = \int p(\mathbf{y} \mid \boldsymbol{z}, \mathcal{M}) \, p(\boldsymbol{z} \mid \mathcal{M}) \, d\boldsymbol{z}
\]
- \(p(\mathbf{y} \mid \mathcal{M})\): Marginal likelihood of the model \(\mathcal{M}\) (model evidence).
- \(p(\mathbf{y} \mid \boldsymbol{z}, \mathcal{M})\): Conditional likelihood of the data given specific values of the latent variable \(\boldsymbol{z}\).
- \(p(\boldsymbol{z} \mid \mathcal{M})\): Prior distribution over the latent variables.
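The integral can be estimated by sampling the latent variable from the prior and averaging the conditional likelihood. This minimal sketch assumes a hypothetical linear-Gaussian model, where the exact answer is available for comparison; when the posterior is much narrower than the prior, plain prior sampling like this can need very many samples.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

# Hypothetical linear-Gaussian latent variable model, chosen because its
# marginal likelihood is known:  z ~ N(0, 1),  y | z ~ N(a z + b, s^2)
a, b, s = 2.0, 0.5, 0.3
y_obs = 1.7

rng = np.random.default_rng(2)
z = rng.standard_normal(400_000)  # draws from the prior p(z | M)
# p(y | M) = E_{p(z | M)}[ p(y | z, M) ]: average the conditional likelihood
mc_estimate = np.mean(np.exp(log_normal_pdf(y_obs, a * z + b, s**2)))

# Closed form for this model: y ~ N(b, a^2 + s^2)
exact = np.exp(log_normal_pdf(y_obs, b, a**2 + s**2))
print(mc_estimate, exact)
```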
Pharmacometrics Notation
- In pharmacometrics, the latent variable \(\boldsymbol{z}\) is denoted by \(\boldsymbol{\eta}\), representing individual-specific random effects.
- The observed data \(\mathbf{y}\) corresponds to the measurements taken from individuals.
- The model \(\mathcal{M}\) encompasses the structural model and population parameters: \(\theta\), \(\Omega\), and \(\sigma\).
Balancing Individual Fit and Population Simulation
- Say we have panel data from \(N\) individuals, \(\mathcal{D} = \{\mathbf{y}_i\}_{i=1}^N\).
- We now consider another view of the marginal likelihood that shows how it balances fitting the individual time-series data while ensuring simulation accuracy at the population level.
- Denote the latent variable (individual random effects) by \(\boldsymbol{\eta}_i\) for individual \(i\) and the population parameters by \(\boldsymbol{\theta}\).
- We start with Bayes’ rule:
\[
\begin{aligned}
p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta}) &= \frac{p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta}) \, p(\boldsymbol{\eta} \mid \boldsymbol{\theta})}{p(\mathbf{y} \mid \boldsymbol{\theta})} \\
p(\mathbf{y} \mid \boldsymbol{\theta}) &= \frac{p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta}) \, p(\boldsymbol{\eta} \mid \boldsymbol{\theta})}{p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})}
\end{aligned}
\]
Balancing Individual Fit and Population Simulation
We then take the log of both sides:
\[
\log p(\mathbf{y} \mid \boldsymbol{\theta}) = \log p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta}) + \log p(\boldsymbol{\eta} \mid \boldsymbol{\theta}) - \log p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})
\]
Then we take the expectation with respect to the posterior \(p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})\):
\[
\log p(\mathbf{y} \mid \boldsymbol{\theta}) =
\underbrace{\mathbb{E}_{p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})} \left[\log p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta})\right]}_{\text{fit quality under posterior}} +
\underbrace{\mathbb{E}_{p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})} \left[\log p(\boldsymbol{\eta} \mid \boldsymbol{\theta}) - \log p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})\right]}_{-\mathrm{KL}\left(p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta}) \| p(\boldsymbol{\eta} \mid \boldsymbol{\theta})\right)}
\]
The first term is a measure of how well the model fits the data, averaged over the posterior distribution of the random effects. It encourages good individual fit. We denote it by \(Q\).
The second term is the negative KL divergence between the posterior distribution of the random effects \(p(\boldsymbol{\eta} \mid \mathbf{y}, \boldsymbol{\theta})\) and the prior distribution \(p(\boldsymbol{\eta} \mid \boldsymbol{\theta})\).
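The decomposition can be verified exactly in a conjugate case. The sketch assumes a hypothetical random-intercept model for one individual, \(\eta \sim \mathcal{N}(0, \omega^2)\), \(y_j \mid \eta \sim \mathcal{N}(\theta + \eta, \sigma^2)\), where the posterior is Gaussian and \(Q\), the KL term, and \(\log p(\mathbf{y} \mid \theta)\) all have closed forms. The data values are made up.

```python
import numpy as np

# Hypothetical conjugate random-intercept model for ONE individual:
#   eta ~ N(0, omega^2),   y_j | eta ~ N(theta + eta, sigma^2),  j = 1..n
theta, omega, sigma = 2.0, 1.0, 0.5
y = np.array([2.9, 3.4, 2.6, 3.1, 3.3])
n = len(y)

# Exact Gaussian posterior p(eta | y, theta) = N(m, v)
v = 1.0 / (1.0 / omega**2 + n / sigma**2)
m = v * np.sum(y - theta) / sigma**2

# Q = E_post[log p(y | eta, theta)], using E[(y_j - theta - eta)^2] = (y_j - theta - m)^2 + v
Q = np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - ((y - theta - m) ** 2 + v) / (2 * sigma**2))

# KL( N(m, v) || N(0, omega^2) )
KL = 0.5 * (np.log(omega**2 / v) + (v + m**2) / omega**2 - 1.0)

# Exact log marginal: y ~ N(theta * 1, sigma^2 I + omega^2 11^T)
C = sigma**2 * np.eye(n) + omega**2 * np.ones((n, n))
r = y - theta
_, logdet = np.linalg.slogdet(C)
log_marginal = -0.5 * (n * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(C, r))

print(Q - KL, log_marginal)  # equal up to floating-point error
```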
Balancing Individual Fit and Population Simulation
Assuming the individuals are independent given \(\boldsymbol{\theta}\), the log marginal likelihood of a population of \(N\) individuals is a sum of per-individual terms: \[
\log p(\mathbf{y} \mid \boldsymbol{\theta}) = \sum_{i=1}^N \left( Q_i - \mathrm{KL}\!\left(p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta}) \| p(\boldsymbol{\eta}_i \mid \boldsymbol{\theta})\right) \right)
\] where
- \(Q_i\) is the fit quality for individual \(i\), and
- \(p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta})\) is the posterior distribution of the random effects for individual \(i\).
In NLME, the prior distribution \(p(\boldsymbol{\eta}_i \mid \boldsymbol{\theta})\) is typically the same for all individuals, e.g. \(\mathcal{N}(\mathbf{0}, \boldsymbol{\Omega})\), where \(\boldsymbol{\Omega}\) is a component of the population parameters \(\boldsymbol{\theta}\). \[
\log p(\mathbf{y} \mid \boldsymbol{\theta}) = \sum_{i=1}^N Q_i \;-\; \sum_{i=1}^N \mathrm{KL}\!\left( p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta}) \,\|\, \mathcal{N}(\mathbf{0}, \boldsymbol{\Omega}) \right)
\]
Balancing Individual Fit and Population Simulation
The KL divergence is convex with respect to its first argument, the posterior \(p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta})\). Therefore, the following inequality holds for the second term above:
\[
\frac{1}{N} \sum_{i=1}^N \mathrm{KL}\!\left(p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta}) \,\|\, \mathcal{N}(\mathbf{0}, \boldsymbol{\Omega})\right)
\;\ge\;
\mathrm{KL}\!\left( \frac{1}{N} \sum_{i=1}^N p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta}) \,\Big\|\, \mathcal{N}(\mathbf{0}, \boldsymbol{\Omega}) \right)
\]
where \(\frac{1}{N} \sum_{i=1}^N p(\boldsymbol{\eta}_i \mid \mathbf{y}_i, \boldsymbol{\theta})\) is the mixture distribution of the individual posteriors.
Maximizing the marginal likelihood therefore balances between:
- Maximizing the ability of the individual posteriors to fit the individual data by maximizing the sum of the individual fit quality (\(Q_i\)) terms, and
- Minimizing the sum of the individual KL divergences, which upper-bounds \(N\) times the KL divergence from the mixture of the individual posteriors to the prior \(\mathcal{N}(\mathbf{0}, \boldsymbol{\Omega})\). This indirectly aligns the mixture with the prior, encouraging a calibrated model and realistic simulations.
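The convexity inequality can be checked on a discretized stand-in: random distributions on a finite grid of \(\boldsymbol{\eta}\) values play the role of the individual posteriors, and a uniform distribution plays the prior (both are assumptions made only for illustration).

```python
import numpy as np

def kl(p, q):
    """Discrete KL divergence D_KL(p || q); assumes strictly positive entries."""
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(3)
K, N = 8, 50                                    # K grid points for eta, N individuals
prior = np.full(K, 1.0 / K)                     # stand-in for p(eta | theta)
posteriors = rng.dirichlet(np.ones(K), size=N)  # stand-ins for p(eta_i | y_i, theta)

mean_kl = np.mean([kl(p_i, prior) for p_i in posteriors])  # (1/N) sum_i KL(post_i || prior)
mixture = posteriors.mean(axis=0)                           # (1/N) sum_i post_i
kl_mixture = kl(mixture, prior)

print(mean_kl, kl_mixture)
print(mean_kl >= kl_mixture)  # True, by convexity of KL in its first argument
```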
Balancing Individual Fit and Population Simulation
When simulating from an NLME model with population parameters \(\boldsymbol{\theta}\), we:
- Sample \(\boldsymbol{\eta}\) from the prior distribution \(p(\boldsymbol{\eta} \mid \boldsymbol{\theta})\), e.g. \(\mathcal{N}(\mathbf{0}, \boldsymbol{\Omega})\),
- Then sample \(\mathbf{y}\) from the conditional distribution \(p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta})\).
If the prior is close to the mixture of the individual posterior distributions, and the individual posteriors themselves fit the individual data well, this ensures that the simulations are realistic and consistent with the observed data.
This is especially useful in pharmacometrics, where it is common to simulate counterfactual scenarios for decision-making, such as predicting the effect of a new dosing regimen.
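The two-step simulation recipe can be sketched for a hypothetical one-compartment intravenous-bolus model with log-normally distributed clearance and volume and a proportional (log-normal) residual error. All parameter values, the function name `simulate_nlme`, and the model structure are illustrative assumptions, not a specific published model.

```python
import numpy as np

def simulate_nlme(theta_cl, theta_v, omega, sigma_prop, dose, times, n_subjects, rng):
    """Simulate from a hypothetical one-compartment IV-bolus NLME model.

    Step 1: eta_i ~ N(0, Omega)          (prior over random effects)
    Step 2: y_ij ~ p(y | eta_i, theta)   (conditional distribution of the data)
    """
    # Step 1: sample random effects from the prior N(0, Omega)
    eta = rng.multivariate_normal(np.zeros(2), omega, size=n_subjects)
    cl = theta_cl * np.exp(eta[:, 0])  # individual clearance
    v = theta_v * np.exp(eta[:, 1])    # individual volume
    # Structural model: C(t) = dose / V * exp(-CL / V * t)
    conc = dose / v[:, None] * np.exp(-(cl / v)[:, None] * times[None, :])
    # Step 2: observations with proportional (log-normal) error
    eps = rng.normal(0.0, sigma_prop, size=conc.shape)
    return conc * np.exp(eps)

rng = np.random.default_rng(4)
times = np.array([0.5, 2.0, 4.0, 8.0, 12.0])
omega = np.array([[0.09, 0.01], [0.01, 0.04]])  # hypothetical Omega
y_sim = simulate_nlme(theta_cl=5.0, theta_v=50.0, omega=omega, sigma_prop=0.1,
                      dose=100.0, times=times, n_subjects=1000, rng=rng)
print(y_sim.shape)                 # (1000, 5)
print(np.median(y_sim, axis=0))    # population-level summary per time point
```

Summaries of such simulated populations (medians, prediction intervals) can then be compared with the observed data, or recomputed under a different `dose` or set of `times` for a counterfactual regimen.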
Predicting the Next Observation
Autoregressive Factorization
Let an individual have \(o\) observations at times \(\mathbf{t} = (t_1,\ldots,t_o)\) with responses \(\mathbf{y} = (y_1,\ldots,y_o)\).
The marginal likelihood factors autoregressively as
\[
\begin{aligned}
p(\mathbf{y} \mid \mathbf{t}, \theta)
&= \prod_{j=1}^o p(y_j \mid y_{1:j-1}, \mathbf{t}, \theta) \\
&= \prod_{j=1}^o p(y_j \mid y_{1:j-1}, t_{1:j}, \theta) \\
\log p(\mathbf{y} \mid \mathbf{t}, \theta)
&= \sum_{j=1}^o \log p(y_j \mid y_{1:j-1}, t_{1:j}, \theta)
\end{aligned}
\]
Here \(y_{1:j}\) denotes the first \(j\) observations and \(t_{1:j}\) their time points.
Predicting the Next Observation
First Observation (j = 1)
\[
\begin{aligned}
p(y_1 \mid t_1, \theta)
&= \int p(y_1 \mid t_1, \boldsymbol{\eta}, \theta)\, p(\boldsymbol{\eta} \mid \theta)\, d\boldsymbol{\eta} \\
&= \mathbb{E}_{p(\boldsymbol{\eta} \mid \theta)} \big[ p(y_1 \mid t_1, \boldsymbol{\eta}, \theta) \big]
\end{aligned}
\]
Average predictive probability of the first observation under the prior distribution of random effects.
Subsequent Observations (j > 1)
\[
\begin{aligned}
p(y_j \mid y_{1:j-1}, t_{1:j}, \theta)
&= \int p(y_j \mid t_j, \boldsymbol{\eta}, \theta)\, p(\boldsymbol{\eta} \mid y_{1:j-1}, t_{1:j-1}, \theta)\, d\boldsymbol{\eta} \\
&= \mathbb{E}_{p(\boldsymbol{\eta} \mid y_{1:j-1}, t_{1:j-1}, \theta)} \big[ p(y_j \mid t_j, \boldsymbol{\eta}, \theta) \big]
\end{aligned}
\]
Uses the updated (posterior) distribution of random effects given past data.
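The factorization can be checked numerically for the conjugate random-intercept model \(\eta \sim \mathcal{N}(0, \omega^2)\), \(y_j \mid \eta \sim \mathcal{N}(\theta + \eta, \sigma^2)\), a hypothetical example in which the observation times do not enter the model. Each one-step-ahead predictive density uses the current distribution of \(\eta\), which starts at the prior and is updated by Bayes' rule after each observation; the sum of their logs should match the joint log marginal likelihood.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

# Hypothetical conjugate model: eta ~ N(0, omega^2), y_j | eta ~ N(theta + eta, sigma^2)
theta, omega, sigma = 2.0, 1.0, 0.5
y = np.array([2.9, 3.4, 2.6, 3.1, 3.3])

m, v = 0.0, omega**2  # distribution of eta: the prior for j = 1
total = 0.0
for y_j in y:
    # One-step-ahead predictive p(y_j | y_{1:j-1}) = N(theta + m, v + sigma^2)
    total += log_normal_pdf(y_j, theta + m, v + sigma**2)
    # Bayesian update of eta's distribution with the new observation
    v_new = 1.0 / (1.0 / v + 1.0 / sigma**2)
    m = v_new * (m / v + (y_j - theta) / sigma**2)
    v = v_new

# Joint log marginal computed directly: y ~ N(theta 1, sigma^2 I + omega^2 11^T)
n = len(y)
C = sigma**2 * np.eye(n) + omega**2 * np.ones((n, n))
r = y - theta
_, logdet = np.linalg.slogdet(C)
joint = -0.5 * (n * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(C, r))
print(total, joint)  # equal up to floating-point error
```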
Predicting the Next Observation
Predictive / Generalization View
Maximizing \(\log p(\mathbf{y} \mid \mathbf{t}, \theta)\) maximizes, on average over j:
\[
\log p(y_j \mid y_{1:j-1}, t_{1:j}, \theta)
\]
This is the one-step-ahead predictive log probability.
Thus the marginal likelihood measures sequential generalization: how well the model predicts each next observation given the past while integrating over uncertainty in \(\boldsymbol{\eta}\) (prior for j=1, posterior thereafter).
Implication
- Encourages good individual fit (accurate one-step-ahead predictions).
- Discourages over-fitting by averaging over \(\boldsymbol{\eta}\) instead of conditioning on point estimates.
Summary
- Fitting a generative model means aligning the model and data distributions under some divergence; the choice of divergence depends on which operations the model supports.
- The marginal likelihood is the central quantity in latent-variable models, and maximizing it is asymptotically equivalent to minimizing the forward KL divergence.
- For panel data, the marginal likelihood balances individual-level fit and population-level calibration, and admits an autoregressive view as one-step-ahead prediction performance.
Fitting Probabilistic PCA, NLME Models, and VAEs
Fitting PPCA
In PPCA, the marginal likelihood is available in closed form, so fitting reduces to direct gradient-based or analytical MLE.
Recall the PPCA generative process with parameters \(\theta = \{\mathbf{W}, \boldsymbol{\mu}, \sigma^2\}\): \[
\begin{aligned}
\mathbf{z} &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \sigma^2 \mathbf{I}) \\
\mathbf{y} &= \mathbf{Wz} + \boldsymbol{\mu} + \boldsymbol{\epsilon}
\end{aligned}
\]
- The marginal likelihood is available in closed form as a Gaussian: \[
p(\mathbf{y} \mid \theta) = \int p(\mathbf{y} \mid \mathbf{z}, \theta) \, p(\mathbf{z}) \, d\mathbf{z} = \mathcal{N}(\mathbf{y} \mid \boldsymbol{\mu}, \mathbf{C}), \quad \mathbf{C} = \mathbf{W}\mathbf{W}^T + \sigma^2 \mathbf{I}
\]
- The marginal likelihood has a closed form so we can directly maximize it with respect to \(\theta\).
- PPCA admits a closed-form solution via eigen-decomposition of the data covariance matrix, similar to classical PCA, but we can also use gradient-based optimization to find the MLE.
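A minimal NumPy sketch of the standard closed-form PPCA maximum likelihood solution of Tipping and Bishop: \(\boldsymbol{\mu}\) is the sample mean, \(\sigma^2\) is the average of the discarded eigenvalues of the sample covariance, and \(\mathbf{W} = \mathbf{U}_q(\boldsymbol{\Lambda}_q - \sigma^2 \mathbf{I})^{1/2}\), identifiable only up to a rotation. The synthetic data and the helper names `fit_ppca` and `ppca_log_lik` are assumptions for illustration.

```python
import numpy as np

def fit_ppca(Y, q):
    """Closed-form PPCA MLE via eigendecomposition. Y: (N, D), q < D latent dims."""
    mu = Y.mean(axis=0)
    S = np.cov(Y, rowvar=False, bias=True)      # ML sample covariance (divides by N)
    evals, evecs = np.linalg.eigh(S)            # ascending order
    evals, evecs = evals[::-1], evecs[:, ::-1]  # sort descending
    sigma2 = evals[q:].mean()                   # average of the discarded eigenvalues
    W = evecs[:, :q] * np.sqrt(evals[:q] - sigma2)  # U_q (Lambda_q - sigma2 I)^{1/2}
    return mu, W, sigma2

def ppca_log_lik(Y, mu, W, sigma2):
    """sum_n log N(y_n | mu, C) with C = W W^T + sigma2 I."""
    N, D = Y.shape
    C = W @ W.T + sigma2 * np.eye(D)
    R = Y - mu
    _, logdet = np.linalg.slogdet(C)
    quad = np.sum(R * np.linalg.solve(C, R.T).T)  # sum_n r_n^T C^{-1} r_n
    return -0.5 * (N * D * np.log(2 * np.pi) + N * logdet + quad)

# Synthetic data from the PPCA generative process
rng = np.random.default_rng(5)
N, D, q = 2000, 5, 2
W_true = rng.normal(size=(D, q))
mu_true = rng.normal(size=D)
sigma2_true = 0.1
Z = rng.standard_normal((N, q))
Y = Z @ W_true.T + mu_true + np.sqrt(sigma2_true) * rng.standard_normal((N, D))

mu_hat, W_hat, sigma2_hat = fit_ppca(Y, q)
print(sigma2_hat)  # close to 0.1
# The MLE fits the training data at least as well as the true parameters
print(ppca_log_lik(Y, mu_hat, W_hat, sigma2_hat) >= ppca_log_lik(Y, mu_true, W_true, sigma2_true))
```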
Fitting NLME Models
In NLME models, the marginal likelihood is not available in closed form due to the nonlinearity of the structural model and the presence of individual random effects.
\[
p(\mathbf{y} \mid \boldsymbol{\theta}) = \int p(\mathbf{y} \mid \boldsymbol{\eta}, \boldsymbol{\theta}) \, p(\boldsymbol{\eta} \mid \boldsymbol{\theta}) \, d\boldsymbol{\eta}
\]
We typically use one of the following two approaches to fit NLME models:
- Laplace approximation: Approximate the integral over the random effects with a Gaussian centered at the mode of the integrand.
- Expectation-Maximization (EM): Treat the random effects as latent variables and iteratively optimize the expected complete-data log-likelihood.
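In pharmacometrics, it is common to further approximate the Laplace approximation with the first-order (FO) method, which linearizes the model around \(\boldsymbol{\eta} = \mathbf{0}\), or the first-order conditional estimation (FOCE) method, which linearizes around the conditional mode of \(\boldsymbol{\eta}\).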
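A minimal sketch of the Laplace approximation for a scalar random effect, assuming finite-difference derivatives and Newton's method to find the mode (practical implementations typically handle a multivariate \(\boldsymbol{\eta}\) with more careful derivatives). The structural model, data, and parameter values are hypothetical; a brute-force grid integral is included as a reference.

```python
import numpy as np

def laplace_log_marginal(log_joint, eta0=0.0, h=1e-4, max_iter=100):
    """Laplace approximation to log of the integral of exp(log_joint(eta)) over scalar eta.

    log_joint(eta) = log p(y | eta, theta) + log p(eta | theta).
    """
    def d1(e):
        return (log_joint(e + h) - log_joint(e - h)) / (2 * h)

    def d2(e):
        return (log_joint(e + h) - 2 * log_joint(e) + log_joint(e - h)) / h**2

    eta = eta0
    for _ in range(max_iter):  # Newton's method for the mode of the integrand
        step = d1(eta) / d2(eta)
        eta -= step
        if abs(step) < 1e-10:
            break
    # Gaussian centred at the mode with variance -1 / (second derivative)
    return log_joint(eta) + 0.5 * np.log(2 * np.pi) - 0.5 * np.log(-d2(eta))

def quadrature_log_marginal(log_joint, lo=-10.0, hi=10.0, n=20001):
    """Brute-force reference: Riemann sum on a fine grid (log-sum-exp for stability)."""
    grid = np.linspace(lo, hi, n)
    vals = np.array([log_joint(e) for e in grid])
    vmax = vals.max()
    return vmax + np.log(np.sum(np.exp(vals - vmax)) * (grid[1] - grid[0]))

def log_normal_pdf(x, mean, var):
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

# Hypothetical nonlinear model: y_j = A * exp(-k * exp(eta) * t_j) + eps_j
t = np.array([0.5, 1.0, 2.0, 4.0, 8.0])
y = np.array([8.1, 6.4, 4.3, 1.8, 0.4])
A, k, omega, sigma = 10.0, 0.4, 0.5, 0.3

def log_joint(eta):
    mean = A * np.exp(-k * np.exp(eta) * t)
    return np.sum(log_normal_pdf(y, mean, sigma**2)) + log_normal_pdf(eta, 0.0, omega**2)

print(laplace_log_marginal(log_joint), quadrature_log_marginal(log_joint))
```

If the structural model were linear in \(\eta\), the integrand would be exactly Gaussian and the Laplace approximation would be exact; the nonlinearity is what makes it an approximation.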
Variational Autoencoders (VAEs)
Variational Autoencoders (VAEs)
The decoder and latent space of a VAE can be viewed as an NLME model with:
- A nonlinear structural model defined by a neural network (decoder).
- A Gaussian prior distribution over the latent variables.
\[
\begin{aligned}
\boldsymbol{z} &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
\mathbf{y} &\sim p(\mathbf{y} \mid \boldsymbol{z}) = \mathcal{N}(\boldsymbol{\mu}(\boldsymbol{z}), \boldsymbol{\Sigma}(\boldsymbol{z}))
\end{aligned}
\]
- The encoder of the VAE can be viewed as an inference network that approximates the posterior distribution of the latent variables given the observed data.
- VAEs are fitted by maximizing the log marginal likelihood (aka evidence) indirectly by maximizing a lower bound called the evidence lower bound (ELBO).
- The ELBO is derived using variational inference and the estimation algorithm is known as variational expectation-maximization (VEM).
Variational Inference
- Variational inference is a method for approximating intractable posterior distributions, such as \(p(\boldsymbol{z} \mid \mathbf{y})\) in models with latent variables.
- The key idea is to introduce a family of simpler distributions \(q(\boldsymbol{z})\), called the variational family, to approximate the true posterior distribution.
- The variational family is often chosen to be a parametric distribution, such as a Gaussian distribution with mean and covariance that are neural network functions of the observed data \(\mathbf{y}\).
- The goal is to find the member of the variational family (by tuning its parameters) that is closest to the true posterior distribution, typically by minimizing the KL divergence between the two distributions. \[
D_{KL}(q(\boldsymbol{z}) \| p(\boldsymbol{z} \mid \mathbf{y})) = \int q(\boldsymbol{z}) \log\left(\frac{q(\boldsymbol{z})}{p(\boldsymbol{z} \mid \mathbf{y})}\right) d\boldsymbol{z}
\]
Variational Inference
In VAEs, we want to maximize the marginal likelihood of the observed data \(\mathbf{y}\): \[
p(\mathbf{y}) = \int p(\mathbf{y} \mid \boldsymbol{z}) \, p(\boldsymbol{z}) \, d\boldsymbol{z}
\]
However, this integral is often intractable.
Instead of directly maximizing the marginal likelihood, we use variational inference to:
- Approximate the posterior distribution of the latent variable \(\boldsymbol{z}\) given the observed data \(\mathbf{y}\), and
- Construct a lower bound on the marginal likelihood that can be optimized.
Variational Inference
Assume the variational family/distribution for each subject \(i\) is defined as: \[
\begin{aligned}
\boldsymbol{\xi}_i &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
\boldsymbol{z}_i &= T(\boldsymbol{\xi}_i; \boldsymbol{\phi}_i)
\end{aligned}
\] where \(T(\cdot; \boldsymbol{\phi}_i)\) is an invertible parameterized transformation (e.g., an invertible neural network) that:
- Has the same structure for all subjects, but
- Has different parameters \(\boldsymbol{\phi}_i\) for each subject \(i\).
We denote the variational distribution and its probability density function by \(q(\boldsymbol{z}; \boldsymbol{\phi}_i)\).
Variational Inference
The probability density of the variational distribution can be computed using the change of variables formula: \[
\begin{aligned}
q(\boldsymbol{z}; \boldsymbol{\phi}_i) &= p(\boldsymbol{\xi}) \left| \det\left( \frac{\partial T(\boldsymbol{\xi}; \boldsymbol{\phi}_i)}{\partial \boldsymbol{\xi}} \right) \right|^{-1} \\
\boldsymbol{\xi} &= T^{-1}(\boldsymbol{z}; \boldsymbol{\phi}_i)
\end{aligned}
\]
The Jacobian determinant has a closed form for simple transformations (e.g., affine transformations) and can be computed using automatic differentiation for more complex ones (e.g., invertible neural networks).
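For an affine transformation \(T(\boldsymbol{\xi}; \boldsymbol{\phi}) = \mathbf{m} + \mathbf{L}\boldsymbol{\xi}\) with lower-triangular \(\mathbf{L}\) (an assumed, simple choice of \(\boldsymbol{\phi} = \{\mathbf{m}, \mathbf{L}\}\)), the Jacobian is \(\mathbf{L}\), its determinant is the product of the diagonal, and the change-of-variables density must equal that of \(\mathcal{N}(\mathbf{m}, \mathbf{L}\mathbf{L}^T)\). The sketch checks this; the numbers are arbitrary.

```python
import numpy as np

def affine_flow_log_density(z, m, L):
    """log q(z) for z = T(xi) = m + L xi, xi ~ N(0, I), L lower-triangular with nonzero diagonal.

    Change of variables: log q(z) = log p(xi) - log|det(dT/dxi)|, with xi = T^{-1}(z).
    """
    xi = np.linalg.solve(L, z - m)                        # xi = T^{-1}(z)
    d = len(z)
    log_p_xi = -0.5 * (d * np.log(2 * np.pi) + xi @ xi)   # standard normal log density
    log_abs_det = np.sum(np.log(np.abs(np.diag(L))))      # Jacobian is L (triangular)
    return log_p_xi - log_abs_det

def mvn_log_density(z, mean, cov):
    """Reference: log N(z | mean, cov)."""
    d = len(z)
    r = z - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(cov, r))

m = np.array([1.0, -2.0, 0.5])
L = np.array([[1.5, 0.0, 0.0],
              [0.3, 0.8, 0.0],
              [-0.4, 0.2, 1.1]])
z = np.array([0.7, -1.5, 1.2])
print(affine_flow_log_density(z, m, L), mvn_log_density(z, m, L @ L.T))  # equal
```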
Amortized Inference
- In VAEs, we use amortized inference to share information across different subjects and reduce the number of parameters to be learned.
- Instead of learning separate variational parameters \(\boldsymbol{\phi}_i\) for each subject \(i\), we use a shared inference network (encoder) to map the observed data \(\mathbf{y}_i\) to the variational parameters \(\boldsymbol{\phi}_i\).
- The inference network is typically a neural network that takes the observed data \(\mathbf{y}_i\) as input and outputs the parameters of the variational distribution \(q(\boldsymbol{z}; \boldsymbol{\phi}_i)\).
- The inference network is trained jointly with the generative model (decoder) by maximizing the evidence lower bound (ELBO) on the marginal likelihood of the observed data.
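- In NLME models, we typically learn separate variational parameters for each subject without a shared inference network. This avoids the approximation error introduced by restricting the variational parameters to the outputs of a shared encoder (the amortization gap), but the number of variational parameters grows with the number of subjects.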
Evidence Lower Bound (ELBO)
- The evidence lower bound (ELBO) is a lower bound on the marginal likelihood of the observed data that can be obtained using variational inference.
- We start from the log marginal likelihood of the observed data \(\mathbf{y}\): \[
\log p(\mathbf{y}) = \log \int p(\mathbf{y} \mid \boldsymbol{z}) \, p(\boldsymbol{z}) \, d\boldsymbol{z}
\]
- We introduce the variational distribution \(q(\boldsymbol{z}; \boldsymbol{\phi})\) (per subject but subscript \(i\) is omitted) by multiplying and dividing the integrand by \(q(\boldsymbol{z}; \boldsymbol{\phi})\): \[
\log p(\mathbf{y}) = \log \int p(\mathbf{y} \mid \boldsymbol{z}) \, p(\boldsymbol{z}) \, \frac{q(\boldsymbol{z}; \boldsymbol{\phi})}{q(\boldsymbol{z}; \boldsymbol{\phi})} \, d\boldsymbol{z} = \log \mathbb{E}_{q(\boldsymbol{z}; \boldsymbol{\phi})} \left[ \frac{p(\mathbf{y} \mid \boldsymbol{z}) \, p(\boldsymbol{z})}{q(\boldsymbol{z}; \boldsymbol{\phi})} \right]
\]
Evidence Lower Bound (ELBO)
- Since \(\log\) is concave, Jensen’s inequality (\(\log \mathbb{E}[X] \ge \mathbb{E}[\log X]\)) gives a lower bound on the log marginal likelihood: \[
\log p(\mathbf{y}) \geq \mathbb{E}_{q(\boldsymbol{z}; \boldsymbol{\phi})} \left[ \log \left(\frac{p(\mathbf{y} \mid \boldsymbol{z}) \, p(\boldsymbol{z})}{q(\boldsymbol{z}; \boldsymbol{\phi})} \right) \right]
\]
- Recall Bayes’ rule \[
p(\mathbf{y} \mid \boldsymbol{z})\, p(\boldsymbol{z}) = p(\boldsymbol{z} \mid \mathbf{y})\, p(\mathbf{y})
\]
- One can alternatively write the ELBO as: \[
\begin{aligned}
\text{ELBO} = \int \log\!\left( \frac{ p(\boldsymbol{z} \mid \mathbf{y})\, p(\mathbf{y}) }{ q(\boldsymbol{z}; \boldsymbol{\phi}) } \right) q(\boldsymbol{z}; \boldsymbol{\phi})\, d\boldsymbol{z} &= \log p(\mathbf{y}) + \int \log\!\left( \frac{ p(\boldsymbol{z} \mid \mathbf{y}) }{ q(\boldsymbol{z}; \boldsymbol{\phi}) } \right) q(\boldsymbol{z}; \boldsymbol{\phi})\, d\boldsymbol{z} \\
&= \log p(\mathbf{y}) - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}) \right)
\end{aligned}
\]
Evidence Lower Bound (ELBO)
\[
\begin{aligned}
\text{ELBO} &= \log p(\mathbf{y}) - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}) \right)
\end{aligned}
\]
The gap between the ELBO and the actual log marginal likelihood is exactly the KL divergence \(\mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}) \right)\).
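A numeric check of this identity, assuming a hypothetical conjugate model where \(\boldsymbol{z}\) is a scalar random intercept with prior \(\mathcal{N}(0, \omega^2)\) and Gaussian observations. The ELBO is computed as \(\mathbb{E}_q[\log p(\mathbf{y} \mid \boldsymbol{z})] - \mathrm{KL}(q \,\|\, p(\boldsymbol{z}))\), which needs no posterior, and its gap to the exact log evidence is compared with \(\mathrm{KL}(q \,\|\, p(\boldsymbol{z} \mid \mathbf{y}))\) for a few Gaussian choices of \(q\).

```python
import numpy as np

# Hypothetical conjugate model: z ~ N(0, omega^2), y_j | z ~ N(theta + z, sigma^2)
theta, omega, sigma = 2.0, 1.0, 0.5
y = np.array([2.9, 3.4, 2.6, 3.1, 3.3])
n = len(y)

def kl_gauss(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians (v = variance)."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(mq, vq):
    """E_q[log p(y | z)] - KL(q || p(z)) for q = N(mq, vq), in closed form."""
    expected_loglik = np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                             - ((y - theta - mq) ** 2 + vq) / (2 * sigma**2))
    return expected_loglik - kl_gauss(mq, vq, 0.0, omega**2)

# Exact posterior N(m_post, v_post) and exact log evidence log p(y)
v_post = 1.0 / (1.0 / omega**2 + n / sigma**2)
m_post = v_post * np.sum(y - theta) / sigma**2
C = sigma**2 * np.eye(n) + omega**2 * np.ones((n, n))
r = y - theta
log_evidence = -0.5 * (n * np.log(2 * np.pi) + np.linalg.slogdet(C)[1]
                       + r @ np.linalg.solve(C, r))

for mq, vq in [(0.0, 1.0), (0.5, 0.1), (m_post, v_post)]:
    gap = log_evidence - elbo(mq, vq)
    print(f"q = N({mq:.3f}, {vq:.3f}): gap = {gap:.6f}, "
          f"KL(q || posterior) = {kl_gauss(mq, vq, m_post, v_post):.6f}")
```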
Minorization Maximization (MM)
The minorization maximization (MM) principle is an iterative optimization technique that can be used to maximize the marginal likelihood.
- First, we construct a surrogate function (the ELBO) that is a lower bound on (minorizes) the objective function (the log marginal likelihood) and, ideally, touches it at the current iterate. The surrogate is easier to compute and maximize than the original objective.
- We then maximize the surrogate, rebuild it at the new maximizer, and repeat.
Maximizing the lower bound indirectly maximizes the original objective function.
When minimizing (instead of maximizing) an objective function, the MM principle is also known as majorization minimization. An upper bound (majorizer) is constructed instead of a lower bound (minorizer).
When the surrogate function is an expectation, the MM principle is also known as expectation maximization (EM).
Minorization Step
- Let’s include the model parameters \(\boldsymbol{\theta}\) explicitly and rewrite the ELBO as: \[
\begin{aligned}
\text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta}) &= \log p(\mathbf{y} \mid \boldsymbol{\theta}) - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}, \boldsymbol{\theta}) \right)
\end{aligned}
\]
- The minorization step finds the best variational parameters \(\boldsymbol{\phi}\) for fixed model parameters \(\boldsymbol{\theta}\). \[
\begin{aligned}
\boldsymbol{\phi}^* &= \arg \max_{\boldsymbol{\phi}} \text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta}) \\
&= \arg \max_{\boldsymbol{\phi}} \log p(\mathbf{y} \mid \boldsymbol{\theta}) - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}, \boldsymbol{\theta}) \right) \\
&= \arg \min_{\boldsymbol{\phi}} \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}, \boldsymbol{\theta}) \right)
\end{aligned}
\]
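- In practice, we don’t need the true posterior \(p(\boldsymbol{z} \mid \mathbf{y}, \boldsymbol{\theta})\) to perform the minorization step. We use the equivalent form \(\mathbb{E}_{q(\boldsymbol{z}; \boldsymbol{\phi})}\left[\log p(\mathbf{y} \mid \boldsymbol{z}, \boldsymbol{\theta})\right] - \mathrm{KL}\!\left(q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \boldsymbol{\theta})\right)\), which does not involve the true posterior.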
- However, conceptually, maximizing the ELBO with respect to \(\boldsymbol{\phi}\) is equivalent to minimizing the KL divergence to the true posterior.
Maximization Step
\[
\begin{aligned}
\text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta}) &= \log p(\mathbf{y} \mid \boldsymbol{\theta}) - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}) \,\|\, p(\boldsymbol{z} \mid \mathbf{y}, \boldsymbol{\theta}) \right)
\end{aligned}
\]
- The maximization step finds the best model parameters \(\boldsymbol{\theta}\) for fixed variational parameters \(\boldsymbol{\phi}\). \[
\begin{aligned}
\boldsymbol{\theta}^* &= \arg \max_{\boldsymbol{\theta}} \text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta})
\end{aligned}
\]
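Alternating the minorization and maximization steps can be sketched on a hypothetical random-intercept model with \(\omega\) and \(\sigma\) treated as known and \(\theta\) as the only model parameter. The Gaussian variational family contains the exact posterior, so the minorization step can use it (and the ELBO then touches the log marginal likelihood), and the maximization step has a closed form. The recorded ELBO never decreases, and \(\theta\) converges to the grand mean, which is the maximum likelihood estimate for this balanced design.

```python
import numpy as np

# Hypothetical random-intercept data: N subjects, n observations each
#   eta_i ~ N(0, omega^2),  y_ij | eta_i ~ N(theta + eta_i, sigma^2)
rng = np.random.default_rng(6)
N, n, omega, sigma = 30, 5, 1.0, 0.5
eta = rng.normal(0.0, omega, size=N)
Y = 2.0 + eta[:, None] + rng.normal(0.0, sigma, size=(N, n))

def elbo(theta, m, v):
    """Sum over subjects of E_q[log p(y_i | eta_i, theta)] - KL(q_i || N(0, omega^2))."""
    exp_ll = np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                    - ((Y - theta - m[:, None]) ** 2 + v) / (2 * sigma**2))
    kl = np.sum(0.5 * (np.log(omega**2 / v) + (v + m**2) / omega**2 - 1.0))
    return exp_ll - kl

theta = 0.0                                # initial guess
v = 1.0 / (1.0 / omega**2 + n / sigma**2)  # posterior variance (same for every subject)
m = np.zeros(N)
history = []
for _ in range(500):
    # Minorization step: q_i = exact posterior of eta_i given y_i and the current theta
    m = v * np.sum(Y - theta, axis=1) / sigma**2
    history.append(elbo(theta, m, v))
    # Maximization step: argmax over theta of the ELBO with q fixed (closed form)
    theta = np.mean(Y - m[:, None])
    history.append(elbo(theta, m, v))

print(theta, Y.mean())                    # converges to the grand mean
print(np.all(np.diff(history) >= -1e-9))  # the ELBO never decreases
```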
Joint Optimization
- Since both the minorization and maximization attempt to maximize the ELBO with respect to different parameters, we can combine them into a single optimization problem: \[
\begin{aligned}
\boldsymbol{\phi}^*, \boldsymbol{\theta}^* &= \arg \max_{\boldsymbol{\phi}, \boldsymbol{\theta}} \text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta})
\end{aligned}
\]
Multiple Subjects
For multiple subjects \(i = 1, \ldots, N\), the ELBO decomposes as a sum of individual ELBOs: \[
\text{ELBO}(\boldsymbol{\phi}, \boldsymbol{\theta}) = \sum_{i=1}^N \text{ELBO}_i(\boldsymbol{\phi}_i, \boldsymbol{\theta})
\] where \[\text{ELBO}_i(\boldsymbol{\phi}_i, \boldsymbol{\theta}) = \mathbb{E}_{q(\boldsymbol{z}; \boldsymbol{\phi}_i)} \left[ \log p(\mathbf{y}_i \mid \boldsymbol{z}, \boldsymbol{\theta}) \right] - \mathrm{KL}\!\left( q(\boldsymbol{z}; \boldsymbol{\phi}_i) \,\|\, p(\boldsymbol{z} \mid \boldsymbol{\theta}) \right)\]
Amortized Inference Revisited
- When using amortized inference, the variational parameters \(\boldsymbol{\phi}_i\) for each subject \(i\) are outputs of a shared inference network (encoder) with parameters \(\boldsymbol{\psi}\).
- Denote this function by \(\boldsymbol{\phi}(\mathbf{y}_i; \boldsymbol{\psi})\), which takes the observed data \(\mathbf{y}_i\) as input and outputs the variational parameters \(\boldsymbol{\phi}_i\).
- The ELBO for a dataset with \(N\) subjects becomes: \[
\text{ELBO}(\boldsymbol{\psi}, \boldsymbol{\theta}) = \sum_{i=1}^N \text{ELBO}_i(\boldsymbol{\phi}(\mathbf{y}_i; \boldsymbol{\psi}), \boldsymbol{\theta})
\]
- In this case, we optimize the ELBO with respect to the shared inference network parameters \(\boldsymbol{\psi}\) and the model parameters \(\boldsymbol{\theta}\): \[
\boldsymbol{\psi}^*, \boldsymbol{\theta}^* = \arg \max_{\boldsymbol{\psi}, \boldsymbol{\theta}} \text{ELBO}(\boldsymbol{\psi}, \boldsymbol{\theta})
\]
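A toy illustration of amortization, assuming a hypothetical random-intercept model (\(\eta_i \sim \mathcal{N}(0, \omega^2)\), \(y_{ij} \mid \eta_i \sim \mathcal{N}(\theta + \eta_i, \sigma^2)\)) and a hypothetical linear "encoder" \(\boldsymbol{\phi}(\mathbf{y}_i; \boldsymbol{\psi}) = (\psi_0 + \psi_1 \bar{y}_i,\, e^{\psi_2})\) with three parameters instead of \(2N\) per-subject ones. Because the exact posterior mean is linear in \(\bar{y}_i\) here, the encoder can reproduce every per-subject optimum, so there is no amortization gap; with nonlinear models a finite encoder generally cannot. \(\theta\) is held fixed for simplicity.

```python
import numpy as np

# Hypothetical random-intercept data
rng = np.random.default_rng(7)
N, n, omega, sigma = 30, 5, 1.0, 0.5
Y = 2.0 + rng.normal(0.0, omega, size=(N, 1)) + rng.normal(0.0, sigma, size=(N, n))
theta = Y.mean()  # model parameter held fixed for this illustration

def elbo_per_subject(m, v):
    """ELBO_i for q_i = N(m_i, v), returned as an array of length N."""
    exp_ll = np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                    - ((Y - theta - m[:, None]) ** 2 + v) / (2 * sigma**2), axis=1)
    kl = 0.5 * (np.log(omega**2 / v) + (v + m**2) / omega**2 - 1.0)
    return exp_ll - kl

# Non-amortized: separate (m_i, v_i) per subject (2N parameters), set to their optima
v_opt = 1.0 / (1.0 / omega**2 + n / sigma**2)
m_opt = v_opt * np.sum(Y - theta, axis=1) / sigma**2

def encoder(Y, psi):
    """Hypothetical linear inference network: y_i -> (mean, variance) of q_i."""
    return psi[0] + psi[1] * Y.mean(axis=1), np.exp(psi[2])

# Each ELBO_i is quadratic in m_i with the same curvature (-1 / v_opt), so the
# best (psi0, psi1) is the least-squares fit of m_opt on the subject means
features = np.column_stack([np.ones(N), Y.mean(axis=1)])
(psi0, psi1), *_ = np.linalg.lstsq(features, m_opt, rcond=None)
psi = np.array([psi0, psi1, np.log(v_opt)])

m_amort, v_amort = encoder(Y, psi)
print(elbo_per_subject(m_opt, v_opt).sum())      # per-subject optimum
print(elbo_per_subject(m_amort, v_amort).sum())  # amortized: the same value here
```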
Denoising Diffusion Probabilistic Models (DDPMs)
Denoising Diffusion Probabilistic Models (DDPMs)
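- DDPMs are analogous to a hierarchical VAE with \(T\) layers of latent variables, for \(T > 1\).
- Unlike a VAE, the approximate posterior is not learned: it is fixed to a process that adds Gaussian noise at each step.
- With a suitable noise schedule, the distribution of \(\mathbf{y}_T\) produced by this noising process approaches a standard normal as \(T\) increases, so the prior over \(\mathbf{y}_T\) is set to \(\mathcal{N}(\mathbf{0}, \mathbf{I})\).
- The conditional likelihood \(p(\mathbf{y}_{t-1} \mid \mathbf{y}_t)\) is learned to undo the noise added by the noising process.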
- DDPMs are trained by maximizing the ELBO, just like VAEs.
- As \(T \to \infty\) (so that the noise added per step is small), the optimal conditional likelihood can be shown to converge to an isotropic Gaussian whose variance has a closed-form expression, which simplifies training. \[
p(\mathbf{y}_{t-1} \mid \mathbf{y}_t) = \mathcal{N}(\mathbf{y}_{t-1} \mid \boldsymbol{\mu}_{\boldsymbol{\theta}}(\mathbf{y}_t, t), \sigma_t^2 \mathbf{I})
\]
Denoising Diffusion Probabilistic Models (DDPMs)
- \(\sigma_t^2\) has a closed form optimal solution as \(T \to \infty\).
- The only parameters to be learned are those of the mean function \(\boldsymbol{\mu}_{\boldsymbol{\theta}}(\mathbf{y}_t, t)\), which is typically parameterized using a neural network.
- We can re-parameterize the mean function in terms of a noise function \(\boldsymbol{\epsilon}\): \[
\boldsymbol{\mu}_{\boldsymbol{\theta}}(\mathbf{y}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{y}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\mathbf{y}_t, t) \right)
\] where \(\alpha_t\) and \(\bar{\alpha}_t\) are known functions of \(t\) that control the noise schedule. \[
\alpha_t = 1 - \beta_t, \quad \bar{\alpha}_t = \prod_{s=1}^t \alpha_s
\]
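The noise parameterization can be sanity-checked: if the true noise \(\boldsymbol{\epsilon}\) used to build \(\mathbf{y}_t\) is plugged in for \(\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\), the formula should reproduce the known closed-form mean of the noising-process posterior \(q(\mathbf{y}_{t-1} \mid \mathbf{y}_t, \mathbf{y}_0)\). The linear \(\beta_t\) schedule and the chosen \(t\) are assumptions for illustration.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear noise schedule beta_1..beta_T
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

rng = np.random.default_rng(8)
y0 = rng.standard_normal(4)
eps = rng.standard_normal(4)
t = 500                             # 1-based step; array index is t - 1
a_t, ab_t, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]

# Closed-form noising: y_t = sqrt(abar_t) y0 + sqrt(1 - abar_t) eps
y_t = np.sqrt(ab_t) * y0 + np.sqrt(1 - ab_t) * eps

# Noise parameterization of the mean, with the true eps in place of eps_theta
mu_from_eps = (y_t - (1 - a_t) / np.sqrt(1 - ab_t) * eps) / np.sqrt(a_t)

# Mean of the noising-process posterior q(y_{t-1} | y_t, y0)
mu_posterior = (np.sqrt(ab_prev) * betas[t - 1] / (1 - ab_t) * y0
                + np.sqrt(a_t) * (1 - ab_prev) / (1 - ab_t) * y_t)

print(np.max(np.abs(mu_from_eps - mu_posterior)))  # floating-point error only
```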
Denoising Diffusion Probabilistic Models (DDPMs)
- Note that the noisy data \(\mathbf{y}_t\) at time step \(t\) can be expressed as a linear combination of the original data \(\mathbf{y}_0\) and the noise \(\boldsymbol{\epsilon}\): \[
\mathbf{y}_t = \sqrt{\bar{\alpha}_t} \mathbf{y}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}
\]
- With the re-parameterization of \(\boldsymbol{\mu}_{\boldsymbol{\theta}}(\mathbf{y}_t, t)\) and the above expression for \(\mathbf{y}_t\), the negative ELBO (the quantity to minimize) reduces to a weighted sum: \[
-\text{ELBO} = \sum_{t=1}^T w_t \, \mathbb{E}_{\mathbf{y}_0, \boldsymbol{\epsilon}} \left[ \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\mathbf{y}_t, t) \right\|^2 \right] + \text{const}
\] where \[
w_t = \frac{(1 - \alpha_t)^2}{2 \sigma_t^2 \alpha_t (1 - \bar{\alpha}_t)}
\]
Denoising Diffusion Probabilistic Models (DDPMs)
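- More generally, one can define a distribution over the time step \(t\): sampling \(t\) with probability proportional to \(w_t\) rewrites the negative ELBO, up to a positive scale factor and an additive constant, as the single expectation below, while sampling \(t\) uniformly gives the widely used simplified (reweighted) objective: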
\[
\mathbb{E}_{t, \mathbf{y}_0, \boldsymbol{\epsilon}} \left[ \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\mathbf{y}_t, t) \right\|^2 \right] = \mathbb{E}_{t, \mathbf{y}_0, \boldsymbol{\epsilon}} \left[ \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\sqrt{\bar{\alpha}_t} \mathbf{y}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}, t) \right\|^2 \right]
\]
- \(\epsilon_{\boldsymbol{\theta}}(\mathbf{y}_t, t)\) is a neural network that takes the noisy data \(\mathbf{y}_t\) and the time step \(t\) as input and outputs an estimate of the noise \(\boldsymbol{\epsilon}\).
- The ELBO objective can be optimized using stochastic gradient descent by sampling \(t\), \(\mathbf{y}_0\), and \(\boldsymbol{\epsilon}\) from their respective distributions.
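As a rough illustration of the training step just described, the following NumPy sketch estimates the simplified objective with \(t\) uniform over \(\{1, \ldots, T\}\). The linear \(\beta_t\) schedule and its endpoints, and the names `linear_beta_schedule`, `ddpm_simple_loss`, and `eps_model`, are assumptions for this example; in practice `eps_model` is a neural network and the gradient comes from an autodiff framework. The toy check uses data concentrated at the origin, where the noise can be recovered exactly from \(\mathbf{y}_t\), so a perfect noise predictor reaches zero loss.

```python
import numpy as np


def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Assumed linear schedule; returns beta_t, alpha_t and alpha_bar_t for t = 1..T."""
    betas = np.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)  # alpha_bar_t = prod_{s<=t} alpha_s
    return betas, alphas, alpha_bars


def ddpm_simple_loss(eps_model, y0, alpha_bars, rng):
    """Monte Carlo estimate of E_{t, y0, eps} ||eps - eps_model(y_t, t)||^2, t ~ Uniform{1..T}."""
    n, T = y0.shape[0], len(alpha_bars)
    t = rng.integers(1, T + 1, size=n)                # one time step per example
    eps = rng.standard_normal(y0.shape)               # eps ~ N(0, I)
    ab = alpha_bars[t - 1][:, None]                   # alpha_bar_t, broadcast over dimensions
    y_t = np.sqrt(ab) * y0 + np.sqrt(1.0 - ab) * eps  # closed-form forward sample
    return np.mean(np.sum((eps - eps_model(y_t, t)) ** 2, axis=1))


rng = np.random.default_rng(0)
_, _, alpha_bars = linear_beta_schedule(1000)
y0 = np.zeros((4096, 2))  # toy data: every example sits at the origin
# For this toy data eps = y_t / sqrt(1 - alpha_bar_t), so this "oracle" predicts the noise exactly.
oracle = lambda y_t, t: y_t / np.sqrt(1.0 - alpha_bars[t - 1])[:, None]
zero_model = lambda y_t, t: np.zeros_like(y_t)
print(ddpm_simple_loss(oracle, y0, alpha_bars, rng))      # close to 0
print(ddpm_simple_loss(zero_model, y0, alpha_bars, rng))  # close to D = 2, i.e. E||eps||^2
```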
Continuous-Time Diffusion Models
- Continuous-time diffusion models are a generalization of discrete-time DDPMs where the time step \(t\) is continuous.
- Fitting continuous-time diffusion models reduces to learning the Stein score function \(\nabla_{\mathbf{y}} \log p(\mathbf{y}_{\text{noisy}} \mid \mathcal{M})\) for all noise levels, where \(\mathbf{y}_{\text{noisy}}\) is the noisy data at a given noise level.
- The Stein score function can be learned using Stein score matching, which minimizes the Fisher divergence between the data distribution and the model distribution.
- We will revisit continuous-time diffusion models and Stein score matching later.
Fitting Energy-Based Models
Energy-Based Models
EBMs do not give us a tractable log density because the partition function \(Z\) is intractable, so exact MLE and the ELBO are off the table. Instead, we minimize the Fisher divergence, which depends only on the score \(\nabla_{\mathbf{y}} \log p(\mathbf{y})\), and the score does not depend on \(Z\).
- Recall that an energy-based model defines \(p(\mathbf{y}) = e^{-E(\mathbf{y})} / Z\) where \(E(\mathbf{y})\) is a neural-network-parameterized energy function and \(Z = \int e^{-E(\mathbf{y})}\, d\mathbf{y}\) is the (typically intractable) partition function.
- Stein score matching is a method for training energy-based and continuous-time diffusion models by minimizing the Fisher divergence between the data distribution and the model distribution.
Stein Score Matching
The Fisher divergence is defined as: \[
D_F\!\left(p_{\text{data}}(\mathbf{y}) \,\|\, p(\mathbf{y} \mid \mathcal{M})\right) = \int p_{\text{data}}(\mathbf{y})\, \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) - \nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) \right\|^2 d\mathbf{y}
\]
The gradient \(\nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M})\) is the model’s Stein score.
\(\nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y})\) is the Stein score of the unknown data distribution.
The Stein score is different from the Fisher score (common in statistics), which is the gradient of the log-likelihood with respect to the model parameters \(\nabla_{\theta} \log p(\mathbf{y} \mid \theta)\).
Stein Score Matching
- The Stein score of the data distribution is unknown.
- The Stein score of the model distribution can be computed using the energy function: \[\nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) = -\nabla_{\mathbf{y}} E(\mathbf{y})\]
- Computing the Stein score of the model does not require computing the partition function \(Z\).
- If the energy function is parameterized using neural networks, the Stein score can be computed using automatic differentiation tools.
- If the model has latent variables, the model’s Stein score can be computed using Fisher’s identity. \[ \nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) = \mathbb{E}_{p(\boldsymbol{z} \mid \mathbf{y}, \mathcal{M})} \left[ \nabla_{\mathbf{y}} \log p(\mathbf{y}, \boldsymbol{z} \mid \mathcal{M}) \right] \]
- The expectation can be approximated by sampling from the posterior \(p(\boldsymbol{z} \mid \mathbf{y}, \mathcal{M})\) using MCMC methods.
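A small sanity check of Fisher's identity, assuming a toy latent-variable model \(z \sim \mathcal{N}(0, 1)\), \(y \mid z \sim \mathcal{N}(z, s^2)\) chosen for this sketch. Its posterior is Gaussian, so exact posterior samples stand in for MCMC, and the marginal \(y \sim \mathcal{N}(0, 1 + s^2)\) gives an exact score \(-y / (1 + s^2)\) to compare against.

```python
import numpy as np


def score_via_fisher_identity(y, s2, num_samples, rng):
    """Estimate d/dy log p(y) for z ~ N(0, 1), y | z ~ N(z, s2) using Fisher's identity."""
    post_mean = y / (1.0 + s2)  # p(z | y) is Gaussian for this toy model
    post_var = s2 / (1.0 + s2)
    # Exact posterior samples stand in for the MCMC samples a general model would need.
    z = post_mean + np.sqrt(post_var) * rng.standard_normal(num_samples)
    joint_score = -(y - z) / s2  # d/dy log p(y, z); only the likelihood term depends on y
    return joint_score.mean()


def exact_marginal_score(y, s2):
    return -y / (1.0 + s2)  # the marginal of y is N(0, 1 + s2)


rng = np.random.default_rng(0)
print(score_via_fisher_identity(1.5, 0.5, 100_000, rng), exact_marginal_score(1.5, 0.5))
```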
Stein Score Matching
- Assuming the probability density \(p_{\text{data}}(\mathbf{y})\) vanishes at the boundaries of the data space, we can rewrite the minimizer of the Fisher divergence as: \[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} D_F\!\left(p_{\text{data}}(\mathbf{y}) \,\|\, p(\mathbf{y} \mid \mathcal{M})\right) \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) - \nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) \right\|^2 d\mathbf{y} \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) + \nabla_{\mathbf{y}} E(\mathbf{y}) \right\|^2 d\mathbf{y} \\
& \quad \text{(non-trivial derivation skipped, requires the above assumption)} \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left( \|\nabla_{\mathbf{y}} E(\mathbf{y})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}) \right) d\mathbf{y} \\
&\approx \arg\min_{\mathcal{M}} \frac{1}{N} \sum_{i=1}^N \left( \|\nabla_{\mathbf{y}} E(\mathbf{y}^{(i)})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}^{(i)}) \right)
\end{aligned}
\]
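The skipped step can be sketched as follows, writing \(p_{\text{data}}\) and \(E\) without their argument \(\mathbf{y}\) for brevity and assuming \(p_{\text{data}}(\mathbf{y}) \nabla_{\mathbf{y}} E(\mathbf{y}) \to \mathbf{0}\) at the boundary of the data space, so that the boundary term from integrating by parts vanishes:
\[
\begin{aligned}
\int p_{\text{data}} \left\| \nabla_{\mathbf{y}} \log p_{\text{data}} + \nabla_{\mathbf{y}} E \right\|^2 d\mathbf{y}
&= \int p_{\text{data}} \|\nabla_{\mathbf{y}} \log p_{\text{data}}\|^2 d\mathbf{y} + \int p_{\text{data}} \|\nabla_{\mathbf{y}} E\|^2 d\mathbf{y} + 2 \int (\nabla_{\mathbf{y}} p_{\text{data}})^T \nabla_{\mathbf{y}} E \, d\mathbf{y} \\
&= \text{const} + \int p_{\text{data}} \|\nabla_{\mathbf{y}} E\|^2 d\mathbf{y} - 2 \int p_{\text{data}} \, \Delta_{\mathbf{y}} E \, d\mathbf{y} \\
&= \text{const} + \int p_{\text{data}} \left( \|\nabla_{\mathbf{y}} E\|^2 - 2 \Delta_{\mathbf{y}} E \right) d\mathbf{y}
\end{aligned}
\]
The first equality uses \(p_{\text{data}} \nabla_{\mathbf{y}} \log p_{\text{data}} = \nabla_{\mathbf{y}} p_{\text{data}}\), and the constant does not depend on \(\mathcal{M}\). The second integrates by parts, which removes the unknown data score from the objective.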
Stein Score Matching
Summary
\[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} D_F\!\left(p_{\text{data}}(\mathbf{y}) \,\|\, p(\mathbf{y} \mid \mathcal{M})\right) \\
&\approx \arg\min_{\mathcal{M}} \frac{1}{N} \sum_{i=1}^N \left( \|\nabla_{\mathbf{y}} E(\mathbf{y}^{(i)})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}^{(i)}) \right)
\end{aligned}
\]
- \(\|\nabla_{\mathbf{y}} E(\mathbf{y})\|^2\) is the squared norm of the gradient of the energy function with respect to \(\mathbf{y}\),
- \(\Delta_{\mathbf{y}} E(\mathbf{y}) = \left( \nabla_{\mathbf{y}}^T \nabla_{\mathbf{y}} \right) E(\mathbf{y}) = \text{trace}(\nabla_{\mathbf{y}} \nabla_{\mathbf{y}}^T E(\mathbf{y}))\) is the Laplacian of the energy function,
- \(\nabla_{\mathbf{y}} \nabla_{\mathbf{y}}^T E(\mathbf{y})\) is the Hessian matrix of the energy function with respect to \(\mathbf{y}\), and
- \(N\) is the number of training samples.
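To see the objective at work, consider a hypothetical Gaussian energy \(E(\mathbf{y}) = \|\mathbf{y} - \boldsymbol{\mu}\|^2 / (2v)\) in \(D\) dimensions, for which \(\nabla_{\mathbf{y}} E = (\mathbf{y} - \boldsymbol{\mu}) / v\) and \(\Delta_{\mathbf{y}} E = D / v\) are available in closed form. The sketch below uses the per-sample term \(\|\nabla_{\mathbf{y}} E\|^2 - 2 \Delta_{\mathbf{y}} E\) (the sign that integration by parts gives for the model score \(-\nabla_{\mathbf{y}} E\)) and checks that a grid search over \(v\) lands on the closed-form minimizer \(v = \frac{1}{ND}\sum_i \|\mathbf{y}^{(i)} - \boldsymbol{\mu}\|^2\), the sample variance. With \(+2\Delta_{\mathbf{y}} E\) instead, this objective would keep decreasing as \(v\) grows, which makes the example a quick check on the sign.

```python
import numpy as np


def gaussian_sm_objective(y, mu, v):
    """(1/N) sum_i ||grad E(y_i)||^2 - 2 * Laplacian E(y_i) for the energy E(y) = ||y - mu||^2 / (2 v)."""
    D = y.shape[1]
    grad = (y - mu) / v  # grad_y E(y)
    laplacian = D / v    # trace of the Hessian, which is I / v
    return np.mean(np.sum(grad ** 2, axis=1) - 2.0 * laplacian)


rng = np.random.default_rng(0)
y = 3.0 + 2.0 * rng.standard_normal((5000, 2))  # toy data: mean 3, variance 4 per dimension
mu_hat = y.mean(axis=0)                         # for any v the objective is minimized at the sample mean
grid = np.linspace(0.5, 10.0, 1000)
objective = [gaussian_sm_objective(y, mu_hat, v) for v in grid]
v_best = grid[int(np.argmin(objective))]
v_closed_form = np.mean(np.sum((y - mu_hat) ** 2, axis=1)) / y.shape[1]
print(v_best, v_closed_form)  # both close to 4
```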
Stein Score Matching
\[
\begin{aligned}
\mathcal{M}^* & \approx \arg\min_{\mathcal{M}} \frac{1}{N} \sum_{i=1}^N \left( \|\nabla_{\mathbf{y}} E(\mathbf{y}^{(i)})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}^{(i)}) \right)
\end{aligned}
\]
- To minimize the above objective, gradient-based optimization methods can be used.
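- To calculate the gradient of the objective with respect to the model parameters, we do not need to compute the full Hessian matrix; only its diagonal enters the Laplacian.
- Reverse-mode automatic differentiation tools can compute the Laplacian \(\Delta_{\mathbf{y}} E(\mathbf{y})\) and its gradient with respect to the model parameters, but doing so exactly typically takes one backward pass per dimension of \(\mathbf{y}\), which becomes expensive for high-dimensional data.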
Sliced Stein Score Matching
- We can write the following expression as a single trace: \[
\begin{aligned}
\|\nabla_{\mathbf{y}} E(\mathbf{y})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}) &= \text{trace}\left( \nabla_{\mathbf{y}} E(\mathbf{y}) \cdot \nabla_{\mathbf{y}}^T E(\mathbf{y}) - 2 \nabla_{\mathbf{y}} \nabla_{\mathbf{y}}^T E(\mathbf{y}) \right)
\end{aligned}
\]
- In sliced Stein score matching, we approximate the above trace using random projections. This is known as Hutchinson’s trace estimator. \[
\text{trace}(\mathbf{A}) = \mathbb{E}_{\mathbf{r} \sim \mathcal{R}^D} \left[ \mathbf{r}^T \mathbf{A} \mathbf{r} \right] \approx \frac{1}{M} \sum_{m=1}^M \mathbf{r}_m^T \mathbf{A} \mathbf{r}_m
\] where \(\mathcal{R}^D\) is a distribution with zero mean and identity covariance, e.g. \(\mathcal{N}(\mathbf{0}, \mathbf{I})\) or the Rademacher distribution, and \(M\) is the number of random projections.
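- This replaces the exact Laplacian, which needs roughly one backward pass per dimension of \(\mathbf{y}\), with \(M\) Hessian-vector products \(\nabla_{\mathbf{y}} \nabla_{\mathbf{y}}^T E(\mathbf{y}) \, \mathbf{r}_m\), each costing a small constant multiple of one gradient evaluation. The full Hessian is never formed, and training is cheaper when \(M\) is much smaller than the data dimension.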
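A minimal NumPy sketch of Hutchinson's estimator with Rademacher probes, and of a sliced estimate of \(\|\nabla_{\mathbf{y}} E\|^2 - 2\Delta_{\mathbf{y}} E\) that only touches the Hessian through Hessian-vector products. Two simplifications are assumptions of this sketch: the squared gradient norm is computed exactly (it is cheap) and only the Laplacian is sliced, and the Hessian-vector product is approximated by a central finite difference of a user-supplied gradient function `grad_E`, where an autodiff framework would compute it exactly. The toy energy is the quadratic \(E(\mathbf{y}) = \mathbf{y}^T \mathbf{A} \mathbf{y} / 2\), whose exact value is \(\|\mathbf{A}\mathbf{y}\|^2 - 2\,\text{trace}(\mathbf{A})\).

```python
import numpy as np


def hutchinson_trace(matvec, dim, num_probes, rng):
    """Estimate trace(A) using only products A @ r with Rademacher probes r."""
    total = 0.0
    for _ in range(num_probes):
        r = rng.choice([-1.0, 1.0], size=dim)  # zero mean, identity covariance
        total += r @ matvec(r)
    return total / num_probes


def sliced_sm_term(grad_E, y, num_probes, rng, h=1e-4):
    """Sliced estimate of ||grad E(y)||^2 - 2 * Laplacian E(y) for a single point y.

    The Hessian-vector product H r is approximated by a central finite difference of grad_E;
    an autodiff framework would compute it exactly.
    """
    g = grad_E(y)
    hvp = lambda r: (grad_E(y + h * r) - grad_E(y - h * r)) / (2.0 * h)
    return g @ g - 2.0 * hutchinson_trace(hvp, y.shape[0], num_probes, rng)


rng = np.random.default_rng(0)
A = np.array([[2.0, 0.5, 0.0], [0.5, 1.0, 0.3], [0.0, 0.3, 3.0]])  # toy energy E(y) = y^T A y / 2
y = np.array([0.3, -1.2, 0.7])
print(hutchinson_trace(lambda r: A @ r, 3, 10_000, rng), np.trace(A))
print(sliced_sm_term(lambda v: A @ v, y, 10_000, rng), (A @ y) @ (A @ y) - 2.0 * np.trace(A))
```

With Rademacher probes the estimator is exact for diagonal matrices, because every \(r_i^2 = 1\); its variance comes only from the off-diagonal entries.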
Fitting Continuous-Time Diffusion Models
Continuous-Time Diffusion Models
Denoising Score Matching
- Recall that to train continuous-time diffusion models, we need to approximate the Stein score of the noisy data distribution at any noise level, from \(t = 0\) (no noise) to \(t = 1\) (high noise).
- Just like neural network based EBMs, we can approximate the Stein score of the noisy data distribution using a neural network.
- This is equivalent to the following EBM: \[
p(\tilde{\mathbf{y}} \mid t) = \frac{e^{-E(\tilde{\mathbf{y}}, t)}}{Z(t)}
\] where \(\tilde{\mathbf{y}}\) is the noisy data at noise level \(t\).
- We can use a single neural network to parameterize the energy function \(E(\tilde{\mathbf{y}}, t)\) for all noise levels \(t \in [0, 1]\).
Denoising Score Matching
- Recall that in regular score matching, the objective is to minimize the Fisher divergence between the data distribution and the model distribution: \[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} D_F\!\left(p_{\text{data}}(\mathbf{y}) \,\|\, p(\mathbf{y} \mid \mathcal{M})\right) \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) - \nabla_{\mathbf{y}} \log p(\mathbf{y} \mid \mathcal{M}) \right\|^2 d\mathbf{y} \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left\| \nabla_{\mathbf{y}} \log p_{\text{data}}(\mathbf{y}) + \nabla_{\mathbf{y}} E(\mathbf{y}) \right\|^2 d\mathbf{y} \\
&= \arg\min_{\mathcal{M}} \int p_{\text{data}}(\mathbf{y}) \left( \|\nabla_{\mathbf{y}} E(\mathbf{y})\|^2 - 2 \Delta_{\mathbf{y}} E(\mathbf{y}) \right) d\mathbf{y}
\end{aligned}
\]
- The final integral/expectation is then approximated using the available samples from the data distribution \(p_{\text{data}}(\mathbf{y})\).
Denoising Score Matching
Similarly, minimizing the Fisher divergence between the noisy data distribution and the model distribution for a specific noise level \(t\) is equivalent to minimizing the following objective: \[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} \int p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t) \left( \|\nabla_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M})\|^2 - 2 \Delta_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M}) \right) d\tilde{\mathbf{y}} \\
&= \arg\min_{\mathcal{M}} \mathbb{E}_{\tilde{\mathbf{y}} \sim p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t)} \left[ \|\nabla_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M})\|^2 - 2 \Delta_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M}) \right]
\end{aligned}
\]
where \(p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t)\) is the noisy data distribution at noise level \(t\).
For all noise levels combined, the objective becomes: \[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} \mathbb{E}_{t \sim \lambda(t), \space \tilde{\mathbf{y}} \sim p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t)} \left[ \|\nabla_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M})\|^2 - 2 \Delta_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M}) \right]
\end{aligned}
\]
where \(\lambda(t)\) is the probability density function of the noise level \(t\), e.g. uniform over \([0, 1]\).
Denoising Score Matching
\[
\begin{aligned}
\mathcal{M}^* &= \arg\min_{\mathcal{M}} \mathbb{E}_{t \sim \lambda(t), \space \tilde{\mathbf{y}} \sim p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t)} \left[ \|\nabla_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M})\|^2 - 2 \Delta_{\tilde{\mathbf{y}}} E(\tilde{\mathbf{y}}, t \mid \mathcal{M}) \right]
\end{aligned}
\]
Given a dataset of \(N\) i.i.d. samples from the data distribution, \(\{\mathbf{y}^{(i)}\}_{i=1}^N \sim p_{\text{data}}(\mathbf{y})\), we can minimize the above objective using Monte Carlo sampling and stochastic gradient descent:
- Sample a noise level \(t\) from its distribution \(\lambda(t)\), e.g. uniform over \([0, 1]\).
- Sample a data point \(\mathbf{y}^{(i)}\) from the training data.
- Sample a noisy data point \(\tilde{\mathbf{y}}\) from \(p(\tilde{\mathbf{y}} \mid \mathbf{y}^{(i)}, t)\) using the forward SDE, which usually has a closed-form Gaussian solution. Together, these two steps draw \(\tilde{\mathbf{y}}\) from \(p_{\text{noisy}}(\tilde{\mathbf{y}} \mid t)\).
- Using the sampled \(\tilde{\mathbf{y}}\) and \(t\), compute the expression inside the expectation and its gradient with respect to the model parameters using automatic differentiation tools.
- Update the model parameters using a stochastic gradient method, e.g. Adam.
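The sampling steps above can be sketched in NumPy for a toy setting. Everything model-specific here is an assumption: the perturbation kernel \(\tilde{\mathbf{y}} = \mathbf{y} + \sigma(t)\boldsymbol{\epsilon}\) with \(\sigma(t) = t\) stands in for a forward SDE, and the closed-form energy \(E(\tilde{\mathbf{y}}, t) = \|\tilde{\mathbf{y}}\|^2 / (2 v(t))\), whose gradient and Laplacian \(D / v(t)\) are analytic, stands in for a neural network. The per-sample term is \(\|\nabla_{\tilde{\mathbf{y}}} E\|^2 - 2 \Delta_{\tilde{\mathbf{y}}} E\). For standard normal toy data the noisy marginal at level \(t\) is \(\mathcal{N}(\mathbf{0}, (1 + t^2)\mathbf{I})\), so \(v(t) = 1 + t^2\) should achieve the lowest loss.

```python
import numpy as np


def sigma(t):
    """Assumed noise scale: y_tilde = y + sigma(t) * eps with sigma(t) = t."""
    return t


def sample_noisy_batch(data, batch_size, rng):
    t = rng.uniform(0.0, 1.0, size=batch_size)             # t ~ lambda(t) = Uniform[0, 1]
    idx = rng.integers(0, data.shape[0], size=batch_size)  # y^(i) from the training data
    eps = rng.standard_normal((batch_size, data.shape[1]))
    y_tilde = data[idx] + sigma(t)[:, None] * eps          # y_tilde ~ p(y_tilde | y^(i), t)
    return t, y_tilde


def noisy_sm_loss(t, y_tilde, var_fn):
    """Batch mean of ||grad E||^2 - 2 * Laplacian E for the toy energy E(y, t) = ||y||^2 / (2 var_fn(t))."""
    v = var_fn(t)
    grad_sq = np.sum(y_tilde ** 2, axis=1) / v ** 2  # ||y / v||^2
    laplacian = y_tilde.shape[1] / v                 # D / v
    return np.mean(grad_sq - 2.0 * laplacian)


rng = np.random.default_rng(0)
data = rng.standard_normal((10_000, 2))  # toy training data drawn from N(0, I)
t, y_tilde = sample_noisy_batch(data, 20_000, rng)
for name, var_fn in [("v = 1 + t^2", lambda s: 1.0 + s ** 2),
                     ("v = 2 (1 + t^2)", lambda s: 2.0 * (1.0 + s ** 2)),
                     ("v = (1 + t^2) / 2", lambda s: 0.5 * (1.0 + s ** 2))]:
    print(name, noisy_sm_loss(t, y_tilde, var_fn))
```

In a real model, the loss would be differentiated with respect to the network parameters and the batch resampled at every step.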
Fitting Generative Adversarial Networks (GANs)
Generative Adversarial Networks (GANs)
GANs take the opposite approach to everything we have seen so far. Instead of working with a tractable density (exact, bounded, or unnormalized), they give up on the density entirely and train via a classifier that learns to tell real samples apart from generated ones.
- Recall that a GAN consists of a generator \(G\) that maps noise \(\mathbf{z} \sim p(\mathbf{z})\) to a sample \(\mathbf{y} = G(\mathbf{z})\), and a discriminator \(D\) that outputs the probability that a given \(\mathbf{y}\) is real rather than generated.
- The generator defines a deterministic mapping (not a conditional likelihood), so the induced distribution \(p_G(\mathbf{y})\) does not have a tractable log density. We cannot fit GANs by MLE.
- The discriminator outputs the parameter of a Bernoulli distribution of whether the input data point \(\mathbf{y}\) is real or generated: \[
\mathbf{y} \text{ is real} \sim \text{Bernoulli}(D(\mathbf{y}))
\]
Generative Adversarial Networks (GANs)
- We know the labels during training! Real samples are labeled as 1 and generated samples as 0.
- The discriminator can be any probabilistic binary classifier with a differentiable log likelihood; in practice it is typically a neural network classifier.
Adversarial Learning
- Given a dataset of real and generated samples \(\{(\mathbf{y}^{(i)}, \text{label}^{(i)})\}_{i=1}^N\), the likelihood of the discriminator is: \[
\mathcal{L}(D) = \prod_{i=1}^N p(\text{label}^{(i)} \mid \mathbf{y}^{(i)}) = \prod_{i=1}^N D(\mathbf{y}^{(i)})^{\text{label}^{(i)}} (1 - D(\mathbf{y}^{(i)}))^{1 - \text{label}^{(i)}}
\]
- The log likelihood is: \[
\log \mathcal{L}(D) = \sum_{i=1}^N \left( \text{label}^{(i)} \log D(\mathbf{y}^{(i)}) + (1 - \text{label}^{(i)}) \log(1 - D(\mathbf{y}^{(i)})) \right)
\]
- The negative log likelihood of the Bernoulli distribution is also known as the binary cross-entropy loss.
Adversarial Learning
- Since any generated \(\mathbf{y}\) can be written as \(\mathbf{y} = G(\mathbf{z})\) where \(\mathbf{z} \sim p(\mathbf{z})\), we can rewrite the log likelihood as: \[
\begin{aligned}
\log \mathcal{L}(D) &= \sum_{i=1}^{N_{\text{real}}} \log D(\mathbf{y}^{(i)}) + \sum_{j=1}^{N_{\text{gen}}} \log(1 - D(G(\mathbf{z}^{(j)})))
\end{aligned}
\] where \(\{\mathbf{y}^{(i)}\}_{i=1}^{N_{\text{real}}}\) are real samples from the data distribution and \(\{\mathbf{z}^{(j)}\}_{j=1}^{N_{\text{gen}}}\) are random noise samples from the prior distribution.
- Averaging the real and generated terms separately, which weights the two classes equally regardless of \(N_{\text{real}}\) and \(N_{\text{gen}}\), gives the class-balanced average log likelihood: \[
\begin{aligned}
\log \mathcal{L}(D) &= \frac{1}{N_{\text{real}}} \sum_{i=1}^{N_{\text{real}}} \log D(\mathbf{y}^{(i)}) + \frac{1}{N_{\text{gen}}} \sum_{j=1}^{N_{\text{gen}}} \log(1 - D(G(\mathbf{z}^{(j)})))
\end{aligned}
\]
Adversarial Learning
- In the limit as \(N_{\text{gen}}\) goes to infinity, we can re-write the above expression in terms of expectations over the data distribution and the prior distribution. \[
\begin{aligned}\log \mathcal{L}(D) &= \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]\end{aligned}
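\] where \(p(\mathbf{z})\) is the prior distribution over the latent space, and the first expectation is, in practice, taken over the empirical distribution of the real training data, which stands in for the true data distribution \(p_{\text{data}}(\mathbf{y})\).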
- The optimal discriminator is one that maximizes the average log likelihood: \[
D^* = \arg \max_D \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]
\]
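A Monte Carlo estimate of the discriminator objective is the negative of the class-balanced binary cross-entropy. The sketch below assumes a toy 1D setup (real data \(\mathcal{N}(1, 1)\), generator \(G(z) = z\) with \(z \sim \mathcal{N}(0, 1)\)) in which the log density ratio is \(\log p_{\text{data}}(y) - \log p_G(y) = y - 1/2\). A discriminator that outputs the logistic sigmoid of that ratio should score higher than one that always outputs \(1/2\), whose value is exactly \(-\log 4\).

```python
import numpy as np


def discriminator_objective(D, G, y_real, z):
    """Monte Carlo estimate of E_data[log D(y)] + E_z[log(1 - D(G(z)))].

    This is the negative of the class-balanced binary cross-entropy with labels 1 (real) and 0 (generated).
    """
    return np.mean(np.log(D(y_real))) + np.mean(np.log(1.0 - D(G(z))))


rng = np.random.default_rng(0)
y_real = 1.0 + rng.standard_normal(100_000)  # toy real data ~ N(1, 1)
z = rng.standard_normal(100_000)             # prior samples ~ N(0, 1)
G = lambda z: z                              # toy generator: p_G = N(0, 1)
D_const = lambda y: np.full_like(y, 0.5)     # a discriminator that cannot tell the classes apart
D_ratio = lambda y: 1.0 / (1.0 + np.exp(-(y - 0.5)))  # sigmoid of the log density ratio y - 1/2
value_const = discriminator_objective(D_const, G, y_real, z)
value_ratio = discriminator_objective(D_ratio, G, y_real, z)
print(value_const, -np.log(4.0))  # equal: 2 * log(1/2)
print(value_ratio)                # larger than value_const
```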
Adversarial Learning
- The optimal generator is one that maximally confuses the discriminator!
- Confusing the discriminator amounts to minimizing the average log likelihood of the discriminator: \[
G^* = \arg \min_G \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]
\]
- The training process can be formulated as a minimax game: \[
G^*, D^* = \arg \min_G \max_D \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]
\]
Adversarial Learning
\[
G^*, D^* = \arg \min_G \max_D \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]
\]
- Both \(G\) and \(D\) are typically parameterized using neural networks.
- We can optimize the above objective using stochastic gradient methods, alternating between updating \(D\) and \(G\).
- Often, one or more updates of \(D\) are performed for each update of \(G\), so that \(D\) stays close to optimal for the current \(G\).
- GANs are notoriously difficult to train due to instability and mode collapse.
- Mode collapse occurs when the generator produces a limited variety of samples, failing to capture the diversity of the real data distribution.
Jensen-Shannon (JS) Divergence
- The Jensen-Shannon (JS) divergence is a symmetric measure of dissimilarity between two probability distributions \(p\) and \(q\).
- It is based on the Kullback-Leibler (KL) divergence but, unlike KL, it is symmetric and always finite: with the natural logarithm it is bounded above by \(\log 2\), even when \(p\) and \(q\) have non-overlapping supports.
- The JS divergence is defined as: \[
\begin{aligned}
D_{JS}(p \| q) &= \frac{1}{2} D_{KL}(p \| m) + \frac{1}{2} D_{KL}(q \| m) \\
m &= \frac{1}{2}(p + q)
\end{aligned}
\]
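A minimal sketch for discrete distributions given as NumPy probability vectors (an assumption; the continuous case replaces sums with integrals). Zero-probability terms contribute nothing to the KL sums, and since \(m > 0\) wherever \(p > 0\) or \(q > 0\), the JS divergence stays finite even for disjoint supports, where it reaches its maximum \(\log 2\); the plain KL divergence would be infinite there.

```python
import numpy as np


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))


def js(p, q):
    m = 0.5 * (p + q)  # m > 0 wherever p > 0 or q > 0, so both KL terms are finite
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])  # disjoint support from p
print(js(p, q), js(q, p), np.log(2.0))  # all three equal
```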
Adversarial Learning as Minimizing JS Divergence
- Recall the minimax objective of GANs: \[
\begin{aligned}
\mathcal{L}(G, D) &= \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [\log(1 - D(G(\mathbf{z})))]
\end{aligned}
\]
- We can re-write this to use the implicit distribution of the generator \(p_G(\mathbf{y})\) instead of the prior distribution \(p(\mathbf{z})\): \[
\begin{aligned}
\mathcal{L}(G, D) &= \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [\log D(\mathbf{y})] + \mathbb{E}_{\mathbf{y} \sim p_G(\mathbf{y})} [\log(1 - D(\mathbf{y}))] \\
&= \int p_{\text{data}}(\mathbf{y}) \log D(\mathbf{y}) d\mathbf{y} + \int p_G(\mathbf{y}) \log(1 - D(\mathbf{y})) d\mathbf{y} \\
&= \int \left( p_{\text{data}}(\mathbf{y}) \log D(\mathbf{y}) + p_G(\mathbf{y}) \log(1 - D(\mathbf{y})) \right) d\mathbf{y}
\end{aligned}
\]
- The optimal discriminator maximizes the above objective for a fixed generator.
Adversarial Learning as Minimizing JS Divergence
- We aim to analytically find the optimal discriminator \(D^*\) for a fixed generator \(G\). \[
\begin{aligned}
D^* &= \arg \max_D \int \left( p_{\text{data}}(\mathbf{y}) \log D(\mathbf{y}) + p_G(\mathbf{y}) \log(1 - D(\mathbf{y})) \right) d\mathbf{y}
\end{aligned}
\]
- To find the optimal discriminator, we look for a point where the functional derivative with respect to \(D\) is zero.
- A point with zero derivative is known as a stationary point; it can be a maximum, a minimum, or a saddle point.
- The functional derivative generalizes the concept of a derivative to allow us to take derivatives with respect to an entire function, not just a finite set of variables.
- Intuitively, if \(D\) is a function with parameters \(\boldsymbol{\phi}\), the functional derivative plays the role of the gradient with respect to \(\boldsymbol{\phi}\).
Adversarial Learning as Minimizing JS Divergence
\[
\begin{aligned}
D^* &= \arg \max_D \int \left( p_{\text{data}}(\mathbf{y}) \log D(\mathbf{y}) + p_G(\mathbf{y}) \log(1 - D(\mathbf{y})) \right) d\mathbf{y}
\end{aligned}
\]
- Since \(D\) is an arbitrary function of \(\mathbf{y}\) and we want to maximize an integral with respect to \(\mathbf{y}\), setting the functional derivative to zero is equivalent to setting the derivative at each point \(\mathbf{y}\) to zero.
- Differentiating the integrand with respect to the value \(D(\mathbf{y})\) at a given point \(\mathbf{y}\) gives: \[
\frac{\partial}{\partial D(\mathbf{y})} \left( p_{\text{data}}(\mathbf{y}) \log D(\mathbf{y}) + p_G(\mathbf{y}) \log(1 - D(\mathbf{y})) \right) = \frac{p_{\text{data}}(\mathbf{y})}{D(\mathbf{y})} - \frac{p_G(\mathbf{y})}{1 - D(\mathbf{y})}
\]
- Setting this to zero and solving for \(D(\mathbf{y})\) gives: \[
D^*(\mathbf{y}) = \frac{p_{\text{data}}(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})}
\]
Adversarial Learning as Minimizing JS Divergence
\[
D^*(\mathbf{y}) = \frac{p_{\text{data}}(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})}
\]
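- This stationary point is a maximum: wherever \(p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y}) > 0\), the second derivative of the integrand with respect to \(D(\mathbf{y})\) is \(-\frac{p_{\text{data}}(\mathbf{y})}{D(\mathbf{y})^2} - \frac{p_G(\mathbf{y})}{(1 - D(\mathbf{y}))^2} < 0\), so the integrand is concave in \(D(\mathbf{y})\) and the stationary point is its global maximum.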
- Substituting the optimal discriminator back into the minimax objective gives: \[
\begin{aligned}
\mathcal{L}(G, D^*) &= \int p_{\text{data}}(\mathbf{y}) \log \left( \frac{p_{\text{data}}(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y} + \int p_G(\mathbf{y}) \log \left( \frac{p_G(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y}
\end{aligned}
\]
- Let the mixture distribution of the true data distribution and the generator distribution be: \[
m(\mathbf{y}) = \frac{1}{2}(p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y}))
\]
Adversarial Learning as Minimizing JS Divergence
- The first term can be re-written as: \[
\begin{aligned}
\int p_{\text{data}}(\mathbf{y}) \log \left( \frac{p_{\text{data}}(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y} &= \int p_{\text{data}}(\mathbf{y}) \log \left( \frac{p_{\text{data}}(\mathbf{y})}{2 m(\mathbf{y})} \right) d\mathbf{y} \\
&= \int p_{\text{data}}(\mathbf{y}) \log \left( \frac{p_{\text{data}}(\mathbf{y})}{m(\mathbf{y})} \right) d\mathbf{y} - \int p_{\text{data}}(\mathbf{y}) \log (2) d\mathbf{y} \\
&= D_{KL}(p_{\text{data}}(\mathbf{y}) \| m(\mathbf{y})) - \log (2)
\end{aligned}
\]
- The second term can be re-written similarly: \[
\begin{aligned}
\int p_G(\mathbf{y}) \log \left( \frac{p_G(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y} &= \int p_G(\mathbf{y}) \log \left( \frac{p_G(\mathbf{y})}{2 m(\mathbf{y})} \right) d\mathbf{y} \\
&= \int p_G(\mathbf{y}) \log \left( \frac{p_G(\mathbf{y})}{m(\mathbf{y})} \right) d\mathbf{y} - \log (2) \\
&= D_{KL}(p_G(\mathbf{y}) \| m(\mathbf{y})) - \log (2)
\end{aligned}
\]
Adversarial Learning as Minimizing JS Divergence
- Recall that the Jensen-Shannon divergence is defined as: \[
D_{JS}(p \| q) = \frac{1}{2} D_{KL}(p \| m) + \frac{1}{2} D_{KL}(q \| m) \quad \text{where} \quad m = \frac{1}{2}(p + q)
\]
- Plugging in the above results, the optimal generator minimizes the following objective: \[
\begin{aligned}
G^* &= \arg \min_G \int p_{\text{data}}(\mathbf{y}) \log \left( \frac{p_{\text{data}}(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y} + \int p_G(\mathbf{y}) \log \left( \frac{p_G(\mathbf{y})}{p_{\text{data}}(\mathbf{y}) + p_G(\mathbf{y})} \right) d\mathbf{y} \\
&= \arg \min_G D_{KL}(p_{\text{data}}(\mathbf{y}) \| m(\mathbf{y})) + D_{KL}(p_G(\mathbf{y}) \| m(\mathbf{y})) - 2\log(2) \\
&= \arg \min_G 2 D_{JS}(p_{\text{data}}(\mathbf{y}) \| p_G(\mathbf{y})) \\
&= \arg \min_G D_{JS}(p_{\text{data}}(\mathbf{y}) \| p_G(\mathbf{y}))
\end{aligned}
\]
- The JS divergence equivalence assumes an optimal discriminator that is a stationary point of the objective. In practice, getting the discriminator to converge may be difficult.
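The identity \(\mathcal{L}(G, D^*) = 2 D_{JS}(p_{\text{data}} \| p_G) - \log 4\) can be checked numerically on discrete distributions. This is a sketch assuming strictly positive probability vectors, so that \(D^*\) is well defined everywhere; sums replace the integrals above.

```python
import numpy as np


def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))


def value_at_optimal_discriminator(p_data, p_g):
    """sum_y p_data log D*(y) + p_G log(1 - D*(y)) with D* = p_data / (p_data + p_G).

    Assumes both probability vectors are strictly positive.
    """
    d_star = p_data / (p_data + p_g)
    one_minus_d_star = p_g / (p_data + p_g)  # computed directly to avoid cancellation
    return np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(one_minus_d_star))


rng = np.random.default_rng(0)
p_data, p_g = rng.dirichlet(np.ones(8)), rng.dirichlet(np.ones(8))
print(value_at_optimal_discriminator(p_data, p_g), 2.0 * js(p_data, p_g) - np.log(4.0))
```

When \(p_G = p_{\text{data}}\), the value drops to its minimum \(-\log 4\), which is the generator's global optimum.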
Fitting Wasserstein GANs (WGANs)
Optimal transport and Wasserstein Distance
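- The JS divergence behaves badly when the data and generator distributions have (nearly) disjoint supports: for disjoint supports it equals \(\log 2\) no matter how far apart the distributions are, so it gives the generator little useful gradient.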
- The Wasserstein distance avoids this problem by measuring distances in \(\mathbf{y}\)-space rather than density ratios.
- Optimal transport is a mathematical framework for comparing probability distributions by measuring the cost of transforming one distribution into another.
- The Wasserstein distance (also known as the Earth Mover’s Distance) is a metric that quantifies the distance between two probability distributions based on the optimal transport cost.
Optimal transport and Wasserstein Distance
- The Wasserstein distance of order 1 between two probability distributions \(p\) and \(q\) is defined as: \[
W_1(p, q) = \inf_{\gamma \in \Pi(p, q)} \mathbb{E}_{(x, y) \sim \gamma} [\|x - y\|]
\] where \(\Pi(p, q)\) is the set of all joint distributions \(\gamma(x, y)\) whose marginals are \(p\) and \(q\) respectively.
- More generally, the Wasserstein distance of order \(r\) is defined as: \[
W_r(p, q) = \left( \inf_{\gamma \in \Pi(p, q)} \mathbb{E}_{(x, y) \sim \gamma} [\|x - y\|^r] \right)^{1/r}
\]
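For samples on the real line, \(W_r\) between two empirical distributions with the same number of points has a simple form: the optimal coupling pairs the sorted samples in order. The sketch below assumes equal sample sizes. Shifting a sample by \(c\) gives \(W_1 = c\), so the distance keeps growing with the shift even after the two samples stop overlapping, which is exactly the regime where the JS divergence stays at \(\log 2\).

```python
import numpy as np


def wasserstein_1d(x, y, r=1):
    """W_r between two 1D empirical distributions with the same number of samples.

    In 1D the optimal coupling pairs the k-th smallest x with the k-th smallest y.
    """
    assert len(x) == len(y), "this sketch assumes equal sample sizes"
    diffs = np.abs(np.sort(x) - np.sort(y))
    return np.mean(diffs ** r) ** (1.0 / r)


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=1000)  # toy samples supported on [0, 1]
for shift in [0.5, 3.0, 10.0]:
    print(shift, wasserstein_1d(x, x + shift))  # W_1 equals the shift
```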
Wasserstein GANs (WGANs)
- WGANs are a variant of GANs that use the Wasserstein distance as the objective function instead of the Jensen-Shannon divergence.
- The WGAN objective is defined as: \[
\begin{aligned}
G^*, C^* = \arg \min_{G} \max_{\|C\|_{\text{Lip}} \leq 1} \mathbb{E}_{\mathbf{y} \sim p_{\text{data}}(\mathbf{y})} [C(\mathbf{y})] - \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})} [C(G(\mathbf{z}))]
\end{aligned}
\]
- \(C\) replaces the classifier/discriminator in regular GANs and is called the critic function. \(C\) outputs a real-valued score instead of a probability.
- The critic is trained to assign higher scores to real samples and lower scores to generated samples.
- The critic \(C\) is typically parameterized using a neural network.
- The WGAN objective has better theoretical properties and is more stable to train compared to the original GAN objective.
Wasserstein GANs (WGANs)
- \(\|C\|_{\text{Lip}} \leq 1\) means that the function \(C\) is constrained to have a Lipschitz constant of at most 1.
- The Lipschitz constant of a function \(f\) is defined as: \[\|f\|_{\text{Lip}} = \sup_{x_1 \neq x_2} \frac{|f(x_1) - f(x_2)|}{\|x_1 - x_2\|}\]
- Intuitively, the Lipschitz constant measures how much the function can change in response to changes in its input.
- For differentiable functions on \(\mathbb{R}^D\), the Lipschitz constant equals the supremum of the gradient norm: \[\|f\|_{\text{Lip}} = \sup_{x} \|\nabla f(x)\|\]
- The 1-Lipschitz constraint can be enforced using weight clipping or gradient penalty (recommended) during training.
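A sketch of the gradient penalty (as in WGAN-GP), which adds \(\lambda \, \mathbb{E}_{\hat{\mathbf{y}}}\left[(\|\nabla_{\hat{\mathbf{y}}} C(\hat{\mathbf{y}})\| - 1)^2\right]\) at random points \(\hat{\mathbf{y}}\) on segments between real and generated samples, with \(\lambda = 10\) as commonly used. The critic here is a hypothetical toy \(C(\mathbf{y}) = \tanh(\mathbf{w}^T \mathbf{y})\) with a hand-written input gradient; with a neural critic the gradient would come from autodiff. Since the critic maximizes the WGAN objective, its loss is the negative objective plus the penalty.

```python
import numpy as np


def critic(y, w):
    return np.tanh(y @ w)  # hypothetical toy critic C(y) = tanh(w^T y)


def critic_grad(y, w):
    return (1.0 - np.tanh(y @ w) ** 2)[:, None] * w[None, :]  # dC/dy, one row per input


def gradient_penalty(y_real, y_gen, w, rng, lam=10.0):
    eps = rng.uniform(0.0, 1.0, size=(y_real.shape[0], 1))
    y_hat = eps * y_real + (1.0 - eps) * y_gen  # random points between real and generated samples
    grad_norm = np.linalg.norm(critic_grad(y_hat, w), axis=1)
    return lam * np.mean((grad_norm - 1.0) ** 2)


def critic_loss(y_real, y_gen, w, rng):
    # The critic maximizes E_data[C] - E_gen[C]; as a loss we minimize its negative plus the penalty.
    wgan_objective = np.mean(critic(y_real, w)) - np.mean(critic(y_gen, w))
    return -wgan_objective + gradient_penalty(y_real, y_gen, w, rng)


rng = np.random.default_rng(0)
y_real = rng.standard_normal((256, 2))
y_gen = rng.standard_normal((256, 2)) + 2.0  # toy generated samples, shifted away from the data
print(critic_loss(y_real, y_gen, np.array([0.8, -0.3]), rng))
```

Unlike weight clipping, the penalty is a soft constraint: it only pushes gradient norms toward 1 at the sampled points.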
Wasserstein GANs (WGANs)
- The inner maximization over 1-Lipschitz critics is the dual form of the Wasserstein-1 distance, so the WGAN objective is the dual of another (primal) optimization problem.
- Duality is a fundamental concept in optimization theory that relates a given optimization problem (the primal problem) to another derived optimization problem (the dual problem).
- Strong duality holds here, which means the optimal values of the primal and dual problems are equal: maximizing over 1-Lipschitz critics gives exactly the Wasserstein-1 distance.
- The primal problem in the WGAN case is the problem of minimizing the Wasserstein distance between the real data distribution and the generator distribution. \[
\begin{aligned}
G^* &= \arg \min_G W_1(p_{\text{data}}(\mathbf{y}), p_G(\mathbf{y}))
\end{aligned}
\]
- The WGAN objective can be derived from the Kantorovich-Rubinstein duality theorem (out of scope).
Fitting Flow-Based Models
Normalizing Flows (NFs)
- Normalizing flows give us back what GANs sacrificed: an exact, tractable log density.
- The trick is to restrict the generator to a sequence of invertible, differentiable transformations so that the change-of-variables formula applies.
- Recall: a normalizing flow \(\mathbf{y} = f(\mathbf{z}) = f_K \circ \ldots \circ f_1(\mathbf{z})\) maps a base sample \(\mathbf{z} \sim p_Z(\mathbf{z})\) through a sequence of invertible, differentiable transformations.
- Similar to GANs, NFs define an implicit distribution \(p_Y(\mathbf{y})\) through a deterministic mapping \(\mathbf{y} = f(\mathbf{z})\).
- Unlike GANs, the likelihood of the generated samples can be computed exactly using the change of variables formula because the transformations are invertible and differentiable: \[
\begin{aligned}
p_Y(\mathbf{y}) &= p_Z(\mathbf{z}) \left| \det \left( \frac{\partial f^{-1}(\mathbf{y})}{\partial \mathbf{y}} \right) \right| \\
\mathbf{z} &= f^{-1}(\mathbf{y}) = f_1^{-1} \circ f_2^{-1} \circ \ldots \circ f_K^{-1}(\mathbf{y})
\end{aligned}
\]
Normalizing Flows (NFs)
- This makes NFs almost as useful as any standard continuous probability distribution (e.g. Gaussian, log normal or Beta) and enables both density estimation and sample generation.
- Because the log likelihood can be computed exactly, NFs can be trained using standard maximum likelihood estimation (MLE).
- Given a dataset of \(N\) i.i.d. samples from the data distribution, \(\{\mathbf{y}^{(i)}\}_{i=1}^N \sim p_{\text{data}}(\mathbf{y})\), the MLE objective is: \[\mathcal{M}^* = \arg \max_{\mathcal{M}} \sum_{i=1}^N \log p(\mathbf{y}^{(i)} \mid \mathcal{M})\]
- The log likelihood of each sample can be computed using the change of variables formula.
- The model parameters can be optimized using stochastic gradient methods, e.g. Adam.
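A minimal sketch of the change-of-variables log likelihood for a flow built from hypothetical elementwise affine layers \(f_k(\mathbf{x}) = \mathbf{a}_k \odot \mathbf{x} + \mathbf{b}_k\), with a standard normal base distribution. Each inverse layer contributes \(-\sum_d \log |a_{k,d}|\) to \(\log |\det(\partial f^{-1} / \partial \mathbf{y})|\). Real flows use more expressive invertible layers and an autodiff framework to minimize the negative log likelihood; a composition of affine layers is itself affine, which makes the result easy to check against a Gaussian density.

```python
import numpy as np


def std_normal_logpdf(z):
    return -0.5 * np.sum(z ** 2 + np.log(2.0 * np.pi), axis=-1)


def flow_log_prob(y, layers):
    """log p_Y(y) for y = f_K(...f_1(z)...), each layer f_k(x) = a_k * x + b_k elementwise."""
    x, log_det_inv = y, 0.0
    for a, b in reversed(layers):                 # f_K^{-1} is applied first, f_1^{-1} last
        x = (x - b) / a                           # invert the affine layer
        log_det_inv -= np.sum(np.log(np.abs(a)))  # log|det Jacobian of f_k^{-1}| = -sum log|a_k|
    return std_normal_logpdf(x) + log_det_inv     # x is now z = f^{-1}(y)


def negative_log_likelihood(y_batch, layers):
    return -np.mean(flow_log_prob(y_batch, layers))  # the quantity minimized in MLE


layers = [(np.array([2.0, 0.5]), np.array([1.0, -1.0])),  # f_1: (a_1, b_1)
          (np.array([-3.0, 1.5]), np.array([0.0, 2.0]))]  # f_2: (a_2, b_2)
rng = np.random.default_rng(0)
y = rng.standard_normal((5, 2))
print(negative_log_likelihood(y, layers))
```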
Continuous Normalizing Flows (CNFs)
- CNFs can be written as: \[
\begin{aligned}
\frac{d\mathbf{x}(t)}{dt} &= f(\mathbf{x}(t), t) \\
\mathbf{x}(0) &\sim p_Z(\mathbf{z}) \\
\mathbf{y} &= \mathbf{x}(T)
\end{aligned}
\] where \(f(\mathbf{x}, t)\) is a time-dependent vector field parameterized by a neural network and \(T\) is the total integration time.
- Similar to NFs, continuous normalizing flows (CNFs) also give us an exact, tractable log density.
- Training a CNF by MLE is possible because we can compute the log likelihood using an ODE solver.
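One way to compute the CNF log likelihood uses the instantaneous change-of-variables formula \(\frac{d \log p(\mathbf{x}(t))}{dt} = -\text{trace}\left(\frac{\partial f(\mathbf{x}(t), t)}{\partial \mathbf{x}(t)}\right)\): integrate \(\mathbf{x}(t)\) backward from \(\mathbf{x}(T) = \mathbf{y}\) together with the accumulated trace term. The sketch below uses a fixed-step RK4 solver (an assumption; adaptive solvers are common), treats the vector field `f` and its Jacobian trace `trace_jac` as user-supplied callables, and assumes a standard normal base distribution. For the linear field \(f(\mathbf{x}, t) = \mathbf{a} \odot \mathbf{x}\) with \(T = 1\), the answer is known in closed form, \(\log p_Z(e^{-\mathbf{a}} \odot \mathbf{y}) - \sum_d a_d\).

```python
import numpy as np


def std_normal_logpdf(z):
    return -0.5 * np.sum(z ** 2 + np.log(2.0 * np.pi))


def cnf_log_prob(y, f, trace_jac, T=1.0, steps=1000):
    """log p_Y(y) for dx/dt = f(x, t), x(0) ~ N(0, I), y = x(T), via fixed-step RK4.

    Integrates x and the running integral of trace(df/dx) backward from t = T to t = 0.
    """
    h = -T / steps  # negative step: we move backward in time
    x, t, int_trace = np.asarray(y, dtype=float), T, 0.0
    for _ in range(steps):
        k1, l1 = f(x, t), trace_jac(x, t)
        k2, l2 = f(x + 0.5 * h * k1, t + 0.5 * h), trace_jac(x + 0.5 * h * k1, t + 0.5 * h)
        k3, l3 = f(x + 0.5 * h * k2, t + 0.5 * h), trace_jac(x + 0.5 * h * k2, t + 0.5 * h)
        k4, l4 = f(x + h * k3, t + h), trace_jac(x + h * k3, t + h)
        x = x + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        int_trace += h / 6.0 * (l1 + 2 * l2 + 2 * l3 + l4)
        t += h
    # x is now x(0) = z, and int_trace = -int_0^T trace(df/dx) dt
    return std_normal_logpdf(x) + int_trace


a = np.array([0.5, -1.0])  # toy linear field f(x, t) = a * x, so trace(df/dx) = sum(a)
y = np.array([1.0, 2.0])
approx = cnf_log_prob(y, lambda x, t: a * x, lambda x, t: np.sum(a))
exact = std_normal_logpdf(np.exp(-a) * y) - np.sum(a)  # closed form for T = 1
print(approx, exact)
```

Each likelihood evaluation requires a full ODE solve, and training by MLE additionally needs gradients through that solve.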
Flow Matching
However, when the CNF uses a complex neural network to parameterize the vector field \(f(\mathbf{x}, t)\), MLE training can be slow because:
- Each evaluation of the log likelihood requires solving an ODE, and
- Calculating the gradient of the log likelihood requires solving another ODE backward in time.
Flow matching trains CNFs without solving an ODE at training time.
The idea is to match the vector field \(f(\mathbf{x}, t)\) to a reference vector field that is only known at training time.
The reference vector field is defined such that it produces correct samples from the data distribution when integrated over time.
Flow Matching
- First, let’s re-define the CNF to map from the data distribution to the base distribution instead of the other way around.
- The sampling procedure then requires solving a final value problem instead of an initial value problem: \[
\begin{aligned}
\frac{d\mathbf{x}(t)}{dt} &= f(\mathbf{x}(t), t) \\
\mathbf{x}(T) &\sim p_Z(\mathbf{z}) \\
\mathbf{y} &= \mathbf{x}(0)
\end{aligned}
\]
- This is an equally valid CNF that can be trained by MLE, but it is more convenient to think about the flow matching algorithm in this direction. From here on we take \(T = 1\).
Flow Matching
The simplest reference vector field and associated reference CNF are: \[
\begin{aligned}
\frac{d\mathbf{x}(t)}{dt} &= \mathbf{z} - \mathbf{y} \\
\mathbf{x}(0) &= \mathbf{y} \\
\mathbf{y} &\sim p_{\text{data}}(\mathbf{y}) \\
\mathbf{z} &\sim p_Z(\mathbf{z})
\end{aligned}
\]
By construction, sampling \(\mathbf{y}\) from the data distribution and \(\mathbf{z}\) from the base distribution and integrating the above ODE from \(t=0\) to \(t=1\) guarantees that:
- The solution at time \(t=0\) is a sample from the data distribution.
- The solution at time \(t=1\) is a sample from the base distribution.
Note that this is cheating! We are using the data distribution to define the reference CNF.
This is possible only because we have an empirical approximation of the data distribution (the training set) during training.
Flow Matching
\[
\begin{aligned}
\frac{d\mathbf{x}(t)}{dt} &= \mathbf{z} - \mathbf{y} \\
\mathbf{x}(0) &= \mathbf{y} \\
\mathbf{y} &\sim p_{\text{data}}(\mathbf{y}) \\
\mathbf{z} &\sim p_Z(\mathbf{z})
\end{aligned}
\]
- This trivial reference CNF behaves like our desired CNF that we would like to learn, i.e. it generates correct samples when integrated over time.
- However, obviously, the final trained CNF should not depend on the data samples because we want to use it at test time when we do not have access to samples from the data distribution.
Flow Matching
- A simple training algorithm is to try to match the vector field of the CNF model (that we are trying to learn) to the reference vector field.
- This can be done by minimizing the following objective: \[
\mathcal{M}^* = \arg \min_{\mathcal{M}} \mathbb{E}_{t \sim \lambda(t), \space \mathbf{y} \sim p_{\text{data}}(\mathbf{y}), \space \mathbf{z} \sim p_Z(\mathbf{z})} \left[ \|f(\mathbf{x}(t), t \mid \mathcal{M}) - (\mathbf{z} - \mathbf{y})\|^2 \right]
\] where \(\lambda(t)\) is a distribution over time, e.g. uniform distribution over \([0, 1]\).
- To calculate \(\mathbf{x}(t)\), we can use the closed-form solution of the reference ODE: \[\mathbf{x}(t) = \mathbf{y} + t(\mathbf{z} - \mathbf{y})\]
- Note that we sample data points \(\mathbf{y}\) and latent points \(\mathbf{z}\) independently at each training step.
- So all pairs of \((\mathbf{y}, \mathbf{z})\) are possible, not just those that correspond to each other through the CNF.
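The training objective and a simple sampler can be sketched in NumPy as follows. `v_model` is a hypothetical callable standing in for the neural vector field \(f(\mathbf{x}, t \mid \mathcal{M})\), \(p_Z\) is assumed standard normal, and sampling uses fixed-step Euler integration from \(t = 1\) back to \(t = 0\). The toy check uses data concentrated at a single point \(\boldsymbol{\mu}\), for which the field \((\mathbf{x} - \boldsymbol{\mu}) / t\) matches the reference velocity exactly and the sampler returns \(\boldsymbol{\mu}\).

```python
import numpy as np


def flow_matching_loss(v_model, y, rng):
    """Monte Carlo estimate of E ||v(x(t), t) - (z - y)||^2, t ~ U[0, 1], z ~ N(0, I) independent of y."""
    t = rng.uniform(0.0, 1.0, size=y.shape[0])
    z = rng.standard_normal(y.shape)  # independent coupling of (y, z)
    x_t = y + t[:, None] * (z - y)    # closed-form solution of the reference ODE
    target = z - y                    # reference velocity dx/dt
    return np.mean(np.sum((v_model(x_t, t) - target) ** 2, axis=1))


def sample(v_model, z, steps=100):
    """Integrate dx/dt = v(x, t) from t = 1 (base) back to t = 0 (data) with Euler steps."""
    x, h = z.copy(), 1.0 / steps
    for k in range(steps):
        t = 1.0 - k * h
        x = x - h * v_model(x, np.full(x.shape[0], t))
    return x


rng = np.random.default_rng(0)
mu = np.array([2.0, -1.0])
y = np.tile(mu, (1000, 1))  # toy data: every example equals mu
# For this data, z - y = (x_t - mu) / t, so this field matches the reference velocity exactly.
v_exact = lambda x, t: (x - mu) / t[:, None]
print(flow_matching_loss(v_exact, y, rng))           # close to 0
print(sample(v_exact, rng.standard_normal((5, 2))))  # every row close to mu
```

For non-degenerate data, no field can reach zero loss, because the same \(\mathbf{x}(t)\) arises from many \((\mathbf{y}, \mathbf{z})\) pairs; the minimizer averages their velocities.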
Flow Matching
\[
\begin{aligned}
\mathcal{M}^* &= \arg \min_{\mathcal{M}} \mathbb{E}_{t \sim \lambda(t), \space \mathbf{y} \sim p_{\text{data}}(\mathbf{y}), \space \mathbf{z} \sim p_Z(\mathbf{z})} \left[ \|f(\mathbf{x}(t), t \mid \mathcal{M}) - (\mathbf{z} - \mathbf{y})\|^2 \right] \\
\mathbf{x}(t) &= \mathbf{y} + t(\mathbf{z} - \mathbf{y})
\end{aligned}
\]
- It is not immediately obvious why minimizing the above objective would lead to a good generative model that matches the data distribution.
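- The main result of the flow matching paper (Lipman et al., 2023) is that this per-sample (conditional) objective has the same gradients as regressing onto the marginal vector field, whose flow transports the data distribution to the base distribution; its minimizer therefore generates samples from the data distribution when integrated backward in time.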
- Other couplings of \((\mathbf{y}, \mathbf{z})\) are possible, e.g. using optimal transport to find a coupling that minimizes the expected distance between \(\mathbf{y}\) and \(\mathbf{z}\).
- Any joint distribution (coupling) of \((\mathbf{y}, \mathbf{z})\) that has the correct marginals can be used in flow matching.
- Using a better coupling can lead to simpler vector fields (closer to linear) that are faster to integrate numerically when generating samples.
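A sketch of a minibatch optimal-transport coupling in 1D, where sorting both batches and pairing them in order minimizes the average squared distance. This is a toy under stated assumptions: in higher dimensions the pairing would instead come from an assignment solver (for example `scipy.optimize.linear_sum_assignment`) applied to a cost matrix. Both couplings only permute samples within the batch, so they keep the batch marginals.

```python
import numpy as np


def independent_pairs(y, z, rng):
    return y, z[rng.permutation(len(z))]  # random pairing: the independent coupling


def ot_pairs_1d(y, z):
    return np.sort(y), np.sort(z)  # in 1D, sorted pairing minimizes the total squared distance


def mean_sq_dist(a, b):
    return np.mean((a - b) ** 2)


rng = np.random.default_rng(0)
y = rng.normal(3.0, 0.5, size=512)  # toy 1D data batch
z = rng.standard_normal(512)        # base samples
print(mean_sq_dist(*independent_pairs(y, z, rng)), mean_sq_dist(*ot_pairs_1d(y, z)))
```

Pairs produced this way would replace the independently sampled \((\mathbf{y}, \mathbf{z})\) in the flow matching loss.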