Source code for adam_core.dynamics.chi

from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp
import numpy as np
from jax import config, jit, lax

from ..constants import Constants as c
from .stumpff import calc_stumpff

# Enable 64-bit floating point in JAX (JAX defaults to 32-bit). This is a
# process-wide JAX configuration change applied at import time; the solver
# below relies on double precision (its default Newton-Raphson tolerance is
# 1e-16).
config.update("jax_enable_x64", True)

# Default gravitational parameter (GM) used by the functions below, taken from
# the project constants (calc_chi documents its units as au**3 / d**2).
MU = c.MU


@dataclass(frozen=True)
class ChiDiagnostics:
    """
    Immutable record of the inputs and result of a universal-anomaly solve.

    Instances are built by ``calc_chi_diagnostics`` to support fail-fast
    error reporting. All numeric fields are plain Python floats there.
    """

    # Time offset passed to the solver (decimal days).
    dt: float
    # Gravitational parameter (GM) of the attracting body.
    mu: float
    # Euclidean norm of the position vector.
    r_norm: float
    # Euclidean norm of the velocity vector.
    v_norm: float
    # Reciprocal semi-major axis, 2 / |r| - |v|**2 / mu (NaN when undefined).
    alpha: float
    # Universal anomaly returned by calc_chi.
    chi: float
    # True only when r_norm, v_norm, alpha and chi are all finite.
    finite: bool
@jit
def calc_chi(
    r: jnp.ndarray,
    v: jnp.ndarray,
    dt: float,
    mu: float = MU,
    max_iter: int = 100,
    tol: float = 1e-16,
) -> Tuple[
    np.float64, np.float64, np.float64, np.float64, np.float64, np.float64, np.float64
]:
    """
    Calculate universal anomaly chi using Newton-Raphson.

    Parameters
    ----------
    r : `~jax.numpy.ndarray` (3)
        Position vector in au.
    v : `~jax.numpy.ndarray` (3)
        Velocity vector in au per day.
    dt : float
        Time from epoch to which calculate chi in units of decimal days.
    mu : float
        Gravitational parameter (GM) of the attracting body in units of
        au**3 / d**2.
    max_iter : int
        Maximum number of Newton-Raphson updates to apply. If the iteration
        has not converged after ``max_iter`` updates, the value of the
        universal anomaly after the last update is returned. If
        ``max_iter <= 0`` no update is applied and the initial estimate is
        returned.
    tol : float
        Numerical tolerance to which to compute chi using the Newton-Raphson
        method. Iteration stops once the magnitude of the last Newton
        correction is no larger than ``tol``. A NaN correction also stops the
        iteration.

    Returns
    -------
    chi : float
        Universal anomaly.
    c0, c1, c2, c3, c4, c5 : 6 x float
        First six Stumpff functions, evaluated at ``psi = alpha * chi**2``
        for the returned ``chi``.

    References
    ----------
    [1] Curtis, H. D. (2014). Orbital Mechanics for Engineering Students. 3rd ed.,
        Elsevier Ltd. ISBN-13: 978-0080977478
    """
    v_mag = jnp.linalg.norm(v)
    r_mag = jnp.linalg.norm(r)
    # Radial component of the velocity (projection of v onto r / |r|).
    rv_mag = jnp.dot(r, v) / r_mag
    sqrt_mu = jnp.sqrt(mu)

    # Equations 3.48 and 3.50 in Curtis (2014) [1]: alpha = 1 / a.
    alpha = -(v_mag**2) / mu + 2 / r_mag

    # Equation 3.66 in Curtis (2014) [1]: initial estimate of chi.
    chi_initial = sqrt_mu * jnp.abs(alpha) * dt

    def _newton_step(carry):
        """Apply one Newton-Raphson update to chi (Equation 3.65 in [1])."""
        chi, _, iterations = carry
        chi2 = chi**2
        # Only c2 and c3 enter F(chi) and F'(chi).
        _, _, c2, c3, _, _ = calc_stumpff(alpha * chi2)

        f = (
            r_mag * rv_mag / sqrt_mu * chi2 * c2
            + (1 - alpha * r_mag) * chi**3 * c3
            + r_mag * chi
            - sqrt_mu * dt
        )
        fp = (
            r_mag * rv_mag / sqrt_mu * chi * (1 - alpha * chi2 * c3)
            + (1 - alpha * r_mag) * chi2 * c2
            + r_mag
        )
        ratio = f / fp
        return chi - ratio, ratio, iterations + 1

    def _keep_iterating(carry):
        """Continue while the last correction exceeds tol and budget remains."""
        _, ratio, iterations = carry
        # Strict '<' so that at most max_iter updates are applied.
        return (jnp.abs(ratio) > tol) & (iterations < max_iter)

    # Carry: (chi, last Newton correction, number of updates applied). The
    # initial correction is a large sentinel so the first update runs
    # (unless max_iter <= 0).
    chi, _, _ = lax.while_loop(
        _keep_iterating, _newton_step, (chi_initial, 1e15, 0)
    )

    # Inside the loop the Stumpff functions are evaluated at chi *before*
    # each update. Re-evaluate them at the final chi so the returned values
    # are consistent with it (also covers the case where no update ran).
    c0, c1, c2, c3, c4, c5 = calc_stumpff(alpha * chi**2)
    return chi, c0, c1, c2, c3, c4, c5
def calc_chi_diagnostics(
    r: np.ndarray,
    v: np.ndarray,
    dt: float,
    mu: float = MU,
    max_iter: int = 100,
    tol: float = 1e-16,
) -> ChiDiagnostics:
    """
    Host-side chi diagnostics helper for fail-fast error reporting.

    Computes ``|r|``, ``|v|`` and ``alpha`` with NumPy on the host. It then
    runs ``calc_chi`` with the same arguments and packages the results into
    a ``ChiDiagnostics`` record.

    Parameters
    ----------
    r : array-like (3)
        Position vector in au (converted to float64).
    v : array-like (3)
        Velocity vector in au per day (converted to float64).
    dt : float
        Time offset passed to ``calc_chi`` in decimal days.
    mu : float
        Gravitational parameter (GM) of the attracting body in au**3 / d**2.
    max_iter : int
        Passed through to ``calc_chi``.
    tol : float
        Passed through to ``calc_chi``.

    Returns
    -------
    ChiDiagnostics
        ``finite`` is True only when ``r_norm``, ``v_norm``, ``alpha`` and
        ``chi`` are all finite. ``alpha`` is NaN, rather than raising, when
        ``|r|`` is not positive or ``mu`` is zero.
    """
    r_arr = np.asarray(r, dtype=np.float64)
    v_arr = np.asarray(v, dtype=np.float64)
    r_norm = float(np.linalg.norm(r_arr))
    v_norm = float(np.linalg.norm(v_arr))

    # alpha = 2 / |r| - |v|**2 / mu. With Python floats a zero divisor raises
    # ZeroDivisionError; a diagnostics helper should report a non-finite
    # value instead of crashing, so both divisors are guarded.
    if r_norm > 0 and mu != 0:
        alpha = float(-(v_norm**2) / mu + 2.0 / r_norm)
    else:
        alpha = np.nan

    chi = float(calc_chi(r_arr, v_arr, dt, mu=mu, max_iter=max_iter, tol=tol)[0])

    finite = all(bool(np.isfinite(value)) for value in (r_norm, v_norm, alpha, chi))

    return ChiDiagnostics(
        dt=float(dt),
        mu=float(mu),
        r_norm=r_norm,
        v_norm=v_norm,
        alpha=alpha,
        chi=chi,
        finite=finite,
    )