# Source code for jwspecfit.plotting

"""Publication-quality visualisation of spectral fits."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    import plotly.graph_objects as go
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from .fitter import FitResult
    from .io import Spectrum


# Default emission-line markers (keys into REST_LINES_A) and their
# preferred display labels.  Shared between plot_spectrum_interactive
# and plot_2d_1d so labels stay consistent.
_DEFAULT_MARKER_NAMES: list[str] = [
    "Lya", "NIV_doublet", "CIV_doublet", "HEII_1640",
    "NIII_doublet", "CIII]",
    "OII_doublet", "NeIII_3869",
    "HEI_4027", "HDELTA", "HEI_4145", "HEII_4200", "HGAMMA", "OIII_4363",
    "FeII_4584", "NIII_4642", "FeIII_4660", "HeII_4687",
    "ArIV_4713", "FeII_4732", "ArIV_4741",
    "HBETA", "OIII_4959", "OIII_5007",
    "CIV_5803", "CIV_5814",
    "HEI_5877", "OI_6302",
    "Ha", "NII_6585", "HEI_6680",
    "SII_6718", "SII_6732",
    "HEI_7067", "ArIII_7138",
]

_DEFAULT_MARKER_LABELS: dict[str, str] = {
    "Lya": "Lyα", "NIV_doublet": "NIV", "CIV_doublet": "CIV",
    "HEII_1640": "HeII 1640",
    "NIII_doublet": "NIII 1750", "CIII]": "CIII]",
    "OII_doublet": "[OII]", "NeIII_3869": "[NeIII]",
    "HEI_4027": "HeI 4027",
    "HDELTA": "Hδ",
    "HEI_4145": "HeI 4145",
    "HEII_4200": "HeII 4200",
    "HGAMMA": "Hγ",
    "OIII_4363": "[OIII]4363",
    "FeII_4584": "FeII 4584", "NIII_4642": "NIII 4642",
    "FeIII_4660": "FeIII 4660", "HeII_4687": "HeII 4687",
    "ArIV_4713": "ArIV+HeI 4705",
    "FeII_4732": "FeII 4732", "ArIV_4741": "ArIV 4741",
    "HBETA": "Hβ",
    "OIII_4959": "[OIII]4959", "OIII_5007": "[OIII]5007",
    "CIV_5803": "CIV 5803", "CIV_5814": "CIV 5814",
    "HEI_5877": "HeI 5877",
    "OI_6302": "[OI] 6302",
    "Ha": "Hα", "NII_6585": "[NII] 6585",
    "HEI_6680": "HeI 6680",
    "SII_6718": "[SII]6716", "SII_6732": "[SII]6731",
    "HEI_7067": "HeI 7065",
    "ArIII_7138": "[ArIII] 7138",
}


def _build_exclude_mask(
    wave_A: np.ndarray,
    exclude_wave_A: list[tuple[float, float]] | None,
) -> np.ndarray:
    """Return a boolean mask that is True for pixels to KEEP.

    Parameters
    ----------
    wave_A : np.ndarray
        Wavelength array in Angstroms.
    exclude_wave_A : list of (lo, hi) tuples, optional
        Wavelength ranges to exclude.

    Returns
    -------
    np.ndarray
        Boolean mask (True = keep).
    """
    keep = np.ones(len(wave_A), dtype=bool)
    if exclude_wave_A is not None:
        for lo, hi in exclude_wave_A:
            keep &= ~((wave_A >= lo) & (wave_A <= hi))
    return keep


[docs] def plot_fit( result: "FitResult", *, fig: "Figure | None" = None, wave_unit: str = "A", flux_unit: str = "fnu", show_residuals: bool = True, show_components: bool = True, label_lines: bool = True, y_pad: float = 1.3, exclude_wave_A: list[tuple[float, float]] | None = None, rest_frame: bool = False, save_path: str | None = None, ) -> "Figure": """Plot a spectral fit with data, model, continuum, and residuals. The y-axis upper limit is set to the peak of the tallest emission line (above continuum) times *y_pad*, so the plot is scaled to the lines rather than noise spikes. Parameters ---------- result : FitResult Output of :func:`~jwspecfit.fitter.fit_lines`. fig : Figure, optional Matplotlib figure to draw on. If ``None``, creates a new one. wave_unit : str ``"A"`` for Angstroms (default) or ``"um"`` for microns. flux_unit : str ``"fnu"`` for µJy (default) or ``"flam"`` for erg/s/cm²/Å. show_residuals : bool Show residual panel below the main plot (default True). show_components : bool Show individual Gaussian components as filled curves (default True). Broad components are drawn with hatching for clarity. label_lines : bool Annotate line identifications (default True). y_pad : float Multiplicative padding above the tallest line peak (default 1.3). exclude_wave_A : list of (float, float), optional Wavelength ranges in Angstroms to hide from the plot. Each tuple is ``(lo, hi)``. Useful for masking noisy detector regions. rest_frame : bool If ``True``, plot wavelengths in the rest frame by dividing by ``(1 + z)`` using the redshift stored in the spectrum. Default ``False`` (observed frame). save_path : str, optional If given, save the figure to this file path (e.g. ``"fit.pdf"``). Returns ------- Figure The matplotlib figure. """ import matplotlib.pyplot as plt from .models import build_model from .io import _flam_to_ujy, _ujy_to_flam # Auto-convert MCMCResult / MCMCBroadFitResult to FitResult. 
if hasattr(result, "to_fit_result") and not hasattr(result, "residuals"): result = result.to_fit_result() spec = result.spectrum # Rest-frame scaling factor. zp1 = 1.0 if rest_frame and spec.z is not None: zp1 = 1.0 + spec.z if wave_unit == "A": wave = spec.wave_A / zp1 xlabel = r"Rest Wavelength [$\mathrm{\AA}$]" if rest_frame else r"Wavelength [$\mathrm{\AA}$]" else: wave = spec.wave_um / zp1 xlabel = r"Rest Wavelength [$\mu$m]" if rest_frame else r"Wavelength [$\mu$m]" use_flam = flux_unit.lower() == "flam" if use_flam: flux = _ujy_to_flam(spec.flux_ujy, spec.wave_um) err = _ujy_to_flam(spec.err_ujy, spec.wave_um) cont = _ujy_to_flam(result.continuum, spec.wave_um) model_total = _ujy_to_flam(result.model_flux + result.continuum, spec.wave_um) resid = _ujy_to_flam(result.residuals, spec.wave_um) else: flux = spec.flux_ujy err = spec.err_ujy cont = result.continuum model_total = result.model_flux + cont resid = result.residuals # Figure setup. if fig is None: if show_residuals: fig, (ax_main, ax_res) = plt.subplots( 2, 1, figsize=(10, 6), height_ratios=[3, 1], sharex=True, gridspec_kw={"hspace": 0.05}, ) else: fig, ax_main = plt.subplots(1, 1, figsize=(10, 4.5)) ax_res = None else: axes = fig.get_axes() ax_main = axes[0] ax_res = axes[1] if len(axes) > 1 else None valid = spec.mask_valid() # Apply wavelength exclusion mask. keep = _build_exclude_mask(spec.wave_A, exclude_wave_A) show = valid & keep # Disable scientific notation / offset on axes so full numbers are shown. from matplotlib.ticker import ScalarFormatter sfmt = ScalarFormatter(useOffset=False) sfmt.set_scientific(False) # --- Individual line components (smooth Gaussians, behind data) --- if show_components and len(result.line_names) > 0: edges = spec.wave_edges_A nL = len(result.line_names) # Colour map: narrow lines in blues/greens, broad in reds/oranges. 
narrow_names = [n for n in result.line_names if "BROAD" not in n] broad_names = [n for n in result.line_names if "BROAD" in n] narrow_colours = plt.cm.Set2(np.linspace(0, 0.8, max(len(narrow_names), 1))) broad_colours = plt.cm.Oranges(np.linspace(0.4, 0.8, max(len(broad_names), 1))) n_narrow = 0 n_broad = 0 # NaN-out excluded regions so components don't plot through them. wave_plot = wave.copy().astype(float) wave_plot[~keep] = np.nan # Build index mapping: line_names may include "Lya" which is not # in the Gaussian params vector. Skip it for component plotting. _gauss_names = [n for n in result.line_names if n != "Lya"] _gauss_idx = {n: j for j, n in enumerate(_gauss_names)} nL_gauss = len(_gauss_names) for i, name in enumerate(result.line_names): if name == "Lya": continue # Lyα uses skewed Gaussian, not in params vector gi = _gauss_idx[name] amp = result.params[gi] p_single = np.zeros(3 * nL_gauss) p_single[gi] = amp p_single[nL_gauss + gi] = result.params[nL_gauss + gi] p_single[2 * nL_gauss + gi] = result.params[2 * nL_gauss + gi] comp_flam = build_model(p_single, edges, nL_gauss) if use_flam: comp_plot = comp_flam + _ujy_to_flam(result.continuum, spec.wave_um) else: comp_plot = _flam_to_ujy(comp_flam, spec.wave_um) + cont # NaN-out excluded regions. comp_plot_masked = comp_plot.copy() comp_plot_masked[~keep] = np.nan cont_masked = cont.copy() cont_masked[~keep] = np.nan is_broad = "BROAD" in name is_abs = name.startswith("abs_") # Fractional uncertainty for shading. 
lr = result.lines.get(name) frac_err = 0.0 if lr is not None and abs(lr.flux) > 0 and lr.flux_err > 0: frac_err = lr.flux_err / abs(lr.flux) if is_broad: colour = broad_colours[n_broad % len(broad_colours)] n_broad += 1 ax_main.fill_between( wave_plot, cont_masked, comp_plot_masked, alpha=0.25, color=colour, hatch="//", linewidth=0, ) ax_main.plot(wave_plot, comp_plot_masked, "-", color=colour, lw=1.2, alpha=0.8) else: colour = narrow_colours[n_narrow % len(narrow_colours)] n_narrow += 1 ax_main.fill_between( wave_plot, cont_masked, comp_plot_masked, alpha=0.20, color=colour, linewidth=0, ) ax_main.plot(wave_plot, comp_plot_masked, "-", color=colour, lw=0.8, alpha=0.7) # Uncertainty shading (±1σ on the Gaussian profile). if frac_err > 0: line_only = comp_plot_masked - cont_masked comp_hi = cont_masked + line_only * (1.0 + frac_err) comp_lo = cont_masked + line_only * max(1.0 - frac_err, 0.0) ax_main.fill_between( wave_plot, comp_lo, comp_hi, alpha=0.12, color=colour, linewidth=0, ) if label_lines: centroid_obs_A = result.params[nL_gauss + gi] centroid_A = centroid_obs_A / zp1 if wave_unit == "A": x_label = centroid_A else: x_label = centroid_A * 1e-4 y_label = comp_plot[np.argmin(np.abs(spec.wave_A - centroid_obs_A))] display_name = name.replace("abs_", "").replace("_", " ") # Absorption labels below the trough; emission labels above. y_offset = -10 if is_abs else 8 va = "top" if is_abs else "baseline" ax_main.annotate( display_name, xy=(x_label, y_label), xytext=(0, y_offset), textcoords="offset points", fontsize=7, ha="center", va=va, color=colour, fontweight="bold" if is_broad else "normal", rotation=45, ) # Lyα asymmetric Gaussian overlay for static plot. 
if show_components: _lya_p_s = getattr(result, "lya_params", None) if _lya_p_s is not None and len(_lya_p_s) == 4: from .models import asymmetric_gaussian as _ag centres_s = 0.5 * (edges[:-1] + edges[1:]) lya_flam_s = _ag(centres_s, _lya_p_s[0], _lya_p_s[1], _lya_p_s[2], _lya_p_s[3]) if use_flam: comp_s = lya_flam_s + _ujy_to_flam(result.continuum, spec.wave_um) else: comp_s = _flam_to_ujy(lya_flam_s, spec.wave_um) + cont comp_s_m = comp_s.copy() comp_s_m[~keep] = np.nan cont_m = cont.copy() cont_m[~keep] = np.nan ax_main.fill_between( wave_plot, cont_m, comp_s_m, alpha=0.25, color="C0", linewidth=0, ) ax_main.plot( wave_plot, comp_s_m, "-", color="C0", lw=0.8, alpha=0.7, ) if label_lines: # Peak of the asymmetric Gaussian (find numerically). _peak_idx_s = np.argmax(lya_flam_s) mu_s = centres_s[_peak_idx_s] x_lbl = mu_s / zp1 if wave_unit == "A" else mu_s * 1e-4 / zp1 peak_flam = lya_flam_s[_peak_idx_s] cont_at = np.interp(mu_s * 1e-4, spec.wave_um, result.continuum) if use_flam: y_lbl = peak_flam + _ujy_to_flam(np.array([cont_at]), np.array([mu_s * 1e-4]))[0] else: y_lbl = _flam_to_ujy(np.array([peak_flam]), np.array([mu_s * 1e-4]))[0] + cont_at ax_main.annotate( "Lyα", xy=(x_lbl, y_lbl), xytext=(0, 8), textcoords="offset points", fontsize=7, ha="center", va="baseline", color="C0", rotation=45, ) # Main panel: data + model + continuum. 
ax_main.step(wave[show], flux[show], where="mid", color="0.3", lw=0.8, label="Data", zorder=3) ax_main.fill_between( wave[show], (flux - err)[show], (flux + err)[show], step="mid", alpha=0.12, color="0.5", zorder=2, ) ax_main.step(wave[keep], cont[keep], where="mid", color="C2", lw=1.0, alpha=0.7, label="Continuum", linestyle="--", zorder=4) ax_main.step(wave[keep], model_total[keep], where="mid", color="C3", lw=1.2, alpha=0.5, label="Model", zorder=5) ylabel = r"$f_\lambda$ [erg s$^{-1}$ cm$^{-2}$ $\mathrm{\AA}^{-1}$]" if use_flam else r"Flux density [$\mu$Jy]" ax_main.set_ylabel(ylabel) ax_main.legend(fontsize=8, loc="upper right") ax_main.xaxis.set_major_formatter(sfmt) if result.spectrum.z is not None: ax_main.set_title(f"z = {result.spectrum.z:.4f} | χ²/dof = {result.chi2:.2f}") # --- Y-axis limits based on emission-line peaks / absorption troughs --- model_peak = np.nanmax(model_total[show]) if np.any(show) else 1.0 model_trough = np.nanmin(model_total[show]) if np.any(show) else 0.0 cont_median = np.nanmedian(cont[show]) if np.any(show) else 0.0 # Upper limit: tallest line peak × y_pad y_upper = cont_median + (model_peak - cont_median) * y_pad # Lower limit: accommodate absorption troughs or minimum continuum y_lower_cont = np.nanmin(cont[show]) * 1.1 if np.any(show) else -0.1 y_lower = min(0.0, y_lower_cont, model_trough - abs(model_trough) * 0.15) if y_upper > y_lower: ax_main.set_ylim(y_lower, y_upper) # Residual panel — x-range limited to the extent of the fitted lines. if show_residuals and ax_res is not None: ax_res.step(wave[show], resid[show], where="mid", color="0.3", lw=0.8) ax_res.axhline(0, color="C3", lw=0.8, ls="--") ax_res.fill_between( wave[show], -err[show], err[show], step="mid", alpha=0.15, color="0.5", ) ax_res.set_ylabel("Residual") ax_res.set_xlabel(xlabel) ax_res.xaxis.set_major_formatter(sfmt) # Clip x-range to the outermost fitted lines ± 5σ margin. 
if len(result.line_names) > 0: nL = len(result.line_names) centroids_A = result.params[nL: 2 * nL] / zp1 sigmas_A = result.params[2 * nL: 3 * nL] / zp1 xlim_lo_A = np.min(centroids_A - 5 * sigmas_A) xlim_hi_A = np.max(centroids_A + 5 * sigmas_A) if wave_unit == "A": ax_res.set_xlim(xlim_lo_A, xlim_hi_A) ax_main.set_xlim(xlim_lo_A, xlim_hi_A) else: ax_res.set_xlim(xlim_lo_A * 1e-4, xlim_hi_A * 1e-4) ax_main.set_xlim(xlim_lo_A * 1e-4, xlim_hi_A * 1e-4) else: ax_main.set_xlabel(xlabel) try: fig.tight_layout() except Exception: pass if save_path is not None: save_path = Path(save_path).with_suffix(".png") fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig
def _to_rgba(colour: str, alpha: float) -> str:
    """Convert a CSS-style colour string to ``rgba(r,g,b,alpha)``.

    Supported inputs:

    * ``"rgba(r,g,b,a)"`` — the existing alpha is replaced by *alpha*;
    * ``"rgb(r,g,b)"`` — *alpha* is appended;
    * ``"#rrggbb"`` / ``"#rrggbbaa"`` (any alpha byte is ignored) and the
      shorthand forms ``"#rgb"`` / ``"#rgba"``.

    Anything else (e.g. named colours such as ``"black"``) falls back to a
    neutral grey carrying the requested alpha.
    """
    if colour.startswith("rgba("):
        # Replace the existing alpha (the last comma-separated field).
        return colour.rsplit(",", 1)[0] + f",{alpha})"
    if colour.startswith("rgb("):
        return colour.replace("rgb(", "rgba(").replace(")", f",{alpha})")
    if colour.startswith("#"):
        h = colour.lstrip("#")
        if len(h) in (3, 4):
            # CSS shorthand: every digit is doubled ("#abc" == "#aabbcc").
            h = "".join(ch * 2 for ch in h)
        r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
        return f"rgba({r},{g},{b},{alpha})"
    return f"rgba(150,150,150,{alpha})"


_MULTI_PALETTE = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
    "#9467bd", "#8c564b", "#e377c2", "#17becf",
]

# Component colours for plot_fit_interactive: every narrow line shares
# one colour, every _BROAD shares another, every _BROAD2 a third, so
# the eye groups them by kinematic class rather than by line identity.
_COMP_COLOUR_NARROW = "rgba(31,119,180,0.85)"   # plotly C0 blue
_COMP_COLOUR_BROAD = "rgba(255,127,14,0.75)"    # plotly C1 orange
_COMP_COLOUR_BROAD2 = "rgba(214,39,40,0.75)"    # plotly C3 red
_COMP_COLOUR_ABS = "rgba(70,130,180,0.8)"       # steel blue


def _component_colour(name: str, is_abs: bool) -> str:
    """Pick the fixed colour for a fitted component by kinematic class.

    Absorption components win over everything else; otherwise the
    ``_BROAD2`` suffix is checked before ``_BROAD`` (which it contains).
    """
    if is_abs:
        return _COMP_COLOUR_ABS
    if "_BROAD2" in name:
        return _COMP_COLOUR_BROAD2
    if "_BROAD" in name:
        return _COMP_COLOUR_BROAD
    return _COMP_COLOUR_NARROW


def _resolve_marker_names(
    lines: "Sequence[str] | str | bool | None",
    default_names: "Sequence[str]",
) -> list[str]:
    """Turn the public ``lines`` argument into an explicit list of line keys.

    ``False`` → no markers; ``None`` or ``True`` → *default_names*;
    a single string → just that key; any other iterable → its items.
    """
    if lines is False:
        return []
    if lines is None or lines is True:
        return list(default_names)
    if isinstance(lines, str):
        # A bare key must not be iterated character by character.
        return [lines]
    return list(lines)


def _stagger_rows(xs: "Sequence[float]", threshold: float) -> list[int]:
    """Assign label rows so labels closer than *threshold* do not overlap.

    *xs* must be sorted ascending.  Each x goes into the first row whose
    most recent label is at least *threshold* to its left; if none
    qualifies, a new row is opened.  Returns one row index per x.
    """
    row_last_x: list[float] = []
    rows: list[int] = []
    for x in xs:
        for r, last_x in enumerate(row_last_x):
            if x - last_x >= threshold:
                row_last_x[r] = x
                rows.append(r)
                break
        else:
            row_last_x.append(x)
            rows.append(len(row_last_x) - 1)
    return rows


def _draw_emission_line_markers(
    fig,
    *,
    z_for_lines: "float | None",
    wave_unit: str,
    x_lo: "float | None",
    x_hi: "float | None",
    lines: "Sequence[str] | bool | None" = None,
    add_lines: "dict[str, float] | Sequence[str] | None" = None,
    line_color: str = "darkred",
    use_paper_shapes: bool = False,
) -> None:
    """Draw curated emission-line markers + labels onto a plotly figure.

    Shared by :func:`plot_spectrum_interactive` and
    :func:`plot_fit_interactive` so both honour the same
    ``lines`` / ``add_lines`` / ``line_color`` semantics and produce
    identical marker styling.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure to draw markers on, in place.
    z_for_lines : float or None
        Redshift used to place markers.  ``0.0`` puts them at rest
        wavelengths; ``None`` skips drawing entirely.
    wave_unit : str
        ``"A"`` (Å) or ``"um"`` (µm) — must match the x-axis.  Any value
        other than ``"A"`` is treated as µm.
    x_lo, x_hi : float or None
        Plotted x-range.  Markers outside the range are dropped.  When
        either is ``None`` no markers are drawn.
    lines : sequence of str, str, bool, or None
        ``None`` or ``True`` draws the package-default marker set,
        ``False`` draws none, and a key (or sequence of keys) from
        ``REST_LINES_A`` draws only those.  Unknown keys are skipped.
    add_lines : dict[str, float], sequence of str, str, or None
        Extra markers: either ``{label: rest_wavelength_A}`` or one or
        more ``REST_LINES_A`` keys.  Unknown keys are skipped.
    line_color : str
        Colour for the marker lines and their labels.
    use_paper_shapes : bool
        ``True`` when *fig* is a multi-row subplot — the marker is a
        paper-coord shape (spans every row) plus a single annotation
        pinned to the figure top.  ``False`` uses ``add_vline`` and works
        for a single-panel figure.
    """
    if z_for_lines is None or x_lo is None or x_hi is None:
        return

    from .lines import REST_LINES_A

    display = _DEFAULT_MARKER_LABELS
    zp1 = 1.0 + z_for_lines
    scale = 1.0 if wave_unit == "A" else 1e-4   # Å → µm when needed

    def _label_for(nm: str) -> str:
        # Same fallback for curated and user-added keys.
        return display.get(nm, nm.replace("_", " "))

    # (rest wavelength in Å, display label)
    entries: list[tuple[float, str]] = []
    for nm in _resolve_marker_names(lines, _DEFAULT_MARKER_NAMES):
        rest_A = REST_LINES_A.get(nm)
        if rest_A is not None:
            entries.append((float(rest_A), _label_for(nm)))

    if add_lines:
        if isinstance(add_lines, dict):
            entries.extend(
                (float(rest_A), str(label)) for label, rest_A in add_lines.items()
            )
        else:
            names = [add_lines] if isinstance(add_lines, str) else add_lines
            for nm in names:
                rest_A = REST_LINES_A.get(nm)
                if rest_A is not None:
                    entries.append((float(rest_A), _label_for(nm)))

    # Redshift, convert units, clip to the plotted range and sort by x
    # (stable sort on x only, so equal-x entries keep their input order).
    markers = [
        (x, lab)
        for x, lab in ((rest_A * zp1 * scale, lab) for rest_A, lab in entries)
        if x_lo <= x <= x_hi
    ]
    markers.sort(key=lambda m: m[0])
    if not markers:
        return

    # Stagger so labels closer than 3 % of the x-range go on separate rows.
    rows = _stagger_rows([x for x, _ in markers], 0.03 * (x_hi - x_lo))

    row_spacing_px = 14
    for (x, label), r in zip(markers, rows):
        if use_paper_shapes:
            # Paper-coord shape spans every subplot row; a single
            # annotation at the top avoids the per-row duplication
            # that add_vline(row="all", ...) produces.
            fig.add_shape(
                type="line", xref="x", yref="paper",
                x0=x, x1=x, y0=0, y1=1,
                line=dict(color=line_color, width=0.8, dash="dash"),
                opacity=0.6, layer="below",
            )
            fig.add_annotation(
                x=x, xref="x", y=1.0, yref="paper",
                yshift=r * row_spacing_px,
                text=label, showarrow=False,
                font=dict(size=9, color=line_color),
                xanchor="center", yanchor="bottom",
            )
        else:
            fig.add_vline(
                x=x, line_width=0.8, line_dash="dash",
                line_color=line_color, opacity=0.6,
                annotation_text=label, annotation_position="top",
                annotation_font_size=9, annotation_font_color=line_color,
                annotation_yshift=r * row_spacing_px,
                layer="below",
            )

    # Grow the top margin so stacked label rows fit above the plot.
    n_rows = max(rows) + 1
    desired_t = 60 + (n_rows - 1) * row_spacing_px + 14
    fig.update_layout(margin=dict(t=desired_t))
def plot_spectrum_interactive(
    source: "Spectrum | str | Path | Sequence[Spectrum | str | Path]",
    *,
    z: float | None = None,
    wave_unit: str = "A",
    flux_unit: str = "fnu",
    rest_frame: bool = False,
    exclude_wave_A: list[tuple[float, float]] | None = None,
    title: str | None = None,
    labels: "str | Sequence[str] | None" = None,
    lines: "Sequence[str] | bool | None" = None,
    add_lines: "dict[str, float] | Sequence[str] | None" = None,
    line_color: str = "darkred",
    show_zero: bool = True,
    show_2d: "bool | str" = "auto",
    cmap_2d: str = "Blues",
    vmin_pct: float = 5.0,
    vmax_pct: float = 99.5,
    y_crop: tuple[float, float] = (0.25, 0.75),
    **read_kwargs,
) -> "go.Figure":
    """Open and interactively plot one or more 1-D spectra.

    Accepts a :class:`~jwspecfit.io.Spectrum` object, a path to a
    ``.fits`` / ``.npz`` file, or a list / tuple of such items.  When
    given a path, the file is read via :func:`~jwspecfit.io.read_fits`
    (or :func:`read_npz` for ``.npz``) and any extra ``read_kwargs`` are
    forwarded to the reader (e.g. ``hdu=``, ``wave_col=`` for FITS
    overrides).

    Parameters
    ----------
    source : Spectrum, str, Path, or sequence of these
        A single spectrum / file path, or a list / tuple of them to
        overplot.
    z : float, optional
        Source redshift.  Forwarded to every reader call when a *source*
        item is a path; for an existing :class:`Spectrum`, its own ``z``
        is preserved.  Also used as the fallback redshift for rest-frame
        scaling (see *rest_frame*) and as the redshift for the
        emission-line markers (see *lines*).
    wave_unit : str
        ``"A"`` for Angstroms (default) or ``"um"`` for microns.
    flux_unit : str
        ``"fnu"`` for µJy (default) or ``"flam"`` for erg/s/cm²/Å.
    rest_frame : bool
        If ``True``, divide each spectrum's wavelengths by ``(1 + z)``,
        using the spectrum's own ``spec.z`` or, failing that, the *z*
        argument.  Spectra with neither are silently left in the
        observed frame.  Default ``False`` (observed frame).
    exclude_wave_A : list of (float, float), optional
        Wavelength ranges in Angstroms to hide from the plot.
    title : str, optional
        Figure title.  When a single spectrum is supplied, defaults to
        filename + redshift + grating if available; when multiple
        spectra are supplied, defaults to no title.
    labels : str or sequence of str, optional
        Legend label(s).  When ``None`` (default), a single spectrum uses
        ``"Data"`` and multiple spectra use each spectrum's filename (or
        ``"Spectrum {i}"`` as a fallback).
    lines : sequence of str, bool, or None
        Emission lines to mark as vertical dashed lines at the supplied
        redshift.  ``None`` (default) draws a curated list of common
        UV/optical lines.  Pass an explicit list of keys from
        :data:`jwspecfit.lines.REST_LINES_A` to override, or ``False`` to
        disable.  The effective redshift is taken from ``z`` if given,
        else from a single spectrum's own ``spec.z``.  When rest-frame
        scaling is actually applied, the markers sit at the rest
        wavelengths.
    add_lines : dict[str, float] or sequence of str, optional
        Extra lines to overlay on top of *lines*.  Two accepted forms:

        - **dict** — ``{label: rest_wavelength_A}``.  Free-form labels
          with explicit rest-frame wavelengths in **Angstroms**.  Use this
          for lines not in :data:`jwspecfit.lines.REST_LINES_A`, e.g.
          ``add_lines={"Mg II 2796": 2796.352}``.
        - **list of str** — names from :data:`REST_LINES_A` (e.g.
          ``add_lines=["H8", "HEPSILON", "FeII_2382"]``).  The rest
          wavelength is looked up automatically.  Call
          :func:`jwspecfit.show_lines` to see what's available.

        Each entry is redshifted by ``(1 + z)`` and staggered alongside
        the default markers.
    line_color : str
        Colour for the emission-line markers and their labels (default
        ``"darkred"``).
    show_zero : bool
        Draw a light-grey dashed horizontal line at ``y = 0`` to make
        continuum detection easier to gauge by eye (default ``True``).
    show_2d : bool or {"auto"}
        Whether to render the 2-D rectified spectrum image above the 1-D
        panel.  ``"auto"`` (default) shows it iff a *single* spectrum is
        supplied **and** its :attr:`~jwspecfit.io.Spectrum.sci_2d` is
        populated (e.g. when read from an msaexp ``.spec.fits`` file).
        ``True`` forces it (silently no-ops when multiple spectra are
        supplied or no 2-D is available); ``False`` always shows the 1-D
        only.  The two panels share the wavelength axis; emission-line
        markers span both.
    cmap_2d : str
        Plotly colorscale for the 2-D panel (default ``"Blues"`` to
        mirror :func:`plot_2d_1d`).
    vmin_pct, vmax_pct : float
        Percentile clip for the 2-D colour scale (default ``5`` /
        ``99.5``).
    y_crop : (float, float)
        Fractional crop ``(lo, hi)`` of the spatial axis on the 2-D panel
        (default ``(0.25, 0.75)`` keeps the middle 50 % of rows).
    **read_kwargs
        Forwarded to the file reader when a *source* item is a path.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If *source* is empty, the number of *labels* does not match the
        number of sources, or a path has an unsupported extension.
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    from .io import Spectrum, read_fits, read_npz, _ujy_to_flam

    # Normalise sources / labels to parallel lists.
    if isinstance(source, (list, tuple)):
        sources_list = list(source)
    else:
        sources_list = [source]
    if not sources_list:
        raise ValueError("source is empty: pass at least one spectrum or file path.")

    if labels is None:
        labels_in = [None] * len(sources_list)
    elif isinstance(labels, str):
        labels_in = [labels]
    else:
        labels_in = list(labels)
    if len(labels_in) != len(sources_list):
        raise ValueError(
            f"labels has length {len(labels_in)} but {len(sources_list)} "
            f"sources were given."
        )
    multi = len(sources_list) > 1

    # Resolve every source to a Spectrum.
    specs: list[Spectrum] = []
    for s in sources_list:
        if isinstance(s, Spectrum):
            specs.append(s)
            continue
        path = Path(s)
        suffix = path.suffix.lower()
        if suffix in (".fits", ".fit", ".fz"):
            specs.append(read_fits(path, z=z, **read_kwargs))
        elif suffix == ".npz":
            specs.append(read_npz(path, z=z, **read_kwargs))
        else:
            raise ValueError(
                f"Unsupported file extension {suffix!r}: pass a .fits "
                f"or .npz file, or a Spectrum object."
            )

    use_flam = flux_unit.lower() == "flam"

    # The 2-D panel is only meaningful for a single spectrum with a
    # populated sci_2d (msaexp .spec.fits).  In multi-spec mode it is
    # silently omitted (the explicit user contract: "if multiple spectra,
    # show 1-D only"), and show_2d=True no-ops when no 2-D is available.
    has_2d = (not multi) and (specs[0].sci_2d is not None)
    _show_2d = has_2d if show_2d == "auto" else (bool(show_2d) and has_2d)

    if _show_2d:
        fig = make_subplots(
            rows=2, cols=1, shared_xaxes=True,
            row_heights=[0.26, 0.74], vertical_spacing=0.02,
        )
        panel_kw = dict(row=2, col=1)
    else:
        fig = go.Figure()
        panel_kw = {}

    all_flux_show: list[np.ndarray] = []
    x_mins: list[float] = []
    x_maxs: list[float] = []
    err_legend_done = False
    # Whether any spectrum was actually shifted to the rest frame; decides
    # where the emission-line markers go.
    rest_applied = False
    xlabel = ylabel = flux_label = None

    for i, spec in enumerate(specs):
        # Rest-frame scaling (per spectrum).  Prefer the Spectrum's own z
        # (e.g. set by read_fits(z=...)), else fall back to the explicit
        # ``z`` kwarg so users who read a file without z=... can still get
        # a rest-frame plot.  With neither, stay observed-frame silently.
        z_data: float | None = spec.z if spec.z is not None else z
        rf = bool(rest_frame) and z_data is not None
        zp1 = (1.0 + float(z_data)) if rf else 1.0
        rest_applied = rest_applied or rf

        if wave_unit == "A":
            wave = spec.wave_A / zp1
            xlabel = "Rest Wavelength [Å]" if rf else "Wavelength [Å]"
        else:
            wave = spec.wave_um / zp1
            xlabel = "Rest Wavelength [µm]" if rf else "Wavelength [µm]"

        if use_flam:
            flux = _ujy_to_flam(spec.flux_ujy, spec.wave_um)
            err = _ujy_to_flam(spec.err_ujy, spec.wave_um)
            flux_label = "erg/s/cm²/Å"
            ylabel = f"fλ [{flux_label}]"
        else:
            flux = spec.flux_ujy
            err = spec.err_ujy
            flux_label = "µJy"
            ylabel = f"Flux density [{flux_label}]"

        # For plotting, only require finite flux — be lenient about errors
        # so spectra without an error array (e.g. image HDUs) still render.
        valid = np.isfinite(flux)
        keep = _build_exclude_mask(spec.wave_A, exclude_wave_A)
        show = valid & keep
        has_err = np.any(np.isfinite(err) & (err > 0))
        err_show = show & np.isfinite(err) & (err > 0)

        # Per-trace colour and label.
        if multi:
            colour = _MULTI_PALETTE[i % len(_MULTI_PALETTE)]
            band_fill = _to_rgba(colour, 0.18)
        else:
            colour = "black"
            band_fill = "rgba(150,150,150,0.20)"
        if labels_in[i] is not None:
            name = labels_in[i]
        elif multi:
            fname = spec.meta.get("filename")
            name = str(fname) if fname else f"Spectrum {i + 1}"
        else:
            name = "Data"

        # Error band — step-shaped fill between ±1σ.  Two traces with
        # fill='tonexty' so the upper/lower edges are drawn as step
        # functions matching the data trace.  The legend entry appears
        # once across all spectra.
        if has_err and np.any(err_show):
            fig.add_trace(go.Scatter(
                x=wave[err_show], y=(flux - err)[err_show],
                mode="lines", line=dict(width=0, shape="hvh"),
                showlegend=False, hoverinfo="skip",
            ), **panel_kw)
            fig.add_trace(go.Scatter(
                x=wave[err_show], y=(flux + err)[err_show],
                mode="lines", line=dict(width=0, shape="hvh"),
                fill="tonexty", fillcolor=band_fill,
                name="±1σ", showlegend=not err_legend_done,
                hoverinfo="skip",
            ), **panel_kw)
            err_legend_done = True

        # Data trace (histogram-step style).
        fig.add_trace(go.Scatter(
            x=wave[show], y=flux[show],
            mode="lines", name=name,
            line=dict(color=colour, width=0.9, shape="hvh"),
            hovertemplate=f"λ=%{{x:.3f}}<br>flux=%{{y:.4e}} {flux_label}<extra></extra>",
        ), **panel_kw)

        if np.any(show):
            all_flux_show.append(flux[show])
            x_mins.append(float(np.nanmin(wave[show])))
            x_maxs.append(float(np.nanmax(wave[show])))

    # Title.
    if title is None:
        if not multi:
            spec = specs[0]
            bits = []
            fname = spec.meta.get("filename")
            if fname:
                bits.append(str(fname))
            # Title z: prefer spec.z, fall back to the explicit z= kwarg.
            _z_title = spec.z if spec.z is not None else z
            if _z_title is not None:
                bits.append(f"z = {float(_z_title):.4f}")
            if spec.grating:
                bits.append(spec.grating)
            title = " | ".join(bits)
        else:
            title = ""

    # Y-limits — show the full vertical range so emission lines (including
    # faint ones like [OIII]λ4363) are visible by default.  The lower bound
    # is the 2nd percentile (floored at 0) so noise outliers do not
    # dominate the axis; the upper bound is the data max plus a small pad.
    y_lower, y_upper = -1.0, 1.0
    if all_flux_show:
        f_show = np.concatenate(all_flux_show)
        finite = np.isfinite(f_show)
        if np.any(finite):
            lo = float(np.nanpercentile(f_show[finite], 2))
            hi = float(np.nanmax(f_show[finite]))
            pad = 0.05 * (hi - lo if hi > lo else max(abs(hi), 1.0))
            y_lower = min(0.0, lo - pad)
            y_upper = hi + pad

    # --- 2-D rectified-spectrum panel (row=1 when enabled) ---
    if _show_2d:
        spec0 = specs[0]
        sci = np.asarray(spec0.sci_2d, dtype=float)
        # Same z-resolution policy as the 1-D loop (spec.z first, explicit
        # z= kwarg as fallback, else observed frame) so the heatmap lines
        # up column-for-column with the 1-D trace.
        _z_2d: float | None = spec0.z if spec0.z is not None else z
        zp1_0 = (1.0 + float(_z_2d)) if (rest_frame and _z_2d is not None) else 1.0
        if wave_unit == "A":
            wave_2d = spec0.wave_A / zp1_0
        else:
            wave_2d = spec0.wave_um / zp1_0

        ny = sci.shape[0]
        y_lo_pix = int(round(ny * y_crop[0]))
        y_hi_pix = max(y_lo_pix + 1, int(round(ny * y_crop[1])))
        sci_crop = sci[y_lo_pix:y_hi_pix, :]
        y_pix = np.arange(y_lo_pix, y_hi_pix)

        finite = np.isfinite(sci_crop)
        if np.any(finite):
            vmin, vmax = np.nanpercentile(
                sci_crop[finite], [vmin_pct, vmax_pct],
            )
        else:
            vmin, vmax = 0.0, 1.0

        fig.add_trace(
            go.Heatmap(
                x=wave_2d, y=y_pix, z=sci_crop,
                colorscale=cmap_2d, zmin=vmin, zmax=vmax,
                showscale=False, hoverinfo="skip",
            ),
            row=1, col=1,
        )

    # Legend below the plot at all times — top-right covered emission
    # lines on tall axes, and horizontal-bottom looks the same whether
    # one spectrum or several are overlaid.
    legend = dict(orientation="h", x=0.5, xanchor="center",
                  y=-0.22, yanchor="top")
    bottom_margin = 110

    # Subplots need per-axis updates (update_xaxes / update_yaxes with
    # row/col); a single panel can be configured in one update_layout.
    if _show_2d:
        fig.update_layout(
            title=title,
            template="plotly_white",
            hovermode="x unified",
            dragmode="zoom",
            legend=legend,
            width=1000,
            height=620,
            margin=dict(b=bottom_margin),
        )
        fig.update_xaxes(exponentformat="none")
        # Bottom (1-D) panel labels + range.
        fig.update_xaxes(title_text=xlabel, row=2, col=1)
        fig.update_yaxes(
            title_text=ylabel, range=[y_lower, y_upper], row=2, col=1,
        )
        # Top (2-D) panel: hide tick labels, label the spatial axis.
        fig.update_yaxes(
            title_text="Spatial pix", showticklabels=False, row=1, col=1,
        )
    else:
        fig.update_layout(
            title=title,
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            yaxis_range=[y_lower, y_upper],
            xaxis=dict(exponentformat="none"),
            template="plotly_white",
            hovermode="x unified",
            dragmode="zoom",
            legend=legend,
            width=1000,
            height=500,
            margin=dict(b=bottom_margin),
        )

    # --- Zero-flux reference (light grey dashed) on the 1-D panel ---
    if show_zero:
        fig.add_hline(
            y=0, line_width=1, line_dash="dash",
            line_color="lightgrey", layer="below",
            **panel_kw,
        )

    # --- Emission-line markers ---
    # Markers go at rest wavelengths only when the data were actually
    # rest-frame scaled; if rest_frame was requested but no redshift was
    # available the plot is observed-frame and z_eff is None as well.
    z_eff = z
    if z_eff is None and not multi:
        z_eff = specs[0].z
    if rest_applied:
        z_for_lines: float | None = 0.0
    elif z_eff is not None:
        z_for_lines = float(z_eff)
    else:
        z_for_lines = None

    _draw_emission_line_markers(
        fig,
        z_for_lines=z_for_lines,
        wave_unit=wave_unit,
        x_lo=min(x_mins) if x_mins else None,
        x_hi=max(x_maxs) if x_maxs else None,
        lines=lines,
        add_lines=add_lines,
        line_color=line_color,
        use_paper_shapes=_show_2d,
    )

    return fig
[docs] def plot_fit_interactive( result: "FitResult", *, wave_unit: str = "A", flux_unit: str = "fnu", show_components: bool = True, show_residuals: bool = True, y_pad: float = 1.3, exclude_wave_A: list[tuple[float, float]] | None = None, rest_frame: bool = False, z: float | None = None, lines: "Sequence[str] | bool | None" = False, add_lines: "dict[str, float] | Sequence[str] | None" = None, line_color: str = "darkred", show_2d: "bool | str" = "auto", cmap_2d: str = "Blues", vmin_pct: float = 5.0, vmax_pct: float = 99.5, y_crop: tuple[float, float] = (0.25, 0.75), ) -> "go.Figure": """Interactive plotly plot of a spectral fit with zoom and hover. Optionally stacks a 2-D rectified-spectrum panel above the main fit panel (when ``result.spectrum.sci_2d`` is populated by :func:`read_fits`) and a residual panel below. By default no curated emission-line markers are drawn — every fitted line is already labelled at its peak above the model. Pass ``lines=None`` to add the package-default markers, or an explicit list of :data:`jwspecfit.lines.REST_LINES_A` keys to mark only those. Parameters ---------- result : FitResult Output of :func:`~jwspecfit.fitter.fit_lines`. wave_unit : str ``"A"`` for Angstroms (default) or ``"um"`` for microns. flux_unit : str ``"fnu"`` for µJy (default) or ``"flam"`` for erg/s/cm²/Å. show_components : bool Show individual line components (default True). show_residuals : bool Show residual panel below the main plot (default True). y_pad : float Multiplicative padding above tallest line (default 1.3). exclude_wave_A : list of (float, float), optional Wavelength ranges in Angstroms to hide from the plot. rest_frame : bool If ``True``, plot wavelengths in the rest frame by dividing by ``(1 + z)``. Default ``False`` (observed frame). Emission-line markers are placed at rest wavelengths when ``True``. z : float, optional Redshift override used for rest-frame conversion and marker placement. When ``None`` (default), ``result.spectrum.z`` is used. 
Raises ``ValueError`` if ``rest_frame=True`` and neither source provides a redshift. lines : sequence of str, bool, or None Curated emission-line markers (vertical dashed lines + labels). ``False`` (default) draws none — the fit-component peak annotations already identify every fitted line. ``None`` opts in to the package-default marker set; an explicit list of :data:`jwspecfit.lines.REST_LINES_A` keys marks only those names. add_lines : dict[str, float] or sequence of str, optional Extra markers to overlay on top of *lines*. Same semantics as in :func:`plot_spectrum_interactive`. line_color : str Colour for the emission-line markers and their labels (default ``"darkred"``). show_2d : bool or {"auto"} Whether to stack the 2-D rectified spectrum image above the fit panel. ``"auto"`` (default) shows it iff ``result.spectrum.sci_2d`` is populated. ``True`` forces it (silently no-ops when no 2-D is available); ``False`` always hides it. cmap_2d : str Plotly colorscale for the 2-D panel (default ``"Blues"``). vmin_pct, vmax_pct : float Percentile clip for the 2-D colour scale (default ``5`` / ``99.5``). y_crop : (float, float) Fractional crop ``(lo, hi)`` of the spatial axis on the 2-D panel (default ``(0.25, 0.75)``). Returns ------- plotly.graph_objects.Figure """ import plotly.graph_objects as go from plotly.subplots import make_subplots from .io import _flam_to_ujy, _ujy_to_flam # Auto-convert MCMCResult / MCMCBroadFitResult to FitResult. if hasattr(result, "to_fit_result") and not hasattr(result, "residuals"): result = result.to_fit_result() spec = result.spectrum # Rest-frame scaling factor. Fail loudly if rest_frame is requested # but no z is available — silently falling back to zp1=1 produced a # plot that looked observed-frame with no warning. z_used = z if z is not None else spec.z if rest_frame and z_used is None: raise ValueError( "rest_frame=True but no redshift available: result.spectrum.z " "is None and no z= override was supplied. Pass z=... 
or set " "spec.z before calling." ) zp1 = (1.0 + float(z_used)) if rest_frame else 1.0 if wave_unit == "A": wave = spec.wave_A / zp1 xlabel = "Rest Wavelength [Å]" if rest_frame else "Wavelength [Å]" else: wave = spec.wave_um / zp1 xlabel = "Rest Wavelength [µm]" if rest_frame else "Wavelength [µm]" use_flam = flux_unit.lower() == "flam" if use_flam: flux = _ujy_to_flam(spec.flux_ujy, spec.wave_um) err = _ujy_to_flam(spec.err_ujy, spec.wave_um) cont = _ujy_to_flam(result.continuum, spec.wave_um) model_total = _ujy_to_flam(result.model_flux + result.continuum, spec.wave_um) resid = _ujy_to_flam(result.residuals, spec.wave_um) flux_label = "erg/s/cm²/Å" else: flux = spec.flux_ujy err = spec.err_ujy cont = result.continuum model_total = result.model_flux + cont resid = result.residuals flux_label = "µJy" valid = spec.mask_valid() keep = _build_exclude_mask(spec.wave_A, exclude_wave_A) show = valid & keep # Insert NaN breaks at excluded-region boundaries so traces don't # draw lines through masked regions. def _nan_mask(arr: np.ndarray, mask: np.ndarray) -> np.ndarray: out = arr.copy().astype(float) out[~mask] = np.nan return out wave_k = _nan_mask(wave, keep) flux_s = _nan_mask(flux, show) err_s = _nan_mask(err, show) cont_k = _nan_mask(cont, keep) model_k = _nan_mask(model_total, keep) resid_s = _nan_mask(resid, show) # Decide whether to stack a 2-D panel on top. When the spectrum # carries no 2-D, "auto" silently no-ops, and an explicit True is # also no-op'd (so the call doesn't fail on synthetic spectra). if show_2d == "auto": _show_2d = spec.sci_2d is not None else: _show_2d = bool(show_2d) and spec.sci_2d is not None # Figure layout: 1, 2, or 3 rows depending on which panels are on. # Row indices and heights are derived once so the rest of the # function can refer to _main_row / _resid_row / _viz_2d_row # without re-deriving them. 
if _show_2d and show_residuals: fig = make_subplots( rows=3, cols=1, shared_xaxes=True, row_heights=[0.20, 0.60, 0.20], vertical_spacing=0.03, ) _viz_2d_row, _main_row, _resid_row = 1, 2, 3 elif _show_2d and not show_residuals: fig = make_subplots( rows=2, cols=1, shared_xaxes=True, row_heights=[0.26, 0.74], vertical_spacing=0.03, ) _viz_2d_row, _main_row, _resid_row = 1, 2, None elif show_residuals: fig = make_subplots( rows=2, cols=1, shared_xaxes=True, row_heights=[0.75, 0.25], vertical_spacing=0.04, ) _viz_2d_row, _main_row, _resid_row = None, 1, 2 else: fig = go.Figure() _viz_2d_row = _main_row = _resid_row = None _has_subplots = _show_2d or show_residuals def _add(trace, row=None): if _has_subplots and row is not None: fig.add_trace(trace, row=row, col=1) else: fig.add_trace(trace) # Error band. _add(go.Scatter( x=np.concatenate([wave[show], wave[show][::-1]]), y=np.concatenate([(flux + err)[show], (flux - err)[show][::-1]]), fill="toself", fillcolor="rgba(150,150,150,0.15)", line=dict(width=0), showlegend=False, hoverinfo="skip", ), row=_main_row) # Individual line components as smooth analytical Gaussians. if show_components and len(result.line_names) > 0: from math import sqrt, pi nL = len(result.line_names) import plotly.express as px palette = px.colors.qualitative.Set2 # Interpolate continuum onto a fine grid for smooth component curves. cont_interp_fn = np.interp peak_info = [] # Build index mapping: skip Lyα (not in Gaussian params vector). _gauss_names_i = [n for n in result.line_names if n != "Lya"] _gauss_idx_i = {n: j for j, n in enumerate(_gauss_names_i)} nL_gauss_i = len(_gauss_names_i) for i, name in enumerate(result.line_names): if name == "Lya": continue # Lyα uses skewed Gaussian, not in params vector gi = _gauss_idx_i[name] amp = result.params[gi] mu_A = result.params[nL_gauss_i + gi] sig_A = result.params[2 * nL_gauss_i + gi] if amp == 0 or sig_A <= 0: continue is_abs = name.startswith("abs_") # Fine wavelength grid around ±5σ of the line. 
w_lo = max(mu_A - 5 * sig_A, spec.wave_A.min()) w_hi = min(mu_A + 5 * sig_A, spec.wave_A.max()) n_fine = max(int((w_hi - w_lo) / (sig_A / 5)), 100) wave_fine_A = np.linspace(w_lo, w_hi, n_fine) wave_fine_um = wave_fine_A * 1e-4 # Analytical Gaussian in F_λ: G(λ) = A / (√(2π) σ) × exp(...) gauss_flam = (amp / (sqrt(2 * pi) * sig_A)) * np.exp( -0.5 * ((wave_fine_A - mu_A) / sig_A) ** 2 ) # Continuum at fine grid points (interpolated). cont_fine_ujy = np.interp(wave_fine_um, spec.wave_um, result.continuum) if use_flam: gauss_plot = gauss_flam + _ujy_to_flam(cont_fine_ujy, wave_fine_um) cont_fine = _ujy_to_flam(cont_fine_ujy, wave_fine_um) else: gauss_plot = _flam_to_ujy(gauss_flam, wave_fine_um) + cont_fine_ujy cont_fine = cont_fine_ujy # Convert to the chosen wave unit (with rest-frame scaling). wave_fine = wave_fine_A / zp1 if wave_unit == "A" else wave_fine_um / zp1 # Apply exclusion mask. keep_fine = _build_exclude_mask(wave_fine_A, exclude_wave_A) gauss_masked = gauss_plot.copy() gauss_masked[~keep_fine] = np.nan cont_fine_masked = cont_fine.copy() cont_fine_masked[~keep_fine] = np.nan is_broad = "BROAD" in name colour = _component_colour(name, is_abs) display_name = name.replace("abs_", "").replace("_", " ") dash = "dot" if is_broad else "solid" # Uncertainty shading (±1σ). lr = result.lines.get(name) frac_err = 0.0 if lr is not None and abs(lr.flux) > 0 and lr.flux_err > 0: frac_err = lr.flux_err / abs(lr.flux) if frac_err > 0: line_only = gauss_masked - cont_fine_masked comp_hi = cont_fine_masked + line_only * (1.0 + frac_err) comp_lo = cont_fine_masked + line_only * max(1.0 - frac_err, 0.0) fill = _to_rgba(colour, 0.12) _add(go.Scatter( x=np.concatenate([wave_fine, wave_fine[::-1]]), y=np.concatenate([comp_hi, comp_lo[::-1]]), fill="toself", fillcolor=fill, line=dict(width=0), showlegend=False, hoverinfo="skip", ), row=_main_row) # Smooth Gaussian curve. 
_add(go.Scatter( x=wave_fine, y=gauss_masked, mode="lines", name=f"{'[B] ' if is_broad else ''}{display_name}", line=dict(color=colour, width=1.5, dash=dash), hovertemplate=f"{display_name}<br>λ=%{{x:.1f}}<br>flux=%{{y:.4e}} {flux_label}<extra></extra>", showlegend=False, ), row=_main_row) # Store peak/trough position for annotation label. gauss_peak_flam = amp / (sqrt(2 * pi) * sig_A) cont_at_peak_ujy = np.interp( mu_A * 1e-4, spec.wave_um, result.continuum, ) if use_flam: y_peak = gauss_peak_flam + _ujy_to_flam( np.array([cont_at_peak_ujy]), np.array([mu_A * 1e-4]), )[0] else: y_peak = ( _flam_to_ujy( np.array([gauss_peak_flam]), np.array([mu_A * 1e-4]), )[0] + cont_at_peak_ujy ) x_peak = mu_A / zp1 if wave_unit == "A" else mu_A * 1e-4 / zp1 peak_info.append((name, x_peak, float(y_peak), colour, is_abs)) # Lyα asymmetric Gaussian overlay. _lya_p = getattr(result, "lya_params", None) if _lya_p is not None and len(_lya_p) == 4: from .models import asymmetric_gaussian as _ag _A_pk, _mu_lya, _sig_lya, _alpha_lya = _lya_p comp_col = _COMP_COLOUR_NARROW # Fine wavelength grid around the line. 
w_lo_c = max(_mu_lya - 8 * _sig_lya, spec.wave_A.min()) w_hi_c = min(_mu_lya + 12 * _sig_lya, spec.wave_A.max()) n_fine_c = max(int((w_hi_c - w_lo_c) / (_sig_lya / 5)), 200) wave_fine_c = np.linspace(w_lo_c, w_hi_c, n_fine_c) wave_fine_c_um = wave_fine_c * 1e-4 prof_flam = _ag(wave_fine_c, _A_pk, _mu_lya, _sig_lya, _alpha_lya) cont_fine_c_ujy = np.interp(wave_fine_c_um, spec.wave_um, result.continuum) if use_flam: prof_plot = prof_flam + _ujy_to_flam(cont_fine_c_ujy, wave_fine_c_um) cont_fine_c = _ujy_to_flam(cont_fine_c_ujy, wave_fine_c_um) else: prof_plot = _flam_to_ujy(prof_flam, wave_fine_c_um) + cont_fine_c_ujy cont_fine_c = cont_fine_c_ujy wave_fine_c_plot = wave_fine_c / zp1 if wave_unit == "A" else wave_fine_c_um / zp1 keep_c = _build_exclude_mask(wave_fine_c, exclude_wave_A) prof_masked = prof_plot.copy() prof_masked[~keep_c] = np.nan cont_fine_c_masked = cont_fine_c.copy() cont_fine_c_masked[~keep_c] = np.nan # Uncertainty shading. lr_lya = result.lines.get("Lya") frac_err_c = 0.0 if lr_lya is not None and abs(lr_lya.flux) > 0: fe = lr_lya.flux_err fe_val = 0.5 * (fe[0] + fe[1]) if isinstance(fe, tuple) else fe if fe_val > 0: frac_err_c = fe_val / abs(lr_lya.flux) if frac_err_c > 0: line_only = prof_masked - cont_fine_c_masked comp_hi = cont_fine_c_masked + line_only * (1.0 + frac_err_c) comp_lo = cont_fine_c_masked + line_only * max(1.0 - frac_err_c, 0.0) fill_c = _to_rgba(comp_col, 0.12) _add(go.Scatter( x=np.concatenate([wave_fine_c_plot, wave_fine_c_plot[::-1]]), y=np.concatenate([comp_hi, comp_lo[::-1]]), fill="toself", fillcolor=fill_c, line=dict(width=0), showlegend=False, hoverinfo="skip", ), row=_main_row) _add(go.Scatter( x=wave_fine_c_plot, y=prof_masked, mode="lines", name="Lyα", line=dict(color=comp_col, width=1.5, dash="solid"), hovertemplate="Lyα<br>λ=%{x:.1f}<br>flux=%{y:.4e} " + flux_label + "<extra></extra>", showlegend=False, ), row=_main_row) # Peak annotation. 
_pk_idx = np.argmax(prof_flam) peak_flam_c = prof_flam[_pk_idx] mu_pk = wave_fine_c[_pk_idx] cont_at_peak_c = np.interp(mu_pk * 1e-4, spec.wave_um, result.continuum) if use_flam: y_peak_c = peak_flam_c + _ujy_to_flam( np.array([cont_at_peak_c]), np.array([mu_pk * 1e-4]))[0] else: y_peak_c = _flam_to_ujy( np.array([peak_flam_c]), np.array([mu_pk * 1e-4]))[0] + cont_at_peak_c x_peak_c = mu_pk / zp1 if wave_unit == "A" else mu_pk * 1e-4 / zp1 peak_info.append(("Lyα", x_peak_c, float(y_peak_c), comp_col, False)) # --- Line-name annotations with row staggering to avoid overlap --- # Separate emission and absorption labels so each side staggers # independently (emission rises above peaks, absorption drops # below troughs). Within each side, sort by x and place each # label in the first row where it sits ≥ `threshold` away from # the last label in that row. emission_lbls = [info for info in peak_info if not info[4]] absorption_lbls = [info for info in peak_info if info[4]] row_spacing_px = 14 if peak_info: x_positions = [info[1] for info in peak_info] x_range = max(x_positions) - min(x_positions) threshold = 0.025 * x_range if x_range > 0 else 0.0 else: threshold = 0.0 def _assign_rows(labels, thr): labels_sorted = sorted(labels, key=lambda info: info[1]) row_last_x: list[float] = [] rows: list[int] = [] for _, x, _, _, _ in labels_sorted: placed = False for r, last_x in enumerate(row_last_x): if x - last_x >= thr: row_last_x[r] = x rows.append(r) placed = True break if not placed: row_last_x.append(x) rows.append(len(row_last_x) - 1) return labels_sorted, rows for labels_group, sign in ((emission_lbls, +1), (absorption_lbls, -1)): if not labels_group: continue labels_sorted, rows = _assign_rows(labels_group, threshold) base_yshift = 10 if sign > 0 else -12 yanchor = "bottom" if sign > 0 else "top" for (name, x_peak, y_peak, colour, _is_abs), row in zip( labels_sorted, rows, ): display_name = name.replace("abs_", "").replace("_", " ") is_broad = "BROAD" in name if "rgba" 
in colour: ann_colour = colour.rsplit(",", 1)[0] + ",1.0)" else: ann_colour = colour yshift = base_yshift + sign * row * row_spacing_px fig.add_annotation( x=x_peak, y=y_peak, xref="x", yref="y", text=f"<b>{display_name}</b>" if is_broad else display_name, showarrow=False, yshift=yshift, font=dict(size=9, color=ann_colour), xanchor="center", yanchor=yanchor, ) # --- 2-D rectified-spectrum panel (row=_viz_2d_row when enabled) --- if _show_2d: sci = np.asarray(spec.sci_2d, dtype=float) # Mirror the 1-D wavelength scaling so the 2-D heatmap aligns # column-for-column with the data trace below. if wave_unit == "A": wave_2d = spec.wave_A / zp1 else: wave_2d = spec.wave_um / zp1 ny = sci.shape[0] y_lo_pix = int(round(ny * y_crop[0])) y_hi_pix = max(y_lo_pix + 1, int(round(ny * y_crop[1]))) sci_crop = sci[y_lo_pix:y_hi_pix, :] y_pix = np.arange(y_lo_pix, y_hi_pix) finite = np.isfinite(sci_crop) if np.any(finite): vmin, vmax = np.nanpercentile( sci_crop[finite], [vmin_pct, vmax_pct], ) else: vmin, vmax = 0.0, 1.0 _add(go.Heatmap( x=wave_2d, y=y_pix, z=sci_crop, colorscale=cmap_2d, zmin=vmin, zmax=vmax, showscale=False, hoverinfo="skip", ), row=_viz_2d_row) # Data (steps). _add(go.Scatter( x=wave[show], y=flux[show], mode="lines", name="Data", line=dict(color="grey", width=0.8, shape="hvh"), hovertemplate=f"Data<br>λ=%{{x:.1f}}<br>flux=%{{y:.4e}} {flux_label}<extra></extra>", ), row=_main_row) # Continuum (steps). _add(go.Scatter( x=wave_k, y=cont_k, mode="lines", name="Continuum", line=dict(color="green", width=1, dash="dash", shape="hvh"), ), row=_main_row) # Model (steps, semi-transparent). _add(go.Scatter( x=wave_k, y=model_k, mode="lines", name="Model", line=dict(color="rgba(255,0,0,0.4)", width=1.5, shape="hvh"), hovertemplate=f"Model<br>λ=%{{x:.1f}}<br>flux=%{{y:.4e}} {flux_label}<extra></extra>", ), row=_main_row) # --- Residual panel --- if show_residuals: # Error band on residuals. 
_add(go.Scatter( x=np.concatenate([wave[show], wave[show][::-1]]), y=np.concatenate([err[show], (-err)[show][::-1]]), fill="toself", fillcolor="rgba(150,150,150,0.15)", line=dict(width=0), showlegend=False, hoverinfo="skip", ), row=_resid_row) # Zero line. _add(go.Scatter( x=[wave[show].min(), wave[show].max()], y=[0, 0], mode="lines", showlegend=False, line=dict(color="rgba(255,0,0,0.4)", width=1, dash="dash"), ), row=_resid_row) # Residual data. _add(go.Scatter( x=wave[show], y=resid[show], mode="lines", name="Residual", showlegend=False, line=dict(color="grey", width=0.8, shape="hvh"), hovertemplate=f"Residual<br>λ=%{{x:.1f}}<br>resid=%{{y:.4e}} {flux_label}<extra></extra>", ), row=_resid_row) # Y limits — account for absorption troughs. model_peak = np.nanmax(model_total[show]) if np.any(show) else 1.0 model_trough = np.nanmin(model_total[show]) if np.any(show) else 0.0 cont_median = np.nanmedian(cont[show]) if np.any(show) else 0.0 y_upper = cont_median + (model_peak - cont_median) * y_pad y_lower_cont = np.nanmin(cont[show]) * 1.1 if np.any(show) else -0.1 y_lower = min(0.0, y_lower_cont, model_trough - abs(model_trough) * 0.15) title_bits = [] if z_used is not None: title_bits.append(f"z = {float(z_used):.4f}") title_bits.append(f"χ²/dof = {result.chi2:.2f}") if rest_frame: title_bits.append("rest frame") title = " | ".join(title_bits) ylabel = f"fλ [{flux_label}]" if use_flam else f"Flux density [{flux_label}]" # Bottom-horizontal legend in every mode — the previous top-right # placement covered emission-line markers and labels. Margin # leaves room for the horizontal legend (and any stacked line # labels above the plot — _draw_emission_line_markers grows the # top margin to fit). _legend = dict( orientation="h", x=0.5, xanchor="center", y=-0.18, yanchor="top", ) if _has_subplots: # Figure height by row count: 1 + 2D, 1 + residual, or 1+2D+residual. 
if _show_2d and show_residuals: fig_height = 760 elif _show_2d: fig_height = 620 else: fig_height = 650 # legacy show_residuals-only height fig.update_layout( title=title, template="plotly_white", hovermode=False, dragmode="zoom", legend=_legend, width=1000, height=fig_height, margin=dict(b=110), ) fig.update_xaxes(exponentformat="none") # Main panel always gets the wavelength label so the unit is # visible without scrolling; bottom-most panel also gets it. fig.update_xaxes(title_text=xlabel, row=_main_row, col=1) fig.update_yaxes( title_text=ylabel, range=[y_lower, y_upper], row=_main_row, col=1, ) if _show_2d: fig.update_yaxes( title_text="Spatial pix", showticklabels=False, row=_viz_2d_row, col=1, ) if show_residuals: # Clip residual y-axis to ±5× median error so noise spikes # don't dominate the panel height. med_err = float(np.nanmedian(err[show])) if np.any(show) else 1.0 res_ylim = 5.0 * med_err fig.update_yaxes( title_text="Residual", range=[-res_ylim, res_ylim], row=_resid_row, col=1, ) fig.update_xaxes(title_text=xlabel, row=_resid_row, col=1) else: fig.update_layout( title=title, xaxis_title=xlabel, yaxis_title=ylabel, yaxis_range=[y_lower, y_upper], xaxis=dict(exponentformat="none"), template="plotly_white", hovermode=False, dragmode="zoom", legend=_legend, width=1000, height=500, margin=dict(b=110), ) # --- Curated emission-line markers (rest-frame aware) --- if rest_frame: z_for_lines: float | None = 0.0 elif z_used is not None: z_for_lines = float(z_used) else: z_for_lines = None if np.any(show): x_data = wave[show] x_lo_m = float(np.nanmin(x_data)) x_hi_m = float(np.nanmax(x_data)) else: x_lo_m = x_hi_m = None _draw_emission_line_markers( fig, z_for_lines=z_for_lines, wave_unit=wave_unit, x_lo=x_lo_m, x_hi=x_hi_m, lines=lines, add_lines=add_lines, line_color=line_color, use_paper_shapes=_has_subplots, ) return fig
def plot_2d_1d(
    path: str | Path,
    z: float | None = None,
    *,
    sci_ext: str = "SCI",
    hdu: str | int | None = None,
    wave_col: str | None = None,
    flux_col: str | None = None,
    err_col: str | None = None,
    flux_scale: float = 1e3,
    flux_label: str = r"$F_\nu$ [nJy]",
    xlabel: str = r"$\lambda_{\rm obs}\,[\mu{\rm m}]$",
    cmap: str = "Blues",
    vmin_pct: float = 5.0,
    vmax_pct: float = 99.5,
    y_crop: tuple[float, float] = (0.25, 0.75),
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    line_colour: str = "steelblue",
    line_width: float = 0.5,
    err_colour: str = "grey",
    err_alpha: float = 0.35,
    figsize: tuple[float, float] = (8, 4),
    dpi: float = 300,
    title: str | None = None,
    lines: list[str] | bool | None = None,
    add_lines: list[str] | dict[str, float] | None = None,
    height_ratios: tuple[float, float] = (0.35, 1.0),
) -> tuple["Figure", tuple["Axes", "Axes"]]:
    """Plot a JWST/NIRSpec 2D + 1D spectrum from a ``.spec.fits`` file.

    Reads the 2D image from ``sci_ext`` and the 1D extraction via
    :func:`jwspecfit.read_fits`, then renders a stacked figure (2D
    pcolormesh above, 1D flux below, shared x-axis) with optional
    emission-line markers at the supplied redshift.

    Parameters
    ----------
    path
        Path to the ``.spec.fits`` file.
    z
        Redshift used to place observed-frame line markers. If ``None``,
        no markers are drawn.
    sci_ext
        Name of the 2D image extension (default ``"SCI"``); matched
        case-insensitively against the file's HDU names.
    hdu, wave_col, flux_col, err_col
        Forwarded to :func:`read_fits` for 1D extraction (override the
        default SPEC1D / column auto-detection).
    flux_scale, flux_label
        Multiplier and y-axis label applied to the 1D flux (default
        converts µJy → nJy).
    xlabel
        Shared x-axis label (default observed wavelength in µm).
    cmap
        Colormap for the 2D panel.
    vmin_pct, vmax_pct
        Percentile clip for the 2D colour scale, computed over the finite
        pixels of the rows kept by ``y_crop`` (falls back to ``(0, 1)``
        when none are finite).
    y_crop
        ``(y_lo_frac, y_hi_frac)`` fractional crop of the 2D vertical
        extent (default keeps the middle 50 % rows).
    xlim, ylim
        Optional axis limits. ``xlim`` applies to both panels.
    line_colour, line_width
        Colour and line width of the 1D flux trace.
    err_colour, err_alpha
        Fill colour and alpha for the ``±1σ`` error band.
    figsize, dpi, height_ratios
        Matplotlib figure size, resolution (default 300), and 2D-vs-1D
        height ratio.
    title
        Optional title placed above the 2D panel.
    lines
        Default-marker control: ``None`` (or ``True``) uses the package
        defaults; a list of :data:`REST_LINES_A` keys (or a single key)
        restricts to those names; ``False`` disables defaults entirely.
    add_lines
        Extra markers: either REST_LINES_A keys, or a
        ``{label: rest_wavelength_A}`` dict. A marker whose label and
        position duplicate an already-drawn one is skipped.

    Returns
    -------
    fig, (ax_2d, ax_1d)
        Matplotlib figure and the two axes.

    Raises
    ------
    KeyError
        If ``sci_ext`` is not an extension of the file.
    ValueError
        If the ``sci_ext`` data is not 2D, or its wavelength axis length
        does not match the 1D wavelength grid.
    """
    import matplotlib.pyplot as plt
    from astropy.io import fits

    from .io import read_fits
    from .lines import REST_LINES_A

    path = Path(path)
    spec = read_fits(
        path, hdu=hdu, wave_col=wave_col, flux_col=flux_col, err_col=err_col,
    )

    with fits.open(path) as hdul:
        ext_names = [h.name for h in hdul]
        # HDU names are stored upper-case and astropy's name lookup is
        # case-insensitive, so check membership the same way.
        if str(sci_ext).upper() not in ext_names:
            raise KeyError(
                f"Extension {sci_ext!r} not found in {path.name} "
                f"(available: {ext_names})"
            )
        # Copy into a native-endian float array so the image does not
        # depend on the (possibly memory-mapped) file after it is closed.
        sci2d = np.array(hdul[sci_ext].data, dtype=float)
    if sci2d.ndim != 2:
        raise ValueError(
            f"Expected 2D array in {sci_ext!r}, got shape {sci2d.shape}"
        )

    wave = np.asarray(spec.wave_um)
    flux = np.asarray(spec.flux_ujy)
    err = np.asarray(spec.err_ujy) if spec.err_ujy is not None else None

    # 2D wavelength axis must match the 1D wave grid in length.
    if sci2d.shape[1] != wave.size:
        raise ValueError(
            f"SCI wavelength axis ({sci2d.shape[1]} pix) does not match "
            f"SPEC1D wave length ({wave.size})."
        )

    fig, (ax2d, ax1d) = plt.subplots(
        2, 1,
        figsize=figsize,
        dpi=dpi,
        gridspec_kw={"height_ratios": list(height_ratios), "hspace": 0.05},
        sharex=True,
    )

    # --- 2D panel --------------------------------------------------
    # Pixel edges for pcolormesh: midpoints between samples, padded by
    # half the median spacing at either end.
    dw = float(np.median(np.diff(wave)))
    wave_edges = np.concatenate([
        [wave[0] - dw / 2.0],
        0.5 * (wave[1:] + wave[:-1]),
        [wave[-1] + dw / 2.0],
    ])
    ny = sci2d.shape[0]
    y_edges = np.arange(ny + 1)
    y_lo, y_hi = ny * y_crop[0], ny * y_crop[1]

    # Derive the colour scale only from rows that remain visible after
    # the y-crop (partially visible rows included), so noisy edge rows
    # outside the view cannot stretch the stretch. Non-finite pixels are
    # ignored; if none remain, use a neutral (0, 1) scale instead of NaN.
    row_lo = max(0, int(np.floor(y_lo)))
    row_hi = min(ny, max(row_lo + 1, int(np.ceil(y_hi))))
    visible = sci2d[row_lo:row_hi]
    finite_pix = visible[np.isfinite(visible)]
    if finite_pix.size:
        vmin, vmax = np.percentile(finite_pix, [vmin_pct, vmax_pct])
    else:
        vmin, vmax = 0.0, 1.0

    ax2d.pcolormesh(
        wave_edges, y_edges, sci2d,
        shading="auto", cmap=cmap, vmin=vmin, vmax=vmax,
    )
    ax2d.set_ylim(y_lo, y_hi)
    ax2d.tick_params(labelbottom=False, labelleft=False)
    ax2d.set_ylabel("Spatial pix", fontsize=9)
    if title is not None:
        ax2d.set_title(title, fontsize=10)

    # --- 1D panel --------------------------------------------------
    f_scaled = flux * flux_scale
    ax1d.plot(wave, f_scaled, color=line_colour, lw=line_width)
    if err is not None:
        ax1d.fill_between(
            wave,
            (flux - err) * flux_scale,
            (flux + err) * flux_scale,
            color=err_colour, alpha=err_alpha, lw=0,
        )
    ax1d.axhline(0, color="k", lw=0.8, ls="--", alpha=0.6)
    ax1d.tick_params(direction="in", top=True, right=True, labelsize=8)
    ax1d.set_xlabel(xlabel, fontsize=10)
    ax1d.set_ylabel(flux_label, fontsize=10)

    if xlim is not None:
        # The axes share x (sharex=True), so this limits both panels.
        ax1d.set_xlim(*xlim)
    if ylim is not None:
        ax1d.set_ylim(*ylim)
    else:
        # Auto-scale to the brightest feature in the visible window:
        # 2nd-percentile floor (never above 0), data max with a small pad.
        x_lo, x_hi = ax1d.get_xlim()
        in_view = (wave >= x_lo) & (wave <= x_hi) & np.isfinite(f_scaled)
        if np.any(in_view):
            lo = float(np.nanpercentile(f_scaled[in_view], 2))
            hi = float(np.nanmax(f_scaled[in_view]))
            pad = 0.05 * (hi - lo if hi > lo else max(abs(hi), 1.0))
            y_lower = min(0.0, lo - pad)
            y_upper = hi + pad
            if y_upper > y_lower:
                ax1d.set_ylim(y_lower, y_upper)

    # --- Emission-line markers ------------------------------------
    if z is not None:
        x_lo, x_hi = ax1d.get_xlim()
        zp1 = 1.0 + float(z)
        markers: list[tuple[float, str]] = []
        seen: set[tuple[str, float]] = set()

        def _label_for(name: str) -> str:
            # Preferred display label, else the key with "_" as spaces.
            return _DEFAULT_MARKER_LABELS.get(name, name.replace("_", " "))

        def _add_marker(rest_A: float, label: str) -> None:
            # Convert rest-frame Å to observed µm; keep markers strictly
            # inside the current x-window and drop exact duplicates.
            obs_um = float(rest_A) * zp1 / 1e4
            key = (label, obs_um)
            if x_lo < obs_um < x_hi and key not in seen:
                seen.add(key)
                markers.append((obs_um, label))

        if lines is not False:
            if lines is None or lines is True:
                names = list(_DEFAULT_MARKER_NAMES)
            elif isinstance(lines, str):
                names = [lines]
            else:
                names = list(lines)
            for nm in names:
                rest_A = REST_LINES_A.get(nm)
                if rest_A is not None:
                    _add_marker(rest_A, _label_for(nm))

        if add_lines:
            if isinstance(add_lines, dict):
                for label, rest_A in add_lines.items():
                    _add_marker(rest_A, str(label))
            else:
                extra = [add_lines] if isinstance(add_lines, str) else add_lines
                for nm in extra:
                    rest_A = REST_LINES_A.get(nm)
                    if rest_A is not None:
                        _add_marker(rest_A, _label_for(nm))

        # Labels use data x but axes-fraction y, so they sit just inside
        # the top of the 1D panel for any y-limits (including non-positive
        # tops) and stay there if the limits are changed afterwards.
        label_transform = ax1d.get_xaxis_transform()
        for x_obs, label in markers:
            for ax in (ax2d, ax1d):
                ax.axvline(x_obs, color="gray", ls="--", lw=0.7, alpha=0.6)
            ax1d.text(
                x_obs, 0.95, label,
                transform=label_transform,
                color="gray", rotation=90, ha="center", va="top",
                fontsize=7,
            )

    fig.subplots_adjust(left=0.10, right=0.97, top=0.94, bottom=0.12)
    return fig, (ax2d, ax1d)