# Source code for adam_core.photometry.magnitude_common

from __future__ import annotations

from functools import lru_cache
from typing import TypeAlias, Union

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
import pyarrow as pa

from .bandpasses.api import assert_filter_ids_have_curves  # noqa: F401
from .bandpasses.api import compute_mix_integrals as _compute_bandpass_mix_integrals
from .bandpasses.api import get_integrals as _get_bandpass_integrals
from .bandpasses.api import load_bandpass_curves as _load_bandpass_curves

# NOTE(review): not referenced in this part of the module; presumably the batch
# size used when evaluating magnitudes with JAX in fixed-size chunks — confirm
# at the call sites.
JAX_CHUNK_SIZE = 8192

# A bandpass composition is either a template_id string (e.g. "C") or a
# (weight_C, weight_S) pair of mixing weights for the C and S templates.
BandpassComposition: TypeAlias = Union[str, tuple[float, float]]


@lru_cache(maxsize=1)
def bandpass_filter_id_table() -> tuple[
    tuple[str, ...],
    pa.Array,
    dict[str, int],
    int,
]:
    """
    Build (once) the filter-ID lookup table used for bandpass conversions.

    The table is built lazily on first call, not at import time, because
    building it reads packaged Parquet data. Every later call returns the same
    cached objects, so callers should not mutate the returned dict.

    Returns
    -------
    filter_ids : tuple[str, ...]
        Filter IDs in the order they appear in the packaged bandpass curves.
    filter_ids_arrow : pa.Array
        The same IDs as a ``large_string`` Arrow array.
    filter_to_id : dict[str, int]
        Maps each filter ID to its position in ``filter_ids``. If an ID
        repeats, its last position wins.
    v_id : int
        Position of the canonical ``'V'`` filter.

    Raises
    ------
    ValueError
        If the bandpass curves contain no ``'V'`` filter.
    """
    loaded = _load_bandpass_curves()
    names = tuple(loaded.filter_id.to_pylist())

    # Every conversion is expressed relative to V, so it must be present.
    if "V" not in names:
        raise ValueError("Bandpass curves must include a canonical 'V' filter_id.")

    names_arrow = pa.array(list(names), type=pa.large_string())
    index_of = dict(zip(names, range(len(names))))
    return names, names_arrow, index_of, int(index_of["V"])
def bandpass_integrals_for_composition(
    composition: BandpassComposition, filter_ids: npt.NDArray[np.object_]
) -> npt.NDArray[np.float64]:
    """
    Return bandpass integrals for ``filter_ids`` under the given composition.

    Parameters
    ----------
    composition
        Either a template_id string (e.g. ``'C'``), which is forwarded to the
        template-integral lookup, or a ``(weight_C, weight_S)`` pair, which is
        forwarded to the C/S mixing routine. The weights are converted with
        ``float()`` but are NOT validated or normalised here; use
        ``bandpass_composition_key`` for that.
    filter_ids
        Object array of filter IDs to evaluate (presumably strings — confirm
        against the bandpass API).

    Returns
    -------
    numpy.ndarray
        Whatever the underlying bandpass API returns, one integral per entry of
        ``filter_ids`` (assumed; the exact shape is defined by that API).

    Raises
    ------
    TypeError
        If ``composition`` is neither a string nor a two-element iterable.
    """
    if isinstance(composition, str):
        return _get_bandpass_integrals(composition, filter_ids)

    try:
        w_c, w_s = composition
    except (TypeError, ValueError) as e:
        # TypeError: not iterable; ValueError: wrong number of elements.
        # Narrow on purpose so unrelated errors are not disguised as a bad
        # composition.
        raise TypeError(
            "composition must be either a template_id string (e.g. 'C') "
            "or a (weight_C, weight_S) tuple"
        ) from e
    return _compute_bandpass_mix_integrals(float(w_c), float(w_s), filter_ids)
def bandpass_composition_key(composition: BandpassComposition) -> BandpassComposition:
    """
    Validate ``composition`` and return a canonical, hashable cache key.

    Parameters
    ----------
    composition
        Either a non-empty template_id string (returned unchanged) or a
        ``(weight_C, weight_S)`` pair of finite, non-negative weights with at
        least one weight > 0.

    Returns
    -------
    str or tuple[float, float]
        The template_id string, or the weights normalised so they sum to 1.

    Raises
    ------
    TypeError
        If ``composition`` is neither a string nor a two-element iterable.
    ValueError
        If the template_id is empty, or a weight is non-finite or negative, or
        both weights are zero.
    """
    if isinstance(composition, str):
        if not composition:
            raise ValueError("composition template_id must be non-empty")
        return composition

    try:
        w_c, w_s = composition
    except (TypeError, ValueError) as e:
        # TypeError: not iterable; ValueError: wrong number of elements.
        raise TypeError(
            "composition must be either a template_id string (e.g. 'C') "
            "or a (weight_C, weight_S) tuple"
        ) from e

    w_c = float(w_c)
    w_s = float(w_s)
    if not np.isfinite(w_c) or not np.isfinite(w_s):
        raise ValueError("composition weights must be finite")
    if w_c < 0.0 or w_s < 0.0:
        raise ValueError("composition weights must be non-negative")

    total = w_c + w_s
    if not np.isfinite(total):
        # Two large finite weights can overflow to inf, which would turn both
        # normalised weights into 0.0. Rescale by the larger weight (so both
        # land in [0, 1]) before summing. This is done only in the overflow
        # case so that ordinary inputs keep bit-identical cache keys.
        scale = max(w_c, w_s)
        w_c /= scale
        w_s /= scale
        total = w_c + w_s
    if total <= 0.0:
        raise ValueError("at least one composition weight must be > 0")
    return (w_c / total, w_s / total)
@lru_cache(maxsize=None)
def bandpass_delta_table_for_composition_cached(
    composition_key: BandpassComposition,
) -> npt.NDArray[np.float64]:
    """
    Compute per-filter delta magnitudes relative to V for the given composition.

        delta[filter] = m_filter - m_V = -2.5 * log10(I_filter / I_V)

    where ``I`` is the bandpass integral for the composition. Entries follow
    the order of ``bandpass_filter_id_table()[0]``.

    Results are memoised per ``composition_key``. The cache is unbounded, so
    callers are presumably expected to pass normalised keys from
    ``bandpass_composition_key``. Because every caller receives the same cached
    array, it is marked read-only; copy it before modifying.

    Raises
    ------
    ValueError
        If the V-band integral is non-finite or not positive.
    FloatingPointError
        If any filter integral is zero or negative. ``np.errstate`` makes the
        resulting divide-by-zero or invalid log10 raise.
    """
    filter_ids, _, _, v_id = bandpass_filter_id_table()
    ids = np.asarray(filter_ids, dtype=object)
    integrals = bandpass_integrals_for_composition(composition_key, ids)

    i_v = float(integrals[v_id])
    if not np.isfinite(i_v) or i_v <= 0.0:
        raise ValueError("Invalid V-band integral for bandpass conversion.")

    # Turn silent -inf/NaN results into hard errors.
    with np.errstate(divide="raise", invalid="raise"):
        delta = -2.5 * np.log10(np.asarray(integrals, dtype=np.float64) / i_v)

    delta = np.asarray(delta, dtype=np.float64)
    # The array is shared through lru_cache; make in-place mutation by one
    # caller fail loudly instead of corrupting every later lookup.
    delta.flags.writeable = False
    return delta
@lru_cache(maxsize=None)
def bandpass_delta_table_jax_for_composition_cached(
    composition_key: BandpassComposition,
) -> jax.Array:
    """
    JAX counterpart of ``bandpass_delta_table_for_composition_cached``.

    Converts the cached NumPy table of per-filter deltas (m_filter - m_V) for
    ``composition_key`` into a ``jax.Array``. The result is memoised per key,
    so the conversion runs once per composition. ``composition_key`` is
    presumably the normalised key from ``bandpass_composition_key`` — confirm
    at call sites.

    NOTE(review): the array is requested as float64. JAX only keeps 64-bit
    precision when x64 mode (``jax_enable_x64``) is enabled — confirm this is
    configured wherever this helper is used.
    """
    return jnp.asarray(
        bandpass_delta_table_for_composition_cached(composition_key),
        dtype=jnp.float64,
    )
def bandpass_delta_table_for_composition(
    composition: BandpassComposition,
) -> npt.NDArray[np.float64]:
    """
    Return per-filter delta magnitudes (m_filter - m_V) for ``composition``.

    The composition is first validated and canonicalised with
    ``bandpass_composition_key``, so equivalent inputs (e.g. ``(1, 1)`` and
    ``(2, 2)``) share one cache entry. The table then comes from the memoised
    ``bandpass_delta_table_for_composition_cached``.

    The returned array is the cached object shared with other callers; do not
    modify it in place.

    Raises
    ------
    TypeError, ValueError
        Propagated from ``bandpass_composition_key`` for malformed or invalid
        compositions.
    """
    canonical = bandpass_composition_key(composition)
    return bandpass_delta_table_for_composition_cached(canonical)