"""MCMC convergence diagnostics.
Provides Gelman--Rubin R-hat and effective sample size (ESS) for
assessing chain convergence.
"""
from __future__ import annotations
import logging
from typing import Any
import numpy as np
logger = logging.getLogger(__name__)
[docs]
def gelman_rubin(chains: np.ndarray) -> np.ndarray:
    r"""Compute the Gelman--Rubin :math:`\hat{R}` statistic per parameter.

    The statistic compares the variance between chain means with the
    average variance inside each chain.

    Parameters
    ----------
    chains : np.ndarray
        MCMC chains of shape ``(n_walkers, n_steps, n_dim)``.

    Returns
    -------
    np.ndarray
        :math:`\hat{R}` per parameter (length ``n_dim``). Values near
        1.0 indicate convergence; values above ~1.05 suggest the chains
        have not mixed. An entry is NaN when the statistic is undefined:
        fewer than two walkers, fewer than two steps, or a within-chain
        variance that is not strictly positive for that parameter.

    Raises
    ------
    ValueError
        If ``chains`` is not three-dimensional (raised by the shape
        unpacking).
    """
    n_walkers, n_steps, n_dim = chains.shape

    # Both sample variances below use ddof=1: the between-chain one needs
    # at least two walkers and the within-chain one at least two steps.
    # Return NaN explicitly instead of letting NumPy emit
    # "degrees of freedom <= 0" warnings on the way to the same result.
    if n_walkers < 2 or n_steps < 2:
        return np.full(n_dim, np.nan)

    # Per-chain means and variances, each of shape (n_walkers, n_dim).
    chain_means = chains.mean(axis=1)
    chain_vars = chains.var(axis=1, ddof=1)

    # Between-chain variance B and within-chain variance W, shape (n_dim,).
    between = n_steps * np.var(chain_means, axis=0, ddof=1)
    within = chain_vars.mean(axis=0)

    # Pooled variance estimate: weighted combination of W and B.
    var_hat = (1.0 - 1.0 / n_steps) * within + (1.0 / n_steps) * between

    # Replace non-positive W with NaN so the ratio yields NaN rather than
    # a division-by-zero infinity for parameters that never moved.
    safe_within = np.where(within > 0, within, np.nan)
    return np.sqrt(var_hat / safe_within)
[docs]
def _integrated_autocorr_time(series: np.ndarray) -> float:
    """Return the integrated autocorrelation time of one 1-D chain, or NaN.

    NaN is returned when the lag-0 autocovariance is not strictly positive
    (for example a chain that never moved), because the normalised
    autocorrelation is undefined in that case.
    """
    n = series.size
    centred = series - series.mean()

    # Zero-pad to 2n so the circular correlation computed through the FFT
    # equals the linear autocovariance for lags 0..n-1.
    spectrum = np.fft.rfft(centred, n=2 * n)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * n)[:n]

    if not acov[0] > 0:
        return float("nan")
    rho = acov / acov[0]

    # Sum lags 1.. up to (not including) the first negative autocorrelation.
    negative_lags = np.flatnonzero(rho[1:] < 0)
    stop = negative_lags[0] + 1 if negative_lags.size else n
    return 1.0 + 2.0 * float(rho[1:stop].sum())


def effective_sample_size(chains: np.ndarray) -> np.ndarray:
    """Estimate effective sample size (ESS) per parameter via FFT autocorrelation.

    For every parameter the integrated autocorrelation time is estimated
    separately for each walker (summing the normalised autocorrelation
    until it first turns negative) and then averaged over walkers. The ESS
    is the total number of draws divided by that average, which is floored
    at 1 so the ESS never exceeds ``n_walkers * n_steps``.

    Walkers whose chain has no positive variance for a parameter have an
    undefined autocorrelation; they are left out of the average instead of
    being counted as a zero autocorrelation time, which would inflate the
    ESS. If no walker has a usable chain, the autocorrelation time falls
    back to 1.

    Parameters
    ----------
    chains : np.ndarray
        MCMC chains of shape ``(n_walkers, n_steps, n_dim)``.

    Returns
    -------
    np.ndarray
        ESS per parameter (length ``n_dim``). All zeros when there are no
        walkers or no steps.

    Raises
    ------
    ValueError
        If ``chains`` is not three-dimensional (raised by the shape
        unpacking).
    """
    n_walkers, n_steps, n_dim = chains.shape
    ess = np.zeros(n_dim)

    # Nothing was sampled: report zero effective samples rather than
    # failing inside the FFT on a zero-length transform.
    if n_walkers == 0 or n_steps == 0:
        return ess

    total_draws = n_walkers * n_steps
    for d in range(n_dim):
        taus = np.array(
            [_integrated_autocorr_time(chains[w, :, d]) for w in range(n_walkers)]
        )
        usable = taus[np.isfinite(taus)]
        avg_tau = float(usable.mean()) if usable.size else 1.0
        ess[d] = total_draws / max(avg_tau, 1.0)
    return ess
[docs]
def summarise_convergence(chains: np.ndarray) -> dict[str, Any]:
    """Bundle the R-hat and ESS diagnostics into a single report.

    Parameters
    ----------
    chains : np.ndarray
        MCMC chains of shape ``(n_walkers, n_steps, n_dim)``.

    Returns
    -------
    dict
        ``r_hat`` and ``ess`` hold the per-parameter arrays.
        ``r_hat_max`` is the largest R-hat with NaN entries ignored, or
        NaN when no entry is finite. ``ess_min`` is the smallest ESS with
        NaN entries ignored, or 0.0 when no entry is finite.
        ``converged`` is True only if ``r_hat_max`` is below 1.05 and
        ``ess_min`` is above 100; a NaN ``r_hat_max`` therefore yields
        False.
    """
    rhat_values = gelman_rubin(chains)
    ess_values = effective_sample_size(chains)

    # Worst-case R-hat; stays NaN if nothing finite is available.
    if np.isfinite(rhat_values).any():
        worst_rhat = float(np.nanmax(rhat_values))
    else:
        worst_rhat = np.nan

    # Worst-case ESS; falls back to zero if nothing finite is available.
    if np.isfinite(ess_values).any():
        lowest_ess = float(np.nanmin(ess_values))
    else:
        lowest_ess = 0.0

    # Comparisons against NaN are False, so an undefined R-hat never passes.
    passed = bool(worst_rhat < 1.05 and lowest_ess > 100)

    return {
        "r_hat": rhat_values,
        "ess": ess_values,
        "r_hat_max": worst_rhat,
        "ess_min": lowest_ess,
        "converged": passed,
    }