"""Spectrum I/O: FITS and NPZ readers, Spectrum container."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
from astropy.io import fits
logger = logging.getLogger(__name__)
# Physical constants for unit conversion.
# Speed of light in CGS units; read by the F_nu <-> F_lambda helpers
# (``_ujy_to_flam`` / ``_flam_to_ujy``) defined at the end of this module.
_C_CGS = 2.99792458e10  # cm/s
@dataclass
class Spectrum:
    """One-dimensional spectrum, optionally carrying a 2-D companion image.

    Attributes
    ----------
    wave_um : np.ndarray
        Observed wavelength grid in microns.
    flux_ujy : np.ndarray
        Flux density per pixel in micro-Jansky.
    err_ujy : np.ndarray
        1-sigma flux uncertainty per pixel in micro-Jansky.
    grating : str or None
        Name of the grating (e.g. ``"PRISM"``, ``"G395M"``); ``None`` for
        stacked spectra.
    z : float or None
        Source redshift, supplied by the user rather than read from a header.
    R : float or None
        Spectral resolving power; takes precedence over ``grating`` when set
        (useful for stacks).
        NOTE(review): earlier documentation also allowed a callable here,
        while the annotation only says ``float | None`` -- confirm.
    meta : dict
        Free-form metadata (FITS header values or user-supplied entries).
    sci_2d : np.ndarray or None
        Optional rectified 2-D spectrum of shape ``(n_spatial, n_pix)`` with
        ``n_pix == len(wave_um)``. :func:`read_fits` fills it in when the file
        has an ``SCI`` image whose second axis matches the 1-D extraction;
        otherwise it stays ``None``. It is consumed by
        :func:`~jwspecfit.plotting.plot_spectrum_interactive` for the stacked
        2-D + 1-D view.
    """

    wave_um: np.ndarray
    flux_ujy: np.ndarray
    err_ujy: np.ndarray
    grating: str | None = None
    z: float | None = None
    R: float | None = None
    meta: dict[str, Any] = field(default_factory=dict)
    sci_2d: np.ndarray | None = None

    # --- Derived properties ---------------------------------------------------
    @property
    def wave_A(self) -> np.ndarray:
        """Wavelength grid converted from microns to Angstroms."""
        return self.wave_um * 1e4

    @property
    def n_pix(self) -> int:
        """Number of pixels on the wavelength grid."""
        return len(self.wave_um)

    @property
    def wave_edges_A(self) -> np.ndarray:
        """Pixel-edge wavelengths in Angstroms (length ``n_pix + 1``).

        Interior edges are midpoints between neighbouring pixel centres; the
        two outer edges mirror the nearest interior edge about the first/last
        pixel centre. Needs at least two pixels (indexing fails otherwise).
        """
        centres = self.wave_A
        interior = (centres[1:] + centres[:-1]) * 0.5
        first_edge = centres[0] * 2.0 - interior[0]
        last_edge = centres[-1] * 2.0 - interior[-1]
        return np.concatenate(([first_edge], interior, [last_edge]))

    @property
    def dlam_A(self) -> np.ndarray:
        """Width of every pixel in Angstroms, from consecutive pixel edges."""
        return np.diff(self.wave_edges_A)

    @property
    def flux_flam(self) -> np.ndarray:
        """Flux density expressed in erg/s/cm²/Å."""
        return _ujy_to_flam(self.flux_ujy, self.wave_um)

    @property
    def err_flam(self) -> np.ndarray:
        """1-sigma uncertainty expressed in erg/s/cm²/Å."""
        return _ujy_to_flam(self.err_ujy, self.wave_um)

    def mask_valid(self) -> np.ndarray:
        """Return a boolean mask of usable pixels.

        A pixel is usable when both flux and error are finite and the error
        is strictly positive.
        """
        both_finite = np.isfinite(self.flux_ujy) & np.isfinite(self.err_ujy)
        return both_finite & (self.err_ujy > 0)

    def copy(self) -> "Spectrum":
        """Return a new Spectrum with duplicated arrays and a shallow-copied ``meta``."""
        arrays = {
            "wave_um": self.wave_um.copy(),
            "flux_ujy": self.flux_ujy.copy(),
            "err_ujy": self.err_ujy.copy(),
        }
        image = self.sci_2d
        if image is not None:
            image = image.copy()
        return Spectrum(
            grating=self.grating,
            z=self.z,
            R=self.R,
            meta=dict(self.meta),
            sci_2d=image,
            **arrays,
        )
# --- Unit aliases for FITS auto-detection ---------------------------------
# All lookups into these tables use the output of ``_normalise_unit``:
# lower-cased, µ/μ mapped to "u", Å mapped to "a", and every space removed.
# Entries therefore have to be spelled in that normalised form to be reachable.

# Multiplicative factor that converts a wavelength in the given unit to microns.
_WAVE_TO_UM = {
    "um": 1.0, "micron": 1.0, "microns": 1.0, "micrometer": 1.0,
    "micrometre": 1.0, "micrometers": 1.0, "micrometres": 1.0,
    "a": 1e-4, "ang": 1e-4, "angstrom": 1e-4, "angstroms": 1e-4,
    "nm": 1e-3, "nanometer": 1e-3, "nanometre": 1e-3,
    "m": 1e6, "meter": 1e6, "metre": 1e6,
}

# Multiplicative factor that converts an F_nu flux density to micro-Jansky.
_FNU_TO_UJY = {
    "ujy": 1.0, "microjansky": 1.0, "microjy": 1.0,
    "njy": 1e-3, "nanojansky": 1e-3,
    "mjy": 1e3, "millijansky": 1e3,
    "jy": 1e6, "jansky": 1e6,
}

# Spellings recognised as F_lambda in erg/s/cm²/Å (converted via wavelength).
_FLAM_UNITS = {
    # Slash notation.
    "erg/s/cm2/a", "erg/s/cm^2/a", "erg/s/cm**2/a",
    "erg/s/cm2/angstrom", "erg/s/cm^2/angstrom", "erg/s/cm**2/angstrom",
    # Exponent / parenthesised notation in normalised (space-free) form.
    # The spaced spellings below can never equal a normalised unit string,
    # so without these entries such units were rejected as unrecognised.
    "ergs-1cm-2angstrom-1", "ergs^-1cm^-2angstrom^-1",
    "ergs-1cm-2a-1", "ergs^-1cm^-2a^-1",
    "erg/(scm2a)", "erg/(scm2angstrom)",
    "flam",
    # Original spaced spellings, retained so that any direct membership test
    # on a raw (un-normalised) string keeps behaving as before.
    "erg s-1 cm-2 angstrom-1", "erg s^-1 cm^-2 angstrom^-1",
    "erg/(s cm2 a)",
}
def _normalise_unit(u: str | None) -> str:
    """Reduce a unit string to a canonical lookup key.

    ``None`` becomes the empty string. Anything else is converted with
    ``str()``, stripped, lower-cased, and then passed through a single
    character map: both mu glyphs become ``"u"``, the ring-A glyph becomes
    ``"a"``, and space characters are dropped.
    """
    if u is None:
        return ""
    # One translate pass replaces the chain of single-character replacements.
    char_map = str.maketrans({"μ": "u", "µ": "u", "å": "a", "Å": "a", " ": None})
    return str(u).strip().lower().translate(char_map)
def _convert_wave_to_um(wave: np.ndarray, unit: str | None) -> np.ndarray:
    """Scale *wave* from *unit* to microns.

    A missing or empty unit is taken to mean the data are already in microns.

    Raises
    ------
    ValueError
        If the normalised unit is non-empty and not a key of ``_WAVE_TO_UM``.
    """
    key = _normalise_unit(unit)
    if key and key not in _WAVE_TO_UM:
        raise ValueError(f"Unrecognised wavelength unit: {unit!r}")
    # An empty key is not in the table, so it falls back to a factor of 1.
    factor = _WAVE_TO_UM.get(key, 1.0)
    return wave * factor
def _convert_flux_to_ujy(
    flux: np.ndarray, unit: str | None, wave_um: np.ndarray,
) -> np.ndarray:
    """Convert *flux* expressed in *unit* to micro-Jansky.

    F_nu units are rescaled by the factor in ``_FNU_TO_UJY``; F_lambda units
    (members of ``_FLAM_UNITS``) are converted with ``_flam_to_ujy`` using
    *wave_um*. A missing or empty unit is treated as µJy already.

    Raises
    ------
    ValueError
        If the unit is neither empty, an F_nu alias, nor an F_lambda alias.
    """
    key = _normalise_unit(unit)
    if not key:
        # No unit given: assume µJy, but still return a fresh array.
        return flux * 1.0
    if key in _FNU_TO_UJY:
        return flux * _FNU_TO_UJY[key]
    if key in _FLAM_UNITS:
        return _flam_to_ujy(flux, wave_um)
    raise ValueError(f"Unrecognised flux unit: {unit!r}")
# --- Column-name aliases for auto-detection -------------------------------
# Lower-case spellings passed to ``_find_column``, which looks for them first
# as exact (case-insensitive) column names and then as substrings.

# Wavelength columns. A column literally named "loglam" is additionally
# interpreted as log10(λ/Å) by ``_read_bintable_spectrum``.
_WAVE_NAMES = {"wave", "wavelength", "lam", "lambda", "loglam", "wavelen"}
# Flux-density columns; whether the values are F_nu or F_lambda is decided
# from the column's unit, not from its name.
_FLUX_NAMES = {"flux", "flux_density", "fnu", "flam", "f_lambda", "f_nu", "spec"}
# 1-sigma uncertainty columns.
_ERR_NAMES = {
    "err", "error", "flux_err", "flux_error", "fluxerr", "sigma", "noise",
    "uncertainty", "stddev", "std",
}
# Inverse-variance columns, consulted only when no uncertainty column is found.
_IVAR_NAMES = {"ivar", "inv_var", "inverse_variance", "weight", "wht"}
def _find_column(table_names: list[str], aliases: set[str]) -> str | None:
    """Return the first column in *table_names* that matches one of *aliases*.

    Matching is case-insensitive (aliases are expected to be lower-case).
    Two passes are made over the columns, both in table order:

    1. a column whose lower-cased name equals an alias;
    2. a column whose lower-cased name contains an alias as a substring.

    Scanning in column order rather than alias order makes the choice
    deterministic: *aliases* is a ``set``, whose iteration order for strings
    varies between interpreter runs, so looping over it could pick a
    different column from one run to the next when several aliases match.

    Parameters
    ----------
    table_names : list of str
        Column names in the order they appear in the table.
    aliases : set of str
        Accepted lower-case spellings.

    Returns
    -------
    str or None
        The matching column name with its original capitalisation, or
        ``None`` when nothing matches.
    """
    lowered = [(name.lower(), name) for name in table_names]
    # Pass 1: exact match (O(1) membership test per column).
    for lname, original in lowered:
        if lname in aliases:
            return original
    # Pass 2: substring fallback. ``any`` makes the outcome independent of
    # the order in which the alias set is iterated.
    for lname, original in lowered:
        if any(alias in lname for alias in aliases):
            return original
    return None
def _read_bintable_spectrum(
    hdu, wave_col: str | None, flux_col: str | None, err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Pull ``(wave_um, flux_ujy, err_ujy, meta)`` out of a table HDU.

    Column names given by the caller win; otherwise the wavelength, flux and
    error columns are located with ``_find_column``. Units come from each
    column's ``unit`` attribute. When no error column exists an
    inverse-variance column is used instead; when neither exists the error is
    set to zero and a warning is logged.

    Raises
    ------
    ValueError
        If no wavelength or flux column can be identified, or a unit is not
        recognised by the conversion helpers.
    """
    table = hdu.data
    names = []
    units = {}
    for col in hdu.columns:
        names.append(col.name)
        units[col.name] = col.unit

    wcol = wave_col if wave_col else _find_column(names, _WAVE_NAMES)
    fcol = flux_col if flux_col else _find_column(names, _FLUX_NAMES)
    if wcol is None or fcol is None:
        raise ValueError(
            f"Could not auto-detect wave/flux columns in HDU "
            f"{hdu.name!r}: available columns = {names}"
        )

    wave_raw = np.asarray(table[wcol], dtype=float)
    flux_raw = np.asarray(table[fcol], dtype=float)
    flux_unit = units.get(fcol)

    # A column named "loglam" holds log10(λ/Å) (SDSS style).
    if wcol.lower() == "loglam":
        wave_um = (10.0 ** wave_raw) * 1e-4
    else:
        wave_um = _convert_wave_to_um(wave_raw, units.get(wcol))
    flux_ujy = _convert_flux_to_ujy(flux_raw, flux_unit, wave_um)

    # Uncertainty: explicit/aliased error column, else inverse variance, else 0.
    ecol = err_col or _find_column(names, _ERR_NAMES)
    ivar_col = None if ecol is not None else _find_column(names, _IVAR_NAMES)
    if ecol is not None:
        err_raw = np.asarray(table[ecol], dtype=float)
        # Error columns frequently lack a unit; borrow the flux unit then.
        err_unit = units.get(ecol) or flux_unit
        err_ujy = _convert_flux_to_ujy(err_raw, err_unit, wave_um)
    elif ivar_col is not None:
        ivar = np.asarray(table[ivar_col], dtype=float)
        usable = ivar > 0
        # Substitute 1.0 where ivar is unusable so sqrt/divide stay well defined;
        # those pixels end up with an infinite error.
        guarded = np.where(usable, ivar, 1.0)
        err_raw = np.where(usable, 1.0 / np.sqrt(guarded), np.inf)
        err_ujy = _convert_flux_to_ujy(err_raw, flux_unit, wave_um)
    else:
        logger.warning(
            "No error/uncertainty column found in HDU %r; setting err=0.",
            hdu.name,
        )
        err_ujy = np.zeros_like(flux_ujy)

    meta = {"hdu": hdu.name, "columns": names, "wave_col": wcol, "flux_col": fcol}
    return wave_um, flux_ujy, err_ujy, meta
def _read_image_spectrum(hdu) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Build a 1-D spectrum from an image HDU and its linear WCS keywords.

    The wavelength of (1-based) pixel ``p`` is
    ``CRVAL1 + (p - CRPIX1) * CDELT1``, with ``CD1_1`` used when ``CDELT1`` is
    absent and defaults of 0, 1 and 1 for ``CRVAL1``, ``CRPIX1`` and the step.
    If ``CTYPE1`` contains ``"LOG"`` the axis value is exponentiated as a
    base-10 logarithm. ``CUNIT1`` and ``BUNIT`` give the wavelength and flux
    units. No uncertainty is available, so the error array is all zeros.

    NOTE(review): the exponentiation assumes a log10 axis; confirm that the
    files read here follow that convention rather than a natural-log one.

    Raises
    ------
    ValueError
        If the image is not one-dimensional or a unit is not recognised.
    """
    header = hdu.header
    pixels = np.asarray(hdu.data, dtype=float)
    if pixels.ndim != 1:
        raise ValueError(
            f"Image HDU has shape {pixels.shape}; need 1-D for spectrum reading."
        )

    crval = float(header.get("CRVAL1", 0.0))
    crpix = float(header.get("CRPIX1", 1.0))
    step_fallback = header.get("CD1_1", 1.0)
    cdelt = float(header.get("CDELT1", step_fallback))
    ctype = str(header.get("CTYPE1", "")).upper()
    cunit = header.get("CUNIT1", None)
    bunit = header.get("BUNIT", None)

    # FITS pixel indices start at 1.
    pix = np.arange(1, pixels.size + 1, dtype=float)
    wave_raw = crval + (pix - crpix) * cdelt
    if any(tag in ctype for tag in ("LOG", "AWAV-LOG")):
        wave_raw = 10.0 ** wave_raw

    wave_um = _convert_wave_to_um(wave_raw, cunit)
    flux_ujy = _convert_flux_to_ujy(pixels, bunit, wave_um)
    err_ujy = np.zeros_like(flux_ujy)
    meta = {"hdu": hdu.name, "ctype1": ctype, "cunit1": cunit, "bunit": bunit}
    return wave_um, flux_ujy, err_ujy, meta
def read_fits(
    path: str | Path,
    z: float | None = None,
    *,
    hdu: str | int | None = None,
    wave_col: str | None = None,
    flux_col: str | None = None,
    err_col: str | None = None,
) -> Spectrum:
    """Read a 1-D spectrum from a FITS file.

    The HDU to read is chosen as follows:

    1. If *hdu* is given, exactly that HDU is read (no auto-detection).
    2. Otherwise, if an extension named ``SPEC1D`` exists, it is read. This
       is the JWST/NIRSpec convention with columns ``wave`` (µm),
       ``flux`` (µJy), ``err`` (µJy).
    3. Otherwise every extension is tried in order and the first one that
       yields a spectrum wins.

    Whichever HDU is chosen, it is interpreted as one of:

    * **BinTable HDU** -- wavelength-, flux- and error-like columns are
      located by name. Recognised aliases include
      ``wave/wavelength/lam/lambda/loglam``, ``flux/fnu/flam`` and
      ``err/error/sigma/noise`` (or ``ivar``). Units are read from the
      column definitions (``TUNITn``); common aliases (µm/Å/nm, µJy/mJy/Jy,
      erg/s/cm²/Å) are converted automatically.
    * **1-D Image HDU** -- the wavelength axis is built from ``CRVAL1``,
      ``CDELT1``/``CD1_1``, ``CRPIX1``, ``CTYPE1``, ``CUNIT1``; the flux unit
      is ``BUNIT``. Errors are not available and are set to zero.

    ``GRATING`` and ``FILTER`` are taken from the header of the HDU that was
    read, falling back to the primary header. If the file has an ``SCI``
    extension holding a 2-D image whose second axis has the same length as
    the 1-D spectrum, it is attached as ``Spectrum.sci_2d``; any problem
    while reading it is logged at debug level and otherwise ignored.

    Parameters
    ----------
    path : str or Path
        Path to the FITS file.
    z : float, optional
        Source redshift to attach to the returned :class:`Spectrum`.
    hdu : str or int, optional
        Force a specific HDU to read (name or index). When ``None``
        (default), tries SPEC1D first, then auto-detects.
    wave_col, flux_col, err_col : str, optional
        Force specific column names instead of auto-detecting. Only
        used for BinTable HDUs.

    Returns
    -------
    Spectrum

    Raises
    ------
    ValueError
        If the selected HDU (or, when auto-detecting, every HDU) cannot be
        interpreted as a 1-D spectrum.
    """
    path = Path(path)
    with fits.open(path) as hdul:
        ext_names = [h.name for h in hdul]
        # Keep a handle on the HDU actually read so its header can be used
        # directly; looking it up again by name picks the wrong HDU when the
        # extension is unnamed or its name is not unique.
        spec_hdu = None

        # 1. Explicit HDU request.
        if hdu is not None:
            spec_hdu = hdul[hdu]
            wave_um, flux_ujy, err_ujy, meta = _read_one(
                spec_hdu, wave_col, flux_col, err_col,
            )
        # 2. Default fast path: SPEC1D, optionally with explicit columns.
        elif "SPEC1D" in ext_names:
            spec_hdu = hdul["SPEC1D"]
            wave_um, flux_ujy, err_ujy, meta = _read_one(
                spec_hdu, wave_col, flux_col, err_col,
            )
        # 3. Auto-detect across all extensions.
        else:
            wave_um, flux_ujy, err_ujy, meta = _autodetect_spectrum(
                hdul, wave_col, flux_col, err_col,
            )
            # Auto-detection reports only the HDU name, so resolve it here.
            detected_name = meta.get("hdu")
            if detected_name:
                try:
                    spec_hdu = hdul[detected_name]
                except KeyError:
                    spec_hdu = None

        # Header metadata: prefer the spectrum HDU, fall back to primary.
        primary_header = hdul[0].header
        spec_header = primary_header if spec_hdu is None else spec_hdu.header
        grating = spec_header.get("GRATING", primary_header.get("GRATING", None))
        filt = spec_header.get("FILTER", primary_header.get("FILTER", None))

        # Opportunistic 2-D pickup: many JWST/NIRSpec pipelines (msaexp
        # in particular) ship a rectified 2-D image in the ``SCI``
        # ImageHDU sharing the SPEC1D wavelength grid. Attach it when
        # shape[1] matches the 1-D length so that the interactive plot can
        # render a 2-D + 1-D preview.
        sci_2d: np.ndarray | None = None
        if "SCI" in ext_names:
            try:
                cand = hdul["SCI"].data
                if (
                    cand is not None
                    and getattr(cand, "ndim", 0) == 2
                    and cand.shape[1] == len(wave_um)
                ):
                    sci_2d = np.asarray(cand, dtype=float)
            except Exception as exc:  # pragma: no cover — never block the 1-D read
                # Best effort only, but leave a trace for debugging.
                logger.debug("Ignoring unreadable SCI extension in %s: %s", path.name, exc)
                sci_2d = None

    meta.update({"filename": path.name, "filter": filt})
    logger.info(
        "Read %s [%s]: %d pixels, grating=%s",
        path.name, meta.get("hdu", "?"), len(wave_um), grating,
    )
    return Spectrum(
        wave_um=wave_um,
        flux_ujy=flux_ujy,
        err_ujy=err_ujy,
        grating=grating,
        z=z,
        meta=meta,
        sci_2d=sci_2d,
    )
def _read_one(
    hdu, wave_col: str | None, flux_col: str | None, err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Route one HDU to the table reader or the 1-D image reader.

    An HDU with a non-``None`` ``columns`` attribute goes to
    ``_read_bintable_spectrum``; otherwise an HDU whose data is a 1-D array
    goes to ``_read_image_spectrum``.

    Raises
    ------
    ValueError
        If the HDU is neither of the two.
    """
    # getattr with a default covers both "no such attribute" and "is None".
    if getattr(hdu, "columns", None) is not None:
        return _read_bintable_spectrum(hdu, wave_col, flux_col, err_col)
    payload = hdu.data
    if payload is not None and payload.ndim == 1:
        return _read_image_spectrum(hdu)
    raise ValueError(
        f"HDU {hdu.name!r} is neither a BinTable nor a 1-D image."
    )
def _autodetect_spectrum(
    hdul, wave_col: str | None, flux_col: str | None, err_col: str | None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Return the spectrum from the first HDU that ``_read_one`` can parse.

    HDUs without data are skipped. A ``ValueError`` or ``KeyError`` from one
    HDU is remembered and the search moves on to the next.

    Raises
    ------
    ValueError
        If no HDU could be read; the message quotes the most recent failure.
    """
    failure = None
    # Lazily filter out data-less HDUs (typically an empty primary HDU).
    with_data = (ext for ext in hdul if ext.data is not None)
    for ext in with_data:
        try:
            return _read_one(ext, wave_col, flux_col, err_col)
        except (ValueError, KeyError) as exc:
            failure = exc
    raise ValueError(
        f"No HDU in file contained a recognisable 1-D spectrum. "
        f"Last error: {failure}"
    )
def read_npz(
    path: str | Path,
    z: float | None = None,
    R: float | None = None,
) -> Spectrum:
    """Read a stacked spectrum from a NumPy ``.npz`` file.

    Expected keys: ``wave_angstrom`` (converted to microns here), ``flux``
    and ``err`` (stored unchanged as the spectrum's flux and error arrays).
    An optional single-valued ``n_stacked`` entry is copied into the
    metadata as an ``int``.

    Parameters
    ----------
    path : str or Path
        Path to the .npz file.
    z : float, optional
        Source redshift.
    R : float, optional
        Effective spectral resolving power of the stack.

    Returns
    -------
    Spectrum

    Raises
    ------
    KeyError
        If one of the required keys is missing from the archive.
    """
    path = Path(path)
    meta: dict[str, Any] = {"filename": path.name}
    # The archive keeps a file handle open until closed, so read everything
    # inside a context manager instead of leaking the handle.
    with np.load(path, allow_pickle=False) as npz:
        wave_A = np.asarray(npz["wave_angstrom"], dtype=float)
        flux = np.asarray(npz["flux"], dtype=float)
        err = np.asarray(npz["err"], dtype=float)
        if "n_stacked" in npz:
            # ``.item()`` handles both 0-d and 1-element arrays; calling
            # ``int()`` directly on a 1-element array is deprecated in NumPy.
            meta["n_stacked"] = int(np.asarray(npz["n_stacked"]).item())
    logger.info("Read %s: %d pixels, R=%s", path.name, len(wave_A), R)
    return Spectrum(
        wave_um=wave_A * 1e-4,
        flux_ujy=flux,
        err_ujy=err,
        grating=None,
        z=z,
        R=R,
        meta=meta,
    )
def read_dict(
    data: dict[str, np.ndarray],
    z: float | None = None,
    grating: str | None = None,
    R: float | None = None,
) -> Spectrum:
    """Create a Spectrum from a dict with keys ``wave``/``lam``, ``flux``, ``err``.

    Wavelength is assumed to be in microns. ``"wave"`` takes precedence over
    ``"lam"`` when both are present.

    Parameters
    ----------
    data : dict
        Must contain ``"wave"`` or ``"lam"`` (µm), ``"flux"`` (µJy), ``"err"`` (µJy).
    z : float, optional
        Source redshift.
    grating : str, optional
        Grating name.
    R : float, optional
        Resolving power.

    Returns
    -------
    Spectrum

    Raises
    ------
    KeyError
        If the wavelength (``"wave"``/``"lam"``), ``"flux"`` or ``"err"``
        entry is missing.
    """
    # Fail loudly when no wavelength is supplied: converting a missing entry
    # (None) to a float array would silently produce a 0-d NaN "grid".
    for key in ("wave", "lam"):
        if key in data:
            wave_src = data[key]
            break
    else:
        raise KeyError("read_dict needs a wavelength array under 'wave' or 'lam'")
    wave = np.asarray(wave_src, dtype=float)
    flux = np.asarray(data["flux"], dtype=float)
    err = np.asarray(data["err"], dtype=float)
    return Spectrum(wave_um=wave, flux_ujy=flux, err_ujy=err, grating=grating, z=z, R=R)
# ---------------------------------------------------------------------------
# Save / load FitResult
# ---------------------------------------------------------------------------
def save_result(result: "FitResult", path: str | Path) -> None:
    """Write a FitResult to a compressed ``.npz`` archive for later replotting.

    The archive holds the fit arrays, the three spectrum arrays, and
    one-element arrays for the scalars. Per-line measurements are stored as a
    JSON string. A missing grating is stored as ``""`` and a missing (or
    zero) redshift as ``0.0``.

    Parameters
    ----------
    result : FitResult
        Fit result to save.
    path : str or Path
        Output file path (should end in ``.npz``).
    """
    # Imported here exactly as before; neither name is referenced below.
    from .fitter import FitResult, LineResult
    import json

    path = Path(path)

    # Attributes of each line result that go into the JSON payload, in order.
    per_line_fields = (
        "rest_wave_A",
        "amplitude",
        "centroid_A",
        "sigma_A",
        "flux",
        "flux_err",
        "ew_A",
        "snr_int_err",
        "snr_int_cont",
        "snr_peak_err",
        "snr_peak_cont",
    )
    lines_data = {
        name: {key: getattr(lr, key) for key in per_line_fields}
        for name, lr in result.lines.items()
    }

    spectrum = result.spectrum
    np.savez_compressed(
        path,
        # Arrays
        params=result.params,
        model_flux=result.model_flux,
        continuum=result.continuum,
        residuals=result.residuals,
        wave_um=spectrum.wave_um,
        flux_ujy=spectrum.flux_ujy,
        err_ujy=spectrum.err_ujy,
        # Scalars / metadata, each wrapped in a one-element array
        chi2=np.array([result.chi2]),
        success=np.array([result.success]),
        line_names=np.array(result.line_names),
        lines_json=np.array([json.dumps(lines_data)]),
        grating=np.array([spectrum.grating or ""]),
        z=np.array([spectrum.z or 0.0]),
    )
    logger.info("Saved FitResult to %s", path)
def load_result(path: str | Path) -> "FitResult":
    """Load a FitResult from a ``.npz`` file saved by :func:`save_result`.

    An empty stored grating is returned as ``None``; a stored redshift of
    exactly ``0.0`` is likewise returned as ``None`` (that is how
    :func:`save_result` encodes a missing redshift). Files that carry only a
    single ``"snr"`` value per line are still readable: it is used for
    ``snr_int_err`` and the remaining SNR fields default to ``0.0``.

    Parameters
    ----------
    path : str or Path
        Path to the ``.npz`` file.

    Returns
    -------
    FitResult
    """
    from .fitter import FitResult, LineResult
    import json

    path = Path(path)
    # Pull everything out inside a context manager so the archive's file
    # handle is released instead of staying open for the process lifetime.
    with np.load(path, allow_pickle=False) as data:
        grating_str = str(data["grating"][0])
        z_val = float(data["z"][0])
        wave_um = data["wave_um"]
        flux_ujy = data["flux_ujy"]
        err_ujy = data["err_ujy"]
        lines_data = json.loads(str(data["lines_json"][0]))
        line_names = list(data["line_names"])
        params = data["params"]
        model_flux = data["model_flux"]
        continuum = data["continuum"]
        residuals = data["residuals"]
        chi2 = float(data["chi2"][0])
        success = bool(data["success"][0])

    spec = Spectrum(
        wave_um=wave_um,
        flux_ujy=flux_ujy,
        err_ujy=err_ujy,
        grating=grating_str if grating_str else None,
        z=z_val if z_val != 0.0 else None,
    )

    lines = {}
    for name, ld in lines_data.items():
        # Backward compatibility: old files have a single "snr" key.
        old_snr = ld.get("snr", 0.0)
        lines[name] = LineResult(
            name=name,
            rest_wave_A=ld["rest_wave_A"],
            amplitude=ld["amplitude"],
            centroid_A=ld["centroid_A"],
            sigma_A=ld["sigma_A"],
            flux=ld["flux"],
            flux_err=ld["flux_err"],
            ew_A=ld["ew_A"],
            snr_int_err=ld.get("snr_int_err", old_snr),
            snr_int_cont=ld.get("snr_int_cont", 0.0),
            snr_peak_err=ld.get("snr_peak_err", 0.0),
            snr_peak_cont=ld.get("snr_peak_cont", 0.0),
        )

    return FitResult(
        lines=lines,
        params=params,
        model_flux=model_flux,
        continuum=continuum,
        residuals=residuals,
        chi2=chi2,
        spectrum=spec,
        line_names=line_names,
        success=success,
    )
def export_lines_txt(result: "FitResult", path: str | Path, z: float | None = None) -> None:
    """Export per-line measurements to a text file.

    The file starts with three ``#`` header lines (redshift, units, column
    names) followed by one row per line in ``result.lines``. Columns, in
    order: name, rest_A, centroid_A, flux, flux_err, EW_A, sigma_v (km/s),
    SNR_i_err, SNR_i_cont, SNR_p_err, SNR_p_cont.

    ``sigma_v`` is ``c * sigma_A / centroid_A`` and is written as 0 when the
    centroid is not positive.

    Parameters
    ----------
    result : FitResult
        Fit result.
    path : str or Path
        Output text file path.
    z : float, optional
        Redshift written to the header. If ``None``, uses
        ``result.spectrum.z`` (or 0 when that is unset).
    """
    path = Path(path)
    z = z if z is not None else (result.spectrum.z or 0.0)
    c_kms = 299792.458
    # Write UTF-8 explicitly: line names may contain non-ASCII characters,
    # which the platform's default encoding is not guaranteed to support.
    with open(path, "w", encoding="utf-8") as f:
        f.write(
            f"# jwspecfit line measurements  z={z:.6f}\n"
            f"# flux units: erg/s/cm2  |  EW units: rest-frame Angstrom\n"
            f"# {'name':<18s} {'rest_A':>10s} {'centroid_A':>12s} "
            f"{'flux':>14s} {'flux_err':>14s} {'EW_A':>10s} "
            f"{'sigma_v':>12s} {'SNR_i_err':>10s} {'SNR_i_cont':>10s} "
            f"{'SNR_p_err':>10s} {'SNR_p_cont':>10s}\n"
        )
        for name, lr in result.lines.items():
            # Velocity dispersion from the fitted width; guard the division.
            sigma_v = c_kms * lr.sigma_A / lr.centroid_A if lr.centroid_A > 0 else 0.0
            f.write(
                f"  {name:<18s} {lr.rest_wave_A:10.3f} {lr.centroid_A:12.3f} "
                f"{lr.flux:14.6e} {lr.flux_err:14.6e} {lr.ew_A:10.3f} "
                f"{sigma_v:12.2f} {lr.snr_int_err:10.2f} {lr.snr_int_cont:10.2f} "
                f"{lr.snr_peak_err:10.2f} {lr.snr_peak_cont:10.2f}\n"
            )
    logger.info("Exported %d lines to %s", len(result.lines), path)
# ---------------------------------------------------------------------------
# Unit conversion helpers
# ---------------------------------------------------------------------------
def _ujy_to_flam(flux_ujy: np.ndarray, wave_um: np.ndarray) -> np.ndarray:
    """Convert a flux density from µJy to erg/s/cm²/Å.

    Uses F_λ = F_ν · c / λ² in CGS units (λ in cm), then rescales from
    per-cm to per-Å.
    """
    wavelength_cm = wave_um * 1e-4
    f_nu_cgs = flux_ujy * 1e-29  # µJy → erg/s/cm²/Hz
    f_lambda_per_cm = f_nu_cgs * _C_CGS / (wavelength_cm ** 2)
    # 1 cm = 1e8 Å, so per-Å is per-cm divided by 1e8.
    return f_lambda_per_cm / 1e8
def _flam_to_ujy(flux_flam: np.ndarray, wave_um: np.ndarray) -> np.ndarray:
    """Convert a flux density from erg/s/cm²/Å to µJy.

    Inverse of the µJy → F_λ conversion: F_ν = F_λ · λ² / c in CGS units
    (λ in cm), then rescaled from erg/s/cm²/Hz to µJy.
    """
    wavelength_cm = wave_um * 1e-4
    f_lambda_per_cm = flux_flam * 1e8  # per-Å → per-cm
    f_nu_cgs = f_lambda_per_cm * wavelength_cm ** 2 / _C_CGS
    # 1 µJy = 1e-29 erg/s/cm²/Hz.
    return f_nu_cgs / 1e-29