# Source code for adam_core.dynamics.aberrations

from typing import Tuple

import jax.numpy as jnp
from jax import config, jit, lax, vmap

from ..constants import Constants as c
from .propagation import _propagate_2body

# Light-time and aberration math needs double precision, so turn on 64-bit
# floats for JAX before any arrays are created in this module.
config.update("jax_enable_x64", True)

# Module-level aliases for the physical constants used as defaults below.
MU, C = c.MU, c.C


@jit
def _add_light_time(
    orbit: jnp.ndarray,
    t0: float,
    observer_position: jnp.ndarray,
    lt_tol: float = 1e-10,
    mu: float = MU,
    max_iter: int = 1000,
    tol: float = 1e-15,
    max_lt_iter: int = 10,
) -> Tuple[jnp.ndarray, jnp.float64]:
    """
    Correct a single barycentric orbit for light-travel time.

    When generating ephemerides, orbits need to be propagated backwards to the
    time at which light was emitted or reflected by the object towards the
    observer. The light time is found by fixed-point iteration:

    1. Compute the observer-object distance.
    2. Convert it to a light time.
    3. Propagate the original orbit from ``t0`` back by that light time.
    4. Repeat until two successive light times differ by no more than
       ``lt_tol``, or ``max_lt_iter`` iterations have run.

    Light time correction must be applied to orbits expressed in an inertial
    frame (i.e., orbits must be barycentric).

    Parameters
    ----------
    orbit : `~jax.numpy.ndarray` (6)
        Barycentric orbit in cartesian elements to correct for light time delay.
    t0 : float
        Epoch at which the orbit is defined.
    observer_position : `~jax.numpy.ndarray` (3)
        Location of the observer in barycentric cartesian elements at the time
        of observation.
    lt_tol : float, optional
        Calculate aberration to within this value in time (units of days).
    mu : float, optional
        Gravitational parameter (GM) of the attracting body in units of
        AU**3 / d**2.
    max_iter : int, optional
        Maximum number of iterations over which to converge for propagation.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly during
        propagation using the Newton-Raphson method.
    max_lt_iter : int, optional
        Maximum number of light-time iterations.

    Returns
    -------
    corrected_orbit : `~jax.numpy.ndarray` (6)
        Orbit adjusted for light travel time. This is the input orbit unchanged
        if no iteration ran (``max_lt_iter <= 0``).
    lt : float
        Light time correction (t0 - corrected_t0). NaN when the light time did
        not converge to within ``lt_tol`` in ``max_lt_iter`` iterations, so
        host-side callers can fail fast with row-level context.
    """

    def _iterate_light_time(carry):
        orbit_i, lt_prev, _, n_iter = carry

        # Topocentric distance of the current (propagated) guess.
        rho = jnp.linalg.norm(orbit_i[:3] - observer_position)

        # Updated light time and its change from the previous guess.
        lt = rho / C
        dlt = jnp.abs(lt - lt_prev)

        # Always propagate the *original* orbit from its *original* epoch t0,
        # so the emission epoch is t0 - lt rather than an accumulated sum of
        # every intermediate light-time guess.
        orbit_emit = _propagate_2body(
            orbit, t0, t0 - lt, mu=mu, max_iter=max_iter, tol=tol
        )
        return orbit_emit, lt, dlt, n_iter + 1

    def _keep_iterating(carry):
        _, _, dlt, n_iter = carry
        return (dlt > lt_tol) & (n_iter < max_lt_iter)

    # Large sentinels guarantee the first iteration runs (dlt > lt_tol).
    init = (orbit, 1e30, 1e30, 0)
    orbit_aberrated, lt, dlt, _ = lax.while_loop(
        _keep_iterating, _iterate_light_time, init
    )

    # Convergence is decided by the final update alone. A run that reaches
    # the tolerance on its last allowed iteration is still a success.
    lt = jnp.where(dlt > lt_tol, jnp.nan, lt)
    return orbit_aberrated, lt


# Batched light-time kernel. The first three arguments (orbit, t0,
# observer_position) are mapped over their leading axis. The five trailing
# settings (lt_tol, mu, max_iter, tol, max_lt_iter) are shared by every row.
# Both outputs (orbit, lt) are stacked along axis 0.
_add_light_time_vmap = jit(
    vmap(
        _add_light_time,
        in_axes=(0,) * 3 + (None,) * 5,
        out_axes=(0, 0),
    )
)


@jit
def add_light_time(
    orbits: jnp.ndarray,
    t0: jnp.ndarray,
    observer_positions: jnp.ndarray,
    lt_tol: float = 1e-10,
    mu: float = MU,
    max_iter: int = 1000,
    tol: float = 1e-15,
    max_lt_iter: int = 10,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Correct a batch of barycentric orbits for light-travel time.

    When generating ephemerides, orbits need to be propagated backwards to the
    time at which light was emitted or reflected by the object towards the
    observer. Each row is handled independently by ``_add_light_time``.

    Light time correction must be applied to orbits expressed in an inertial
    frame (i.e., orbits must be barycentric).

    Parameters
    ----------
    orbits : `~jax.numpy.ndarray` (N, 6)
        Barycentric orbits in cartesian elements to correct for light time delay.
    t0 : `~jax.numpy.ndarray` (N)
        Epoch at which each orbit is defined.
    observer_positions : `~jax.numpy.ndarray` (N, 3)
        Location of the observer in barycentric cartesian elements at the time
        of each observation.
    lt_tol : float, optional
        Calculate aberration to within this value in time (units of days).
    mu : float, optional
        Gravitational parameter (GM) of the attracting body in units of
        AU**3 / d**2.
    max_iter : int, optional
        Maximum number of iterations over which to converge for propagation.
    tol : float, optional
        Numerical tolerance to which to compute universal anomaly during
        propagation using the Newton-Raphson method.
    max_lt_iter : int, optional
        Maximum number of light-time iterations per orbit.

    Returns
    -------
    corrected_orbits : `~jax.numpy.ndarray` (N, 6)
        Orbits adjusted for light travel time.
    lt : `~jax.numpy.ndarray` (N)
        Light time correction (t0 - corrected_t0). ``_add_light_time`` reports
        rows whose light-time iteration is treated as unconverged as NaN.
    """
    # vmap's in_axes only describe positional arguments, so everything is
    # forwarded positionally in the order the batched kernel expects.
    return _add_light_time_vmap(
        orbits, t0, observer_positions, lt_tol, mu, max_iter, tol, max_lt_iter
    )
@jit
def add_stellar_aberration(
    orbits: jnp.ndarray, observer_states: jnp.ndarray
) -> jnp.ndarray:
    """
    Apply stellar aberration to the topocentric position of each orbit.

    The motion of the observer in an inertial frame causes an object to appear
    in a different location than its true geometric location. This aberration
    is typically applied after light time corrections have been added.

    Only positions are returned: the result is the aberrated topocentric
    position vector. Velocities of the input orbits are not modified or
    returned.

    Parameters
    ----------
    orbits : `~jax.numpy.ndarray` (N, 6)
        Orbits in barycentric cartesian elements.
    observer_states : `~jax.numpy.ndarray` (N, 6)
        Observer states in barycentric cartesian elements.

    Returns
    -------
    rho_aberrated : `~jax.numpy.ndarray` (N, 3)
        The topocentric position vector for each orbit with added stellar
        aberration, as float64. Rows with zero topocentric distance produce NaN
        because the direction vector is undefined.

    References
    ----------
    [1] Urban, S. E; Seidelmann, P. K. (2013) Explanatory Supplement to the
        Astronomical Almanac. 3rd ed., University Science Books.
        ISBN-13: 978-1891389856
    """
    # Only positions enter the formula, so skip differencing the velocities.
    topo_position = orbits[:, :3] - observer_states[:, :3]

    # Observer velocity in units of the speed of light (V in [1]).
    v_over_c = observer_states[:, 3:] / C
    # beta^-1 = sqrt(1 - |V|^2), following the notation of Eq. 7.40 in [1].
    beta_inv = jnp.sqrt(1 - jnp.linalg.norm(v_over_c, axis=1, keepdims=True) ** 2)

    # Topocentric distance and unit direction towards each object.
    delta = jnp.linalg.norm(topo_position, axis=1, keepdims=True)
    u = topo_position / delta
    u_dot_v = jnp.sum(u * v_over_c, axis=1, keepdims=True)

    # Equation 7.40 in Urban & Seidelmann (2013) [1]
    u_aberrated = (
        beta_inv * u + v_over_c + u_dot_v * v_over_c / (1 + beta_inv)
    ) / (1 + u_dot_v)

    # Rescale the aberrated unit vector back to the topocentric distance. The
    # float64 cast keeps the output dtype the function has always returned.
    return u_aberrated.astype(jnp.float64) * delta