Source code for adam_core.dynamics.kepler

from typing import Tuple

import jax.numpy as jnp
import jax.typing as jnpt
from jax import config, jit, lax

from .barker import solve_barker

config.update("jax_enable_x64", True)


[docs] @jit def calc_period(a: jnpt.ArrayLike, mu: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the period of an orbit given the semi-major axis and gravitational parameter. Parameters ---------- a Semi-major axis. mu Gravitational parameter. Returns ------- P Period. """ return jnp.where(a < 0.0, jnp.inf, 2 * jnp.pi * jnp.sqrt(a**3 / mu))
[docs] @jit def calc_periapsis_distance(a: jnpt.ArrayLike, e: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the periapsis distance of an orbit given the semi-major axis and eccentricity. Parameters ---------- a Semi-major axis. e Eccentricity. Returns ------- q Periapsis distance. """ return a * (1 - e)
[docs] @jit def calc_apoapsis_distance(a: jnpt.ArrayLike, e: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the apoapsis distance of an orbit given the semi-major axis and eccentricity. Parameters ---------- a Semi-major axis. e Eccentricity. Returns ------- Q Apoapsis distance. """ return jnp.where(e >= 1.0, jnp.inf, a * (1 + e))
[docs] @jit def calc_semi_major_axis(q: jnpt.ArrayLike, e: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the semi-major axis of an orbit given the periapsis distance and eccentricity. Parameters ---------- q Periapsis distance. e Eccentricity. Returns ------- a Semi-major axis. """ return q / (1 - e)
[docs] @jit def calc_semi_latus_rectum(a: jnpt.ArrayLike, e: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the semi-latus rectum of an orbit given the semi-major axis and eccentricity. Parameters ---------- a Semi-major axis. e Eccentricity. Returns ------- p Semi-latus rectum. """ return a * (1 - e**2)
[docs] @jit def calc_mean_motion(a: jnpt.ArrayLike, mu: jnpt.ArrayLike) -> jnpt.ArrayLike: """ Calculate the mean motion of an orbit given the semi-major axis and gravitational parameter. Parameters ---------- a Semi-major axis. mu Gravitational parameter. Returns ------- n Mean motion in radians per unit time. """ return jnp.sqrt(mu / jnp.abs(a) ** 3)
[docs] @jit def calc_mean_anomaly(nu: float, e: float) -> float: """ Calculate the mean anomaly given true anomaly in radians and eccentricity. Parameters ---------- nu : float True anomaly in radians. e : float Eccentricity. Returns ------- M : float Mean anomaly in radians. """ E, M = lax.cond( e < 1.0, _calc_elliptical_anomalies, lambda nu, e: lax.cond( e > 1.0, _calc_hyperbolic_anomalies, _calc_parabolic_anomalies, nu, e ), nu, e, ) return M
@jit def _calc_elliptical_anomalies(nu: float, e: float) -> Tuple[float, float]: nu_ = jnp.where(nu >= 2 * jnp.pi, nu % (2 * jnp.pi), nu) E = jnp.arctan2(jnp.sqrt(1 - e**2) * jnp.sin(nu_), e + jnp.cos(nu_)) M = E - e * jnp.sin(E) M = jnp.where(M < 0.0, M + 2 * jnp.pi, M) return E, M @jit def _calc_hyperbolic_anomalies(nu: float, e: float) -> Tuple[float, float]: nu_ = jnp.where(nu >= 2 * jnp.pi, nu % (2 * jnp.pi), nu) H = 2 * jnp.arctanh(jnp.sqrt((e - 1) / (e + 1)) * jnp.tan(nu_ / 2)) M = e * jnp.sinh(H) - H M = jnp.where(M < 0.0, M + 2 * jnp.pi, M) return H, M @jit def _calc_parabolic_anomalies(nu: float, e: float) -> Tuple[float, float]: nu_ = jnp.where(nu >= 2 * jnp.pi, nu % (2 * jnp.pi), nu) D = jnp.arctan(nu_ / 2) M = D + (D**3 / 3) M = jnp.where(M < 0.0, M + 2 * jnp.pi, M) return D, M
[docs] @jit def solve_kepler(e: float, M: float, max_iter: int = 100, tol: float = 1e-15) -> float: """ Solve Kepler's equation for true anomaly (nu) given eccentricity and mean anomaly using Newton-Raphson. Technically, this is only valid for orbits with eccentricity < 1.0 and eccentricity > 1.0. However, for parabolic orbits (e = 1.0) this function will call the `solve_barker` function from `thor.dynamics.barker`. Parameters ---------- e : float Eccentricity M : float Mean anomaly in radians. max_iter : int, optional Maximum number of iterations over which to converge. If number of iterations is exceeded, will use the value of the relevant anomaly at the last iteration. tol : float, optional Numerical tolerance to which to compute anomalies using the Newtown-Raphson method. Returns ------- nu : float True anomaly in radians. References ---------- [1] Curtis, H. D. (2014). Orbital Mechanics for Engineering Students. 3rd ed., Elsevier Ltd. ISBN-13: 978-0080977478 """ ratio = 1e15 iterations = 0 E_init = jnp.where(e < 1.0, M, M) p = [E_init, e, M, ratio, iterations] @jit def _elliptical_newton_raphson(p): E = p[0] e = p[1] M = p[2] iterations = p[4] # Newton-Raphson # Equation 3.17 in Curtis (2014) [1] f = E - e * jnp.sin(E) - M fp = 1 - e * jnp.cos(E) ratio = f / fp E -= ratio iterations += 1 p[0] = E p[1] = e p[2] = M p[3] = ratio p[4] = iterations return p @jit def _hyperbolic_newton_raphson(p): F = p[0] e = p[1] M = p[2] iterations = p[4] # Newton-Raphson # Equation 3.45 in Curtis (2014) [1] f = e * jnp.sinh(F) - F - M fp = e * jnp.cosh(F) - 1 ratio = f / fp F -= ratio iterations += 1 p[0] = F p[1] = e p[2] = M p[3] = ratio p[4] = iterations return p # Define while loop condition function @jit def _while_condition(p): ratio = p[-2] iterations = p[-1] return (jnp.abs(ratio) > tol) & (iterations <= max_iter) # Calculate parameters, if e < 1.0 then the orbit is elliptical # if e > 1.0 then the orbit is hyperbolic p = lax.cond( e < 1.0, lambda p: lax.while_loop(_while_condition, _elliptical_newton_raphson, p), lambda p: lax.cond( e > 1.0, lambda p: lax.while_loop( _while_condition, _hyperbolic_newton_raphson, p, ), # For parabolic orbits return the parameters as is since # no iteration is needed for parabolic orbits lambda p: p, p, ), p, ) nu = lax.cond( e < 1.0, lambda E, e, M: 2 * jnp.arctan2( jnp.sqrt(1 + e) * jnp.sin(E / 2), jnp.sqrt(1 - e) * jnp.cos(E / 2) ), lambda E, e, M: lax.cond( e > 1.0, lambda H, e, M: 2 * jnp.arctan( jnp.sqrt(e + 1) * jnp.sinh(H / 2) / (jnp.sqrt(e - 1) * jnp.cosh(H / 2)) ), lambda P, e, M: solve_barker(M), p[0], p[1], p[2], ), p[0], p[1], p[2], ) # True anomaly should be in the range [0, 2*pi) nu = jnp.where(nu < 0.0, nu + 2 * jnp.pi, nu) nu = jnp.where(nu >= 2 * jnp.pi, nu % (2 * jnp.pi), nu) return nu