# Source code for adam_core.dynamics.lagrange

from typing import Tuple

import jax.numpy as jnp
from jax import config, jit

from ..constants import Constants as C
from .chi import calc_chi
from .stumpff import STUMPFF_TYPES

# Turn on 64-bit (double precision) arrays in JAX so the jnp.float64 values used
# throughout this module are really 64-bit. Note: this updates JAX's global
# configuration as a side effect of importing this module.
config.update("jax_enable_x64", True)


# Default gravitational parameter, taken from the project constants.
# NOTE(review): presumably in au**3 / d**2 to match the docstrings below — confirm.
MU = C.MU
# Element types of the (f, g, f_dot, g_dot) tuple returned by
# calc_lagrange_coefficients.
LAGRANGE_TYPES = Tuple[(jnp.float64,) * 4]


@jit
def calc_lagrange_coefficients(
    r: jnp.ndarray,
    v: jnp.ndarray,
    dt: float,
    mu: float = MU,
    max_iter: int = 100,
    tol: float = 1e-16,
) -> Tuple[LAGRANGE_TYPES, STUMPFF_TYPES, jnp.float64]:
    """
    Calculate the exact Lagrange coefficients given an initial state defined at t0,
    and the change in time from t0 to t1 (dt = t1 - t0).

    Parameters
    ----------
    r : `~jax.numpy.ndarray` (3)
        Position vector in au.
    v : `~jax.numpy.ndarray` (3)
        Velocity vector in au per day.
    dt : float
        Time elapsed from the epoch of the initial state (t0) to the epoch at
        which the coefficients are evaluated (t1), in decimal days.
    mu : float
        Gravitational parameter (GM) of the attracting body in units of au**3 / d**2.
    max_iter : int
        Maximum number of iterations over which to converge. If number of iterations
        is exceeded, will return the value of the universal anomaly at the last
        iteration.
    tol : float
        Numerical tolerance to which to compute chi using the Newton-Raphson method.

    Returns
    -------
    lagrange_coeffs : (float x 4)
        f : float
            Lagrange f coefficient.
        g : float
            Lagrange g coefficient.
        f_dot : float
            Time derivative of the Lagrange f coefficient.
        g_dot : float
            Time derivative of the Lagrange g coefficient.
    stumpff_coeffs : (float x 6)
        First six Stumpff functions (c0, c1, c2, c3, c4, c5), as returned by
        `calc_chi`.
    chi : float
        Universal anomaly.

    References
    ----------
    [1] Curtis, H. D. (2014). Orbital Mechanics for Engineering Students. 3rd ed.,
        Elsevier Ltd. ISBN-13: 978-0080977478
    """
    sqrt_mu = jnp.sqrt(mu)
    chi, c0, c1, c2, c3, c4, c5 = calc_chi(r, v, dt, mu=mu, max_iter=max_iter, tol=tol)
    stumpff_coeffs = (c0, c1, c2, c3, c4, c5)

    # Powers of the universal anomaly; each is used by two of the coefficients.
    chi2 = chi**2
    chi3 = chi**3

    r_mag = jnp.linalg.norm(r)
    v_mag = jnp.linalg.norm(v)

    # Equations 3.48 and 3.50 in Curtis (2014) [1]: alpha is the reciprocal of
    # the semi-major axis, from vis-viva (1 / a = 2 / |r| - |v|**2 / mu).
    alpha = -(v_mag**2) / mu + 2 / r_mag

    # Equations 3.69a and 3.69b in Curtis (2014) [1]
    f = 1 - chi2 / r_mag * c2
    g = dt - 1 / sqrt_mu * chi3 * c3

    # f_dot and g_dot need the magnitude of the propagated position, which
    # follows directly from f and g (r1 = f * r0 + g * v0).
    r_new = f * r + g * v
    r_new_mag = jnp.linalg.norm(r_new)

    # Equations 3.69c and 3.69d in Curtis (2014) [1]
    f_dot = sqrt_mu / (r_mag * r_new_mag) * (alpha * chi3 * c3 - chi)
    g_dot = 1 - chi2 / r_new_mag * c2

    lagrange_coeffs = (f, g, f_dot, g_dot)
    return lagrange_coeffs, stumpff_coeffs, chi
@jit
def apply_lagrange_coefficients(
    r: jnp.ndarray, v: jnp.ndarray, f: float, g: float, f_dot: float, g_dot: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Propagate a state vector by applying precomputed Lagrange coefficients.

    The new position is ``f * r + g * v`` and the new velocity is
    ``f_dot * r + g_dot * v``.

    Parameters
    ----------
    r : `~jax.numpy.ndarray` (3)
        Initial position vector in au.
    v : `~jax.numpy.ndarray` (3)
        Initial velocity vector in au per day.
    f : float
        Lagrange f coefficient.
    g : float
        Lagrange g coefficient.
    f_dot : float
        Time derivative of the Lagrange f coefficient.
    g_dot : float
        Time derivative of the Lagrange g coefficient.

    Returns
    -------
    r_new : `~jax.numpy.ndarray` (3)
        Position vector in au propagated with the Lagrange coefficients.
    v_new : `~jax.numpy.ndarray` (3)
        Velocity vector in au per day propagated with the Lagrange coefficients.
    """
    # Both propagated vectors are linear combinations of the initial position
    # and velocity; only the weights differ.
    return f * r + g * v, f_dot * r + g_dot * v