# Source code for adam_core.dynamics.propagation

import multiprocessing as mp
from typing import Dict, List, Optional, Tuple

import jax.numpy as jnp
import numpy as np
import quivr as qv
import ray
from jax import config, jit, vmap
from ray import ObjectRef

from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.covariances import (
    CoordinateCovariances,
    transform_covariances_jacobian,
)
from ..coordinates.origin import Origin
from ..orbits.orbits import Orbits
from ..ray_cluster import initialize_use_ray
from ..time import Timestamp
from ..utils.chunking import process_in_chunks
from ..utils.iter import _iterate_chunks
from .chi import calc_chi_diagnostics
from .exceptions import DynamicsNumericalError
from .lagrange import apply_lagrange_coefficients, calc_lagrange_coefficients

# Switch JAX to 64-bit mode at import time. The arrays handled in this module
# are float64, and the default solver tolerance used below (1e-14) is finer
# than single precision can resolve.
config.update("jax_enable_x64", True)


@jit
def _propagate_2body(
    orbit: jnp.ndarray,
    t0: float,
    t1: float,
    mu: float,
    max_iter: int = 1000,
    tol: float = 1e-14,
) -> jnp.ndarray:
    """
    Propagate a single Cartesian orbit from epoch t0 to epoch t1.

    This is the per-row kernel: it builds one 6-element state from one input
    state. Batches of orbits and epochs are handled by mapping this function
    over the leading axis of its inputs (see `_propagate_2body_vmap`).

    Parameters
    ----------
    orbit : `~jax.numpy.ndarray` (6)
        Cartesian orbit with position in units of au and velocity in units of au per day.
    t0 : float
        Epoch in MJD at which the orbit is defined.
    t1 : float
        Epoch in MJD to which to propagate the orbit.
    mu : float
        Gravitational parameter (GM) of the attracting body in units of
        au**3 / d**2.
    max_iter : int, optional
        Maximum number of iterations over which to converge. If number of iterations is
        exceeded, will return the value of the universal anomaly at the last iteration.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly using the Newton-Raphson
        method.

    Returns
    -------
    orbit : `~jax.numpy.ndarray` (6)
        Orbit propagated to t1 with position in units of au and velocity in units
        of au per day.
    """
    position = orbit[0:3]
    velocity = orbit[3:6]
    dt = t1 - t0

    # Only the Lagrange coefficients are used here; the other two values
    # returned by the solver (Stumpff coefficients and chi) are discarded.
    lagrange_coeffs, _, _ = calc_lagrange_coefficients(
        position, velocity, dt, mu=mu, max_iter=max_iter, tol=tol
    )
    r_new, v_new = apply_lagrange_coefficients(position, velocity, *lagrange_coeffs)

    return jnp.array([r_new[0], r_new[1], r_new[2], v_new[0], v_new[1], v_new[2]])


# Batched, jit-compiled form of _propagate_2body. The orbit, t0, t1 and mu
# arguments are mapped along their leading axis (one row per propagation),
# while max_iter and tol are shared by every row. Results are stacked along
# axis 0.
_propagate_2body_vmap = jit(
    vmap(
        _propagate_2body,
        in_axes=(0, 0, 0, 0, None, None),
        out_axes=0,
    )
)


def _first_non_finite_row(values: np.ndarray) -> Optional[int]:
    """
    Return the index of the first row containing a NaN or infinity.

    Parameters
    ----------
    values : `~numpy.ndarray` (N, D)
        Array whose rows are checked for non-finite entries.

    Returns
    -------
    index : int or None
        Index of the first row with at least one non-finite value, or None
        when every row is finite (including when there are no rows).
    """
    row_is_bad = ~np.isfinite(values).all(axis=1)
    if not row_is_bad.any():
        return None
    # argmax on a boolean array gives the position of the first True.
    return int(np.argmax(row_is_bad))


def _raise_non_finite_propagation_error(
    *,
    stage: str,
    reason: str,
    absolute_idx: int,
    orbit_id: str,
    object_id: str,
    orbit_row: np.ndarray,
    t0: float,
    t1: float,
    mu: float,
    max_iter: int,
    tol: float,
) -> None:
    """
    Raise a DynamicsNumericalError describing one offending propagation row.

    Diagnostics for the row are computed with `calc_chi_diagnostics` and
    attached to the exception's context together with the identifiers and
    solver settings passed in. This function always raises.

    Parameters
    ----------
    stage : str
        Stage label forwarded to the exception.
    reason : str
        Reason label forwarded to the exception.
    absolute_idx : int
        Row index reported as "row_index" in the context.
    orbit_id : str
        Orbit identifier of the offending row.
    object_id : str
        Object identifier of the offending row.
    orbit_row : `~numpy.ndarray` (6)
        State of the offending row; elements 0-2 and 3-5 are passed to the
        diagnostics as two separate vectors.
    t0 : float
        Starting epoch of the row.
    t1 : float
        Target epoch of the row.
    mu : float
        Gravitational parameter used for the row.
    max_iter : int
        Iteration cap that was in effect.
    tol : float
        Tolerance that was in effect.

    Raises
    ------
    DynamicsNumericalError
        Always.
    """
    elapsed = t1 - t0
    leading_vec = orbit_row[0:3]
    trailing_vec = orbit_row[3:6]
    diagnostics = calc_chi_diagnostics(
        leading_vec,
        trailing_vec,
        elapsed,
        mu=mu,
        max_iter=max_iter,
        tol=tol,
    )

    # Key order is kept stable so the reported context always reads the same.
    context = {
        "row_index": absolute_idx,
        "orbit_id": orbit_id,
        "object_id": object_id,
        "t0": float(t0),
        "t1": float(t1),
        "dt": float(elapsed),
        "mu": float(mu),
        "r_norm": diagnostics.r_norm,
        "v_norm": diagnostics.v_norm,
        "alpha": diagnostics.alpha,
        "chi": diagnostics.chi,
        "chi_finite": diagnostics.finite,
        "max_iter": int(max_iter),
        "tol": float(tol),
    }
    raise DynamicsNumericalError(stage=stage, reason=reason, context=context)


def _propagate_2body_serial(
    orbits: Orbits,
    times: Timestamp,
    *,
    max_iter: int,
    tol: float,
) -> Orbits:
    """
    Propagate every orbit to every epoch within the current process.

    Each of the N input orbits is paired with each of the M epochs, giving
    N * M output rows in orbit-major order (all epochs for the first orbit,
    then all epochs for the second, and so on). The Ray backend calls this
    function inside each worker.

    Parameters
    ----------
    orbits : `~adam_core.orbits.orbits.Orbits` (N)
        Orbits to propagate. Their epochs are rescaled to TDB before use.
    times : Timestamp (M)
        Target epochs, rescaled to TDB before use.
    max_iter : int
        Iteration cap forwarded to the 2-body kernel.
    tol : float
        Tolerance forwarded to the 2-body kernel.

    Returns
    -------
    propagated : `~adam_core.orbits.orbits.Orbits` (N * M)
        Propagated orbits with TDB epochs. Covariances are transformed with
        `transform_covariances_jacobian` unless every input covariance is
        NaN, in which case no covariance is attached.

    Raises
    ------
    DynamicsNumericalError
        If an input state or a propagated state contains a non-finite value.
    RuntimeError
        If the number of rows written differs from N * M.
    """
    # A constant batch length keeps the array shapes handed to the jitted
    # kernel stable.
    batch_len = 200

    coords = orbits.coordinates
    states = coords.values
    epochs_from = coords.time.rescale("tdb").mjd()
    epochs_to = times.rescale("tdb").mjd()
    mu_per_orbit = coords.origin.mu()
    ids_orbit = orbits.orbit_id.to_numpy(zero_copy_only=False)
    ids_object = orbits.object_id.to_numpy(zero_copy_only=False)

    n_orbits = states.shape[0]
    n_times = len(times)

    # Orbit-major expansion: per-orbit quantities are repeated once per
    # epoch, and the epoch list is tiled once per orbit.
    orbit_id_rows = np.repeat(ids_orbit, n_times)
    object_id_rows = np.repeat(ids_object, n_times)
    state_rows = np.repeat(states, n_times, axis=0)
    mu_rows = np.repeat(mu_per_orbit, n_times)
    t0_rows = np.repeat(epochs_from, n_times)
    t1_rows = np.tile(epochs_to, n_orbits)

    # Physical parameters follow the same per-orbit repetition.
    source_rows = np.repeat(np.arange(n_orbits), n_times).tolist()
    physical_parameter_rows = orbits.physical_parameters.take(source_rows)

    def _abort(reason: str, offset: int, row: int, batch: tuple) -> None:
        # Report a bad row using its global index for the identifiers and its
        # in-batch index for the numerical inputs.
        batch_states, batch_t0, batch_t1, batch_mu = batch
        global_row = offset + row
        _raise_non_finite_propagation_error(
            stage="propagation",
            reason=reason,
            absolute_idx=global_row,
            orbit_id=str(orbit_id_rows[global_row]),
            object_id=str(object_id_rows[global_row]),
            orbit_row=np.asarray(batch_states[row], dtype=np.float64),
            t0=float(batch_t0[row]),
            t1=float(batch_t1[row]),
            mu=float(batch_mu[row]),
            max_iter=max_iter,
            tol=tol,
        )

    num_entries = n_orbits * n_times
    orbits_propagated = np.empty((num_entries, 6), dtype=np.float64)
    start = 0
    batches = zip(
        process_in_chunks(state_rows, batch_len),
        process_in_chunks(t0_rows, batch_len),
        process_in_chunks(t1_rows, batch_len),
        process_in_chunks(mu_rows, batch_len),
    )
    for batch in batches:
        # Only the first `n_real` rows of a batch are treated as data;
        # presumably process_in_chunks pads short batches to the fixed
        # length -- confirm against its implementation.
        n_real = min(batch_len, num_entries - start)

        bad_row = _first_non_finite_row(batch[0][:n_real])
        if bad_row is not None:
            _abort("non_finite_input_state", start, bad_row, batch)

        batch_result = _propagate_2body_vmap(*batch, max_iter, tol)
        batch_result = np.asarray(batch_result, dtype=np.float64)[:n_real]

        bad_row = _first_non_finite_row(batch_result)
        if bad_row is not None:
            _abort("non_finite_output_state", start, bad_row, batch)

        stop = start + n_real
        orbits_propagated[start:stop] = batch_result
        start = stop

    if start != num_entries:
        raise RuntimeError(
            f"Internal error: expected {num_entries} propagated rows, got {start}"
        )

    input_covariance = coords.covariance
    if input_covariance.is_all_nan():
        propagated_covariance = None
    else:
        covariance_rows = np.repeat(input_covariance.to_matrix(), n_times, axis=0)
        propagated_covariance = CoordinateCovariances.from_matrix(
            transform_covariances_jacobian(
                state_rows,
                covariance_rows,
                _propagate_2body,
                in_axes=(0, 0, 0, 0, None, None),
                out_axes=0,
                t0=t0_rows,
                t1=t1_rows,
                mu=mu_rows,
                max_iter=max_iter,
                tol=tol,
            )
        )

    origin_codes = np.repeat(
        coords.origin.code.to_numpy(zero_copy_only=False), n_times
    )

    # Each row of the transpose is one state component across all outputs.
    x, y, z, vx, vy, vz = orbits_propagated.T

    # NOTE(review): the output frame is fixed to "ecliptic" rather than taken
    # from the input coordinates -- confirm inputs are always ecliptic.
    propagated_coordinates = CartesianCoordinates.from_kwargs(
        x=x,
        y=y,
        z=z,
        vx=vx,
        vy=vy,
        vz=vz,
        covariance=propagated_covariance,
        time=Timestamp.from_mjd(t1_rows, scale="tdb"),
        origin=Origin.from_kwargs(code=origin_codes),
        frame="ecliptic",
    )

    return Orbits.from_kwargs(
        orbit_id=orbit_id_rows,
        object_id=object_id_rows,
        physical_parameters=physical_parameter_rows,
        coordinates=propagated_coordinates,
    )


@ray.remote
def propagate_2body_worker_ray(
    start: int,
    idx_chunk: np.ndarray,
    orbits: Orbits,
    times: Timestamp,
    max_iter: int,
    tol: float,
) -> Tuple[int, Orbits]:
    """
    Ray task that propagates one subset of orbits with the serial routine.

    Parameters
    ----------
    start : int
        Tag returned unchanged alongside the result so the caller can
        identify which subset it belongs to.
    idx_chunk : `~numpy.ndarray`
        Row indices selecting the subset of `orbits` to propagate.
    orbits : `~adam_core.orbits.orbits.Orbits`
        Full set of orbits from which the subset is taken.
    times : Timestamp
        Epochs to which each selected orbit is propagated.
    max_iter : int
        Iteration cap forwarded to the serial routine.
    tol : float
        Tolerance forwarded to the serial routine.

    Returns
    -------
    start : int
        The `start` value that was passed in.
    propagated : `~adam_core.orbits.orbits.Orbits`
        Propagated orbits for the selected subset.
    """
    subset = orbits.take(idx_chunk)
    return start, _propagate_2body_serial(subset, times, max_iter=max_iter, tol=tol)


def propagate_2body(
    orbits: Orbits,
    times: Timestamp,
    max_iter: int = 1000,
    tol: float = 1e-14,
    *,
    max_processes: Optional[int] = 1,
    chunk_size: int = 100,
) -> Orbits:
    """
    Propagate orbits using the 2-body universal anomaly formalism.

    Parameters
    ----------
    orbits : `~adam_core.orbits.orbits.Orbits` (N)
        Cartesian orbits with position in units of au and velocity in units of au per day.
    times : Timestamp (M)
        Epochs to which to propagate each orbit. If a single epoch is given, all orbits are
        propagated to this epoch. If multiple epochs are given, then each orbit will be
        propagated to each epoch.
    max_iter : int, optional
        Maximum number of iterations over which to converge. If number of iterations is
        exceeded, will return the value of the universal anomaly at the last iteration.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly using the Newton-Raphson
        method.
    max_processes : int or None, optional
        Number of processes to use. None means `multiprocessing.cpu_count()`.
        A value of 1 or less runs the propagation serially in the calling
        process; larger values distribute the work with Ray.
    chunk_size : int, optional
        Chunk size passed to `_iterate_chunks` when splitting the orbit
        indices; each chunk becomes one Ray task. Only used on the Ray path.

    Returns
    -------
    orbits : `~adam_core.orbits.orbits.Orbits` (N*M)
        Orbits propagated to each MJD.
    """
    n_workers = mp.cpu_count() if max_processes is None else max_processes
    if n_workers <= 1:
        return _propagate_2body_serial(orbits, times, max_iter=max_iter, tol=tol)

    initialize_use_ray(num_cpus=n_workers)

    # Put large inputs in the object store once so tasks share them.
    orbits_ref = ray.put(orbits)
    times_ref = ray.put(times)

    in_flight: List[ObjectRef] = []
    by_start: Dict[int, Orbits] = {}
    # Cap on outstanding tasks before one result is collected.
    backlog_limit = n_workers * 1.5

    def _collect_one(waiting: List[ObjectRef]) -> List[ObjectRef]:
        # Wait for one task, store its result under its start index, and
        # hand back the tasks that are still outstanding.
        done, remaining = ray.wait(waiting, num_returns=1)
        offset, propagated = ray.get(done[0])
        by_start[int(offset)] = propagated
        return remaining

    all_rows = np.arange(0, len(orbits), dtype=np.int64)
    for rows in _iterate_chunks(all_rows, chunk_size):
        first_row = int(rows[0]) if len(rows) else 0
        task = propagate_2body_worker_ray.remote(
            first_row, rows, orbits_ref, times_ref, max_iter, tol
        )
        in_flight.append(task)
        if len(in_flight) >= backlog_limit:
            in_flight = _collect_one(in_flight)

    while in_flight:
        in_flight = _collect_one(in_flight)

    # Tasks finish in arbitrary order; reassemble by ascending start index.
    ordered = [by_start[key] for key in sorted(by_start)]
    if not ordered:
        return Orbits.empty()
    return qv.concatenate(ordered)