Source code for jwspecmcmc.samplers

"""Sampler wrappers for emcee (ensemble MCMC), nautilus (nested sampling), and NumPyro NUTS.

All wrappers accept a :class:`~jwspecmcmc.likelihood.LikelihoodSpec`
and :class:`~jwspecmcmc.priors.PriorSet`, run the sampler, and return
a common result dict.
"""

from __future__ import annotations

import logging
from typing import Any

import numpy as np

from .likelihood import LikelihoodSpec, log_probability
from .priors import PriorSet

logger = logging.getLogger(__name__)


def _auto_n_walkers(n_dim: int) -> int:
    """Choose n_walkers from n_dim and available CPU cores.

    The result is the smallest integer that is simultaneously

    * at least ``2 * n_dim + 2`` (emcee's stretch move needs
      ``n_walkers >= 2 * n_dim``; the extra 2 is headroom),
    * a multiple of ``n_cores``, so that pool-based parallelisation can
      distribute likelihood evaluations evenly across cores, and
    * even.

    Rounding to a multiple of ``n_cores`` first and then adding one to
    fix the parity would break the multiple-of-``n_cores`` property
    whenever ``n_cores`` is odd, so both constraints are folded into a
    single rounding stride instead.

    Parameters
    ----------
    n_dim : int
        Number of free parameters.

    Returns
    -------
    int
        Even number of walkers that is a multiple of the core count.
    """
    import os

    # ``os.cpu_count()`` may return None; fall back to 4 cores.
    n_cores = os.cpu_count() or 4
    min_walkers = 2 * n_dim + 2

    # Smallest step that is both even and a multiple of n_cores.
    stride = n_cores if n_cores % 2 == 0 else 2 * n_cores

    # Ceiling division; always use at least one full stride.
    n_strides = max(1, -(-min_walkers // stride))
    return n_strides * stride


def run_emcee(
    spec: LikelihoodSpec,
    prior_set: PriorSet,
    p0_free: np.ndarray,
    *,
    n_walkers: int | str = "auto",
    n_steps: int = 2000,
    n_burn: int | None = None,
    progress: bool = True,
    seed: int = 42,
    moves: Any = None,
) -> dict[str, Any]:
    """Run the emcee ensemble sampler.

    Parameters
    ----------
    spec : LikelihoodSpec
        Cached data for likelihood evaluation.
    prior_set : PriorSet
        Prior distributions.
    p0_free : np.ndarray
        MLE estimate in free-parameter space (used to initialise
        walkers).  Must have shape ``(prior_set.n_dim,)``.
    n_walkers : int or ``"auto"``
        Number of walkers. ``"auto"`` (default) picks a value based on
        ``n_dim`` and the number of CPU cores.  Values below
        ``2 * n_dim + 2`` are raised to that minimum.
    n_steps : int
        Number of MCMC steps (default 2000).
    n_burn : int or None
        Burn-in steps to discard. If ``None``, estimated as twice the
        largest integrated autocorrelation time, falling back to
        ``n_steps // 4`` when that estimate fails.  The value is always
        clamped to ``[0, n_steps - 1]``.
    progress : bool
        Show a progress bar.
    seed : int
        Random seed.  Seeds both the walker initialisation and emcee's
        internal proposal generator, so a run is reproducible for a
        fixed seed.
    moves : optional
        Custom emcee moves. If ``None``, uses the default ``StretchMove``.

    Returns
    -------
    dict
        Keys: ``flat_chains`` (n_samples, n_dim), ``flat_log_prob``
        (n_samples,), ``chains`` (n_walkers, n_steps_kept, n_dim),
        ``log_prob_chains`` (n_walkers, n_steps_kept), ``n_burn`` (int),
        ``sampler_name`` (str), ``sampler_meta`` (dict).

    Raises
    ------
    ValueError
        If ``p0_free`` does not have shape ``(prior_set.n_dim,)``.
    """
    import os

    import emcee

    from .priors import GaussianPrior, LogUniformPrior, UniformPrior

    n_dim = prior_set.n_dim
    p0_free = np.asarray(p0_free, dtype=float)
    if p0_free.shape != (n_dim,):
        raise ValueError(
            f"p0_free must have shape ({n_dim},), got {p0_free.shape}."
        )

    if n_walkers == "auto":
        n_walkers = _auto_n_walkers(n_dim)
        logger.info(
            "Auto n_walkers=%d (n_dim=%d, n_cores=%d).",
            n_walkers, n_dim, os.cpu_count() or 4,
        )

    # emcee requires n_walkers >= 2 * n_dim for the stretch move; the
    # extra 2 is headroom.  The minimum is even by construction.
    min_walkers = 2 * n_dim + 2
    if n_walkers < min_walkers:
        n_walkers = min_walkers
        logger.info(
            "Increased n_walkers to %d (>= 2 * n_dim = %d).",
            n_walkers, 2 * n_dim,
        )

    rng = np.random.default_rng(seed)

    # Initialise walkers as a Gaussian ball around the MLE.  Use 1% of
    # the prior range (or of the Gaussian prior width) as the scatter
    # scale so that walkers are spread enough to be linearly independent
    # while staying close to the MLE.
    scale = np.zeros(n_dim)
    for i, prior in enumerate(prior_set.priors):
        if isinstance(prior, (UniformPrior, LogUniformPrior)):
            scale[i] = 0.01 * (prior.hi - prior.lo)
        elif isinstance(prior, GaussianPrior):
            scale[i] = 0.01 * prior.std
        else:
            scale[i] = np.abs(p0_free[i]) * 1e-3
    scale = np.maximum(scale, 1e-30)

    p0 = p0_free[np.newaxis, :] + scale[np.newaxis, :] * rng.standard_normal(
        (n_walkers, n_dim)
    )

    # Pull walkers back strictly inside the prior support to avoid a
    # -inf log-probability at the start.  A fixed 1e-30 offset is lost
    # to rounding for any bound larger than ~1e-14, so the offset is at
    # least one representable float away from the bound.
    for i, prior in enumerate(prior_set.priors):
        if not isinstance(prior, (UniformPrior, LogUniformPrior, GaussianPrior)):
            continue
        lo, hi = prior.lo, prior.hi
        inner_lo = (
            max(lo + 1e-30, np.nextafter(lo, np.inf))
            if np.isfinite(lo) else -np.inf
        )
        inner_hi = (
            min(hi - 1e-30, np.nextafter(hi, -np.inf))
            if np.isfinite(hi) else np.inf
        )
        if inner_lo > inner_hi:
            # Degenerate (zero-width) support: clip to the bounds as given.
            inner_lo, inner_hi = lo, hi
        p0[:, i] = np.clip(p0[:, i], inner_lo, inner_hi)

    sampler = emcee.EnsembleSampler(
        n_walkers,
        n_dim,
        log_probability,
        args=(spec, prior_set),
        moves=moves,
    )

    # emcee draws its proposals from its own RandomState, which is not
    # seeded by ``rng``.  Derive a seed *after* the walker positions are
    # drawn (so they are unaffected) and install it on the sampler.
    emcee_seed = int(rng.integers(0, 2**32))
    sampler.random_state = np.random.RandomState(emcee_seed).get_state()

    logger.info(
        "Running emcee: %d walkers, %d steps, %d dims",
        n_walkers, n_steps, n_dim,
    )
    # skip_initial_state_check: emcee's condition number check can fail
    # when parameters span many orders of magnitude (e.g. amplitude
    # ~1e-18, centroid ~30000 Å).  This is cosmetic — the sampling is fine.
    sampler.run_mcmc(p0, n_steps, progress=progress, skip_initial_state_check=True)

    # Determine burn-in.
    if n_burn is None:
        try:
            tau = sampler.get_autocorr_time(quiet=True)
            n_burn = int(2.0 * np.nanmax(tau))
        except Exception:
            # Deliberate best-effort: any failure of the estimate (e.g.
            # an all-NaN tau) falls back to a fixed fraction of the chain.
            n_burn = n_steps // 4
            logger.info(
                "Autocorrelation time estimation failed; using n_burn=%d", n_burn
            )
        else:
            logger.info("Auto burn-in from autocorrelation time: %d steps", n_burn)

    # Keep at least one step and never pass a negative discard.
    n_burn = max(0, min(n_burn, n_steps - 1))

    chains = sampler.get_chain(discard=n_burn)  # (n_steps_kept, n_walkers, n_dim)
    chains = chains.transpose(1, 0, 2)  # (n_walkers, n_steps_kept, n_dim)
    log_prob_chains = sampler.get_log_prob(discard=n_burn).T  # (n_walkers, n_steps_kept)

    flat_chains = sampler.get_chain(discard=n_burn, flat=True)  # (n_samples, n_dim)
    flat_log_prob = sampler.get_log_prob(discard=n_burn, flat=True)

    return {
        "flat_chains": flat_chains,
        "flat_log_prob": flat_log_prob,
        "chains": chains,
        "log_prob_chains": log_prob_chains,
        "n_burn": n_burn,
        "sampler_name": "emcee",
        "sampler_meta": {
            "n_walkers": n_walkers,
            "n_steps": n_steps,
            "n_dim": n_dim,
            "n_burn": n_burn,
        },
    }

def run_nautilus(
    spec: LikelihoodSpec,
    prior_set: PriorSet,
    *,
    n_live: int = 2000,
    n_eff: int = 10000,
    progress: bool = True,
    seed: int = 42,
) -> dict[str, Any]:
    """Run the nautilus nested sampler.

    nautilus draws points from the prior it is given and weights them by
    the likelihood, so the callback handed to it returns the
    log-*likelihood* only.  Passing the log-posterior there would count
    every non-flat prior twice.

    Parameters
    ----------
    spec : LikelihoodSpec
        Cached data for likelihood evaluation.
    prior_set : PriorSet
        Prior distributions.
    n_live : int
        Number of live points (default 2000).
    n_eff : int
        Target effective sample size (default 10000).
    progress : bool
        Show a progress bar.
    seed : int
        Random seed (used for the sampler and for the resampling step).

    Returns
    -------
    dict
        Same keys as :func:`run_emcee`, except ``chains`` and
        ``log_prob_chains`` are ``None`` (nautilus does not produce
        walker chains).  ``flat_log_prob`` holds ``log_probability``
        evaluated at each returned sample, i.e. the same quantity
        :func:`run_emcee` reports.
    """
    from nautilus import Prior as NautilusPrior, Sampler
    from scipy.stats import loguniform, truncnorm

    from .likelihood import log_likelihood
    from .priors import GaussianPrior, LogUniformPrior, UniformPrior

    n_dim = prior_set.n_dim

    # nautilus passes parameters as a dict keyed by parameter name
    # (e.g. {"p0": val0, "p1": val1, ...}); keep the names in order so
    # the dict can be converted back to a flat array.
    param_names = [f"p{i}" for i in range(n_dim)]

    # Build the nautilus Prior object from our PriorSet.
    naut_prior = NautilusPrior()
    for i, prior in enumerate(prior_set.priors):
        param_name = f"p{i}"
        if isinstance(prior, UniformPrior):
            naut_prior.add_parameter(param_name, dist=(prior.lo, prior.hi))
        elif isinstance(prior, LogUniformPrior):
            naut_prior.add_parameter(
                param_name,
                dist=loguniform(prior.lo, prior.hi),
            )
        elif isinstance(prior, GaussianPrior):
            # truncnorm takes its limits in units of std from the mean.
            a = (prior.lo - prior.mean) / prior.std if np.isfinite(prior.lo) else -np.inf
            b = (prior.hi - prior.mean) / prior.std if np.isfinite(prior.hi) else np.inf
            naut_prior.add_parameter(
                param_name,
                dist=truncnorm(a, b, loc=prior.mean, scale=prior.std),
            )
        else:
            # Fallback: treat as uniform on a generous range.
            logger.warning("Unknown prior type for param %d; using uniform fallback.", i)
            naut_prior.add_parameter(param_name, dist=(-1e30, 1e30))

    def likelihood_fn(params: dict) -> float:
        p_arr = np.array([float(params[name]) for name in param_names])
        ll = float(log_likelihood(p_arr, spec))
        # A NaN likelihood carries no information; reject the point.
        return -np.inf if np.isnan(ll) else ll

    sampler = Sampler(
        naut_prior,
        likelihood_fn,
        n_live=n_live,
        seed=seed,
    )

    logger.info(
        "Running nautilus: %d live points, target n_eff=%d, %d dims",
        n_live, n_eff, n_dim,
    )
    sampler.run(n_eff=n_eff, verbose=progress)

    # ``log_l`` is not used: the reported log-probability is recomputed
    # below so that it includes the prior term.
    points, log_w, _log_l = sampler.posterior()

    # Resample to equally-weighted samples.
    weights = np.exp(log_w - np.max(log_w))
    weights /= weights.sum()
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(points), size=min(n_eff, len(points)), p=weights)
    flat_chains = points[indices]

    # Report the log-posterior, matching run_emcee.  Resampling draws
    # with replacement, so evaluate each distinct point only once.
    unique_idx, inverse = np.unique(indices, return_inverse=True)
    unique_log_prob = np.array(
        [log_probability(points[j], spec, prior_set) for j in unique_idx],
        dtype=float,
    )
    flat_log_prob = unique_log_prob[inverse]

    return {
        "flat_chains": flat_chains,
        "flat_log_prob": flat_log_prob,
        "chains": None,
        "log_prob_chains": None,
        "n_burn": 0,
        "sampler_name": "nautilus",
        "sampler_meta": {
            "n_live": n_live,
            "n_eff": n_eff,
            "n_dim": n_dim,
            "n_samples": len(flat_chains),
        },
    }

def _numpyro_priors(prior_set: PriorSet) -> tuple[list[Any], np.ndarray, np.ndarray]:
    """Translate a PriorSet into per-parameter NumPyro distributions and support bounds."""
    import numpyro.distributions as dist

    from .priors import GaussianPrior, LogUniformPrior, UniformPrior

    n_dim = prior_set.n_dim
    lb = np.full(n_dim, -np.inf)
    ub = np.full(n_dim, np.inf)
    prior_dists: list[Any] = []
    for i, prior in enumerate(prior_set.priors):
        if isinstance(prior, UniformPrior):
            lb[i], ub[i] = prior.lo, prior.hi
            prior_dists.append(dist.Uniform(float(prior.lo), float(prior.hi)))
        elif isinstance(prior, LogUniformPrior):
            lb[i], ub[i] = prior.lo, prior.hi
            prior_dists.append(dist.LogUniform(float(prior.lo), float(prior.hi)))
        elif isinstance(prior, GaussianPrior):
            lb[i], ub[i] = prior.lo, prior.hi
            low = float(prior.lo) if np.isfinite(prior.lo) else None
            high = float(prior.hi) if np.isfinite(prior.hi) else None
            if low is None and high is None:
                prior_dists.append(dist.Normal(float(prior.mean), float(prior.std)))
            else:
                prior_dists.append(
                    dist.TruncatedNormal(
                        loc=float(prior.mean),
                        scale=float(prior.std),
                        low=low,
                        high=high,
                    )
                )
        else:
            # A Uniform over (-inf, inf) has no usable unconstraining
            # transform; an improper flat prior over the reals does.
            logger.warning(
                "Unknown prior type for param %d; using an improper flat prior.", i
            )
            prior_dists.append(dist.ImproperUniform(dist.constraints.real, (), ()))
    return prior_dists, lb, ub


def _move_strictly_inside(
    p: np.ndarray, lb: np.ndarray, ub: np.ndarray, rel_margin: float = 1e-6
) -> np.ndarray:
    """Return a copy of *p* with entries on or outside a finite bound moved just inside it."""
    out = np.array(p, dtype=float)
    for i, (lo, hi) in enumerate(zip(lb, ub)):
        lo_finite = bool(np.isfinite(lo))
        hi_finite = bool(np.isfinite(hi))
        if lo_finite and hi_finite:
            # Scale the margin with the width so that it is meaningful for
            # both tiny (~1e-18) and large (~1e4) parameters.
            margin = rel_margin * (hi - lo)
        elif lo_finite or hi_finite:
            ref = abs(lo if lo_finite else hi)
            margin = rel_margin * (ref if ref > 0 else 1.0)
        else:
            continue  # unbounded parameter: nothing to clip against
        if lo_finite and out[i] <= lo:
            out[i] = lo + margin
        elif hi_finite and out[i] >= hi:
            out[i] = hi - margin
    return out


def _log_nuts_init_diagnostics(
    spec: LikelihoodSpec,
    p0_free: np.ndarray,
    lb: np.ndarray,
    ub: np.ndarray,
    log_likelihood_jax: Any,
    ll_init: float,
) -> None:
    """Log a report that helps locate the cause of a non-finite initial log-likelihood."""
    import jax.numpy as jnp

    from .jax_likelihood import _compile_tying_ops
    from .likelihood import log_likelihood as log_likelihood_np

    n_dim = p0_free.size
    n_lines = spec.n_lines
    report = [
        f"=== NUTS INIT FAILURE: log-likelihood = {ll_init} ===",
        f"n_free = {n_dim}, n_lines = {n_lines}",
    ]
    for idx in range(n_dim):
        flag = ""
        if p0_free[idx] <= lb[idx]:
            flag = " *** AT LOWER BOUND"
        elif p0_free[idx] >= ub[idx]:
            flag = " *** AT UPPER BOUND"
        if lb[idx] == ub[idx]:
            flag = " *** ZERO-WIDTH BOUND"
        report.append(
            f"  [{idx:3d}] p0={p0_free[idx]:+.6e}  "
            f"lb={lb[idx]:+.6e}  ub={ub[idx]:+.6e}{flag}"
        )

    # Rebuild the full parameter vector (free + tied entries) and check
    # the trailing third, which the original diagnostics treat as sigmas.
    free_mask = spec.constraints.free_mask()
    full = np.zeros(3 * n_lines)
    full[free_mask] = p0_free
    for dst, src, ratio in _compile_tying_ops(spec.constraints):
        full[dst] = full[src] * ratio
    for i, sig in enumerate(full[2 * n_lines:]):
        if sig <= 0:
            report.append(
                f"  *** sigma[{i}] = {sig:.6e} (line index {i}) — ZERO/NEGATIVE"
            )

    # Compare with the NumPy likelihood.
    report.append(f"NumPy log-likelihood at p0: {log_likelihood_np(p0_free, spec)}")

    # Test at the midpoint of the bounds; parameters without two finite
    # bounds have no midpoint, so they keep their p0 value.
    bounded = np.isfinite(lb) & np.isfinite(ub)
    p_mid = np.array(p0_free, dtype=float)
    p_mid[bounded] = 0.5 * (lb[bounded] + ub[bounded])
    ll_mid_jax = float(log_likelihood_jax(jnp.array(p_mid, dtype=jnp.float64)))
    ll_mid_np = log_likelihood_np(p_mid, spec)
    report.append(f"Midpoint JAX ll: {ll_mid_jax}, NumPy ll: {ll_mid_np}")

    # Check for non-finite data inputs.
    n_bad_flam = int(np.sum(~np.isfinite(spec.flam)))
    n_bad_err = int(np.sum(~np.isfinite(spec.flam_err)))
    n_zero_err = int(np.sum((spec.flam_err <= 0) & spec.valid))
    report.append(
        f"Data: {n_bad_flam} non-finite flam, {n_bad_err} non-finite err, "
        f"{n_zero_err} zero/neg err in valid pixels"
    )
    report.append("=== END DIAGNOSTICS ===")
    logger.error("\n%s", "\n".join(report))


def run_nuts(
    spec: LikelihoodSpec,
    prior_set: PriorSet,
    p0_free: np.ndarray,
    *,
    n_warmup: int = 500,
    n_samples: int = 2000,
    n_chains: int = 6,
    progress: bool = True,
    seed: int = 42,
    target_accept_prob: float = 0.8,
    max_tree_depth: int = 10,
) -> dict[str, Any]:
    """Run NumPyro NUTS (No-U-Turn Sampler) with JAX-accelerated likelihood.

    Uses automatic differentiation through the emission-line model for
    gradient-based HMC sampling, requiring far fewer likelihood
    evaluations than ensemble samplers like emcee.

    Each free parameter is sampled from the NumPyro counterpart of its
    prior (``Uniform``, ``LogUniform``, ``Normal``/``TruncatedNormal``),
    so the target density matches the one used by the other samplers in
    this module.  Parameters with an unrecognised prior type get an
    improper flat prior and a warning.

    Parameters
    ----------
    spec : LikelihoodSpec
        Cached data for likelihood evaluation.
    prior_set : PriorSet
        Prior distributions.
    p0_free : np.ndarray
        MLE estimate in free-parameter space (used to initialise).  Must
        have shape ``(prior_set.n_dim,)``.  Entries on or outside a
        finite prior bound are moved just inside it, with a warning.
    n_warmup : int
        Number of warmup (adaptation) steps (default 500).
    n_samples : int
        Number of posterior samples per chain (default 2000).
    n_chains : int
        Number of independent chains (default 6).
    progress : bool
        Show a progress bar.
    seed : int
        Random seed.
    target_accept_prob : float
        Target acceptance probability for NUTS adaptation (default 0.8).
    max_tree_depth : int
        Maximum tree depth for NUTS (default 10).

    Returns
    -------
    dict
        Same keys as :func:`run_emcee`.  ``chains`` has shape
        (n_chains, n_samples, n_dim) and ``flat_log_prob`` is the
        log-likelihood plus the log-density of the priors.

    Raises
    ------
    ValueError
        If ``p0_free`` does not have shape ``(prior_set.n_dim,)``.
    RuntimeError
        If the log-likelihood at the (clipped) starting point is not
        finite.  A diagnostic report is logged first.
    """
    import jax
    import jax.numpy as jnp
    import numpyro
    from numpyro.infer import MCMC, NUTS, init_to_value

    from .jax_likelihood import make_jax_log_likelihood

    n_dim = prior_set.n_dim
    p0_free = np.asarray(p0_free, dtype=float)
    if p0_free.shape != (n_dim,):
        raise ValueError(
            f"p0_free must have shape ({n_dim},), got {p0_free.shape}."
        )

    # Build JIT-compiled log-likelihood (the static data is not needed here).
    log_likelihood_jax, _static_data = make_jax_log_likelihood(spec)

    # One NumPyro distribution per free parameter, plus its support bounds.
    prior_dists, lb, ub = _numpyro_priors(prior_set)
    site_names = [f"p{i}" for i in range(n_dim)]

    def numpyro_model():
        # Sample each free parameter with its own prior.
        values = [
            numpyro.sample(name, prior)
            for name, prior in zip(site_names, prior_dists)
        ]
        p_free = jnp.stack(values)
        # Add log-likelihood as a factor.
        numpyro.factor("log_likelihood", log_likelihood_jax(p_free))

    # Validate initial parameters before passing to NumPyro: a start on
    # a bound maps to infinity in the unconstrained space.
    oob_lo = p0_free <= lb
    oob_hi = p0_free >= ub
    if np.any(oob_lo) or np.any(oob_hi):
        for idx in np.where(oob_lo | oob_hi)[0]:
            logger.warning(
                "  p0_free[%d] = %.6e  bounds = [%.6e, %.6e]  %s",
                idx, p0_free[idx], lb[idx], ub[idx],
                "BELOW" if oob_lo[idx] else "ABOVE",
            )
        # Clip to strictly inside bounds.
        p0_free = _move_strictly_inside(p0_free, lb, ub)
    p0_jax = jnp.array(p0_free, dtype=jnp.float64)

    ll_init = float(log_likelihood_jax(p0_jax))
    if not np.isfinite(ll_init):
        # Dump diagnostics to help identify the problem.  They are
        # best-effort: a failure while assembling them must not hide the
        # actual initialisation error raised below.
        try:
            _log_nuts_init_diagnostics(
                spec, p0_free, lb, ub, log_likelihood_jax, ll_init
            )
        except Exception:
            logger.exception("Could not assemble NUTS initialisation diagnostics.")
        raise RuntimeError(
            f"NUTS initialisation failed: log-likelihood = {ll_init} at p0. "
            f"This usually means parameter bounds are too tight or inconsistent. "
            f"See diagnostics above."
        )

    logger.info("Initial log-likelihood at p0: %.2f", ll_init)

    kernel = NUTS(
        numpyro_model,
        target_accept_prob=target_accept_prob,
        max_tree_depth=max_tree_depth,
        init_strategy=init_to_value(
            values={name: p0_jax[i] for i, name in enumerate(site_names)},
        ),
    )

    mcmc = MCMC(
        kernel,
        num_warmup=n_warmup,
        num_samples=n_samples,
        num_chains=n_chains,
        progress_bar=progress,
    )

    logger.info(
        "Running NumPyro NUTS: %d warmup, %d samples, %d chain(s), %d dims",
        n_warmup, n_samples, n_chains, n_dim,
    )

    rng_key = jax.random.PRNGKey(seed)
    mcmc.run(rng_key)

    # Extract samples.  ``group_by_chain=True`` keeps the per-chain axis,
    # which is needed for Gelman--Rubin R-hat, ESS, and trace plots.
    samples_by_chain = mcmc.get_samples(group_by_chain=True)
    chains = np.stack(
        [np.asarray(samples_by_chain[name]) for name in site_names],
        axis=-1,
    )  # (n_chains, n_samples, n_dim)
    flat_chains = chains.reshape(-1, chains.shape[-1])  # (n_chains * n_samples, n_dim)

    # Compute log-probabilities (likelihood + priors) for each sample.
    def log_posterior(p):
        log_prior = sum(
            prior.log_prob(p[i]) for i, prior in enumerate(prior_dists)
        )
        return log_likelihood_jax(p) + log_prior

    flat_log_prob = np.array(jax.vmap(jax.jit(log_posterior))(jnp.array(flat_chains)))
    log_prob_chains = flat_log_prob.reshape(chains.shape[:2])  # (n_chains, n_samples)

    # Extract diagnostics.
    extra_fields = mcmc.get_extra_fields()
    diverging = extra_fields.get("diverging", None)
    n_divergent = int(np.sum(diverging)) if diverging is not None else 0
    if n_divergent > 0:
        logger.warning(
            "%d divergent transitions detected (%.1f%%). Consider increasing "
            "target_accept_prob or max_tree_depth.",
            n_divergent, 100.0 * n_divergent / len(flat_chains),
        )

    return {
        "flat_chains": flat_chains,
        "flat_log_prob": flat_log_prob,
        "chains": chains,
        "log_prob_chains": log_prob_chains,
        "n_burn": 0,
        "sampler_name": "nuts",
        "sampler_meta": {
            "n_warmup": n_warmup,
            "n_samples": n_samples,
            "n_chains": n_chains,
            "n_dim": n_dim,
            "n_divergent": n_divergent,
            "target_accept_prob": target_accept_prob,
            "max_tree_depth": max_tree_depth,
        },
    }