Markov Chain Monte Carlo Convergence

Authors

Jose Storopoli

Mohamed Tarek

using Pumas

MCMC has a useful property: it asymptotically converges to the target distribution.

Note

MCMC's asymptotic convergence stems from the Central Limit Theorem (CLT) for Markov chains. The assumptions are that the parameters have finite variance (inherited from the CLT) and that the chain is reversible (i.e., it satisfies detailed balance).

If you are curious about MCMC convergence theory, please check Brooks et al. (2011).

This means that, if we had all the time in the world, MCMC would be guaranteed to give the right answer, regardless of the geometry of the target posterior distribution.

However, we don’t have all the time in the world. Different MCMC algorithms, like HMC and NUTS, can reduce the sampling (and warmup) time necessary for convergence to the target distribution.

Note

Markov chain Monte Carlo (MCMC) samplers can approximate any target posterior density by iteratively drawing discrete samples. They explore the target posterior density by proposing new parameter values and accepting or rejecting each proposal according to an acceptance/rejection rule.
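To make the accept/reject rule concrete, here is a minimal random-walk Metropolis sketch in plain Julia. It is only an illustration, not what Pumas runs (Pumas uses NUTS): the standard normal target `logp`, the proposal scale of 2.4, and the helper name `rw_metropolis` are assumptions chosen for the example.

```julia
using Random, Statistics

# Random-walk Metropolis for a one-dimensional target.
# logp: log target density (up to an additive constant); scale: proposal standard deviation.
function rw_metropolis(logp, x0::Float64, nsamples::Int; scale = 1.0, rng = Random.default_rng())
    samples = Vector{Float64}(undef, nsamples)
    x, lp = x0, logp(x0)
    naccept = 0
    for i in 1:nsamples
        proposal = x + scale * randn(rng)   # symmetric random-walk proposal
        lp_prop = logp(proposal)
        # Accept with probability min(1, p(proposal) / p(x)), computed on the log scale
        if log(rand(rng)) < lp_prop - lp
            x, lp = proposal, lp_prop
            naccept += 1
        end
        samples[i] = x                      # a rejection repeats the current value
    end
    return samples, naccept / nsamples
end

# Hypothetical target: standard normal log density, up to a constant
logp(x) = -x^2 / 2

samples, accept_rate = rw_metropolis(logp, 0.0, 20_000; scale = 2.4, rng = MersenneTwister(1))
println("mean = ", mean(samples), ", std = ", std(samples), ", acceptance = ", accept_rate)
```

Because the proposal is blind to the shape of the posterior, many proposals are rejected and consecutive draws are strongly correlated; this is the inefficiency that gradient-based samplers such as HMC aim to reduce.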

Modern MCMC samplers, such as Hamiltonian Monte Carlo (HMC), use the gradient of the log posterior to guide proposals toward areas of high probability, which makes posterior exploration and approximation more efficient than with random-walk proposals. However, this guidance scheme depends on two hyperparameters:

  • \(L\): the number of steps used to guide the proposal towards areas of high probability
  • \(\epsilon\): the step size of each of those steps
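The role of \(L\) and \(\epsilon\) can be seen in a single HMC transition, which simulates the trajectory with the leapfrog integrator. The sketch below is a simplified illustration under stated assumptions (identity mass matrix, a two-dimensional standard normal target, hand-picked \(L = 10\) and \(\epsilon = 0.1\), hypothetical helper `hmc_step`); it is not Pumas's NUTS implementation, which tunes \(\epsilon\) and chooses \(L\) dynamically.

```julia
using Random, Statistics

# One HMC transition with an identity mass matrix.
# logp: log target density; grad: its gradient; L: leapfrog steps; ϵ: step size
function hmc_step(logp, grad, θ::Vector{Float64}, L::Int, ϵ::Float64, rng)
    p0 = randn(rng, length(θ))                # draw a fresh momentum
    θnew, p = copy(θ), copy(p0)
    p .+= (ϵ / 2) .* grad(θnew)               # initial half step for the momentum
    for k in 1:L
        θnew .+= ϵ .* p                       # full step for the position
        if k < L
            p .+= ϵ .* grad(θnew)             # full momentum step, except after the last one
        end
    end
    p .+= (ϵ / 2) .* grad(θnew)               # final half step for the momentum
    # Hamiltonian = potential energy (-logp) + kinetic energy
    H0 = -logp(θ) + sum(abs2, p0) / 2
    H1 = -logp(θnew) + sum(abs2, p) / 2
    # Metropolis correction: accept with probability min(1, exp(H0 - H1))
    return log(rand(rng)) < H0 - H1 ? θnew : θ
end

# Hypothetical target: two-dimensional standard normal
logp(θ) = -sum(abs2, θ) / 2
grad(θ) = -θ

rng = MersenneTwister(42)
θ = [3.0, -3.0]                               # deliberately poor starting point
draws = Matrix{Float64}(undef, 5_000, 2)
for i in 1:size(draws, 1)
    global θ = hmc_step(logp, grad, θ, 10, 0.1, rng)   # L = 10 steps of size ϵ = 0.1
    draws[i, :] = θ
end
println(mean(draws[1001:end, :]; dims = 1), std(draws[1001:end, :]; dims = 1))
```

A larger \(\epsilon\) makes each step cheaper in terms of trajectory length covered but increases the energy error \(H_1 - H_0\), so more proposals are rejected; this trade-off is what the adaptation phase tries to balance.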

The No-U-Turn Sampler (NUTS) (Hoffman & Gelman, 2011) is an MCMC sampler that builds on HMC and automatically tunes \(\epsilon\). This tuning happens during the adaptation phase of the sampler (also called “warm-up”). Once adaptation is done, \(\epsilon\) is fixed for the sampling phase, while \(L\) is chosen dynamically in each iteration. The sampling phase is the phase that generates the samples used to approximate the target posterior density.

For most models, NUTS performs as expected and converges quickly to the target posterior density. Nevertheless, there is a common posterior density topology that NUTS (and almost all other MCMC samplers) struggles to explore, sample, and approximate.

1 Can We Prove Convergence?

In the ideal scenario, the NUTS sampler converges to the true posterior and doesn’t miss any mode.

Unfortunately, this is not easy to prove in general. Convergence diagnostics only test for symptoms of a lack of convergence. In other words, if all the diagnostics look normal, we still cannot prove that the sampler converged; we can only say that we found no evidence that it did not.

2 Signs of Lack of Convergence

Some signs of lack of convergence are:

  • Any of the moments (e.g. the mean or standard deviation) is changing with time. This is diagnosed using stationarity tests by comparing different parts of a single chain to each other.
  • Any of the moments is sensitive to the initial parameter values. This is diagnosed using multiple chains by comparing their summary statistics to each other.
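Both checks boil down to comparing summary statistics across pieces of the draws. Below is a rough plain-Julia sketch, assuming a hypothetical n × m matrix `chains` (iterations × chains) for a single parameter. The helpers `half_means` and `per_chain_summary` are illustrative, not Pumas functions, and comparing half-chain means is only a crude stand-in for formal stationarity tests.

```julia
using Random, Statistics

# Stationarity: compare the mean of the first and second half of each chain.
# chains: n × m matrix, one column per chain (hypothetical layout)
function half_means(chains::AbstractMatrix)
    n = size(chains, 1)
    h = n ÷ 2
    first_half = vec(mean(chains[1:h, :]; dims = 1))
    second_half = vec(mean(chains[(n - h + 1):n, :]; dims = 1))
    return first_half, second_half
end

# Sensitivity to initial values: compare per-chain means and standard deviations
per_chain_summary(chains::AbstractMatrix) =
    (vec(mean(chains; dims = 1)), vec(std(chains; dims = 1)))

rng = MersenneTwister(7)
good = randn(rng, 1_000, 4)                       # 4 chains that agree with each other
drifting = good .+ range(0, 3; length = 1_000)    # every chain drifts upward over time
stuck = hcat(good[:, 1:3], good[:, 4] .+ 5)       # chain 4 stuck in a different region

println(half_means(good))
println(half_means(drifting))
println(per_chain_summary(stuck))
```

For `drifting`, the second-half means are clearly larger than the first-half means (a moment changing with time); for `stuck`, one chain's mean disagrees with the others (sensitivity to where that chain started).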

While high auto-correlation is not strictly a sign of lack of convergence, a sampler with high auto-correlation requires many more samples to reach the same estimation precision as a sampler with low auto-correlation. So low auto-correlation is usually desirable.
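To put a rough number on this, assume (purely for illustration) that a sampler's autocorrelation decays geometrically, \(\rho_t = \phi^t\) with \(0 \le \phi < 1\), as for an AR(1) process. With the standard definition of the effective sample size of \(m\) chains with \(n\) draws each, \(\eta_{\text{eff}} = mn / (1 + 2\sum_{t=1}^{\infty}\rho_t)\) (Gelman et al., 2013):

\[1 + 2\sum_{t=1}^{\infty}\phi^t = 1 + \frac{2\phi}{1 - \phi} = \frac{1 + \phi}{1 - \phi} \quad\Longrightarrow\quad \eta_{\text{eff}} = mn\,\frac{1 - \phi}{1 + \phi}\]

For \(\phi = 0.9\), \(\eta_{\text{eff}} = mn/19\): such a sampler needs roughly 19 times as many draws as an independent sampler (\(\phi = 0\)) to estimate a posterior mean with the same Monte Carlo precision.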

To showcase and give applied explanations of MCMC convergence, let’s recall the model from the Introduction to Bayesian Models in Pumas tutorial:

pk_1cmp = @model begin
    @param begin
        tvcl ~ LogNormal(log(3.2), 1)
        tvv ~ LogNormal(log(16.4), 1)
        tvka ~ LogNormal(log(3.8), 1)
        ω²cl ~ LogNormal(log(0.04), 0.25)
        ω²v ~ LogNormal(log(0.04), 0.25)
        ω²ka ~ LogNormal(log(0.04), 0.25)
        σ_p ~ LogNormal(log(0.2), 0.25)
    end
    @random begin
        ηcl ~ Normal(0, sqrt(ω²cl))
        ηv ~ Normal(0, sqrt(ω²v))
        ηka ~ Normal(0, sqrt(ω²ka))
    end
    @covariates begin
        Dose
    end
    @pre begin
        CL = tvcl * exp(ηcl)
        Vc = tvv * exp(ηv)
        Ka = tvka * exp(ηka)
    end
    @dynamics Depots1Central1
    @derived begin
        cp := @. Central / Vc
        Conc ~ @. Normal(cp, abs(cp) * σ_p)
    end
end
┌ Warning: Covariate Dose is not used in the model.
└ @ Pumas ~/_work/PumasTutorials.jl/PumasTutorials.jl/custom_julia_depot/packages/Pumas/aZRyj/src/dsl/model_macro.jl:2856
PumasModel
  Parameters: tvcl, tvv, tvka, ω²cl, ω²v, ω²ka, σ_p
  Random effects: ηcl, ηv, ηka
  Covariates: Dose
  Dynamical system variables: Depot, Central
  Dynamical system type: Closed form
  Derived: Conc
  Observed: Conc
using PharmaDatasets
pkpain_df = dataset("pk_painrelief");
using DataFramesMeta
@chain pkpain_df begin
    @rsubset! :Dose != "Placebo"
    @rtransform! begin
        :amt = :Time == 0 ? parse(Int, chop(:Dose; tail = 3)) : missing
        :evid = :Time == 0 ? 1 : 0
        :cmt = :Time == 0 ? 1 : 2
    end
    @rtransform! :Conc = :evid == 1 ? missing : :Conc
end;
first(pkpain_df, 5)
5×10 DataFrame
Row  Subject  Time     Conc      PainRelief  PainScore  RemedStatus  Dose     amt      evid   cmt
     Int64    Float64  Float64?  Int64       Int64      Int64        String7  Int64?   Int64  Int64
1    1        0.0      missing   0           3          1            20 mg    20       1      1
2    1        0.5      1.15578   1           1          0            20 mg    missing  0      2
3    1        1.0      1.37211   1           0          0            20 mg    missing  0      2
4    1        1.5      1.30058   1           0          0            20 mg    missing  0      2
5    1        2.0      1.19195   1           1          0            20 mg    missing  0      2
pop =
    pkpain_noplb = read_pumas(
        pkpain_df,
        id = :Subject,
        time = :Time,
        amt = :amt,
        observations = [:Conc],
        covariates = [:Dose],
        evid = :evid,
        cmt = :cmt,
    )
Population
  Subjects: 120
  Covariates: Dose
  Observations: Conc
pk_1cmp_fit = fit(
    pk_1cmp,
    pop,
    init_params(pk_1cmp),
    BayesMCMC(; nsamples = 1_000, nadapts = 500, constantcoef = (; tvka = 2)),
);
pk_1cmp_tfit = Pumas.truncate(pk_1cmp_fit; burnin = 500)
[ Info: Checking the initial parameter values.
[ Info: The initial log probability and its gradient are finite. Check passed.
[ Info: Checking the initial parameter values.
[ Info: Checking the initial parameter values.
[ Info: The initial log probability and its gradient are finite. Check passed.
[ Info: The initial log probability and its gradient are finite. Check passed.
[ Info: Checking the initial parameter values.
[ Info: The initial log probability and its gradient are finite. Check passed.
Chains MCMC chain (500×6×4 Array{Float64, 3}):

Iterations        = 1:1:500
Number of chains  = 4
Samples per chain = 500
Wall duration     = 481.94 seconds
Compute duration  = 1794.83 seconds
parameters        = tvcl, tvv, ω²cl, ω²v, ω²ka, σ_p

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        tvcl    3.1922    0.0804    0.0078    107.2489    248.2158    1.0169   ⋯
         tvv   13.2465    0.2635    0.0171    234.8128    562.5041    1.0059   ⋯
        ω²cl    0.0727    0.0082    0.0005    334.1100    490.7868    1.0080   ⋯
         ω²v    0.0465    0.0058    0.0003    532.5855    820.4850    1.0013   ⋯
        ω²ka    1.1000    0.1429    0.0033   1898.9600   1599.7130    0.9997   ⋯
         σ_p    0.1040    0.0024    0.0001   2151.0237   1358.5256    1.0072   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        tvcl    3.0269    3.1385    3.1962    3.2492    3.3368
         tvv   12.7466   13.0595   13.2487   13.4240   13.7655
        ω²cl    0.0584    0.0671    0.0719    0.0775    0.0908
         ω²v    0.0366    0.0424    0.0462    0.0502    0.0586
        ω²ka    0.8516    0.9980    1.0888    1.1875    1.4094
         σ_p    0.0995    0.1024    0.1039    0.1056    0.1087

3 Diagnostic Plots

There are several diagnostic plots that can be used to assess convergence of a Markov chain:

  • Trace Plot
  • Cumulative Mean Plot
  • Auto-correlation Plot

We’ll dive into each of these plots.

using PumasUtilities

The trace plot of a parameter shows the value of the parameter in each iteration of the MCMC algorithm.

A good trace plot is one that:

  • is noisy, rather than, for example, a steadily increasing or decreasing line
  • has a fixed mean
  • has a fixed variance
  • shows all chains overlapping with each other, i.e. the chains are mixing

Let’s see an example with tvcl and tvv:

trace_plot(pk_1cmp_tfit; parameters = [:tvcl, :tvv], linkyaxes = :none)

You can see that all 4 chains have mixed well and that the trace plot shows no trend across iterations. This suggests that the chains were stationary during the sampling phase of the MCMC sampler.

The cumulative mean plot of a parameter shows the mean of the parameter value in each MCMC chain up to a certain iteration.

An MCMC chain converging to a stationary posterior distribution should have the cumulative mean of each parameter converge to a fixed value.
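The quantity behind this plot is simple to compute. A minimal plain-Julia sketch (the helper names `cummean` and `cummean_chains` and the toy values are illustrative assumptions; this is not how `cummean_plot` is implemented internally):

```julia
# Cumulative mean of one chain: element s is the mean of the first s draws
cummean(x::AbstractVector) = cumsum(x) ./ (1:length(x))

# For an n × m matrix with one chain per column, compute it column by column
cummean_chains(chains::AbstractMatrix) = cumsum(chains; dims = 1) ./ (1:size(chains, 1))

x = [2.0, 4.0, 3.0, 3.0]
cummean(x)                            # → [2.0, 3.0, 3.0, 3.0]
cummean_chains([1.0 2.0; 3.0 4.0])    # → [1.0 2.0; 2.0 3.0]
```

For a converging chain, the tail of `cummean` flattens out; the plot draws one such curve per chain.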

Let’s see an example with tvcl and tvv:

cummean_plot(pk_1cmp_tfit; parameters = [:tvcl, :tvv], linkyaxes = :none)

MCMC chains are prone to auto-correlation between the samples because each sample in the chain is a function of the previous sample.

The auto-correlation plot shows the correlation between every sample with index \(s\) and the corresponding sample with index \(s + \text{lag}\), for all \(s \in [1, N - \text{lag}]\), where \(N\) is the total number of samples.

For each value of lag, we can compute a correlation measure between the samples and their lag-steps-ahead counterparts. The correlation is usually a value between \(0\) and \(1\) but can sometimes be between \(-1\) and \(0\) as well.

The auto-correlation plot shows the lag on the \(x\)-axis and the correlation value on the \(y\)-axis.

For well-behaved MCMC chains, the correlation gets closer to \(0\) as the lag increases. This means that two samples are less and less correlated the further apart they are in the chain.
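The lag-based correlation described above can be estimated directly. The sketch below uses a common estimator (lag-\(k\) autocovariance divided by the sample variance, both with divisor \(N\)); plotting tools may use a slightly different estimator, and StatsBase.jl also provides a ready-made `autocor`. The AR(1) chain with coefficient \(\phi = 0.9\) is a made-up stand-in for a highly autocorrelated sampler; its theoretical autocorrelation at lag \(k\) is \(\phi^k\). The names `autocor_at` and `ar1` are hypothetical.

```julia
using Random, Statistics

# Sample autocorrelation at a given lag: correlation between x[s] and x[s + lag]
function autocor_at(x::AbstractVector, lag::Int)
    N = length(x)
    μ = mean(x)
    c0 = sum(abs2, x .- μ) / N                                         # variance
    ck = sum((x[s] - μ) * (x[s + lag] - μ) for s in 1:(N - lag)) / N   # lag autocovariance
    return ck / c0
end

# Hypothetical highly autocorrelated chain: AR(1) with coefficient φ
function ar1(rng, N, φ)
    x = zeros(N)
    for s in 2:N
        x[s] = φ * x[s - 1] + randn(rng)
    end
    return x
end

x = ar1(MersenneTwister(3), 50_000, 0.9)
for lag in (1, 5, 20, 50)
    println(lag, " => ", round(autocor_at(x, lag); digits = 3),
            "  (theory: ", round(0.9^lag; digits = 3), ")")
end
```

The estimates start near 0.9 at lag 1 and decay towards 0, which is the shape an auto-correlation plot of a slowly mixing but well-behaved chain would show.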

Let’s see an example with tvcl and tvv:

autocor_plot(pk_1cmp_tfit; parameters = [:tvcl, :tvv])

4 Convergence Metrics

A few metrics are commonly used to assess and diagnose Markov chains:

  • Effective Sample Size (ESS): an approximation of the “number of independent samples” generated by a Markov chain.
  • \(\widehat{R}\) (Rhat): the potential scale reduction factor, a metric that measures whether the Markov chains have mixed and, potentially, converged.

Tip

To find more about convergence metrics, please check Gelman et al. (2013) and McElreath (2020).

The formula for the effective sample size is:

\[\widehat{\eta}_{\text{eff}} = \frac{mn}{1 + 2\sum^T_{t=1}\widehat{\rho}_t}\]

where:

  • \(m\): number of Markov chains
  • \(n\): total samples per Markov chain (discarding warmup or burnin)
  • \(\widehat{\rho}_t\): an estimate of the autocorrelation at lag \(t\)
  • \(T\): the lag at which the sum of autocorrelation estimates is truncated

This formula is only an approximation of the “number of independent samples” generated by the Markov chains, because we cannot recover the true autocorrelation \(\rho\). Instead, we rely on an estimate \(\widehat{\rho}\).
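As an illustration, here is a plain-Julia sketch in the spirit of the estimator described in Gelman et al. (2013), \(mn / (1 + 2\sum_{t=1}^{T}\widehat{\rho}_t)\). It assumes a variogram-based \(\widehat{\rho}_t\) (built on the pooled variance estimate \(\widehat{\operatorname{var}}^+\) that also appears in Rhat) and truncates the sum at the first odd \(T\) for which \(\widehat{\rho}_{T+1} + \widehat{\rho}_{T+2} < 0\). The n × m matrix layout and the name `ess_basic` are assumptions; Pumas's `ess_bulk` and `ess_tail` use a refined estimator, so the numbers will not match exactly.

```julia
using Random, Statistics

# Effective sample size for one parameter.
# chains: n × m matrix (n draws per chain after warmup, m chains); assumed layout.
function ess_basic(chains::AbstractMatrix)
    n, m = size(chains)
    W = mean(var(chains; dims = 1))                   # mean within-chain variance
    B = n * var(vec(mean(chains; dims = 1)))          # between-chain variance
    varplus = (n - 1) / n * W + B / n                 # pooled variance estimate
    # Variogram-based autocorrelation estimate at lag k
    ρ(k) = 1 - sum(abs2, chains[(k + 1):n, :] .- chains[1:(n - k), :]) / (m * (n - k)) / (2 * varplus)
    # Sum ρ̂_t for t = 1, ..., T, stopping at the first odd T with ρ̂_{T+1} + ρ̂_{T+2} < 0
    s = 0.0
    t = 1
    while t + 2 < n
        s += ρ(t)
        isodd(t) && ρ(t + 1) + ρ(t + 2) < 0 && break
        t += 1
    end
    return m * n / (1 + 2 * s)
end

# Hypothetical test chains: AR(1) processes started from their stationary distribution
function ar1_chains(rng, n, m, φ)
    x = zeros(n, m)
    x[1, :] = randn(rng, m) ./ sqrt(1 - φ^2)
    for j in 1:m, s in 2:n
        x[s, j] = φ * x[s - 1, j] + randn(rng)
    end
    return x
end

rng = MersenneTwister(11)
iid = randn(rng, 1_000, 4)                   # independent draws: ESS should be near m * n = 4000
correlated = ar1_chains(rng, 2_000, 4, 0.9)  # theory: 8000 * (1 - 0.9) / (1 + 0.9) ≈ 421
println(ess_basic(iid))
println(ess_basic(correlated))
```

Even though `correlated` holds twice as many draws as `iid`, its effective sample size is much smaller, because the strongly positive autocorrelations inflate the denominator.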

To get the ESS from a Pumas Bayesian result (the output of fit or Pumas.truncate), use summarystats and check the ess_bulk and ess_tail columns:

summarystats(pk_1cmp_tfit)
Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        tvcl    3.1922    0.0804    0.0078    107.2489    248.2158    1.0169   ⋯
         tvv   13.2465    0.2635    0.0171    234.8128    562.5041    1.0059   ⋯
        ω²cl    0.0727    0.0082    0.0005    334.1100    490.7868    1.0080   ⋯
         ω²v    0.0465    0.0058    0.0003    532.5855    820.4850    1.0013   ⋯
        ω²ka    1.1000    0.1429    0.0033   1898.9600   1599.7130    0.9997   ⋯
         σ_p    0.1040    0.0024    0.0001   2151.0237   1358.5256    1.0072   ⋯
                                                                1 column omitted

Alternatively, use the ess function, which returns a vector with each parameter's ESS, in the same order as the parameters appear in summarystats:

ess(pk_1cmp_tfit)
6-element Vector{Float64}:
  106.09258881683644
  236.98756500780098
  323.8463745293615
  523.378597607564
 1846.3342910080996
 2187.8781031059707

which makes it convenient to apply summarizing functions:

mean(ess(pk_1cmp_tfit))
870.7529200126055

The formula for Rhat is:

\[\widehat{R} = \sqrt{\frac{\widehat{\operatorname{var}}^+ \left( \psi \mid y \right)}{W}}\]

where \(\widehat{\operatorname{var}}^+ \left( \psi \mid y \right)\) is an estimate of the marginal posterior variance of a parameter \(\psi\), pooled across all Markov chains. We calculate it as a weighted sum of the within-chain variance \(W\) and the between-chain variance \(B\):

\[\widehat{\operatorname{var}}^+ \left( \psi \mid y \right) = \frac{n - 1}{n} W + \frac{1}{n} B\]

Intuitively, \(\widehat{R} = 1.0\) if all chains have fully converged. As a heuristic, if \(\widehat{R} > 1.1\), you should worry, because the chains have probably not converged adequately.
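The Rhat formulas can be sketched in a few lines of plain Julia. This is the basic (non-split) version, assuming \(W\) is the mean of the per-chain sample variances and \(B\) is \(n\) times the sample variance of the chain means, as in Gelman et al. (2013). Pumas's rhat likely uses a refined variant (e.g. split, rank-normalized chains), so values will differ slightly. The n × m matrix layout, the name `rhat_basic`, and the simulated chains are assumptions for illustration.

```julia
using Random, Statistics

# Potential scale reduction factor for one parameter.
# chains: n × m matrix (n draws per chain, m chains); assumed layout.
function rhat_basic(chains::AbstractMatrix)
    n, m = size(chains)
    W = mean(var(chains; dims = 1))              # mean within-chain variance
    B = n * var(vec(mean(chains; dims = 1)))     # between-chain variance
    varplus = (n - 1) / n * W + B / n            # weighted sum of W and B
    return sqrt(varplus / W)
end

rng = MersenneTwister(5)
mixed = randn(rng, 1_000, 4)                     # all chains sample the same distribution
unmixed = mixed .+ [0.0 0.0 0.0 2.0]             # chain 4 is centered somewhere else
println(rhat_basic(mixed))                       # close to 1.0
println(rhat_basic(unmixed))                     # well above the 1.1 heuristic
```

When one chain sits in a different region, \(B\) dominates the pooled variance and \(\widehat{R}\) rises well above 1; when all chains agree, \(B/n\) is negligible and \(\widehat{R} \approx 1\) (it can even fall slightly below 1, as seen for some parameters in the output above).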

To get the Rhat from a Pumas Bayesian result (the output of fit or Pumas.truncate), use summarystats and check the rhat column:

summarystats(pk_1cmp_tfit)
Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

        tvcl    3.1922    0.0804    0.0078    107.2489    248.2158    1.0169   ⋯
         tvv   13.2465    0.2635    0.0171    234.8128    562.5041    1.0059   ⋯
        ω²cl    0.0727    0.0082    0.0005    334.1100    490.7868    1.0080   ⋯
         ω²v    0.0465    0.0058    0.0003    532.5855    820.4850    1.0013   ⋯
        ω²ka    1.1000    0.1429    0.0033   1898.9600   1599.7130    0.9997   ⋯
         σ_p    0.1040    0.0024    0.0001   2151.0237   1358.5256    1.0072   ⋯
                                                                1 column omitted

Alternatively, use the rhat function, which returns a vector with each parameter's Rhat, in the same order as the parameters appear in summarystats:

rhat(pk_1cmp_tfit)
6-element Vector{Float64}:
 1.016190648689789
 1.0061381210347342
 1.007300375675496
 1.0003389357549985
 0.9985363226142564
 0.9993023867585773

which makes it convenient to apply summarizing functions:

mean(rhat(pk_1cmp_tfit))
1.0046344650879753

5 What To Do If the Markov Chains Do Not Converge?

Note

Before diving into MCMC tweaks and tuning, let’s learn about Gelman’s (2008) Folk Theorem:

“When you have computational problems, often there’s a problem with your model.”

The reasoning is that most MCMC samplers and algorithms have undergone a great deal of scientific and academic scrutiny, as well as engineering stress tests.

Thus, when MCMC convergence is poor, we (being Bayesians) place a higher prior mass on the model having issues than on the MCMC sampler having issues.

There are some things that you can do to make your Markov chains converge:

  • lower the target acceptance ratio:

    fit(
      ...,
      BayesMCMC(;
        target_accept=0.6 # default is 0.8
      )
    )
  • re-parameterize your model to have less parameter dependence (check Random Effects in Bayesian Models tutorial)

  • fix some parameter values to known good values, e.g. values obtained by maximum-a-posteriori (MAP) optimization

  • initialize the sampling from good parameter values, i.e. good initial values in the fit function

  • use a stronger prior around suspected good parameter values

  • simplify your model, e.g. using simpler dynamics

  • try the marginal MCMC algorithm MarginalMCMC instead of the full joint MCMC algorithm BayesMCMC (check Marginal MCMC Algorithm)

Note

Please refer to the Pumas Bayesian workflow documentation for an in-depth explanation behind these suggestions.

6 References

Gelman, A. (2008). The folk theorem of statistical computing. https://statmodeling.stat.columbia.edu/2008/05/13/the_folk_theore/

Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian Data Analysis. Chapman and Hall/CRC.

McElreath, R. (2020). Statistical rethinking: A Bayesian course with examples in R and Stan. CRC press.

Brooks, S., Gelman, A., Jones, G., & Meng, X.-L. (2011). Handbook of Markov Chain Monte Carlo. CRC Press.