# Source code for adam_core.photometry.rotation.jax_backend

from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from math import ceil

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt

# Switch JAX to 64-bit mode at import time: the fitting code in this module
# explicitly requests float64 / int64 arrays.
jax.config.update("jax_enable_x64", True)  # type: ignore[no-untyped-call]

# Design-matrix column-bucket size.  The design width varies per object (it grows with
# the per-(filter, session) offset columns), and array width is a JIT shape -> a fresh
# XLA compile per distinct width, i.e. ~per object.  Padding the width up to a multiple
# of this collapses those into a few shared shapes so the compile cache is reused across
# objects.  Padded columns are zeros (the +1e-15*I ridge solves their coeffs to 0), so
# the fit is bit-identical; ``df`` is computed from the REAL parameter count.
_DESIGN_COL_PAD_MULTIPLE = 8


[docs] @dataclass(slots=True) class JAXBatchFitResult: scores: npt.NDArray[np.float64] best_valid: bool best_coeffs: npt.NDArray[np.float64] best_mask: npt.NDArray[np.bool_] best_sigma: float best_rss: float best_df: int best_n_fit: int best_n_clipped: int
def _next_multiple(value: int, multiple: int) -> int:
    """Round ``value`` up to the nearest multiple of ``multiple``.

    A ``multiple`` of 1 or less disables rounding and returns ``value`` as an int.
    Integer ceiling division is used so the result is exact for arbitrarily
    large integers (float division loses precision above 2**53).
    """
    value = int(value)
    if multiple <= 1:
        return value
    # -(-a // b) is ceil(a / b) without going through floating point.
    return int(-(-value // multiple) * multiple)


def _row_bucket(n_rows: int, floor: int) -> int:
    """Round the row count up to a power-of-two bucket (>= floor).

    Padded rows are masked out of the fit (``row_mask``), so the bucket is purely a
    JIT-shape control: n_obs spans ~30..2000 across the candle set, and a fine ×64 pad
    makes almost every object a distinct row shape -> a fresh XLA compile each.
    Power-of-two buckets collapse that to ~7 shared shapes (waste capped at <2x rows).
    """
    n = max(int(n_rows), int(floor))
    # Smallest power of two that is >= n (for n >= 1).
    return 1 << (n - 1).bit_length()


def _pad_rows(
    *,
    time_rel: npt.NDArray[np.float64],
    y: npt.NDArray[np.float64],
    fixed: npt.NDArray[np.float64],
    weights: npt.NDArray[np.float64] | None,
    row_pad_multiple: int,
) -> tuple[
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.bool_],
]:
    """Zero-pad the per-observation arrays up to the bucketed row count.

    Parameters
    ----------
    time_rel, y
        Per-observation values; ``y.shape[0]`` defines the real row count.
    fixed
        2-D array of fixed design columns, one row per observation.
    weights
        Optional per-observation weights; ``None`` gives every real row weight 1.
    row_pad_multiple
        Passed to ``_row_bucket`` as the minimum bucket size.

    Returns
    -------
    tuple
        ``(time, y, fixed, weights, row_mask)`` padded to the bucket size.  Padded
        rows are all zeros (including their weight) and ``row_mask`` is False there.
    """
    n_rows = int(y.shape[0])
    n_padded = _row_bucket(n_rows, row_pad_multiple)

    time_pad = np.zeros(n_padded, dtype=np.float64)
    y_pad = np.zeros(n_padded, dtype=np.float64)
    fixed_pad = np.zeros((n_padded, fixed.shape[1]), dtype=np.float64)
    weights_pad = np.zeros(n_padded, dtype=np.float64)
    row_mask = np.zeros(n_padded, dtype=bool)

    time_pad[:n_rows] = time_rel
    y_pad[:n_rows] = y
    fixed_pad[:n_rows, :] = fixed
    weights_pad[:n_rows] = (
        np.ones(n_rows, dtype=np.float64)
        if weights is None
        else np.asarray(weights, dtype=np.float64)
    )
    row_mask[:n_rows] = True
    return time_pad, y_pad, fixed_pad, weights_pad, row_mask


@partial(jax.jit, static_argnames=("fourier_order",))
def _build_fourier_batch(
    time_rel: jnp.ndarray,
    frequencies: jnp.ndarray,
    *,
    fourier_order: int,
) -> jnp.ndarray:
    """Build the Fourier design columns for every (frequency, time) pair.

    For 1-D ``frequencies`` and ``time_rel`` the result has shape
    ``(n_freq, n_time, 2 * fourier_order)`` with columns ordered
    ``cos(1), sin(1), cos(2), sin(2), ...`` where harmonic ``k`` uses the phase
    ``2*pi*k*f*t``.  A ``fourier_order`` below 1 yields an empty last axis
    instead of failing on an empty ``jnp.stack``.
    """
    phase = 2.0 * jnp.pi * frequencies[:, None] * time_rel[None, :]
    if fourier_order < 1:
        # No harmonics requested: return a zero-width column block.
        return jnp.zeros((*phase.shape, 0), dtype=phase.dtype)
    cols = []
    for harmonic in range(1, fourier_order + 1):
        harmonic_phase = harmonic * phase
        cols.append(jnp.cos(harmonic_phase))
        cols.append(jnp.sin(harmonic_phase))
    return jnp.stack(cols, axis=2)
@partial(
    jax.jit,
    static_argnames=("fourier_order", "max_clip_iterations", "has_observation_weights"),
)
def _evaluate_frequency_batch_jit(
    time_rel: jnp.ndarray,
    y: jnp.ndarray,
    fixed: jnp.ndarray,
    weights: jnp.ndarray,
    row_mask: jnp.ndarray,
    prior_rows: jnp.ndarray,
    prior_target: jnp.ndarray,
    prior_weights: jnp.ndarray,
    frequencies: jnp.ndarray,
    frequency_valid: jnp.ndarray,
    n_par_real: jnp.ndarray,
    *,
    fourier_order: int,
    clip_sigma: float,
    max_clip_iterations: int,
    has_observation_weights: bool,
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
]:
    """Fit the linear model at every frequency of a batch and select the best one.

    For each frequency the design matrix is the ``fixed`` columns followed by the
    Fourier columns from ``_build_fourier_batch``.  The weighted normal equations
    (observation rows plus the weighted prior rows) are solved, rows whose absolute
    residual exceeds ``clip_sigma * sigma`` are dropped, and this is repeated
    ``max_clip_iterations`` times before a final solve.

    Parameters
    ----------
    time_rel, y, fixed, weights, row_mask
        Row-padded observation arrays; rows with ``row_mask`` False never enter a fit.
    prior_rows, prior_target, prior_weights
        Extra rows appended to every per-frequency system, scaled by
        ``sqrt(prior_weights)``.
    frequencies, frequency_valid
        Trial frequencies and a flag per entry; entries flagged False get no active
        rows and an infinite score.
    n_par_real
        Parameter count subtracted from the number of fitted rows to obtain ``df``.
    fourier_order, max_clip_iterations, has_observation_weights
        Static (compile-time) arguments: number of harmonics, number of clipping
        passes, and whether the weighted sigma formula is used.
    clip_sigma
        Clipping threshold in units of the current per-frequency sigma.

    Returns
    -------
    tuple
        ``(scores, valid, coeffs, active, sigma, rss, df, n_fit)`` where ``scores``
        and ``valid`` have one entry per frequency and the remaining six items are
        taken at the frequency with the smallest score.
    """
    fourier = _build_fourier_batch(
        time_rel,
        frequencies,
        fourier_order=fourier_order,
    )
    # Per-frequency design: the shared fixed columns followed by that frequency's
    # Fourier columns.
    design_real = jnp.concatenate(
        [
            jnp.broadcast_to(
                fixed[None, :, :],
                (frequencies.shape[0], fixed.shape[0], fixed.shape[1]),
            ),
            fourier,
        ],
        axis=2,
    )
    n_par = design_real.shape[2]
    eye = jnp.eye(n_par, dtype=jnp.float64)
    target = jnp.broadcast_to(y[None, :], (frequencies.shape[0], y.shape[0]))
    # A row is active only if it is a real (non-padded) row of a valid frequency.
    active = (
        jnp.broadcast_to(row_mask[None, :], target.shape) & frequency_valid[:, None]
    )
    n_obs = jnp.sum(row_mask)
    prior_design = jnp.broadcast_to(
        prior_rows[None, :, :],
        (frequencies.shape[0], prior_rows.shape[0], prior_rows.shape[1]),
    )
    # The prior rows do not depend on the clipping mask, so weight them once here.
    sqrt_prior_weights = jnp.sqrt(prior_weights)
    prior_design_w = prior_design * sqrt_prior_weights[None, :, None]
    prior_target_w = jnp.broadcast_to(
        prior_target[None, :] * sqrt_prior_weights[None, :],
        (frequencies.shape[0], prior_target.shape[0]),
    )

    def solve_with_mask(
        active_mask: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        # Inactive rows get zero weight, which removes them from the normal equations.
        w_eff = jnp.where(active_mask, weights[None, :], 0.0)
        sqrt_w = jnp.sqrt(w_eff)
        design_real_w = design_real * sqrt_w[:, :, None]
        target_real_w = target * sqrt_w
        design_w = jnp.concatenate([design_real_w, prior_design_w], axis=1)
        target_w = jnp.concatenate([target_real_w, prior_target_w], axis=1)
        xt = jnp.swapaxes(design_w, 1, 2)
        xtx = jnp.matmul(xt, design_w)
        xty = jnp.matmul(xt, target_w[:, :, None])[:, :, 0]
        # Tiny ridge on the diagonal so all-zero columns do not make the system singular.
        coeffs = jnp.linalg.solve(
            xtx + 1.0e-15 * eye[None, :, :],
            xty[:, :, None],
        )[:, :, 0]
        # Residuals are evaluated on the unweighted design for every row.
        model = jnp.matmul(design_real, coeffs[:, :, None])[:, :, 0]
        resid = target - model
        n_fit = jnp.sum(active_mask, axis=1)
        df = n_fit - n_par_real
        if has_observation_weights:
            rss = jnp.sum(jnp.where(active_mask, w_eff * resid * resid, 0.0), axis=1)
            weight_sum = jnp.sum(w_eff, axis=1)
            # Weighted variance rescaled by n_obs / sum(weights); infinite when the
            # fit has no degrees of freedom or no weight.
            sigma2 = jnp.where(
                (df > 0) & (weight_sum > 0.0),
                (n_obs / weight_sum) * rss / df,
                jnp.inf,
            )
        else:
            rss = jnp.sum(jnp.where(active_mask, resid * resid, 0.0), axis=1)
            sigma2 = jnp.where(df > 0, rss / df, jnp.inf)
        sigma = jnp.sqrt(jnp.maximum(sigma2, 0.0))
        return coeffs, resid, rss, sigma

    coeffs = jnp.zeros((frequencies.shape[0], n_par), dtype=jnp.float64)
    resid = jnp.zeros_like(target)
    rss = jnp.full((frequencies.shape[0],), jnp.inf, dtype=jnp.float64)
    sigma = jnp.full((frequencies.shape[0],), jnp.inf, dtype=jnp.float64)
    # Sigma clipping: rows can only be removed, never re-added.  The loop is unrolled
    # at trace time because ``max_clip_iterations`` is a static argument.
    for _ in range(max_clip_iterations):
        coeffs, resid, rss, sigma = solve_with_mask(active)
        clip_limit = clip_sigma * sigma[:, None]
        keep = jnp.abs(resid) <= clip_limit
        new_active = active & keep
        active = jnp.where(
            frequency_valid[:, None],
            new_active,
            active,
        )
    # Final refit on the surviving rows.
    # NOTE(review): assumed to sit after the clipping loop, not inside it — confirm.
    coeffs, resid, rss, sigma = solve_with_mask(active)
    n_fit = jnp.sum(active, axis=1)
    df = n_fit - n_par_real
    valid = frequency_valid & (df > 0) & jnp.isfinite(sigma)
    # Invalid entries score +inf so argmin can only select them when nothing is valid.
    scores = jnp.where(valid, sigma, jnp.inf)
    best_idx = jnp.argmin(scores)
    return (
        scores,
        valid,
        coeffs[best_idx],
        active[best_idx],
        sigma[best_idx],
        rss[best_idx],
        df[best_idx],
        n_fit[best_idx],
    )
def evaluate_frequency_indices_jax(
    *,
    time_rel: npt.NDArray[np.float64],
    y: npt.NDArray[np.float64],
    fixed: npt.NDArray[np.float64],
    weights: npt.NDArray[np.float64] | None,
    prior_rows: npt.NDArray[np.float64],
    prior_target: npt.NDArray[np.float64],
    prior_weights: npt.NDArray[np.float64],
    frequencies: npt.NDArray[np.float64],
    sample_indices: npt.NDArray[np.int64],
    fourier_order: int,
    clip_sigma: float,
    jax_batch_size: int,
    row_pad_multiple: int,
    max_clip_iterations: int,
) -> JAXBatchFitResult:
    """Score the selected trial frequencies in fixed-size JAX batches.

    The observation arrays are row-padded and the fixed design columns are
    column-padded so that different objects share JIT-compiled shapes; the
    frequencies ``frequencies[sample_indices]`` are then evaluated
    ``jax_batch_size`` at a time and the lowest-sigma valid fit is tracked.

    Parameters
    ----------
    time_rel, y
        Per-observation times and target values.
    fixed
        2-D array of fixed design columns, one row per observation.
    weights
        Optional per-observation weights.  ``None`` uses unit weights and the
        unweighted sigma formula.
    prior_rows, prior_target, prior_weights
        Prior rows appended to every fit.  ``prior_rows`` is expected to have the
        fixed columns first, followed by the ``2 * fourier_order`` Fourier columns.
    frequencies
        Trial-frequency grid indexed by ``sample_indices``.
    sample_indices
        Indices into ``frequencies`` to evaluate; may be empty.
    fourier_order
        Number of Fourier harmonics in the model.
    clip_sigma
        Sigma-clipping threshold passed to the batch fit.
    jax_batch_size
        Frequencies per JIT call (values below 1 are treated as 1).  The final
        batch is padded with entries flagged invalid so its shape is unchanged.
    row_pad_multiple
        Minimum row-bucket size passed to ``_pad_rows``.
    max_clip_iterations
        Number of clipping passes performed before the final refit.

    Returns
    -------
    JAXBatchFitResult
        ``scores`` holds one sigma per entry of ``sample_indices`` (NaN where the
        fit was invalid); the ``best_*`` fields describe the lowest-sigma valid fit
        with all row and column padding stripped.
    """
    n_rows = int(time_rel.shape[0])
    real_design_width = int(fixed.shape[1])
    n_coeffs = fixed.shape[1] + 2 * fourier_order

    if sample_indices.size == 0:
        return JAXBatchFitResult(
            scores=np.zeros(0, dtype=np.float64),
            best_valid=False,
            best_coeffs=np.zeros(n_coeffs, dtype=np.float64),
            best_mask=np.zeros(n_rows, dtype=bool),
            best_sigma=float("inf"),
            best_rss=float("inf"),
            best_df=0,
            best_n_fit=0,
            best_n_clipped=n_rows,
        )

    time_pad, y_pad, fixed_pad, weights_pad, row_mask = _pad_rows(
        time_rel=np.asarray(time_rel, dtype=np.float64),
        y=np.asarray(y, dtype=np.float64),
        fixed=np.asarray(fixed, dtype=np.float64),
        weights=None if weights is None else np.asarray(weights, dtype=np.float64),
        row_pad_multiple=row_pad_multiple,
    )

    # Column-bucket the design width (see _DESIGN_COL_PAD_MULTIPLE): pad with zero
    # columns so distinct per-object widths share JIT-compiled shapes. ``n_par_real``
    # carries the true parameter count (for df); padded coeffs are stripped on return.
    n_par_real = real_design_width + 2 * int(fourier_order)
    padded_design_width = _next_multiple(real_design_width, _DESIGN_COL_PAD_MULTIPLE)
    col_pad = int(padded_design_width - real_design_width)
    if col_pad:
        fixed_pad = np.concatenate(
            [fixed_pad, np.zeros((fixed_pad.shape[0], col_pad), dtype=np.float64)],
            axis=1,
        )
        # Insert the same zero columns between the fixed and Fourier parts of the
        # prior rows so they stay aligned with the padded design.
        pr = np.asarray(prior_rows, dtype=np.float64)
        prior_rows = np.concatenate(
            [
                pr[:, :real_design_width],
                np.zeros((pr.shape[0], col_pad), dtype=np.float64),
                pr[:, real_design_width:],
            ],
            axis=1,
        )

    # Everything below is identical for every batch: convert it once rather than
    # on each loop iteration.
    time_dev = jnp.asarray(time_pad)
    y_dev = jnp.asarray(y_pad)
    fixed_dev = jnp.asarray(fixed_pad)
    weights_dev = jnp.asarray(weights_pad)
    row_mask_dev = jnp.asarray(row_mask)
    prior_rows_dev = jnp.asarray(prior_rows, dtype=jnp.float64)
    prior_target_dev = jnp.asarray(prior_target, dtype=jnp.float64)
    prior_weights_dev = jnp.asarray(prior_weights, dtype=jnp.float64)
    n_par_real_dev = jnp.asarray(n_par_real, dtype=jnp.int64)
    has_observation_weights = weights is not None

    n_scores = int(sample_indices.size)
    scores = np.full(n_scores, np.nan, dtype=np.float64)
    best_valid = False
    best_sigma = float("inf")
    best_coeffs = np.zeros(n_coeffs, dtype=np.float64)
    best_mask = np.zeros(n_rows, dtype=bool)
    best_rss = float("inf")
    best_df = 0
    best_n_fit = 0

    batch_size = max(1, int(jax_batch_size))
    for start in range(0, n_scores, batch_size):
        stop = min(start + batch_size, n_scores)
        valid_count = stop - start
        batch_indices = np.asarray(sample_indices[start:stop], dtype=np.int64)

        # Always submit a full-size batch so the JIT shape never changes; unused
        # slots are flagged invalid and ignored by the fit.
        frequencies_batch = np.zeros(batch_size, dtype=np.float64)
        frequency_valid = np.zeros(batch_size, dtype=bool)
        frequencies_batch[:valid_count] = frequencies[batch_indices]
        frequency_valid[:valid_count] = True

        (
            scores_batch,
            valid_batch,
            coeffs_batch,
            mask_batch,
            sigma_batch,
            rss_batch,
            df_batch,
            n_fit_batch,
        ) = _evaluate_frequency_batch_jit(
            time_dev,
            y_dev,
            fixed_dev,
            weights_dev,
            row_mask_dev,
            prior_rows_dev,
            prior_target_dev,
            prior_weights_dev,
            jnp.asarray(frequencies_batch),
            jnp.asarray(frequency_valid),
            n_par_real_dev,
            fourier_order=int(fourier_order),
            clip_sigma=float(clip_sigma),
            max_clip_iterations=int(max_clip_iterations),
            has_observation_weights=has_observation_weights,
        )

        scores_np = np.asarray(scores_batch, dtype=np.float64)[:valid_count]
        valid_np = np.asarray(valid_batch, dtype=bool)[:valid_count]
        masked_scores = np.where(valid_np, scores_np, np.nan)
        scores[start:stop] = masked_scores

        if np.any(valid_np):
            local_idx = int(np.nanargmin(masked_scores))
            local_sigma = float(scores_np[local_idx])
            if local_sigma < best_sigma:
                best_valid = True
                best_sigma = local_sigma
                cb = np.asarray(coeffs_batch, dtype=np.float64)
                # Drop the zero-padded design columns from the coefficient vector.
                best_coeffs = (
                    cb
                    if not col_pad
                    else np.concatenate(
                        [
                            cb[:real_design_width],
                            cb[
                                padded_design_width : padded_design_width
                                + 2 * fourier_order
                            ],
                        ]
                    )
                )
                # Drop the padded rows from the mask.
                best_mask = np.asarray(mask_batch, dtype=bool)[:n_rows]
                best_rss = float(np.asarray(rss_batch, dtype=np.float64))
                best_df = int(np.asarray(df_batch, dtype=np.int64))
                best_n_fit = int(np.asarray(n_fit_batch, dtype=np.int64))

    return JAXBatchFitResult(
        scores=scores,
        best_valid=best_valid,
        best_coeffs=best_coeffs,
        best_mask=best_mask,
        best_sigma=best_sigma,
        best_rss=best_rss,
        best_df=best_df,
        best_n_fit=best_n_fit,
        best_n_clipped=int(n_rows - best_n_fit),
    )