jwspecmcmc.jax_likelihood
JAX-accelerated likelihood for NUTS/HMC sampling.
Provides JIT-compiled versions of the emission-line model and log-likelihood that are compatible with JAX’s automatic differentiation, enabling gradient-based samplers like NumPyro NUTS.
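This is a minimal sketch of a JIT-compiled, differentiable emission-line model and log-likelihood in JAX. Several details are illustrative assumptions, not this module's actual internals:

- the Gaussian line profile;
- the `(n_lines, 3)` parameter layout;
- the names `emission_model`, `log_likelihood` and `grad_log_likelihood`.

The log-likelihood is the Gaussian one, -χ²/2, up to an additive constant.

```python
import jax
import jax.numpy as jnp


def emission_model(wave, line_params):
    """Sum of Gaussian lines on a wavelength grid.

    line_params has shape (n_lines, 3): amplitude, centre, sigma per row
    (a layout assumed for this sketch).
    """
    amp, cen, sig = line_params[:, 0], line_params[:, 1], line_params[:, 2]
    # Broadcast (n_lines, 1) against (1, n_wave), then sum over lines.
    z = (wave[None, :] - cen[:, None]) / sig[:, None]
    return jnp.sum(amp[:, None] * jnp.exp(-0.5 * z**2), axis=0)


@jax.jit
def log_likelihood(line_params, wave, flux, err):
    """Gaussian log-likelihood, -chi^2 / 2, dropping the constant term."""
    model = emission_model(wave, line_params)
    chi2 = jnp.sum(((flux - model) / err) ** 2)
    return -0.5 * chi2


# Reverse-mode gradient with respect to the first argument (line_params),
# itself JIT-compiled; HMC/NUTS evaluates such a gradient at every leapfrog step.
grad_log_likelihood = jax.jit(jax.grad(log_likelihood))
```

`wave`, `flux` and `err` are ordinary arguments. The compiled function is therefore reused for any data with the same shapes and dtypes.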
The constraint system is “compiled” into static index arrays so that the entire likelihood — from free parameters to chi-squared — is a single differentiable JAX computation graph.
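One way to picture compiling constraints into static index arrays is to write every full-model parameter as `p_free[src] * scale + offset`. Free, tied and fixed parameters then all reduce to one gather plus an affine map. Everything in this sketch is hypothetical and chosen only to illustrate the idea:

- the `("free",)` / `("tied", target, ratio)` / `("fixed", value)` encoding;
- the `compile_constraints` and `expand` helpers;
- the parameter names and the ratio of 3.0.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical constraint encoding (an assumption for illustration):
#   ("free",)                -> sampled directly
#   ("tied", target, ratio)  -> ratio * value of the *free* parameter `target`
#   ("fixed", value)         -> constant, never sampled
def compile_constraints(constraints):
    """Turn (name, spec) pairs into static arrays with
    full = p_free[src] * scale + offset. Assumes at least one free parameter."""
    free_names = [name for name, spec in constraints if spec[0] == "free"]
    free_index = {name: i for i, name in enumerate(free_names)}
    src, scale, offset = [], [], []
    for name, spec in constraints:
        kind = spec[0]
        if kind == "free":
            idx, s, o = free_index[name], 1.0, 0.0
        elif kind == "tied":
            _, target, ratio = spec
            idx, s, o = free_index[target], float(ratio), 0.0
        elif kind == "fixed":
            # Index 0 is a placeholder; scale 0 removes its contribution.
            idx, s, o = 0, 0.0, float(spec[1])
        else:
            raise ValueError(f"unknown constraint kind: {kind!r}")
        src.append(idx)
        scale.append(s)
        offset.append(o)
    # Built once with NumPy, then handed to JAX as constant arrays.
    return (
        jnp.asarray(np.array(src, dtype=np.int32)),
        jnp.asarray(np.array(scale, dtype=np.float32)),
        jnp.asarray(np.array(offset, dtype=np.float32)),
        len(free_names),
    )


@jax.jit
def expand(p_free, src, scale, offset):
    # One gather plus an affine map: differentiable with respect to p_free,
    # and the array shapes are fixed, so JIT traces it only once per shape.
    return p_free[src] * scale + offset


constraints = [
    ("amp_1", ("free",)),
    ("amp_2", ("tied", "amp_1", 3.0)),
    ("centre", ("free",)),
    ("sigma", ("fixed", 2.5)),
]
src, scale, offset, n_free = compile_constraints(constraints)
full = expand(jnp.array([2.0, 6563.0]), src, scale, offset)
# values: 2.0, 6.0, 6563.0, 2.5
```

`jax.jacobian(expand)` gives a zero row for each fixed parameter and a row of `ratio` for each tied one. Gradients therefore reach only the parameters the sampler actually moves.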
Functions
make_jax_log_likelihood(spec) – Build a JIT-compiled JAX log-likelihood function.
- jwspecmcmc.jax_likelihood.make_jax_log_likelihood(spec)
Build a JIT-compiled JAX log-likelihood function.
- Parameters:
spec (LikelihoodSpec) – Cached data for likelihood evaluation (NumPy arrays).
- Returns:
log_likelihood_jax (callable) – JIT-compiled function f(p_free) -> float that maps the vector of free parameters to the log-likelihood.
static_data (dict) – Pre-computed JAX arrays and metadata for the sampler.
- Return type:
tuple
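The following sketch mimics the documented calling contract. A factory returns `(log_likelihood_jax, static_data)`, and a gradient-based sampler consumes the potential energy together with its gradient; with flat priors, that potential is U = -log L. `make_toy_log_likelihood`, its single-Gaussian model, `make_potential_and_grad` and the `"wave"`/`"n_free"` keys are all hypothetical stand-ins. The real function instead takes a `LikelihoodSpec`, whose contents are not shown here.

```python
import jax
import jax.numpy as jnp


def make_toy_log_likelihood(wave, flux, err):
    """Hypothetical stand-in returning (log_likelihood_jax, static_data).

    Only the return shape mirrors make_jax_log_likelihood; the model
    (one Gaussian line with free amplitude, centre, sigma) and the
    static_data keys are assumptions for this sketch.
    """
    wave, flux, err = (jnp.asarray(a) for a in (wave, flux, err))

    @jax.jit
    def log_likelihood_jax(p_free):
        amp, cen, sig = p_free[0], p_free[1], p_free[2]
        model = amp * jnp.exp(-0.5 * ((wave - cen) / sig) ** 2)
        return -0.5 * jnp.sum(((flux - model) / err) ** 2)

    static_data = {"wave": wave, "n_free": 3}
    return log_likelihood_jax, static_data


def make_potential_and_grad(log_likelihood_jax):
    """With flat priors, the HMC potential energy is U(p) = -log L(p)."""
    def potential(p_free):
        return -log_likelihood_jax(p_free)
    # value_and_grad returns (U, dU/dp) from one forward and one backward pass.
    return jax.jit(jax.value_and_grad(potential))
```

In a NumPyro model, such a log-likelihood value would typically enter the joint density through `numpyro.factor("log_likelihood", log_likelihood_jax(p_free))`. This page does not state whether the package wires it up that way.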