# Source code for jwspecfit.io

"""Spectrum I/O: FITS and NPZ readers, Spectrum container."""

from __future__ import annotations

import logging
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np
from astropy.io import fits

# Module logger; no handlers or levels are configured in this module.
logger = logging.getLogger(__name__)

# Physical constants for unit conversion.
_C_CGS = 2.99792458e10  # speed of light, cm/s


@dataclass
class Spectrum:
    """Container for a 1-D spectrum.

    Attributes
    ----------
    wave_um : np.ndarray
        Observed wavelength in microns.
    flux_ujy : np.ndarray
        Flux density in micro-Jansky.
    err_ujy : np.ndarray
        1-sigma uncertainty in micro-Jansky.
    grating : str or None
        Grating name (e.g. ``"PRISM"``, ``"G395M"``). ``None`` for stacked
        spectra.
    z : float or None
        Source redshift (set by user, not from header).
    R : float or callable or None
        Spectral resolving power. Overrides ``grating`` when set (useful for
        stacks).
    meta : dict
        Arbitrary metadata from the FITS header or user.
    sci_2d : np.ndarray or None
        Optional 2-D rectified spectrum image with shape
        ``(n_spatial, n_pix)`` where ``n_pix == len(wave_um)``. Populated by
        :func:`read_fits` when an ``SCI`` ImageHDU is present and its
        wavelength axis matches the 1-D extraction; otherwise ``None``.
        Read by :func:`~jwspecfit.plotting.plot_spectrum_interactive` to
        render a 2-D + 1-D stacked view; ignored by every other code path.
    """

    wave_um: np.ndarray
    flux_ujy: np.ndarray
    err_ujy: np.ndarray
    grating: str | None = None
    z: float | None = None
    # The documented contract allows either a constant resolving power or a
    # callable, so the annotation admits both.
    R: float | Callable[..., Any] | None = None
    meta: dict[str, Any] = field(default_factory=dict)
    sci_2d: np.ndarray | None = None

    # --- Derived properties ---------------------------------------------------

    @property
    def wave_A(self) -> np.ndarray:
        """Wavelength in Angstroms."""
        return self.wave_um * 1e4

    @property
    def n_pix(self) -> int:
        """Number of pixels (length of ``wave_um``)."""
        return len(self.wave_um)

    @property
    def wave_edges_A(self) -> np.ndarray:
        """Pixel-edge wavelengths in Angstroms (length ``n_pix + 1``).

        Interior edges are midpoints between adjacent pixel centres. The two
        outer edges mirror the nearest interior edge about the first/last
        pixel centre.

        Raises
        ------
        ValueError
            If the spectrum has fewer than two pixels, in which case no
            pixel width can be inferred.
        """
        w = self.wave_A
        if len(w) < 2:
            raise ValueError(
                f"wave_edges_A needs at least 2 pixels; spectrum has {len(w)}."
            )
        mid = 0.5 * (w[:-1] + w[1:])
        left = 2.0 * w[0] - mid[0]
        right = 2.0 * w[-1] - mid[-1]
        return np.concatenate([[left], mid, [right]])

    @property
    def dlam_A(self) -> np.ndarray:
        """Pixel widths in Angstroms (differences of :attr:`wave_edges_A`)."""
        edges = self.wave_edges_A
        return edges[1:] - edges[:-1]

    @property
    def flux_flam(self) -> np.ndarray:
        """Flux density in erg/s/cm²/Å."""
        return _ujy_to_flam(self.flux_ujy, self.wave_um)

    @property
    def err_flam(self) -> np.ndarray:
        """Error in erg/s/cm²/Å."""
        return _ujy_to_flam(self.err_ujy, self.wave_um)

    def mask_valid(self) -> np.ndarray:
        """Boolean mask: True where flux and error are finite and err > 0."""
        return (
            np.isfinite(self.flux_ujy)
            & np.isfinite(self.err_ujy)
            & (self.err_ujy > 0)
        )

    def copy(self) -> "Spectrum":
        """Return a copy with copied arrays and a shallow copy of ``meta``."""
        return Spectrum(
            wave_um=self.wave_um.copy(),
            flux_ujy=self.flux_ujy.copy(),
            err_ujy=self.err_ujy.copy(),
            grating=self.grating,
            z=self.z,
            R=self.R,
            meta=dict(self.meta),
            sci_2d=None if self.sci_2d is None else self.sci_2d.copy(),
        )
# --- Unit aliases for FITS auto-detection ---------------------------------

_WAVE_TO_UM = {
    "um": 1.0, "micron": 1.0, "microns": 1.0,
    "micrometer": 1.0, "micrometre": 1.0,
    "micrometers": 1.0, "micrometres": 1.0,
    "a": 1e-4, "ang": 1e-4, "angstrom": 1e-4, "angstroms": 1e-4,
    "nm": 1e-3, "nanometer": 1e-3, "nanometre": 1e-3,
    "m": 1e6, "meter": 1e6, "metre": 1e6,
}

_FNU_TO_UJY = {
    "ujy": 1.0, "microjansky": 1.0, "microjy": 1.0,
    "njy": 1e-3, "nanojansky": 1e-3,
    "mjy": 1e3, "millijansky": 1e3,
    "jy": 1e6, "jansky": 1e6,
    "megajansky": 1e12,
}

# FITS unit prefixes are case-sensitive ("M" = mega, "m" = milli), but
# _normalise_unit lower-cases everything, so "MJy" would otherwise be read as
# millijansky (a factor 1e9 too small). These are checked on the raw string.
_FNU_TO_UJY_CASE_SENSITIVE = {
    "MJy": 1e12,
}


def _normalise_unit(u: str | None) -> str:
    """Lowercase, strip, and remove common ornamentation from a unit string.

    Micro signs become ``u``, ``å`` becomes ``a`` and all spaces are removed.
    ``None`` becomes the empty string.
    """
    if u is None:
        return ""
    s = str(u).strip().lower()
    # Replace fancy characters and squash whitespace.
    return (
        s.replace("μ", "u").replace("µ", "u")
        .replace("å", "a").replace("Å", "a")
        .replace(" ", "")
    )


# F_lambda aliases, stored in *normalised* form. _normalise_unit removes every
# space, so entries written with spaces (e.g. "erg s-1 cm-2 angstrom-1") could
# never match a normalised header value unless normalised the same way.
_FLAM_UNITS = frozenset(
    _normalise_unit(alias)
    for alias in (
        "erg/s/cm2/a",
        "erg/s/cm^2/a",
        "erg/s/cm**2/a",
        "erg/s/cm2/angstrom",
        "erg/cm2/s/a",
        "erg/cm2/s/angstrom",
        "erg s-1 cm-2 angstrom-1",
        "erg s^-1 cm^-2 angstrom^-1",
        "erg s-1 cm-2 a-1",
        "erg cm-2 s-1 angstrom-1",
        "erg cm-2 s-1 a-1",
        "erg/(s cm2 a)",
        "erg/(s cm2 angstrom)",
        "erg / (angstrom cm2 s)",
        "erg / (angstrom s cm2)",
        "flam",
    )
)


def _convert_wave_to_um(wave: np.ndarray, unit: str | None) -> np.ndarray:
    """Convert wavelength array in *unit* to microns. Defaults to µm.

    Raises
    ------
    ValueError
        If *unit* is non-empty and not a recognised wavelength unit.
    """
    u = _normalise_unit(unit)
    if u == "" or u in _WAVE_TO_UM:
        return wave * _WAVE_TO_UM.get(u, 1.0)
    raise ValueError(f"Unrecognised wavelength unit: {unit!r}")


def _convert_flux_to_ujy(
    flux: np.ndarray,
    unit: str | None,
    wave_um: np.ndarray,
) -> np.ndarray:
    """Convert flux array in *unit* to µJy. Defaults to µJy.

    F_nu units are rescaled directly. F_lambda units (erg/s/cm²/Å) are
    converted with :func:`_flam_to_ujy` at wavelengths *wave_um* (µm).

    Raises
    ------
    ValueError
        If *unit* is non-empty and not a recognised flux unit.
    """
    raw = "" if unit is None else str(unit).strip()
    if raw in _FNU_TO_UJY_CASE_SENSITIVE:
        return flux * _FNU_TO_UJY_CASE_SENSITIVE[raw]
    u = _normalise_unit(unit)
    if u == "" or u in _FNU_TO_UJY:
        return flux * _FNU_TO_UJY.get(u, 1.0)
    if u in _FLAM_UNITS:
        return _flam_to_ujy(flux, wave_um)
    raise ValueError(f"Unrecognised flux unit: {unit!r}")


# --- Column-name aliases for auto-detection -------------------------------
# Tuples, not sets: the order is the match preference. Iterating a set of
# strings follows hash order, which Python randomises per process, so the
# chosen column could change between runs when several aliases match.

_WAVE_NAMES = ("wave", "wavelength", "wavelen", "lambda", "lam", "loglam")
_FLUX_NAMES = ("flux", "flux_density", "fnu", "f_nu", "flam", "f_lambda", "spec")
_ERR_NAMES = (
    "err", "error", "flux_err", "flux_error", "fluxerr",
    "sigma", "uncertainty", "noise", "stddev", "std",
)
_IVAR_NAMES = ("ivar", "inv_var", "inverse_variance", "weight", "wht")


def _find_column(
    table_names: list[str],
    aliases: Iterable[str],
    *,
    exclude: Iterable[str | None] = (),
) -> str | None:
    """Return the column in *table_names* that best matches *aliases*.

    Matching is case-insensitive. Exact name matches are tried before
    substring matches. Within each pass, *aliases* are tried in their given
    order (most preferred first) and, for each alias, columns in table
    order, so the result is deterministic. Columns named in *exclude* are
    never returned. Returns ``None`` when nothing matches.
    """
    skip = set(exclude)
    candidates = [(name.lower(), name) for name in table_names if name not in skip]
    ordered = tuple(alias.lower() for alias in aliases)
    for alias in ordered:
        for lname, orig in candidates:
            if lname == alias:
                return orig
    # Fallback: substring match.
    for alias in ordered:
        for lname, orig in candidates:
            if alias in lname:
                return orig
    return None


def _column_as_1d(data, name: str) -> np.ndarray:
    """Return table column *name* as a 1-D float array.

    Handles one-row-per-pixel tables, which give shape ``(n_pix,)``. Also
    handles single-row tables holding vector columns, with shape
    ``(1, n_pix)``; these are flattened to ``(n_pix,)``.
    """
    arr = np.asarray(data[name], dtype=float)
    if arr.ndim == 2 and arr.shape[0] == 1:
        arr = arr[0]
    return arr


def _read_bintable_spectrum(
    hdu,
    wave_col: str | None,
    flux_col: str | None,
    err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Extract (wave_um, flux_ujy, err_ujy, meta) from a BinTable HDU.

    Columns not named explicitly are auto-detected with :func:`_find_column`.
    Each auto-detected column excludes the ones already chosen, so the flux
    or error column can never silently alias the wavelength column. Units
    come from the columns' ``TUNITn`` values.

    Raises
    ------
    ValueError
        If no wavelength or flux column can be found, or a unit is not
        recognised.
    """
    data = hdu.data
    columns = hdu.columns
    names = [c.name for c in columns]
    units = {c.name: c.unit for c in columns}

    wcol = wave_col or _find_column(names, _WAVE_NAMES)
    fcol = flux_col or _find_column(names, _FLUX_NAMES, exclude=(wcol,))
    if wcol is None or fcol is None:
        raise ValueError(
            f"Could not auto-detect wave/flux columns in HDU "
            f"{hdu.name!r}: available columns = {names}"
        )

    wave_raw = _column_as_1d(data, wcol)
    flux_raw = _column_as_1d(data, fcol)

    # Handle SDSS-style log10(λ/Å) column.
    if wcol.lower() == "loglam":
        wave_um = (10.0 ** wave_raw) * 1e-4
    else:
        wave_um = _convert_wave_to_um(wave_raw, units.get(wcol))
    flux_ujy = _convert_flux_to_ujy(flux_raw, units.get(fcol), wave_um)

    # Error column: explicit name, then aliases, then derive from ivar.
    used = (wcol, fcol)
    ecol = err_col or _find_column(names, _ERR_NAMES, exclude=used)
    if ecol is not None:
        err_raw = _column_as_1d(data, ecol)
        # The error column often lacks its own TUNIT; assume flux units then.
        err_unit = units.get(ecol) or units.get(fcol)
        err_ujy = _convert_flux_to_ujy(err_raw, err_unit, wave_um)
    else:
        ivar_col = _find_column(names, _IVAR_NAMES, exclude=used)
        if ivar_col is not None:
            ivar = _column_as_1d(data, ivar_col)
            positive = ivar > 0
            # sigma = 1/sqrt(ivar); non-positive (or NaN) ivar gets an infinite
            # error. The inner where avoids sqrt/divide warnings on bad pixels.
            err_raw = np.where(
                positive, 1.0 / np.sqrt(np.where(positive, ivar, 1.0)), np.inf
            )
            err_ujy = _convert_flux_to_ujy(err_raw, units.get(fcol), wave_um)
        else:
            logger.warning(
                "No error/uncertainty column found in HDU %r; setting err=0.",
                hdu.name,
            )
            err_ujy = np.zeros_like(flux_ujy)

    meta = {"hdu": hdu.name, "columns": names, "wave_col": wcol, "flux_col": fcol}
    return wave_um, flux_ujy, err_ujy, meta


def _read_image_spectrum(hdu) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Extract a 1-D spectrum from an Image HDU using WCS keywords.

    Reads ``CRVAL1``, ``CDELT1`` (or ``CD1_1``), ``CRPIX1``, ``CTYPE1``,
    ``CUNIT1`` to build the wavelength axis. Flux units come from ``BUNIT``.
    No error array is available from a single image, so the error is set to
    zero.

    Raises
    ------
    ValueError
        If the image is not 1-D, or a unit is not recognised.
    """
    header = hdu.header
    data = np.asarray(hdu.data, dtype=float)
    if data.ndim != 1:
        raise ValueError(
            f"Image HDU has shape {data.shape}; need 1-D for spectrum reading."
        )
    n = data.size
    crval = float(header.get("CRVAL1", 0.0))
    crpix = float(header.get("CRPIX1", 1.0))
    cdelt = float(header.get("CDELT1", header.get("CD1_1", 1.0)))
    ctype = str(header.get("CTYPE1", "")).upper()
    cunit = header.get("CUNIT1", None)
    bunit = header.get("BUNIT", None)

    # FITS pixel indices are 1-based.
    pix = np.arange(n, dtype=float) + 1.0
    wave_raw = crval + (pix - crpix) * cdelt
    # NOTE(review): any CTYPE1 containing "LOG" is treated as a log10
    # wavelength axis (IRAF/SDSS-style). The FITS WCS standard defines the
    # "-LOG" algorithm differently (lambda = CRVAL1 * exp(w / CRVAL1)); confirm
    # which convention the target files follow.
    if "LOG" in ctype:
        wave_raw = 10.0 ** wave_raw
    wave_um = _convert_wave_to_um(wave_raw, cunit)
    flux_ujy = _convert_flux_to_ujy(data, bunit, wave_um)
    err_ujy = np.zeros_like(flux_ujy)

    meta = {"hdu": hdu.name, "ctype1": ctype, "cunit1": cunit, "bunit": bunit}
    return wave_um, flux_ujy, err_ujy, meta
def read_fits(
    path: str | Path,
    z: float | None = None,
    *,
    hdu: str | int | None = None,
    wave_col: str | None = None,
    flux_col: str | None = None,
    err_col: str | None = None,
) -> Spectrum:
    """Read a 1-D spectrum from a FITS file.

    The spectrum HDU is chosen as follows:

    1. If *hdu* is given, exactly that HDU is read, as a BinTable or a 1-D
       image. No other extension is tried.
    2. Otherwise, if a ``SPEC1D`` extension exists, it is read. This is the
       JWST/NIRSpec convention with columns ``wave`` (µm), ``flux`` (µJy)
       and ``err`` (µJy).
    3. Otherwise every extension with data is tried in order, and the first
       one that yields a spectrum wins:

       * **BinTable HDUs** — wavelength- and flux-like columns are
         auto-detected. Recognised column-name aliases include
         ``wave/wavelength/lam/lambda/loglam`` and ``flux/fnu/flam`` and
         ``err/error/sigma/noise`` (or ``ivar``). Units are read from
         ``TUNITn`` keywords; common aliases (µm/Å/nm, µJy/mJy/Jy,
         erg/s/cm²/Å) are converted automatically.
       * **Image HDUs** — a 1-D image is read with the WCS keywords
         ``CRVAL1``, ``CDELT1``/``CD1_1``, ``CRPIX1``, ``CTYPE1``,
         ``CUNIT1``; flux unit is ``BUNIT``. Errors are not available from
         an image and are set to zero.

    ``GRATING`` and ``FILTER`` are taken from the spectrum HDU's header,
    falling back to the primary header. If an ``SCI`` extension holds a 2-D
    image whose second axis length equals the number of 1-D pixels, it is
    attached as ``Spectrum.sci_2d``.

    Parameters
    ----------
    path : str or Path
        Path to the FITS file.
    z : float, optional
        Source redshift to attach to the returned :class:`Spectrum`.
    hdu : str or int, optional
        Force a specific HDU to read (name or index). When ``None``
        (default), tries SPEC1D first, then auto-detects.
    wave_col, flux_col, err_col : str, optional
        Force specific column names instead of auto-detecting. Only used for
        BinTable HDUs.

    Returns
    -------
    Spectrum
    """
    path = Path(path)
    with fits.open(path) as hdul:
        ext_names = [h.name for h in hdul]
        spec_hdu = None

        # 1. Explicit HDU request.
        if hdu is not None:
            spec_hdu = hdul[hdu]
            wave_um, flux_ujy, err_ujy, meta = _read_one(
                spec_hdu, wave_col, flux_col, err_col,
            )
        # 2. Default fast path: SPEC1D, optionally with explicit columns.
        elif "SPEC1D" in ext_names:
            spec_hdu = hdul["SPEC1D"]
            wave_um, flux_ujy, err_ujy, meta = _read_one(
                spec_hdu, wave_col, flux_col, err_col,
            )
        # 3. Auto-detect across all extensions.
        else:
            wave_um, flux_ujy, err_ujy, meta = _autodetect_spectrum(
                hdul, wave_col, flux_col, err_col,
            )
            hdu_name = meta.get("hdu")
            if hdu_name:
                try:
                    spec_hdu = hdul[hdu_name]
                except KeyError:
                    spec_hdu = None

        # Header metadata: prefer the spectrum HDU, fall back to primary.
        # Using the HDU object itself (not a name lookup) also works for
        # unnamed extensions selected by index.
        primary_header = hdul[0].header
        spec_header = primary_header if spec_hdu is None else spec_hdu.header
        grating = spec_header.get("GRATING", primary_header.get("GRATING", None))
        filt = spec_header.get("FILTER", primary_header.get("FILTER", None))

        sci_2d = _extract_sci_2d(hdul, ext_names, len(wave_um))

    meta.update({"filename": path.name, "filter": filt})
    logger.info(
        "Read %s [%s]: %d pixels, grating=%s",
        path.name, meta.get("hdu", "?"), len(wave_um), grating,
    )
    return Spectrum(
        wave_um=wave_um,
        flux_ujy=flux_ujy,
        err_ujy=err_ujy,
        grating=grating,
        z=z,
        meta=meta,
        sci_2d=sci_2d,
    )


def _extract_sci_2d(hdul, ext_names: list[str], n_pix: int) -> np.ndarray | None:
    """Return the ``SCI`` image as a float array if it is 2-D with *n_pix* columns.

    Opportunistic 2-D pickup. Some JWST/NIRSpec pipelines (msaexp in
    particular) ship a rectified 2-D image in the ``SCI`` ImageHDU sharing
    the SPEC1D wavelength grid. This is a best-effort read: any failure is
    logged at DEBUG level and yields ``None``, so it never blocks the 1-D
    read.
    """
    if "SCI" not in ext_names:
        return None
    try:
        cand = hdul["SCI"].data
        if (
            cand is not None
            and getattr(cand, "ndim", 0) == 2
            and cand.shape[1] == n_pix
        ):
            # np.array (not asarray) always copies. np.asarray can return the
            # file-backed array itself when the dtype already matches, which
            # would outlive the closed HDUList.
            return np.array(cand, dtype=float)
    except Exception:  # pragma: no cover — never block the 1-D read
        logger.debug("Could not read SCI extension; skipping 2-D image.", exc_info=True)
    return None
def _read_one(
    hdu,
    wave_col: str | None,
    flux_col: str | None,
    err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Dispatch a single HDU to the BinTable or 1-D Image reader.

    Raises
    ------
    ValueError
        If the HDU has no table ``columns`` and its data is not 1-D.
    """
    if getattr(hdu, "columns", None) is not None:
        return _read_bintable_spectrum(hdu, wave_col, flux_col, err_col)
    data = hdu.data
    if data is not None and data.ndim == 1:
        return _read_image_spectrum(hdu)
    raise ValueError(
        f"HDU {hdu.name!r} is neither a BinTable nor a 1-D image."
    )


def _autodetect_spectrum(
    hdul,
    wave_col: str | None,
    flux_col: str | None,
    err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Return the spectrum from the first HDU in *hdul* that can be read.

    HDUs without data (e.g. an empty primary) are skipped. HDUs whose read
    raises ``ValueError`` or ``KeyError`` are skipped as well.

    Raises
    ------
    ValueError
        If no HDU yields a spectrum. The last per-HDU failure is chained as
        ``__cause__`` so the underlying reason is not lost.
    """
    last_err: Exception | None = None
    for hdu in hdul:
        # Skip primary HDUs with no data.
        if hdu.data is None:
            continue
        try:
            return _read_one(hdu, wave_col, flux_col, err_col)
        except (ValueError, KeyError) as exc:
            last_err = exc
    detail = "no HDU contained data" if last_err is None else f"Last error: {last_err}"
    raise ValueError(
        f"No HDU in file contained a recognisable 1-D spectrum. {detail}"
    ) from last_err
def read_npz(
    path: str | Path,
    z: float | None = None,
    R: float | None = None,
) -> Spectrum:
    """Read a stacked spectrum from a NumPy .npz file.

    Expected keys: ``wave_angstrom``, ``flux``, ``err``. ``n_stacked`` is
    optional and, when present, is stored in ``meta["n_stacked"]``.

    Parameters
    ----------
    path : str or Path
        Path to the .npz file.
    z : float, optional
        Source redshift.
    R : float, optional
        Effective spectral resolving power of the stack.

    Returns
    -------
    Spectrum
    """
    path = Path(path)
    # NpzFile holds the file open until closed. The context manager releases
    # it once the arrays have been read into memory.
    with np.load(path, allow_pickle=False) as npz:
        wave_A = np.asarray(npz["wave_angstrom"], dtype=float)
        flux = np.asarray(npz["flux"], dtype=float)
        err = np.asarray(npz["err"], dtype=float)
        meta = {"filename": path.name}
        if "n_stacked" in npz:
            # .item() accepts both 0-d and size-1 arrays; int() on an ndim>0
            # array is deprecated in recent NumPy.
            meta["n_stacked"] = int(np.asarray(npz["n_stacked"]).item())

    logger.info("Read %s: %d pixels, R=%s", path.name, len(wave_A), R)
    return Spectrum(
        wave_um=wave_A * 1e-4,
        flux_ujy=flux,
        err_ujy=err,
        grating=None,
        z=z,
        R=R,
        meta=meta,
    )
def read_dict(
    data: dict[str, np.ndarray],
    z: float | None = None,
    grating: str | None = None,
    R: float | None = None,
) -> Spectrum:
    """Create a Spectrum from a dict with keys ``wave``/``lam``, ``flux``, ``err``.

    Wavelength assumed in microns. ``"wave"`` takes precedence over ``"lam"``
    when both are present.

    Parameters
    ----------
    data : dict
        Must contain ``"wave"`` or ``"lam"`` (µm), ``"flux"`` (µJy),
        ``"err"`` (µJy).
    z : float, optional
        Source redshift.
    grating : str, optional
        Grating name.
    R : float, optional
        Resolving power.

    Returns
    -------
    Spectrum

    Raises
    ------
    KeyError
        If the wavelength key (``"wave"``/``"lam"``), ``"flux"`` or ``"err"``
        is missing. Previously a missing wavelength key silently produced a
        NaN scalar.
    """
    if "wave" in data:
        wave_raw = data["wave"]
    elif "lam" in data:
        wave_raw = data["lam"]
    else:
        raise KeyError("read_dict requires a 'wave' or 'lam' key (wavelength in µm)")
    wave = np.asarray(wave_raw, dtype=float)
    flux = np.asarray(data["flux"], dtype=float)
    err = np.asarray(data["err"], dtype=float)
    return Spectrum(wave_um=wave, flux_ujy=flux, err_ujy=err, grating=grating, z=z, R=R)
# --------------------------------------------------------------------------- # Save / load FitResult # ---------------------------------------------------------------------------
def save_result(result: "FitResult", path: str | Path) -> None:
    """Save a FitResult to a .npz file for later replotting.

    Parameters
    ----------
    result : FitResult
        Fit result to save.
    path : str or Path
        Output file path (should end in ``.npz``). If it does not,
        ``np.savez_compressed`` appends ``.npz``; the logged path is the one
        actually written.

    Notes
    -----
    The redshift is stored as ``z`` (0.0 when unset, as before) plus a
    boolean ``has_z`` flag. The flag lets :func:`load_result` distinguish a
    genuine ``z = 0.0`` from "no redshift".
    """
    import json

    path = Path(path)
    # Mirror numpy's rule so the log message names the real file.
    written = path if path.name.endswith(".npz") else path.with_name(path.name + ".npz")
    spec = result.spectrum

    # Serialise line results as JSON-compatible dict.
    lines_data = {
        name: {
            "rest_wave_A": lr.rest_wave_A,
            "amplitude": lr.amplitude,
            "centroid_A": lr.centroid_A,
            "sigma_A": lr.sigma_A,
            "flux": lr.flux,
            "flux_err": lr.flux_err,
            "ew_A": lr.ew_A,
            "snr_int_err": lr.snr_int_err,
            "snr_int_cont": lr.snr_int_cont,
            "snr_peak_err": lr.snr_peak_err,
            "snr_peak_cont": lr.snr_peak_cont,
        }
        for name, lr in result.lines.items()
    }

    def _json_default(obj):
        # NumPy scalars other than float64 (e.g. float32, int64) are not
        # JSON-serialisable by default; convert them to Python scalars.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    np.savez_compressed(
        written,
        # Arrays
        params=result.params,
        model_flux=result.model_flux,
        continuum=result.continuum,
        residuals=result.residuals,
        wave_um=spec.wave_um,
        flux_ujy=spec.flux_ujy,
        err_ujy=spec.err_ujy,
        # Scalars / metadata
        chi2=np.array([result.chi2]),
        success=np.array([result.success]),
        line_names=np.array(result.line_names),
        lines_json=np.array([json.dumps(lines_data, default=_json_default)]),
        grating=np.array([spec.grating or ""]),
        z=np.array([spec.z if spec.z is not None else 0.0]),
        has_z=np.array([spec.z is not None]),
    )
    logger.info("Saved FitResult to %s", written)
def load_result(path: str | Path) -> "FitResult":
    """Load a FitResult from a .npz file saved by :func:`save_result`.

    Parameters
    ----------
    path : str or Path
        Path to the ``.npz`` file.

    Returns
    -------
    FitResult

    Notes
    -----
    Files that carry a ``has_z`` flag restore ``z = 0.0`` faithfully. Older
    files without the flag stored an unset redshift as 0.0, so for them 0.0
    is read back as ``None``. Old files with a single ``snr`` key per line
    map it to ``snr_int_err``.
    """
    from .fitter import FitResult, LineResult
    import json

    path = Path(path)
    # Read everything inside the context manager so the NpzFile's handle is
    # closed afterwards (np.load on .npz keeps the file open until closed).
    with np.load(path, allow_pickle=False) as data:
        grating_str = str(data["grating"][0])
        z_val = float(data["z"][0])
        if "has_z" in data:
            has_z = bool(data["has_z"][0])
        else:
            # Files written before the has_z flag stored None as 0.0.
            has_z = z_val != 0.0
        spec = Spectrum(
            wave_um=data["wave_um"],
            flux_ujy=data["flux_ujy"],
            err_ujy=data["err_ujy"],
            grating=grating_str or None,
            z=z_val if has_z else None,
        )
        lines_data = json.loads(str(data["lines_json"][0]))
        line_names = [str(n) for n in data["line_names"]]
        params = data["params"]
        model_flux = data["model_flux"]
        continuum = data["continuum"]
        residuals = data["residuals"]
        chi2 = float(data["chi2"][0])
        success = bool(data["success"][0])

    lines = {}
    for name, ld in lines_data.items():
        # Backward compatibility: old files have a single "snr" key.
        old_snr = ld.get("snr", 0.0)
        lines[name] = LineResult(
            name=name,
            rest_wave_A=ld["rest_wave_A"],
            amplitude=ld["amplitude"],
            centroid_A=ld["centroid_A"],
            sigma_A=ld["sigma_A"],
            flux=ld["flux"],
            flux_err=ld["flux_err"],
            ew_A=ld["ew_A"],
            snr_int_err=ld.get("snr_int_err", old_snr),
            snr_int_cont=ld.get("snr_int_cont", 0.0),
            snr_peak_err=ld.get("snr_peak_err", 0.0),
            snr_peak_cont=ld.get("snr_peak_cont", 0.0),
        )

    return FitResult(
        lines=lines,
        params=params,
        model_flux=model_flux,
        continuum=continuum,
        residuals=residuals,
        chi2=chi2,
        spectrum=spec,
        line_names=line_names,
        success=success,
    )
def export_lines_txt(result: "FitResult", path: str | Path, z: float | None = None) -> None:
    """Export per-line measurements to a whitespace-aligned text file.

    The file starts with ``#`` header lines (redshift, units, column names).
    It then has one row per fitted line with columns: name, rest_A,
    centroid_A, flux, flux_err, EW_A, sigma_v (km/s), SNR_i_err,
    SNR_i_cont, SNR_p_err, SNR_p_cont.

    Parameters
    ----------
    result : FitResult
        Fit result.
    path : str or Path
        Output text file path (written as UTF-8).
    z : float, optional
        Redshift written to the file header. If ``None``, uses
        ``result.spectrum.z`` (0.0 when that is also ``None``). It does not
        enter any column: ``sigma_v = c * sigma_A / centroid_A`` is a ratio
        of two observed-frame quantities.
    """
    path = Path(path)
    z = z if z is not None else (result.spectrum.z or 0.0)
    c_kms = 299792.458  # speed of light, km/s

    with open(path, "w", encoding="utf-8") as f:
        f.write(
            f"# jwspecfit line measurements z={z:.6f}\n"
            f"# flux units: erg/s/cm2 | EW units: rest-frame Angstrom\n"
            f"# {'name':<18s} {'rest_A':>10s} {'centroid_A':>12s} "
            f"{'flux':>14s} {'flux_err':>14s} {'EW_A':>10s} "
            f"{'sigma_v':>12s} {'SNR_i_err':>10s} {'SNR_i_cont':>10s} "
            f"{'SNR_p_err':>10s} {'SNR_p_cont':>10s}\n"
        )
        for name, lr in result.lines.items():
            # Guard against a non-positive (or NaN) centroid; report 0 km/s.
            sigma_v = c_kms * lr.sigma_A / lr.centroid_A if lr.centroid_A > 0 else 0.0
            f.write(
                f" {name:<18s} {lr.rest_wave_A:10.3f} {lr.centroid_A:12.3f} "
                f"{lr.flux:14.6e} {lr.flux_err:14.6e} {lr.ew_A:10.3f} "
                f"{sigma_v:12.2f} {lr.snr_int_err:10.2f} {lr.snr_int_cont:10.2f} "
                f"{lr.snr_peak_err:10.2f} {lr.snr_peak_cont:10.2f}\n"
            )
    logger.info("Exported %d lines to %s", len(result.lines), path)
# ---------------------------------------------------------------------------
# Unit conversion helpers
# ---------------------------------------------------------------------------


def _ujy_to_flam(flux_ujy: np.ndarray, wave_um: np.ndarray) -> np.ndarray:
    """Convert µJy → erg/s/cm²/Å at wavelengths *wave_um* (µm).

    Applies F_λ = F_ν · c / λ² in CGS (giving a per-cm density), then divides
    by 1e8 Å/cm to express it per Ångström.
    """
    wavelength_cm = wave_um * 1e-4
    f_nu_cgs = flux_ujy * 1e-29  # 1 µJy = 1e-29 erg/s/cm²/Hz
    f_lam_per_cm = f_nu_cgs * _C_CGS / (wavelength_cm**2)
    return f_lam_per_cm / 1e8


def _flam_to_ujy(flux_flam: np.ndarray, wave_um: np.ndarray) -> np.ndarray:
    """Convert erg/s/cm²/Å → µJy at wavelengths *wave_um* (µm).

    Exact inverse of :func:`_ujy_to_flam`: F_ν = F_λ · λ² / c, with the
    per-Å density first rescaled by 1e8 to per-cm.
    """
    wavelength_cm = wave_um * 1e-4
    f_nu_cgs = flux_flam * 1e8 * wavelength_cm**2 / _C_CGS
    return f_nu_cgs / 1e-29