jwspecmcmc.jax_likelihood

JAX-accelerated likelihood for NUTS/HMC sampling.

Provides JIT-compiled versions of the emission-line model and log-likelihood that are compatible with JAX’s automatic differentiation, enabling gradient-based samplers such as NumPyro’s NUTS.
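The sketch below shows why differentiability matters for NUTS. It is not jwspecmcmc code. It uses a hypothetical single-Gaussian line model on toy, noiseless data, and shows that `jax.value_and_grad` turns a jitted log-likelihood into one compiled function returning both the value and the gradient that a gradient-based sampler needs. The names `line_model` and `log_likelihood` and the data arrays are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: one Gaussian line on a wavelength grid (noiseless).
wave = jnp.linspace(6500.0, 6620.0, 200)
err = jnp.full_like(wave, 0.2)


def line_model(p):
    """Gaussian emission line; p = [amplitude, centre, sigma] (illustrative)."""
    return p[0] * jnp.exp(-0.5 * ((wave - p[1]) / p[2]) ** 2)


p_true = jnp.array([5.0, 6563.0, 3.0])
obs = line_model(p_true)


@jax.jit
def log_likelihood(p):
    """Gaussian log-likelihood up to an additive constant: -chi^2 / 2."""
    return -0.5 * jnp.sum(((obs - line_model(p)) / err) ** 2)


# NUTS/HMC need the log-density and its gradient at every leapfrog step;
# jax.value_and_grad returns both from one traced, compiled function.
logL_and_grad = jax.jit(jax.value_and_grad(log_likelihood))
logL, grad = logL_and_grad(jnp.array([4.0, 6560.0, 2.5]))
```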

The constraint system is “compiled” into static index arrays, so the entire likelihood, from free parameters to chi-squared, is a single differentiable JAX computation graph.
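The index-array idea can be sketched as follows. This is an assumed scheme, not the module's actual layout. Every full-model parameter is written as `p_free[SRC] * SCALE + OFFSET`, so a single gather covers free, tied, and fixed parameters. `SRC`, `SCALE`, `OFFSET`, `expand`, `model`, `chi2`, and the two-line model with its illustrative ratio and separation are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical constraint "compilation" for a full parameter vector
#   theta = [amp_a, cen_a, sig_a, amp_b, cen_b, sig_b, cont]
# driven by three free parameters p_free = [amp_a, cen_a, sig_a].
# Every entry is theta[i] = p_free[SRC[i]] * SCALE[i] + OFFSET[i].
SRC = np.array([0, 1, 2, 0, 1, 2, 0])                  # free parameter to read
SCALE = np.array([1.0, 1.0, 1.0, 1.0 / 3.0, 1.0, 1.0, 0.0])
OFFSET = np.array([0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.1])
# Row 3: amp_b tied to amp_a / 3 (illustrative ratio).
# Row 4: cen_b = cen_a + 20 (illustrative fixed separation).
# Row 6: cont fixed at 0.1 (SCALE = 0 discards the gathered value).

src, scale, offset = jnp.asarray(SRC), jnp.asarray(SCALE), jnp.asarray(OFFSET)
wave = jnp.linspace(6500.0, 6620.0, 300)


def expand(p_free):
    """Free -> full parameters with one gather; no Python branching."""
    return p_free[src] * scale + offset


def model(theta):
    """Two Gaussian lines plus a flat continuum."""
    # Strided slices pick (amp_a, amp_b), (cen_a, cen_b), (sig_a, sig_b).
    amps, cens, sigs = theta[0:6:3], theta[1:6:3], theta[2:6:3]
    profiles = jnp.exp(-0.5 * ((wave[None, :] - cens[:, None]) / sigs[:, None]) ** 2)
    return (amps[:, None] * profiles).sum(axis=0) + theta[6]


def chi2(p_free, obs, err):
    """Chi-squared computed directly from the free parameters."""
    return jnp.sum(((obs - model(expand(p_free))) / err) ** 2)


# Gradients of tied entries flow back through the gather into p_free.
grad_chi2 = jax.jit(jax.grad(chi2))

p_true = jnp.array([6.0, 6548.0, 2.0])
err = jnp.full_like(wave, 0.1)
obs = model(expand(p_true))
g = grad_chi2(jnp.array([5.0, 6547.0, 2.5]), obs, err)
```

Because the constraint arrays are constants, tracing `chi2` produces one graph with no data-dependent control flow. `jax.grad` then accumulates the gradient of each tied entry back into the free parameter that entry reads.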

Functions

make_jax_log_likelihood(spec) – Build a JIT-compiled JAX log-likelihood function.

jwspecmcmc.jax_likelihood.make_jax_log_likelihood(spec)

Build a JIT-compiled JAX log-likelihood function.

Parameters:

spec (LikelihoodSpec) – Cached data for likelihood evaluation (NumPy arrays).

Return type:

tuple

Returns:

  • log_likelihood_jax (callable) – JIT-compiled function f(p_free) -> float that maps the vector of free parameters to the log-likelihood.

  • static_data (dict) – Pre-computed JAX arrays and metadata for the sampler.
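The sketch below mirrors the documented calling convention: NumPy inputs go in, and a `(log_likelihood_jax, static_data)` tuple comes out. It uses a hypothetical `ToySpec` in place of `LikelihoodSpec`, whose fields are not documented here, and a hypothetical linear template model. It illustrates the factory-plus-closure pattern and is not the package's implementation. The keys in `static_data` are likewise assumptions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class ToySpec(NamedTuple):
    """Hypothetical stand-in for LikelihoodSpec (NumPy arrays only)."""
    flux: np.ndarray       # observed spectrum, shape (n_pix,)
    err: np.ndarray        # 1-sigma uncertainties, shape (n_pix,)
    templates: np.ndarray  # fixed unit line profiles, shape (n_free, n_pix)


def make_toy_log_likelihood(spec):
    """Hypothetical factory: NumPy spec in, (jitted fn, static_data) out."""
    # Convert the cached NumPy arrays to JAX arrays once, at build time.
    flux = jnp.asarray(spec.flux)
    ivar = 1.0 / jnp.asarray(spec.err) ** 2
    templates = jnp.asarray(spec.templates)

    @jax.jit
    def log_likelihood_jax(p_free):
        # Linear model: each free parameter scales one fixed template.
        model = p_free @ templates
        return -0.5 * jnp.sum((flux - model) ** 2 * ivar)

    static_data = {"flux": flux, "ivar": ivar, "templates": templates,
                   "n_free": templates.shape[0]}
    return log_likelihood_jax, static_data


# Usage with toy data: three unit Gaussians at illustrative centres.
wave = np.linspace(6500.0, 6620.0, 200)
templates = np.stack([np.exp(-0.5 * ((wave - c) / 3.0) ** 2)
                      for c in (6548.0, 6563.0, 6583.0)])
p_true = np.array([1.0, 5.0, 3.0])
spec = ToySpec(flux=p_true @ templates, err=np.full_like(wave, 0.1),
               templates=templates)

log_likelihood_jax, static_data = make_toy_log_likelihood(spec)


def potential(p):
    """Potential energy U(p) = -log L(p) as used by HMC/NUTS (no prior here)."""
    return -log_likelihood_jax(p)


grad_U = jax.grad(potential)(jnp.zeros(static_data["n_free"]))
```

If a callable like this is used directly with NumPyro, one option is to pass its negation, plus any log-prior, as `potential_fn` to `numpyro.infer.NUTS` and to supply a starting point through `init_params` in `MCMC.run`. This page does not say whether jwspecmcmc wires the sampler up this way.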