"""adam_core.dynamics.ephemeris: generate on-sky ephemerides for propagated orbits as seen by observers."""

import multiprocessing as mp
from typing import Dict, List, Optional, Tuple

import jax.numpy as jnp
import numpy as np
import pyarrow as pa
import quivr as qv
import ray
from jax import jit, lax, vmap
from ray import ObjectRef

from ..constants import Constants as c
from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.covariances import (
    CoordinateCovariances,
    transform_covariances_jacobian,
)
from ..coordinates.origin import Origin, OriginCodes
from ..coordinates.spherical import SphericalCoordinates
from ..coordinates.transform import _cartesian_to_spherical, transform_coordinates
from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..photometry.magnitude import (
    calculate_apparent_magnitude_v,
    calculate_apparent_magnitude_v_and_phase_angle,
    calculate_phase_angle,
)
from ..ray_cluster import initialize_use_ray
from ..utils.chunking import process_in_chunks
from ..utils.iter import _iterate_chunks
from .aberrations import _add_light_time, add_stellar_aberration
from .exceptions import DynamicsNumericalError

# Ecliptic J2000 -> equatorial J2000 rotation matrix as a float64 JAX array,
# built once at import time and closed over by the jitted helper below.
_TRANSFORM_EC2EQ = jnp.array(c.TRANSFORM_EC2EQ, dtype=jnp.float64)


@jit
def _rotate_cartesian_state_ec2eq(state_ec: jnp.ndarray) -> jnp.ndarray:
    """
    Rotate a 6-element Cartesian state from ecliptic J2000 into equatorial J2000.

    The same 3x3 rotation is applied separately to the position (elements 0-2)
    and the velocity (elements 3-5); the rotated halves are re-joined in order.
    """
    rotation = _TRANSFORM_EC2EQ
    position = state_ec[:3]
    velocity = state_ec[3:6]
    return jnp.concatenate((rotation @ position, rotation @ velocity))


@jit
def _generate_ephemeris_2body(
    propagated_orbit: np.ndarray,
    observation_time: float,
    observer_coordinates: jnp.ndarray,
    mu: float,
    lt_tol: float = 1e-10,
    max_iter: int = 100,
    tol: float = 1e-15,
    stellar_aberration: bool = False,
) -> Tuple[jnp.ndarray, jnp.float64, jnp.ndarray]:
    """
    Given a propagated orbit, generate its on-sky ephemeris as viewed from the observer.
    This function calculates the light time delay between the propagated orbit and the observer,
    and then propagates the orbit backward by that amount to when the light from object was actually
    emitted towards the observer ("astrometric coordinates").

    The motion of the observer in an inertial frame will cause an object
    to appear in a different location than its true location, this is known as
    stellar aberration (often referred to in combination with other aberrations as "apparent
    coordinates"). Stellar aberration can optionally be applied after
    light time correction has been added but it should not be necessary when comparing to ephemerides
    of solar system small bodies extracted from astrometric catalogs. The stars to which the
    catalog is calibrated undergo the same aberration as the moving objects as seen from the observer.

    If stellar aberration is applied then the velocity of the input orbits are unmodified, only the position
    vector is modified with stellar aberration.

    For more details on aberrations see:
        https://naif.jpl.nasa.gov/pub/naif/toolkit_docs/FORTRAN/req/abcorr.html
        https://ssd.jpl.nasa.gov/horizons/manual.html#defs

    This is a single-row kernel; `_generate_ephemeris_2body_vmap` is the
    batched form. It is also passed to `transform_covariances_jacobian` so the
    ephemeris covariance can be propagated through it.

    Parameters
    ----------
    propagated_orbit : `~jax.numpy.ndarray` (6)
        Barycentric Cartesian orbit propagated to the given time. The
        topocentric vector is formed by subtracting `observer_coordinates`, so
        this state must share the observer's origin and frame (the caller
        provides solar-system-barycentric ecliptic states).
    observation_time : float
        Epoch at which orbit and observer coordinates are defined.
    observer_coordinates : `~jax.numpy.ndarray` (6)
        Barycentric Cartesian observer state (position and velocity). Only the
        position (first three elements) is used for the light-time solution;
        the full 6-element state is subtracted to form topocentric coordinates.
    mu : float (1)
        Gravitational parameter (GM) of the attracting body in units of
        AU**3 / d**2.
    lt_tol : float, optional
        Calculate aberration to within this value in time (units of days).
    max_iter : int, optional
        Maximum number of iterations over which to converge for propagation.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly during propagation using the Newton-Raphson
        method.
    stellar_aberration : bool, optional
        Apply stellar aberration to the ephemerides. Under `jit` this is a traced
        boolean that selects a branch via `lax.cond`.

    Returns
    -------
    ephemeris_spherical : `~jax.numpy.ndarray` (6)
        Topocentric Spherical ephemeris in the equatorial frame (the
        topocentric state is rotated ecliptic -> equatorial before the
        Cartesian -> spherical conversion).
    lt : float
        Light time correction (t0 - corrected_t0).
    aberrated_orbit : `~jax.numpy.ndarray` (6)
        Barycentric Cartesian orbit corrected for light time (emission time state).
    """
    # Add light time correction
    propagated_orbits_aberrated, light_time = _add_light_time(
        propagated_orbit,
        observation_time,
        observer_coordinates[0:3],
        lt_tol=lt_tol,
        mu=mu,
        max_iter=max_iter,
        tol=tol,
    )

    # Calculate topocentric coordinates
    topocentric_coordinates = propagated_orbits_aberrated - observer_coordinates

    # Apply stellar aberration to topocentric coordinates.
    # The "true" branch replaces only the position (elements 0-2); velocity is
    # left untouched. NOTE(review): assumes `add_stellar_aberration` returns the
    # aberrated *topocentric* position for each row — confirm in aberrations.py.
    topocentric_coordinates = lax.cond(
        stellar_aberration,
        lambda topocentric_coords: topocentric_coords.at[0:3].set(
            add_stellar_aberration(
                propagated_orbits_aberrated.reshape(1, -1),
                observer_coordinates.reshape(1, -1),
            )[0],
        ),
        lambda topocentric_coords: topocentric_coords,
        topocentric_coordinates,
    )

    # Convert to spherical coordinates in the equatorial frame.
    #
    # `topocentric_coordinates` is in the same (ecliptic) inertial frame as the inputs.
    # Rotating the Cartesian state and then converting avoids an expensive
    # spherical->cartesian->spherical round-trip later in the public wrapper.
    ephemeris_spherical = _cartesian_to_spherical(
        _rotate_cartesian_state_ec2eq(topocentric_coordinates)
    )

    return ephemeris_spherical, light_time, propagated_orbits_aberrated


# Batched ephemeris kernel. The first four arguments (orbit state, observation
# time, observer state, mu) are mapped row-by-row; the trailing four solver
# settings (lt_tol, max_iter, tol, stellar_aberration) are shared by every row.
# All three outputs (spherical ephemeris, light time, aberrated state) are
# stacked along axis 0.
_generate_ephemeris_2body_vmap = jit(
    vmap(
        _generate_ephemeris_2body,
        in_axes=(0,) * 4 + (None,) * 4,
        out_axes=(0,) * 3,
    )
)


def _first_non_finite(values: np.ndarray) -> Optional[int]:
    """
    Return the index (along axis 0) of the first row containing a NaN or inf.

    For 1-D input this is simply the index of the first non-finite element.
    For N-D input the finiteness mask is reduced over all trailing axes first,
    so the result is a *row* index rather than an index into the flattened
    array. Callers use the result to look up per-row metadata (orbit IDs,
    observation times, light times), so a flattened index would point at the
    wrong row or run past the end of those 1-D arrays.

    Parameters
    ----------
    values : array_like
        Values to inspect.

    Returns
    -------
    int or None
        Row index of the first non-finite entry, or None when every entry is
        finite (including for empty input).
    """
    finite = np.isfinite(np.asarray(values))
    if finite.ndim > 1:
        # A row is bad if any of its entries is non-finite.
        finite = finite.all(axis=tuple(range(1, finite.ndim)))
    bad = np.flatnonzero(~finite)
    return int(bad[0]) if bad.size > 0 else None


def _raise_ephemeris_numerical_error(
    *,
    reason: str,
    row_index: int,
    orbit_id: str,
    object_id: str,
    observation_time: float,
    light_time: Optional[float],
    max_iter: int,
    tol: float,
    lt_tol: float,
) -> None:
    """
    Raise a ``DynamicsNumericalError`` (stage ``"ephemeris"``) for one bad row.

    This function never returns normally. The offending row's identifiers,
    observation time, light time (``None`` allowed) and the solver settings
    are packed into the error's ``context`` mapping for diagnostics.
    """
    details = {
        "row_index": row_index,
        "orbit_id": orbit_id,
        "object_id": object_id,
        "observation_time_mjd_tdb": float(observation_time),
        "light_time_days": (
            float(light_time) if light_time is not None else None
        ),
        "max_iter": int(max_iter),
        "tol": float(tol),
        "lt_tol": float(lt_tol),
    }
    raise DynamicsNumericalError(stage="ephemeris", reason=reason, context=details)


def _generate_ephemeris_2body_serial(
    propagated_orbits: Orbits,
    observers: Observers,
    *,
    lt_tol: float,
    max_iter: int,
    tol: float,
    stellar_aberration: bool,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> Ephemeris:
    """
    Generate ephemerides in the current process.

    Thin wrapper around `generate_ephemeris_2body` that pins
    ``max_processes=1``, so that function's Ray fan-out branch is skipped and
    the in-process JAX path runs instead. Used by the Ray worker task.
    """
    options = dict(
        lt_tol=lt_tol,
        max_iter=max_iter,
        tol=tol,
        stellar_aberration=stellar_aberration,
        predict_magnitudes=predict_magnitudes,
        predict_phase_angle=predict_phase_angle,
        max_processes=1,
    )
    return generate_ephemeris_2body(propagated_orbits, observers, **options)


@ray.remote
def ephemeris_2body_worker_ray(
    start: int,
    idx_chunk: np.ndarray,
    propagated_orbits: Orbits,
    observers: Observers,
    lt_tol: float,
    max_iter: int,
    tol: float,
    stellar_aberration: bool,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> Tuple[int, Ephemeris]:
    """
    Ray task: compute ephemerides for the rows selected by ``idx_chunk``.

    The paired orbit and observer rows are sliced out with ``take`` and run
    through the in-process ephemeris path. ``start`` is echoed back alongside
    the result so the driver can reassemble chunks in their original order.
    """
    ephemeris_chunk = _generate_ephemeris_2body_serial(
        propagated_orbits.take(idx_chunk),
        observers.take(idx_chunk),
        lt_tol=lt_tol,
        max_iter=max_iter,
        tol=tol,
        stellar_aberration=stellar_aberration,
        predict_magnitudes=predict_magnitudes,
        predict_phase_angle=predict_phase_angle,
    )
    return start, ephemeris_chunk


def _generate_ephemeris_2body_ray(
    propagated_orbits: Orbits,
    observers: Observers,
    *,
    lt_tol: float,
    max_iter: int,
    tol: float,
    stellar_aberration: bool,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
    max_processes: int,
    chunk_size: int,
) -> Ephemeris:
    """Fan ephemeris generation out over Ray tasks of ``chunk_size`` rows each."""
    initialize_use_ray(num_cpus=max_processes)

    num_entries = len(observers)
    assert (
        len(propagated_orbits) == num_entries
    ), "Orbits and observers must be paired and orbits must be propagated to observer times."

    # Put the full tables in the object store once; each task slices by index.
    propagated_ref = ray.put(propagated_orbits)
    observers_ref = ray.put(observers)

    idx = np.arange(0, num_entries, dtype=np.int64)
    pending: List[ObjectRef] = []
    results: Dict[int, Ephemeris] = {}

    for idx_chunk in _iterate_chunks(idx, chunk_size):
        start = int(idx_chunk[0]) if len(idx_chunk) else 0
        pending.append(
            ephemeris_2body_worker_ray.remote(
                start,
                idx_chunk,
                propagated_ref,
                observers_ref,
                lt_tol,
                max_iter,
                tol,
                stellar_aberration,
                predict_magnitudes,
                predict_phase_angle,
            )
        )
        # Back-pressure: cap the number of in-flight tasks.
        if len(pending) >= max_processes * 1.5:
            finished, pending = ray.wait(pending, num_returns=1)
            start_i, eph_i = ray.get(finished[0])
            results[int(start_i)] = eph_i

    while pending:
        finished, pending = ray.wait(pending, num_returns=1)
        start_i, eph_i = ray.get(finished[0])
        results[int(start_i)] = eph_i

    # Reassemble in row order (keys are each chunk's first row index).
    chunks = [results[k] for k in sorted(results)]
    return qv.concatenate(chunks) if chunks else Ephemeris.empty()


def _to_barycentric_ecliptic(
    propagated_orbits: Orbits, observers: Observers
) -> Tuple[Orbits, Observers]:
    """
    Return orbits and observers re-expressed as SSB-centered ecliptic Cartesian.

    Fast path: when both inputs are SUN-centered ecliptic on an identical TDB
    time grid, the SUN->SSB translation is computed once and applied to both.
    Any failure in the fast path is deliberately ignored and the general
    `transform_coordinates` path is used instead.
    """
    try:
        po = propagated_orbits.coordinates
        obc = observers.coordinates
        po_origin = po.origin.code.to_numpy(zero_copy_only=False)
        ob_origin = obc.origin.code.to_numpy(zero_copy_only=False)
        if (
            str(po.frame) == "ecliptic"
            and str(obc.frame) == "ecliptic"
            and np.all(po_origin == OriginCodes.SUN.name)
            and np.all(ob_origin == OriginCodes.SUN.name)
        ):
            t_po = po.time.rescale("tdb")
            t_ob = obc.time.rescale("tdb")
            same_time = np.array_equal(
                t_po.days.to_numpy(zero_copy_only=False),
                t_ob.days.to_numpy(zero_copy_only=False),
            ) and np.array_equal(
                t_po.nanos.to_numpy(zero_copy_only=False),
                t_ob.nanos.to_numpy(zero_copy_only=False),
            )
            if same_time:
                from ..utils.spice import get_perturber_state

                sun_wrt_ssb = get_perturber_state(
                    OriginCodes.SUN,
                    t_po,
                    frame="ecliptic",
                    origin=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
                ).values
                ssb = OriginCodes.SOLAR_SYSTEM_BARYCENTER.name
                coords_po = po.translate(sun_wrt_ssb, ssb)
                coords_ob = obc.translate(sun_wrt_ssb, ssb)
                return (
                    propagated_orbits.set_column("coordinates", coords_po),
                    observers.set_column("coordinates", coords_ob),
                )
    except Exception:
        # Best-effort optimization only: fall through to the general path.
        pass

    orbits_ssb = propagated_orbits.set_column(
        "coordinates",
        transform_coordinates(
            propagated_orbits.coordinates,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
        ),
    )
    observers_ssb = observers.set_column(
        "coordinates",
        transform_coordinates(
            observers.coordinates,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
        ),
    )
    return orbits_ssb, observers_ssb


def _raise_on_non_finite_rows(
    values: np.ndarray,
    *,
    reason: str,
    start: int,
    times_chunk: np.ndarray,
    light_time_chunk: np.ndarray,
    orbits: Orbits,
    max_iter: int,
    tol: float,
    lt_tol: float,
) -> None:
    """Raise DynamicsNumericalError for the first row of ``values`` holding a NaN/inf."""
    finite = np.isfinite(values)
    if finite.ndim > 1:
        # Reduce (rows, 6) state arrays to one flag per row so the reported
        # index is a row index, not a position in the flattened array.
        finite = finite.all(axis=tuple(range(1, finite.ndim)))
    bad_rows = np.flatnonzero(~finite)
    if bad_rows.size == 0:
        return
    bad = int(bad_rows[0])
    abs_idx = start + bad
    _raise_ephemeris_numerical_error(
        reason=reason,
        row_index=abs_idx,
        orbit_id=str(orbits.orbit_id.to_numpy(zero_copy_only=False)[abs_idx]),
        object_id=str(orbits.object_id.to_numpy(zero_copy_only=False)[abs_idx]),
        observation_time=float(times_chunk[bad]),
        light_time=float(light_time_chunk[bad]),
        max_iter=max_iter,
        tol=tol,
        lt_tol=lt_tol,
    )


def _add_photometry(
    ephemeris: Ephemeris,
    propagated_orbits: Orbits,
    observers: Observers,
    aberrated_coordinates: CartesianCoordinates,
    *,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> Ephemeris:
    """Optionally fill the ``alpha`` and ``predicted_magnitude_v`` columns."""
    want_alpha = bool(predict_phase_angle)
    want_mags = bool(predict_magnitudes)
    if not want_alpha and not want_mags:
        return ephemeris

    # Magnitudes need finite H_v and G; skip them entirely if no row has both.
    has_params = None
    H_v = None
    G = None
    if want_mags:
        H_v = propagated_orbits.physical_parameters.H_v.to_numpy(zero_copy_only=False)
        G = propagated_orbits.physical_parameters.G.to_numpy(zero_copy_only=False)
        has_params = np.isfinite(H_v) & np.isfinite(G)
        if not np.any(has_params):
            want_mags = False

    if not want_alpha and not want_mags:
        return ephemeris

    # Photometry is computed from heliocentric object and observer positions.
    aberrated_heliocentric = transform_coordinates(
        aberrated_coordinates,
        CartesianCoordinates,
        frame_out="ecliptic",
        origin_out=OriginCodes.SUN,
    )
    observers_heliocentric = observers.set_column(
        "coordinates",
        transform_coordinates(
            observers.coordinates,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SUN,
        ),
    )

    alpha_deg = None
    mags = None
    if want_mags and want_alpha:
        mags, alpha_deg = calculate_apparent_magnitude_v_and_phase_angle(
            H_v=H_v,
            object_coords=aberrated_heliocentric,
            observer=observers_heliocentric,
            G=G,
        )
    elif want_alpha:
        alpha_deg = calculate_phase_angle(aberrated_heliocentric, observers_heliocentric)
    else:
        mags = calculate_apparent_magnitude_v(
            H_v=H_v,
            object_coords=aberrated_heliocentric,
            observer=observers_heliocentric,
            G=G,
        )

    if alpha_deg is not None:
        alpha_deg = np.asarray(alpha_deg, dtype=np.float64)
        ephemeris = ephemeris.set_column(
            "alpha",
            pa.array(alpha_deg, mask=~np.isfinite(alpha_deg), type=pa.float64()),
        )

    if mags is not None:
        mags = np.asarray(mags, dtype=np.float64)
        # Null out magnitudes for rows lacking H_v/G or with non-finite results.
        valid = has_params & np.isfinite(mags)
        ephemeris = ephemeris.set_column(
            "predicted_magnitude_v", pa.array(mags, mask=~valid, type=pa.float64())
        )

    return ephemeris


def generate_ephemeris_2body(
    propagated_orbits: Orbits,
    observers: Observers,
    lt_tol: float = 1e-10,
    max_iter: int = 1000,
    tol: float = 1e-15,
    stellar_aberration: bool = False,
    predict_magnitudes: bool = True,
    *,
    predict_phase_angle: bool = False,
    max_processes: Optional[int] = 1,
    chunk_size: int = 100,
) -> Ephemeris:
    """
    Generate on-sky ephemerides for each propagated orbit as viewed by the observers.

    This function calculates the light time delay between the propagated orbit and
    the observer, and then propagates the orbit backward by that amount to when the
    light from object was actually emitted towards the observer ("astrometric
    coordinates").

    The motion of the observer in an inertial frame will cause an object to appear
    in a different location than its true location, this is known as stellar
    aberration (often referred to in combination with other aberrations as
    "apparent coordinates"). Stellar aberration can optionally be applied after
    light time correction has been added but it should not be necessary when
    comparing to ephemerides of solar system small bodies extracted from
    astrometric catalogs. The stars to which the catalog is calibrated undergo the
    same aberration as the moving objects as seen from the observer.

    If stellar aberration is applied then the velocity of the input orbits are
    unmodified, only the position vector is modified with stellar aberration.

    For more details on aberrations see:
        https://naif.jpl.nasa.gov/pub/naif/toolkit_docs/FORTRAN/req/abcorr.html
        https://ssd.jpl.nasa.gov/horizons/manual.html#defs

    Parameters
    ----------
    propagated_orbits : `~adam_core.orbits.orbits.Orbits` (N)
        Propagated orbits.
    observers : `~adam_core.observers.observers.Observers` (N)
        Observers for which to generate ephemerides. Orbits should already have
        been propagated to the same times as the observers.
    lt_tol : float, optional
        Calculate aberration to within this value in time (units of days).
    max_iter : int, optional
        Maximum number of iterations over which to converge for propagation.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly during
        propagation using the Newton-Raphson method.
    stellar_aberration : bool, optional
        Apply stellar aberration to the ephemerides.
    predict_magnitudes : bool, optional
        Fill ``predicted_magnitude_v`` for rows whose orbit has finite ``H_v`` and
        ``G`` (other rows are null).
    predict_phase_angle : bool, optional
        Fill the ``alpha`` (phase angle) column.
    max_processes : int or None, optional
        ``None`` uses ``multiprocessing.cpu_count()``. Values greater than 1
        distribute the work over Ray tasks; otherwise everything runs in-process.
    chunk_size : int, optional
        Number of rows per Ray task (only used when ``max_processes > 1``).

    Returns
    -------
    ephemeris : `~adam_core.orbits.ephemeris.Ephemeris` (N)
        Topocentric equatorial ephemerides for each propagated orbit as observed
        by the given observers, including light time and the light-time-corrected
        (emission time) barycentric ecliptic state.

    Raises
    ------
    AssertionError
        If orbits and observers have different lengths.
    DynamicsNumericalError
        If any row yields a non-finite light time, ephemeris, or aberrated state.
    """
    if max_processes is None:
        max_processes = mp.cpu_count()

    if max_processes > 1:
        return _generate_ephemeris_2body_ray(
            propagated_orbits,
            observers,
            lt_tol=lt_tol,
            max_iter=max_iter,
            tol=tol,
            stellar_aberration=stellar_aberration,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
            max_processes=max_processes,
            chunk_size=chunk_size,
        )

    num_entries = len(observers)
    assert (
        len(propagated_orbits) == num_entries
    ), "Orbits and observers must be paired and orbits must be propagated to observer times."

    propagated_orbits_barycentric, observers_barycentric = _to_barycentric_ecliptic(
        propagated_orbits, observers
    )

    orbit_states = propagated_orbits_barycentric.coordinates.values
    observer_coordinates = observers_barycentric.coordinates.values
    observer_codes = observers_barycentric.code.to_numpy(zero_copy_only=False)
    mu = observers_barycentric.coordinates.origin.mu()
    times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

    # Inner (JAX) batch size: the shape of the vmapped kernel per call. Kept
    # distinct from the `chunk_size` parameter, which sizes Ray tasks.
    jax_chunk_size = 2000

    ephemeris_spherical = np.empty((num_entries, 6), dtype=np.float64)
    light_time = np.empty((num_entries,), dtype=np.float64)
    aberrated_orbits = np.empty((num_entries, 6), dtype=np.float64)

    start = 0
    for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip(
        process_in_chunks(orbit_states, jax_chunk_size),
        process_in_chunks(times, jax_chunk_size),
        process_in_chunks(observer_coordinates, jax_chunk_size),
        process_in_chunks(mu, jax_chunk_size),
    ):
        # The final chunk may be padded; only the first `valid` rows are real.
        valid = min(jax_chunk_size, num_entries - start)
        ephemeris_chunk, light_time_chunk, aberrated_chunk = (
            _generate_ephemeris_2body_vmap(
                orbits_chunk,
                times_chunk,
                observer_coords_chunk,
                mu_chunk,
                lt_tol,
                max_iter,
                tol,
                stellar_aberration,
            )
        )
        eph_np = np.asarray(ephemeris_chunk, dtype=np.float64)[:valid]
        lt_np = np.asarray(light_time_chunk, dtype=np.float64)[:valid]
        aberrated_np = np.asarray(aberrated_chunk, dtype=np.float64)[:valid]

        for reason, values in (
            ("non_finite_light_time", lt_np),
            ("non_finite_ephemeris_state", eph_np),
            ("non_finite_aberrated_state", aberrated_np),
        ):
            _raise_on_non_finite_rows(
                values,
                reason=reason,
                start=start,
                times_chunk=times_chunk,
                light_time_chunk=lt_np,
                orbits=propagated_orbits_barycentric,
                max_iter=max_iter,
                tol=tol,
                lt_tol=lt_tol,
            )

        ephemeris_spherical[start : start + valid] = eph_np
        light_time[start : start + valid] = lt_np
        aberrated_orbits[start : start + valid] = aberrated_np
        start += valid

    if start != num_entries:
        raise RuntimeError(
            f"Internal error: expected {num_entries} ephemeris rows, got {start}"
        )

    # Emission times = observation times minus light time (days). Every light
    # time was already verified finite inside the loop above.
    emission_times = propagated_orbits_barycentric.coordinates.time.add_fractional_days(
        pa.array(-light_time)
    )

    aberrated_coordinates = CartesianCoordinates.from_kwargs(
        x=aberrated_orbits[:, 0],
        y=aberrated_orbits[:, 1],
        z=aberrated_orbits[:, 2],
        vx=aberrated_orbits[:, 3],
        vy=aberrated_orbits[:, 4],
        vz=aberrated_orbits[:, 5],
        time=emission_times,
        origin=Origin.from_kwargs(
            code=np.full(num_entries, OriginCodes.SOLAR_SYSTEM_BARYCENTER.name)
        ),
        frame="ecliptic",
    )

    covariances_spherical = None
    if not propagated_orbits.coordinates.covariance.is_all_nan():
        # Linearize the ephemeris about the same SSB/ecliptic states (and
        # covariance expressed in that frame/origin) that produced the
        # ephemeris above; `_generate_ephemeris_2body` pairs them with the
        # barycentric observer coordinates and SSB mu.
        cartesian_covariances = (
            propagated_orbits_barycentric.coordinates.covariance.to_matrix()
        )
        covariance_matrices = transform_covariances_jacobian(
            orbit_states,
            cartesian_covariances,
            _generate_ephemeris_2body,
            in_axes=(0, 0, 0, 0, None, None, None, None),
            out_axes=(0, 0, 0),
            observation_times=times,
            observer_coordinates=observer_coordinates,
            mu=mu,
            lt_tol=lt_tol,
            max_iter=max_iter,
            tol=tol,
            stellar_aberration=stellar_aberration,
        )
        covariances_spherical = CoordinateCovariances.from_matrix(
            np.array(covariance_matrices)
        )

    spherical_coordinates = SphericalCoordinates.from_kwargs(
        time=propagated_orbits.coordinates.time,
        rho=ephemeris_spherical[:, 0],
        lon=ephemeris_spherical[:, 1],
        lat=ephemeris_spherical[:, 2],
        vrho=ephemeris_spherical[:, 3],
        vlon=ephemeris_spherical[:, 4],
        vlat=ephemeris_spherical[:, 5],
        covariance=covariances_spherical,
        origin=Origin.from_kwargs(code=observer_codes),
        frame="equatorial",
    )

    ephemeris = Ephemeris.from_kwargs(
        orbit_id=propagated_orbits_barycentric.orbit_id,
        object_id=propagated_orbits_barycentric.object_id,
        coordinates=spherical_coordinates,
        light_time=light_time,
        aberrated_coordinates=aberrated_coordinates,
    )

    return _add_photometry(
        ephemeris,
        propagated_orbits,
        observers,
        aberrated_coordinates,
        predict_magnitudes=predict_magnitudes,
        predict_phase_angle=predict_phase_angle,
    )