# Source code for jwspecmcmc.jax_likelihood

"""JAX-accelerated likelihood for NUTS/HMC sampling.

Provides JIT-compiled versions of the emission-line model and
log-likelihood that are compatible with JAX's automatic
differentiation, enabling gradient-based samplers like NumPyro NUTS.

The constraint system is "compiled" into static index arrays so that
the entire likelihood — from free parameters to chi-squared — is a
single differentiable JAX computation graph.
"""

from __future__ import annotations

import logging
from math import sqrt

import numpy as np

from jwspecfit.constraints import (
    ConstraintSet,
    NII_RATIO,
    NV_RATIO,
    _INTERCOM_LOW_DENSITY_RATIOS,
)
from jwspecfit.lines import REST_LINES_A

from .likelihood import LikelihoodSpec

# Module logger, named after this module so records can be filtered by it.
logger: logging.Logger = logging.getLogger(__name__)

# sqrt(2), computed once; divides sigma when forming erf arguments.
_SQRT2: float = sqrt(2.0)


def _compile_tying_ops(cs: ConstraintSet) -> list[tuple[int, int, float]]:
    """Flatten a ConstraintSet into ordered ``(dst, src, ratio)`` tying ops.

    The full parameter vector consists of three consecutive blocks of
    length ``n = len(cs.line_names)``: amplitudes ``[0, n)``, centroids
    ``[n, 2n)`` and sigmas ``[2n, 3n)``. Each returned operation encodes
    ``p[dst] = p[src] * ratio``. Operations are emitted in the same order
    as ``ConstraintSet.apply()`` so that dependent constraints see their
    sources already updated.

    Parameters
    ----------
    cs : ConstraintSet
        The constraint set to compile. Only lines listed in
        ``cs.line_names`` can produce operations.

    Returns
    -------
    list of (int, int, float)
        Tying operations, in application order.
    """
    n = len(cs.line_names)
    where = {line: k for k, line in enumerate(cs.line_names)}
    # Offsets of the amplitude / centroid / sigma blocks.
    amp_off, cen_off, sig_off = 0, n, 2 * n
    compiled: list[tuple[int, int, float]] = []

    def tie(offset, dst, src, factor):
        # Record p[offset + dst] = p[offset + src] * factor.
        compiled.append((offset + where[dst], offset + where[src], factor))

    def present(*names):
        return all(nm in where for nm in names)

    # --- Balmer lines follow [OIII]5007: all sigmas first, then centroids.
    # Both scale with rest wavelength (shared velocity width / redshift).
    if cs.tie_balmer_to_oiii and "OIII_5007" in where:
        lam_ref = REST_LINES_A["OIII_5007"]
        balmer = [
            nm
            for nm in ("Ha", "HBETA", "HGAMMA", "HDELTA",
                       "HEPSILON", "H8", "H9", "H10")
            if nm in where
        ]
        for off in (sig_off, cen_off):
            for nm in balmer:
                tie(off, nm, "OIII_5007", REST_LINES_A[nm] / lam_ref)

    # --- Broad components share the centroid of their narrow partner.
    for broad, narrow in (
        ("Ha_BROAD", "Ha"), ("Ha_BROAD2", "Ha"),
        ("HBETA_BROAD", "HBETA"), ("HBETA_BROAD2", "HBETA"),
        ("HDELTA_BROAD", "HDELTA"), ("HDELTA_BROAD2", "HDELTA"),
        ("HGAMMA_BROAD", "HGAMMA"), ("HGAMMA_BROAD2", "HGAMMA"),
    ):
        if present(broad, narrow):
            tie(cen_off, broad, narrow, 1.0)

    # --- [OIII] broad doublet: 4959 follows 5007 in both broad tiers
    # (BROAD = outflow, BROAD2 = fast wind), centroid then sigma.
    o3_doublet = REST_LINES_A["OIII_4959"] / REST_LINES_A["OIII_5007"]
    for tier in ("_BROAD", "_BROAD2"):
        lead, follow = f"OIII_5007{tier}", f"OIII_4959{tier}"
        if present(lead, follow):
            tie(cen_off, follow, lead, o3_doublet)
            tie(sig_off, follow, lead, o3_doublet)

    # --- HeI broads: within each tier every present line follows the
    # first present one, with centroid / sigma scaled by rest wavelength.
    from jwspecfit.broad import HEI_BROAD_CANDIDATES as _HEI_BC
    for tier in ("_BROAD", "_BROAD2"):
        members = [f"{cand}{tier}" for cand in _HEI_BC
                   if f"{cand}{tier}" in where]
        if len(members) < 2:
            continue
        head, *rest = members
        lam_head = REST_LINES_A[head]
        for member in rest:
            scale = REST_LINES_A[member] / lam_head
            tie(cen_off, member, head, scale)
            tie(sig_off, member, head, scale)

    # --- [NII] doublet: 6549 follows 6585 (amplitude, centroid, sigma).
    if cs.tie_nii and present("NII_6549", "NII_6585"):
        nii_scale = REST_LINES_A["NII_6549"] / REST_LINES_A["NII_6585"]
        tie(amp_off, "NII_6549", "NII_6585", NII_RATIO)
        tie(cen_off, "NII_6549", "NII_6585", nii_scale)
        tie(sig_off, "NII_6549", "NII_6585", nii_scale)

    # --- UV doublets ---
    if cs.tie_uv_doublets:
        # Amplitude-tied doublets (NV).
        for lead, follow, amp_ratio in (("NV_1", "NV_2", NV_RATIO),):
            if not present(lead, follow):
                continue
            scale = REST_LINES_A[follow] / REST_LINES_A[lead]
            tie(amp_off, follow, lead, amp_ratio)
            if cs.tie_uv_centroids:
                tie(cen_off, follow, lead, scale)
            tie(sig_off, follow, lead, scale)

        # Kinematically tied intercombination doublets; blended ones also
        # get a fixed amplitude ratio (default 0.67 if not tabulated).
        blended = cs.blended_doublets or set()
        for lead, follow in (
            ("CIV_1", "CIV_2"),
            ("OIII_1666", "OIII_1661"),
            ("CIII]_1907", "CIII]"),
            ("NIV_1486", "NIV_1483"),
            ("NIII_1749", "NIII_1752"),
            ("SiIII_1", "SiIII_2"),
        ):
            if not present(lead, follow):
                continue
            scale = REST_LINES_A[follow] / REST_LINES_A[lead]
            if cs.tie_uv_centroids:
                tie(cen_off, follow, lead, scale)
            tie(sig_off, follow, lead, scale)
            if follow in blended:
                tie(amp_off, follow, lead,
                    _INTERCOM_LOW_DENSITY_RATIOS.get((lead, follow), 0.67))

        # Intercombination width tying: later lines follow the first
        # present one, sigma scaled by rest wavelength.
        # NOTE(review): the NIV/NIII pair sigma ops above are emitted
        # before NIV_1486 / NIII_1749 are tied here, so the secondary
        # lines read the pre-tie sigma — confirm this matches
        # ConstraintSet.apply().
        if cs.tie_uv_widths:
            intercom = [nm for nm in ("CIII]_1907", "NIV_1486", "NIII_1749")
                        if nm in where]
            if len(intercom) >= 2:
                head, *rest = intercom
                lam_head = REST_LINES_A[head]
                for nm in rest:
                    tie(sig_off, nm, head, REST_LINES_A[nm] / lam_head)

    # --- Unconditional width ties. When tie_uv_doublets is set this
    # repeats the sigma op already emitted for the CIII] pair above.
    for lead, follow in (("CIII]_1907", "CIII]"),):
        if present(lead, follow):
            tie(sig_off, follow, lead,
                REST_LINES_A[follow] / REST_LINES_A[lead])

    return compiled


def make_jax_log_likelihood(
    spec: LikelihoodSpec,
) -> tuple:
    """Build a JIT-compiled JAX log-likelihood function.

    The returned function does four things:

    1. Expands the free parameters into the full
       ``[amplitudes, centroids, sigmas]`` vector (three blocks of length
       ``spec.n_lines``).
    2. Applies the tying operations compiled by ``_compile_tying_ops``,
       in order.
    3. Builds the model from bin-averaged Gaussians, plus an optional
       skewed-Gaussian Lyα component.
    4. Returns ``-0.5 * chi^2``, using per-pixel weights
       ``w_pix / flam_err`` that are zero for invalid pixels.

    Calling this function enables JAX 64-bit mode (``jax_enable_x64``)
    as a process-wide side effect.

    Parameters
    ----------
    spec : LikelihoodSpec
        Cached data for likelihood evaluation (NumPy arrays).

    Returns
    -------
    log_likelihood_jax : callable
        ``f(p_free) -> float``, JIT-compiled. ``p_free`` is a 1-D array
        holding the free Gaussian parameters. When ``spec.n_lya > 0`` it
        is followed by the Lyα parameters ``[A_peak, mu, sigma, alpha]``.
        Raises ``ValueError`` at trace time if ``p_free`` does not have
        shape ``(static_data["n_free"],)``.
    static_data : dict
        Metadata for the sampler:

        - ``"free_idx"``: NumPy indices of the free entries in the full
          parameter vector.
        - ``"ops"``: the ``(dst, src, ratio)`` tying operations.
        - ``"n_lines"``: number of lines.
        - ``"n_free"``: total length of ``p_free``.
        - ``"n_lya"``: number of Lyα parameters.
    """
    import jax
    import jax.numpy as jnp

    # Enable 64-bit precision (essential for wavelengths in Angstroms).
    # This is a global JAX setting, not scoped to this likelihood.
    jax.config.update("jax_enable_x64", True)

    constraints = spec.constraints
    nL = spec.n_lines
    n_full = 3 * nL
    free_idx = np.where(constraints.free_mask())[0]
    n_gauss_free = len(free_idx)
    n_lya = spec.n_lya
    has_lya = n_lya > 0
    n_free = n_gauss_free + n_lya

    # Plain Python tuples so the tying loop below unrolls with static
    # indices at trace time.
    ops_tuples = [
        (int(d), int(s), float(r))
        for d, s, r in _compile_tying_ops(constraints)
    ]
    free_idx_jax = jnp.array(free_idx, dtype=jnp.int32)

    # Sanitize non-finite values. Invalid pixels are masked by
    # inv_err_w=0, but JAX evaluates all arithmetic (NaN * 0 = NaN in
    # IEEE 754), so no NaN may enter the computation.
    flam_safe = np.where(spec.valid, spec.flam, 0.0)
    err_safe = np.where(spec.valid, spec.flam_err, 1.0)
    w_safe = np.where(spec.valid, spec.w_pix, 0.0)

    left = jnp.array(spec.edges[:-1], dtype=jnp.float64)
    right = jnp.array(spec.edges[1:], dtype=jnp.float64)
    widths = right - left
    flam = jnp.array(flam_safe, dtype=jnp.float64)
    flam_err = jnp.array(err_safe, dtype=jnp.float64)
    w_pix = jnp.array(w_safe, dtype=jnp.float64)
    valid = jnp.array(spec.valid, dtype=jnp.bool_)

    # Combined weight w_pix / err, zeroed for invalid pixels.
    inv_err_w = jnp.where(valid, w_pix / flam_err, 0.0)

    # Bin centres for the skewed Gaussian (used only for Lyα).
    centres = 0.5 * (left + right) if has_lya else None

    @jax.jit
    def log_likelihood_jax(p_free: jnp.ndarray) -> float:
        p_free = jnp.asarray(p_free)
        # Shapes are static under jit, so this check runs once per trace.
        # It rejects inputs that .at[].set() would otherwise broadcast
        # silently, e.g. a scalar or length-1 array.
        if p_free.shape != (n_free,):
            raise ValueError(
                f"p_free must have shape ({n_free},), got {p_free.shape}"
            )

        # Gaussian parameters first, then (optionally) Lyα parameters.
        p_gauss = p_free[:n_gauss_free]
        p_lya = p_free[n_gauss_free:]

        # Expand free -> full vector. Tied entries start at 0 and are
        # filled by the tying ops, which run in order so later ops see
        # values written by earlier ones (loop unrolled at trace time).
        full = jnp.zeros(n_full, dtype=jnp.float64).at[free_idx_jax].set(
            p_gauss
        )
        for dst, src, ratio in ops_tuples:
            full = full.at[dst].set(full[src] * ratio)

        amps = full[:nL]
        mus = full[nL:2 * nL]
        sigs = full[2 * nL:]

        # Bin-averaged Gaussians via erf: (CDF(right) - CDF(left)) / width.
        inv = 1.0 / (_SQRT2 * sigs)  # (nL,)
        cdf_r = 0.5 * (1.0 + jax.lax.erf(
            (right[:, None] - mus[None, :]) * inv[None, :]
        ))
        cdf_l = 0.5 * (1.0 + jax.lax.erf(
            (left[:, None] - mus[None, :]) * inv[None, :]
        ))
        profiles = (cdf_r - cdf_l) / widths[:, None]  # (n_pix, nL)
        model = profiles @ amps  # (n_pix,)

        # Lyα: single asymmetric Gaussian (Bolan+2025 parameterisation),
        # evaluated at bin centres. p_lya = [A_peak, mu, sigma, alpha].
        if has_lya:
            t = (centres - p_lya[1]) / p_lya[2]
            gauss = p_lya[0] * jnp.exp(-0.5 * t**2)
            skew_term = 1.0 + jax.lax.erf(p_lya[3] * t / _SQRT2)
            model = model + gauss * skew_term

        # Weighted residual -> -0.5 * chi^2.
        resid = (flam - model) * inv_err_w
        return -0.5 * jnp.dot(resid, resid)

    static_data = {
        "free_idx": free_idx,
        "ops": ops_tuples,
        "n_lines": nL,
        "n_free": n_free,
        "n_lya": n_lya,
    }
    return log_likelihood_jax, static_data