"""
Implementation of Lambert's problem using Izzo's method.
This implementation follows the algorithm described in:
Izzo, D. (2015). Revisiting Lambert's problem. Celestial Mechanics and Dynamical Astronomy, 121(1), 1-15.
Credits: Based on poliastro implementation by Juan Luis Cano Rodríguez and lamberthub by Jorge Martinez
"""
from typing import Tuple, Union
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from jax import config, jit, lax, vmap
from ..constants import Constants as C
# Enable 64-bit precision in JAX; the routines in this module explicitly
# request float64 arrays.
config.update("jax_enable_x64", True)
# Default gravitational parameter used by the solvers, read from the project
# constants.
# NOTE(review): presumably expressed in au^3/day^2 to match the units
# documented in solve_lambert -- confirm against Constants.
MU = C.MU
@jit
def _hyp2f1b(x):
    """Evaluate the hypergeometric function 2F1(3, 1; 5/2; x) by its power series.

    The series is summed for a fixed number of terms (there is no convergence
    test) so that the loop has a static length and can be traced by JAX.

    Parameters
    ----------
    x : jnp.ndarray
        Argument of the function.

    Returns
    -------
    jnp.ndarray
        The truncated series sum, or ``inf`` wherever ``x >= 1``.
    """
    n_terms = 100

    def _accumulate(carry, k):
        arg, last_term, total = carry
        # Each term is the previous one times the ratio of consecutive
        # series coefficients.
        next_term = last_term * (3 + k) * (1 + k) / (5 / 2 + k) * arg / (k + 1)
        return (arg, next_term, total + next_term), None

    arg0 = x.astype(jnp.float64)
    unit = jnp.ones_like(x, dtype=jnp.float64)
    indices = jnp.arange(n_terms, dtype=jnp.float64)

    # The zeroth term (1) is already included in the starting sum.
    final_carry, _ = lax.scan(_accumulate, (arg0, unit, unit), indices)
    series_sum = final_carry[2]

    # Report infinity for arguments at or beyond 1.
    return jnp.where(x >= 1.0, jnp.inf, series_sum)
@jit
def _compute_y(x, ll):
    """Return ``y = sqrt(1 - ll**2 * (1 - x**2))`` for iteration variable ``x`` and parameter ``ll``."""
    one_minus_x_sq = 1 - x**2
    radicand = 1 - ll**2 * one_minus_x_sq
    return jnp.sqrt(radicand)
@jit
def _compute_psi(x, y, ll):
    """Compute the auxiliary angle ``psi`` used by the time-of-flight equation.

    Three regimes are selected by ``x``:

    * ``x < 1``  (elliptic):   ``arccos(x*y + ll*(1 - x**2))``, with the
      argument clipped to ``[-1, 1]``;
    * ``x > 1``  (hyperbolic): ``arcsinh((y - x*ll) * sqrt(x**2 - 1))``;
    * otherwise (parabolic, including NaN ``x``): ``0``.

    Every branch is evaluated for every input and the result is picked with
    ``jnp.where``. The inputs of the branches that are *not* selected are
    replaced by harmless values first, so that no NaN/inf is produced there.
    Without this, reverse-mode differentiation multiplies a zero cotangent by
    a NaN/inf local derivative and returns NaN gradients even though the
    forward value is correct.

    Parameters
    ----------
    x : jnp.ndarray
        Iteration variable.
    y : jnp.ndarray
        Value returned by ``_compute_y`` for the same ``x`` and ``ll``.
    ll : jnp.ndarray
        Lambda parameter.

    Returns
    -------
    jnp.ndarray
        The value of ``psi``.
    """
    is_elliptic = x < 1.0
    is_hyperbolic = x > 1.0

    # Elliptic branch: outside its domain feed arccos a mid-range argument (0)
    # instead of the clipped boundary value, whose derivative is infinite.
    cos_psi = jnp.clip(x * y + ll * (1 - x**2), -1.0, 1.0)
    elliptic = jnp.arccos(jnp.where(is_elliptic, cos_psi, 0.0))

    # Hyperbolic branch: outside its domain feed sqrt a positive dummy (1)
    # instead of a negative number.
    radicand = jnp.where(is_hyperbolic, x**2 - 1, 1.0)
    hyperbolic = jnp.arcsinh((y - x * ll) * jnp.sqrt(radicand))

    # Parabolic case (x == 1.0).
    parabolic = jnp.zeros_like(x)

    return jnp.where(
        is_elliptic, elliptic, jnp.where(is_hyperbolic, hyperbolic, parabolic)
    )
@jit
def _tof_equation_y(x, y, T0, ll, M):
    """Residual of the time-of-flight equation, ``T(x) - T0``, for a given ``y``.

    Two expressions for ``T(x)`` are evaluated and one is selected
    element-wise:

    * a hypergeometric-series form, used only when ``M == 0`` and
      ``sqrt(0.6) < x < sqrt(1.4)``;
    * a closed form based on ``psi``, used everywhere else.

    Parameters
    ----------
    x : jnp.ndarray
        Iteration variable.
    y : jnp.ndarray
        Value returned by ``_compute_y`` for the same ``x`` and ``ll``.
    T0 : jnp.ndarray
        Time of flight subtracted from ``T(x)``.
    ll : jnp.ndarray
        Lambda parameter.
    M : int
        Number of revolutions.

    Returns
    -------
    jnp.ndarray
        ``T(x) - T0``.
    """
    # Closed form (general case).
    psi = _compute_psi(x, y, ll)
    one_minus_x_sq = 1 - x**2
    closed_form = jnp.divide(
        jnp.divide(psi + M * jnp.pi, jnp.sqrt(jnp.abs(one_minus_x_sq)))
        - x
        + ll * y,
        one_minus_x_sq,
    )

    # Series form (zero revolutions, x in a band around 1).
    eta = y - ll * x
    s_1 = (1 - ll - x * eta) * 0.5
    q = 4 / 3 * _hyp2f1b(s_1)
    series_form = (eta**3 * q + 4 * ll * eta) * 0.5

    in_series_band = (M == 0) & (jnp.sqrt(0.6) < x) & (x < jnp.sqrt(1.4))
    return jnp.where(in_series_band, series_form, closed_form) - T0
@jit
def _tof_equation(x, T0, ll, M):
    """Residual of the time-of-flight equation, ``T(x) - T0``.

    Convenience wrapper that derives ``y`` from ``x`` and ``ll`` and then
    delegates to ``_tof_equation_y``.
    """
    return _tof_equation_y(x, _compute_y(x, ll), T0, ll, M)
@jit
def _tof_equation_p(x, y, T, ll):
    """First derivative of the time-of-flight equation with respect to ``x``.

    Parameters
    ----------
    x, y, ll : jnp.ndarray
        Iteration variable, its companion ``y`` and the lambda parameter.
    T : jnp.ndarray
        Time-of-flight value entering the formula.
    """
    numer = 3 * T * x - 2 + 2 * ll**3 * x / y
    denom = 1 - x**2
    return numer / denom
@jit
def _tof_equation_p2(x, y, T, dT, ll):
    """Second derivative of the time-of-flight equation with respect to ``x``.

    Parameters
    ----------
    x, y, ll : jnp.ndarray
        Iteration variable, its companion ``y`` and the lambda parameter.
    T : jnp.ndarray
        Time-of-flight value entering the formula.
    dT : jnp.ndarray
        First derivative, as returned by ``_tof_equation_p``.
    """
    numer = 3 * T + 5 * x * dT + 2 * (1 - ll**2) * ll**3 / y**3
    denom = 1 - x**2
    return numer / denom
@jit
def _tof_equation_p3(x, y, _, dT, ddT, ll):
    """Third derivative of the time-of-flight equation with respect to ``x``.

    The third positional argument is accepted for signature symmetry with the
    lower-order derivatives and is not used.

    Parameters
    ----------
    x, y, ll : jnp.ndarray
        Iteration variable, its companion ``y`` and the lambda parameter.
    dT, ddT : jnp.ndarray
        First and second derivatives, as returned by ``_tof_equation_p`` and
        ``_tof_equation_p2``.
    """
    numer = 7 * x * ddT + 8 * dT - 6 * (1 - ll**2) * ll**5 * x / y**5
    denom = 1 - x**2
    return numer / denom
@jit
def _initial_guess(T, ll, M, low_path):
    """Starting value of ``x`` for the root-finding iteration.

    Both the single-revolution and the multi-revolution guesses are computed
    and the applicable one is selected with ``jnp.where`` (``M == 0`` picks
    the single-revolution guess).

    Parameters
    ----------
    T : jnp.ndarray
        Non-dimensional time of flight.
    ll : jnp.ndarray
        Lambda parameter.
    M : int
        Number of revolutions.
    low_path : bool
        For ``M > 0``: take the larger of the two candidate guesses when
        True, the smaller when False.

    Returns
    -------
    jnp.ndarray
        Initial guess for ``x``.
    """
    # --- Single revolution -------------------------------------------------
    T_0 = jnp.arccos(ll) + ll * jnp.sqrt(1 - ll**2) + M * jnp.pi  # Equation 19
    T_1 = 2 * (1 - ll**3) / 3  # Equation 21

    guess_long = (T_0 / T) ** (2 / 3) - 1  # used when T >= T_0
    guess_short = 5 / 2 * T_1 / T * (T_1 - T) / (1 - ll**5) + 1  # used when T < T_1
    # In between, interpolate logarithmically.
    guess_mid = jnp.exp(jnp.log(2) * jnp.log(T / T_0) / jnp.log(T_1 / T_0)) - 1

    single_rev = jnp.where(
        T >= T_0, guess_long, jnp.where(T < T_1, guess_short, guess_mid)
    )

    # --- Multiple revolutions ----------------------------------------------
    left_pow = ((M * jnp.pi + jnp.pi) / (8 * T)) ** (2 / 3)
    right_pow = ((8 * T) / (M * jnp.pi)) ** (2 / 3)
    x_left = (left_pow - 1) / (left_pow + 1)
    x_right = (right_pow - 1) / (right_pow + 1)

    multi_rev = jnp.where(
        low_path, jnp.maximum(x_left, x_right), jnp.minimum(x_left, x_right)
    )

    return jnp.where(M == 0, single_rev, multi_rev)
@jit
def _householder(p0, T0, ll, M, atol, rtol, maxiter):
    """Find a zero of the time-of-flight equation with Householder's method.

    The iteration starts exactly at ``p0``. It stops as soon as the change
    between two consecutive iterates satisfies
    ``|p_new - p| < rtol * |p| + atol`` or after ``maxiter`` steps, whichever
    comes first. No error is raised when the iteration limit is hit; the last
    iterate is returned as is.

    Parameters
    ----------
    p0 : jnp.ndarray
        Initial guess for ``x``.
    T0 : jnp.ndarray
        Target non-dimensional time of flight.
    ll : jnp.ndarray
        Lambda parameter.
    M : int
        Number of revolutions.
    atol, rtol : float
        Absolute and relative tolerances of the stopping test.
    maxiter : int
        Maximum number of iterations.

    Returns
    -------
    jnp.ndarray
        The last iterate.
    """
    tiny = 1e-15

    def cond_fun(state):
        _, converged, iter_count = state
        return jnp.logical_not(converged) & (iter_count < maxiter)

    def body_fun(state):
        x, _, iter_count = state

        # Function value and its first three derivatives at the current point.
        y = _compute_y(x, ll)
        fval = _tof_equation_y(x, y, T0, ll, M)
        T = fval + T0
        fder = _tof_equation_p(x, y, T, ll)
        fder2 = _tof_equation_p2(x, y, T, fder, ll)
        fder3 = _tof_equation_p3(x, y, T, fder, fder2, ll)

        # Householder step (quartic).
        numerator = fder**2 - fval * fder2 / 2
        denominator = fder * (fder**2 - fval * fder2) + fder3 * fval**2 / 6

        # Keep the denominator away from zero. The replacement must be
        # non-zero even when the denominator is exactly 0, so the sign is
        # chosen with a comparison rather than jnp.sign (sign(0) == 0).
        safe_denominator = jnp.where(
            jnp.abs(denominator) < tiny,
            jnp.where(denominator < 0, -tiny, tiny),
            denominator,
        )
        delta = fval * (numerator / safe_denominator)

        # Limit the step size to prevent divergence.
        max_step = jnp.maximum(0.1, jnp.abs(x))
        delta = jnp.clip(delta, -max_step, max_step)

        x_new = x - delta
        converged = jnp.abs(x_new - x) < (rtol * jnp.abs(x) + atol)
        return (x_new, converged, iter_count + 1)

    # An explicit "converged" flag (initially False) guarantees that at least
    # one step is attempted without having to perturb the initial guess.
    init_state = (p0, jnp.zeros_like(p0, dtype=bool), 0)
    p, _, _ = lax.while_loop(cond_fun, body_fun, init_state)
    return p
@jit
def _compute_T_min(ll, M, maxiter, atol, rtol):
    """Compute the minimum time of flight and the ``x`` at which it occurs.

    All candidate results are evaluated and the applicable one is selected
    with ``jnp.where``:

    * ``ll`` equal to 1 (within 1e-10): ``x = 0`` and ``T`` evaluated there;
    * otherwise, ``M == 0``: ``x = inf`` and ``T = 0``;
    * otherwise: ``x`` found by ``_halley`` started from ``x = 0.1`` and
      ``T`` evaluated there.

    Parameters
    ----------
    ll : jnp.ndarray
        Lambda parameter.
    M : int
        Number of revolutions.
    maxiter : int
        Maximum number of Halley iterations.
    atol, rtol : float
        Tolerances forwarded to ``_halley``.

    Returns
    -------
    x_T_min, T_min : jnp.ndarray
        Location of the minimum and the minimum value.
    """
    # Multi-revolution candidate. The start is set to x > 0 to avoid problems
    # at ll = -1.
    x_start = 0.1
    T_start = _tof_equation(x_start, 0.0, ll, M)
    x_multi = _halley(x_start, T_start, ll, atol, rtol, maxiter)
    T_multi = _tof_equation(x_multi, 0.0, ll, M)

    # Candidate for ll == 1.
    x_unit_ll = 0.0
    T_unit_ll = _tof_equation(x_unit_ll, 0.0, ll, M)

    no_revolutions = M == 0
    ll_is_one = jnp.abs(ll - 1.0) < 1e-10

    # For ll != 1: zero revolutions give (inf, 0), otherwise the Halley result.
    x_general = jnp.where(no_revolutions, jnp.inf, x_multi)
    T_general = jnp.where(no_revolutions, 0.0, T_multi)

    x_T_min = jnp.where(ll_is_one, x_unit_ll, x_general)
    T_min = jnp.where(ll_is_one, T_unit_ll, T_general)
    return x_T_min, T_min
@jit
def _halley(p0, T0, ll, atol, rtol, maxiter):
    """Locate a minimum of the time-of-flight equation with Halley's method.

    The iteration is applied to the first derivative of the time of flight,
    using its second and third derivatives. ``T0`` is held fixed throughout
    the iteration.

    The loop keeps stepping while the *next* step would move the iterate by
    at least ``rtol * |p| + atol`` and fewer than ``maxiter`` steps have been
    taken; the iterate reached at that point is returned.

    Parameters
    ----------
    p0 : jnp.ndarray
        Initial guess for ``x``.
    T0 : jnp.ndarray
        Time-of-flight value used in the derivative formulas.
    ll : jnp.ndarray
        Lambda parameter.
    atol, rtol : float
        Absolute and relative tolerances of the stopping test.
    maxiter : int
        Maximum number of iterations.

    Returns
    -------
    jnp.ndarray
        The final iterate.
    """

    def _next_iterate(p):
        y = _compute_y(p, ll)
        d1 = _tof_equation_p(p, y, T0, ll)
        d2 = _tof_equation_p2(p, y, T0, d1, ll)
        d3 = _tof_equation_p3(p, y, T0, d1, d2, ll)
        # Halley step (cubic).
        return p - 2 * d1 * d2 / (2 * d2**2 - d1 * d3)

    def _keep_going(state):
        current, n_done = state
        proposed = _next_iterate(current)
        step_is_large = jnp.abs(proposed - current) >= rtol * jnp.abs(current) + atol
        return step_is_large & (n_done < maxiter)

    def _advance(state):
        current, n_done = state
        return (_next_iterate(current), n_done + 1)

    result, _ = lax.while_loop(_keep_going, _advance, (p0, 0))
    return result
@jit
def _find_xy(ll, T, M, maxiter, atol, rtol, low_path):
    """Compute ``x`` and ``y`` for the requested number of revolutions.

    The maximum feasible number of revolutions ``M_max`` is estimated as
    ``floor(T / pi)`` and lowered by one when ``T`` is below the minimum time
    of flight for that many revolutions. The root ``x`` is then found with
    ``_householder`` starting from ``_initial_guess``.

    Parameters
    ----------
    ll : jnp.ndarray
        Lambda parameter.
    T : jnp.ndarray
        Non-dimensional time of flight.
    M : int
        Number of revolutions.
    maxiter : int
        Maximum number of iterations.
    atol, rtol : float
        Absolute and relative tolerances.
    low_path : bool
        Selects between the two multi-revolution initial guesses.

    Returns
    -------
    x, y : jnp.ndarray
        Solution, or NaN for both when ``|ll| >= 1``, ``T <= 0``, ``M < 0``
        or ``M > M_max``.
    """
    M_max = jnp.floor(T / jnp.pi)

    # Time of flight at x = 0 with zero revolutions. The signed ll must be
    # used here: the expression is not symmetric in ll, and taking |ll|
    # lowers the threshold for ll < 0 so that the refinement below would be
    # skipped in cases that need it.
    T_00 = jnp.arccos(ll) + ll * jnp.sqrt(1 - ll**2)

    # Refine the maximum number of revolutions if necessary.
    need_refine = (T < T_00 + M_max * jnp.pi) & (M_max > 0)

    def _refined(m_max):
        _, T_min = _compute_T_min(ll, m_max, maxiter, atol, rtol)
        return jnp.where(T < T_min, m_max - 1, m_max)

    def _unchanged(m_max):
        return m_max

    M_max = lax.cond(need_refine, _refined, _unchanged, M_max)

    # A solution exists only for a proper lambda, a positive time of flight
    # and a feasible number of revolutions.
    is_valid = (jnp.abs(ll) < 1.0) & (T > 0) & (M >= 0) & (M <= M_max)

    # Initial guess followed by Householder iterations.
    x_0 = _initial_guess(T, ll, M, low_path)
    x = _householder(x_0, T, ll, M, atol, rtol, maxiter)
    y = _compute_y(x, ll)

    # Exceptions cannot be raised from traced code, so invalid cases are
    # reported as NaN.
    x_result = jnp.where(is_valid, x, jnp.nan)
    y_result = jnp.where(is_valid, y, jnp.nan)
    return x_result, y_result
# Public API: Lambert solver for a single pair of position vectors.
@jit
def izzo_lambert(
    r1: jnp.ndarray,
    r2: jnp.ndarray,
    tof: float,
    mu: float = MU,
    M: int = 0,
    prograde: bool = True,
    low_path: bool = True,
    maxiter: int = 35,
    atol: float = 1e-10,
    rtol: float = 1e-10,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Solves Lambert's problem using Izzo's devised algorithm.

    Parameters
    ----------
    r1: jnp.ndarray
        Initial position vector.
    r2: jnp.ndarray
        Final position vector.
    tof: float
        Time of flight.
    mu: float
        Gravitational parameter, equivalent to GM of attractor body.
    M: int
        Number of revolutions. Must be equal or greater than 0.
    prograde: bool
        If True, specifies prograde motion. Otherwise, retrograde motion is imposed.
    low_path: bool
        If two solutions are available, it selects between high or low path.
    maxiter: int
        Maximum number of iterations.
    atol: float
        Absolute tolerance.
    rtol: float
        Relative tolerance.

    Returns
    -------
    v1: jnp.ndarray
        Initial velocity vector.
    v2: jnp.ndarray
        Final velocity vector.

    Notes
    -----
    No exception is raised when a solution cannot be found: both returned
    vectors are filled with NaN whenever ``_find_xy`` reports NaN for ``x``
    or ``y``.

    When ``r1`` and ``r2`` are parallel or anti-parallel their cross product
    vanishes; the normalisation is then skipped and the tangential unit
    vectors are zero.
    """
    # Use logical (not bitwise) negation so the flag also behaves correctly
    # when it is a plain Python bool, e.g. with jit disabled (~True == -2,
    # which is truthy).
    retrograde = jnp.logical_not(prograde)

    # Basic geometric quantities.
    c = r2 - r1  # Chord
    c_norm = jnp.linalg.norm(c)
    r1_norm = jnp.linalg.norm(r1)
    r2_norm = jnp.linalg.norm(r2)

    # Semiperimeter
    s = (r1_norm + r2_norm + c_norm) * 0.5

    # Normalized vectors
    i_r1 = r1 / r1_norm
    i_r2 = r2 / r2_norm

    # Angular momentum unit vector (left as the zero vector if degenerate).
    i_h = jnp.cross(i_r1, i_r2)
    i_h_norm = jnp.linalg.norm(i_h)
    i_h = i_h / jnp.where(i_h_norm > 0, i_h_norm, 1.0)  # Avoid division by zero

    # Geometry of the problem: lambda parameter
    ll = jnp.sqrt(1 - jnp.minimum(1.0, c_norm / s))

    # Flip lambda and the tangential directions when the angular momentum
    # points towards negative z.
    h_points_down = i_h[2] < 0
    ll = jnp.where(h_points_down, -ll, ll)
    i_t1 = jnp.where(h_points_down, jnp.cross(i_r1, i_h), jnp.cross(i_h, i_r1))
    i_t2 = jnp.where(h_points_down, jnp.cross(i_r2, i_h), jnp.cross(i_h, i_r2))

    # Account for retrograde motion
    ll = jnp.where(retrograde, -ll, ll)
    i_t1 = jnp.where(retrograde, -i_t1, i_t1)
    i_t2 = jnp.where(retrograde, -i_t2, i_t2)

    # Non-dimensional time of flight
    T = jnp.sqrt(2 * mu / s**3) * tof

    # Solve for x and y (NaN when no solution exists).
    x, y = _find_xy(ll, T, M, maxiter, atol, rtol, low_path)
    no_solution = jnp.isnan(x) | jnp.isnan(y)

    # Reconstruct the solution. These values are always computed because
    # traced code cannot branch on data; invalid results are masked below.
    gamma = jnp.sqrt(mu * s / 2)
    rho = (r1_norm - r2_norm) / c_norm
    # Rounding can push rho**2 marginally above 1; clamp so sqrt stays real.
    sigma = jnp.sqrt(jnp.maximum(0.0, 1 - rho**2))

    # Radial and tangential velocity components at both ends.
    V_r1 = gamma * ((ll * y - x) - rho * (ll * y + x)) / r1_norm
    V_r2 = -gamma * ((ll * y - x) + rho * (ll * y + x)) / r2_norm
    V_t1 = gamma * sigma * (y + ll * x) / r1_norm
    V_t2 = gamma * sigma * (y + ll * x) / r2_norm

    # Construct velocity vectors
    v1_valid = V_r1 * i_r1 + V_t1 * i_t1
    v2_valid = V_r2 * i_r2 + V_t2 * i_t2

    # Mask invalid cases with NaN. A scalar NaN is used instead of
    # full_like(r1, nan), which cannot represent NaN for integer-typed inputs.
    v1 = jnp.where(no_solution, jnp.nan, v1_valid)
    v2 = jnp.where(no_solution, jnp.nan, v2_valid)
    return v1, v2
# Batched, JIT-compiled Lambert solver: the first three arguments (r1, r2,
# tof) are mapped over their leading axis, the remaining seven (mu, M,
# prograde, low_path, maxiter, atol, rtol) are shared by the whole batch.
_izzo_lambert_vmap = jit(
    vmap(
        izzo_lambert,
        in_axes=(0, 0, 0) + (None,) * 7,
        out_axes=(0, 0),
    )
)
# Public API: batched, NumPy-facing Lambert solver.
def solve_lambert(
    r1: Union[np.ndarray, jnp.ndarray],
    r2: Union[np.ndarray, jnp.ndarray],
    tof: Union[np.ndarray, float],
    mu: float = MU,
    prograde: bool = True,
    max_iter: int = 35,
    tol: float = 1e-10,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Solve Lambert's problem for multiple initial and final positions and times of flight.

    This implementation uses Izzo's method and solves the single-revolution
    (``M = 0``), low-path problem for every case in the batch.

    A single position vector of shape ``(3,)`` is treated as a batch of one,
    and ``r1``, ``r2`` and ``tof`` are broadcast against each other along the
    batch axis, so a scalar ``tof`` or a single ``r1`` can be combined with
    ``N`` other inputs. All inputs are converted to 64-bit floats.

    Parameters
    ----------
    r1 : array_like (N, 3)
        Initial position vectors in au.
    r2 : array_like (N, 3)
        Final position vectors in au.
    tof : array_like (N) or float
        Times of flight in days.
    mu : float, optional
        Gravitational parameter (GM) of the attracting body in units of au³/day².
    prograde : bool, optional
        If True, assume prograde motion. If False, assume retrograde motion.
    max_iter : int, optional
        Maximum number of iterations for convergence.
    tol : float, optional
        Convergence tolerance, used as both the absolute and the relative
        tolerance of the solver.

    Returns
    -------
    v1 : ndarray (N, 3)
        Initial velocity vectors in au/day with origin at the attractor
    v2 : ndarray (N, 3)
        Final velocity vectors in au/day with origin at the attractor

    Notes
    -----
    Cases for which the solver reports no solution are returned as NaN rows
    rather than raising.
    """
    # Convert to float arrays; a lone (3,) vector becomes a batch of one.
    r1 = jnp.atleast_2d(jnp.asarray(r1, dtype=jnp.float64))
    r2 = jnp.atleast_2d(jnp.asarray(r2, dtype=jnp.float64))
    # asarray handles Python numbers, NumPy scalars, 0-d and 1-d arrays alike.
    tof = jnp.asarray(tof, dtype=jnp.float64)

    # Bring all three inputs to a common batch shape, since vmap requires
    # identical leading dimensions.
    batch_shape = jnp.broadcast_shapes(r1.shape[:-1], r2.shape[:-1], tof.shape)
    r1 = jnp.broadcast_to(r1, batch_shape + r1.shape[-1:])
    r2 = jnp.broadcast_to(r2, batch_shape + r2.shape[-1:])
    tof = jnp.broadcast_to(tof, batch_shape)

    # Call vectorized solver (M=0 for single-revolution case, low path).
    v1, v2 = _izzo_lambert_vmap(
        r1, r2, tof, mu, 0, prograde, True, max_iter, tol, tol
    )

    # Hand plain NumPy arrays back to the caller.
    return np.asarray(v1), np.asarray(v2)
@jit
def _calculate_c3(
    v1: Union[np.ndarray, jnp.ndarray], body_v: Union[np.ndarray, jnp.ndarray]
) -> jnp.ndarray:
    """Squared magnitude of the velocity difference ``v1 - body_v``.

    The norm is taken over the last axis, so both a batch of shape ``(N, 3)``
    (result shape ``(N,)``) and a single vector of shape ``(3,)`` (scalar
    result) are accepted.
    """
    v_infinity = v1 - body_v
    # Use jnp.linalg.norm for JAX arrays to avoid TracerArrayConversionError
    c3 = jnp.linalg.norm(v_infinity, axis=-1) ** 2
    return c3
# Public API: C3 helper returning NumPy arrays.
def calculate_c3(
    v1: Union[np.ndarray, jnp.ndarray], body_v: Union[np.ndarray, jnp.ndarray]
) -> npt.NDArray[np.float64]:
    """
    Calculate the C3 of a spacecraft given its velocity relative to a body.

    The computation is delegated to the JIT-compiled ``_calculate_c3`` and
    the result is converted to a NumPy array.

    Parameters
    ----------
    v1 : array_like (N, 3)
        Velocity of the spacecraft in au/d.
    body_v : array_like (N, 3)
        Velocity of the body in au/d.

    Returns
    -------
    c3 : array_like (N)
        C3 of the spacecraft in au^2/d^2.
    """
    return np.asarray(_calculate_c3(v1, body_v))