"""Diagnostic plots for MCMC results.
Corner plots, trace plots, and flux posterior histograms.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from matplotlib.figure import Figure
from .result import MCMCResult
logger = logging.getLogger(__name__)
[docs]
def plot_corner(
    result: MCMCResult,
    *,
    params: list[str] | None = None,
    truths: np.ndarray | None = None,
    quantiles: list[float] | None = None,
    **corner_kwargs,
) -> Figure:
    """Corner plot of posterior samples.

    Parameters
    ----------
    result : MCMCResult
        MCMC result.
    params : list of str, optional
        Parameter names to include (e.g. ``["A_OIII_5007", "sigma_Ha"]``).
        Each name is ``A_<line>``, ``mu_<line>`` or ``sigma_<line>`` where
        ``<line>`` is an entry of ``result.line_names``. Names that cannot
        be resolved are skipped with a logged warning.
        If ``None``, plots all free parameters (``result.flat_chains_free``)
        with generic ``p0, p1, ...`` labels.
    truths : np.ndarray, optional
        True values to mark on the plot. When ``params`` is given, this is
        indexed with the same column indices as ``result.flat_chains``;
        otherwise it is passed to ``corner`` unchanged.
    quantiles : list of float, optional
        Quantiles for corner (default ``[0.16, 0.5, 0.84]``).
    **corner_kwargs
        Additional keyword arguments passed to :func:`corner.corner`.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``params`` is given but none of the names can be resolved.
    """
    if quantiles is None:
        quantiles = [0.16, 0.5, 0.84]

    if params is not None:
        # Columns of ``flat_chains`` are laid out in three consecutive
        # blocks of ``n_lines`` entries each: A_*, then mu_*, then sigma_*.
        n_lines = len(result.line_names)
        line_index = {name: i for i, name in enumerate(result.line_names)}
        prefix_blocks = (("A_", 0), ("mu_", 1), ("sigma_", 2))

        col_indices = []
        labels = []
        skipped = []
        for pname in params:
            col = None
            for prefix, block in prefix_blocks:
                if pname.startswith(prefix):
                    line_pos = line_index.get(pname[len(prefix):])
                    if line_pos is not None:
                        col = block * n_lines + line_pos
                    break
            if col is None:
                skipped.append(pname)
                continue
            col_indices.append(col)
            labels.append(pname)

        # Fail with an explicit message rather than handing corner an
        # array with zero columns.
        if not col_indices:
            raise ValueError(
                f"None of the requested params could be resolved: {params!r}"
            )
        if skipped:
            logger.warning(
                "plot_corner: ignoring unrecognised parameter names: %s",
                skipped,
            )

        data = result.flat_chains[:, col_indices]
        if truths is not None:
            truths = [truths[i] for i in col_indices]
    else:
        data = result.flat_chains_free
        labels = [f"p{i}" for i in range(data.shape[1])]

    # Imported here so that ``corner`` is only needed when this function
    # is actually called, and after the arguments have been validated.
    import corner

    return corner.corner(
        data,
        labels=labels,
        truths=truths,
        quantiles=quantiles,
        show_titles=True,
        title_kwargs={"fontsize": 10},
        **corner_kwargs,
    )
[docs]
def plot_traces(
    result: MCMCResult,
    *,
    params: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
) -> Figure:
    """Trace plots of MCMC chains.

    Parameters
    ----------
    result : MCMCResult
        MCMC result. Must have ``chains`` (i.e. from emcee).
    params : list of str, optional
        Parameter names to plot, each of the form ``A_<line>``,
        ``mu_<line>`` or ``sigma_<line>`` with ``<line>`` taken from
        ``result.line_names``. Names that are unknown, or that are not
        free according to ``result.constraints.free_mask()``, are skipped
        with a logged warning. If ``None``, plots all free parameters.
    figsize : tuple, optional
        Figure size. Defaults to ``(10, 2.5 * n_panels)``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If per-chain samples are not available (e.g. nautilus result), or
        if ``params`` is given but none of the names maps to a free
        parameter.
    """
    if result.chains is None:
        raise ValueError(
            "Trace plots require per-chain samples; the nautilus sampler "
            "does not produce them. Use sampler='emcee' or 'nuts' instead."
        )

    chains = result.chains  # indexed below as [walker, step, parameter]
    n_walkers, _, n_dim = chains.shape

    if params is not None:
        # The full parameter vector is laid out in three consecutive
        # blocks of ``n_lines`` entries: A_*, then mu_*, then sigma_*.
        n_lines = len(result.line_names)
        line_index = {name: i for i, name in enumerate(result.line_names)}
        free_mask = (
            result.constraints.free_mask()
            if result.constraints
            else np.ones(3 * n_lines, dtype=bool)
        )
        prefix_blocks = (("A_", 0), ("mu_", 1), ("sigma_", 2))

        col_indices = []
        labels = []
        skipped = []
        for pname in params:
            full_idx = None
            for prefix, block in prefix_blocks:
                if pname.startswith(prefix):
                    line_pos = line_index.get(pname[len(prefix):])
                    # An unknown line must stay unresolved for every
                    # prefix; it must never fall back to index 0.
                    if line_pos is not None:
                        full_idx = block * n_lines + line_pos
                    break
            if full_idx is None or not free_mask[full_idx]:
                skipped.append(pname)
                continue
            # Position among free parameters = number of free entries
            # that precede it in the full vector.
            col_indices.append(int(np.sum(free_mask[:full_idx])))
            labels.append(pname)

        if not col_indices:
            raise ValueError(
                "None of the requested params maps to a free parameter: "
                f"{params!r}"
            )
        if skipped:
            logger.warning(
                "plot_traces: skipping unknown or non-free parameters: %s",
                skipped,
            )
    else:
        col_indices = list(range(n_dim))
        labels = [f"p{i}" for i in range(n_dim)]

    n_plot = len(col_indices)
    if figsize is None:
        figsize = (10, 2.5 * n_plot)

    # squeeze=False always yields a 2-D axes array, so the single-panel
    # case needs no special handling.
    fig, axes_grid = plt.subplots(
        n_plot, 1, figsize=figsize, sharex=True, squeeze=False
    )
    axes = axes_grid[:, 0]

    n_burn = result.sampler_meta.get("n_burn", 0)
    for ax, ci, label in zip(axes, col_indices, labels):
        for w in range(n_walkers):
            ax.plot(chains[w, :, ci], alpha=0.2, lw=0.5, color="C0")
        if n_burn > 0:
            # The x axis is labelled as post-burn-in steps, so the
            # burn-in boundary is drawn at step 0.
            ax.axvline(0, color="red", ls="--", lw=0.8, label=f"burn-in={n_burn}")
        ax.set_ylabel(label, fontsize=9)

    axes[-1].set_xlabel("Step (post burn-in)")
    fig.tight_layout()
    return fig
[docs]
def plot_flux_posterior(
    result: MCMCResult,
    line_name: str,
    *,
    bins: int = 50,
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """Draw a density histogram of one line's flux posterior.

    Vertical markers are added at ``flux``, ``flux - flux_err[0]`` and
    ``flux + flux_err[1]`` of the selected line result, labelled as the
    median, 16th and 84th percentile respectively.

    Parameters
    ----------
    result : MCMCResult
        MCMC result; ``result.lines`` is looked up by ``line_name``.
    line_name : str
        Line name (e.g. ``"OIII_5007"``).
    bins : int
        Number of histogram bins.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 6x4 figure is created when ``None``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.

    Raises
    ------
    KeyError
        If ``line_name`` is not present in ``result.lines``.
    """
    if line_name not in result.lines:
        raise KeyError(f"Line '{line_name}' not found in result.")

    line_res = result.lines[line_name]
    samples = line_res.flux_posterior

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    ax.hist(
        samples,
        bins=bins,
        density=True,
        alpha=0.7,
        color="C0",
        edgecolor="C0",
    )

    # (x position, line style, line width, legend label), in legend order.
    markers = (
        (line_res.flux, "-", 1.5, "Median"),
        (line_res.flux - line_res.flux_err[0], "--", 1.0, "16th pctl"),
        (line_res.flux + line_res.flux_err[1], "--", 1.0, "84th pctl"),
    )
    for x_pos, style, width, text in markers:
        ax.axvline(x_pos, color="C1", ls=style, lw=width, label=text)

    ax.set_xlabel(f"Flux [{line_name}] (erg/s/cm$^2$)")
    ax.set_ylabel("Probability density")
    ax.legend(fontsize=8)
    ax.set_title(f"{line_name} flux posterior")
    return ax