Source code for adam_core.missions.porkchop

import logging
import multiprocessing as mp
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
import quivr as qv
import ray
from astropy.time import Time

from adam_core.coordinates import CartesianCoordinates, transform_coordinates
from adam_core.coordinates.origin import Origin, OriginCodes
from adam_core.coordinates.spherical import SphericalCoordinates
from adam_core.coordinates.units import au_per_day_to_km_per_s
from adam_core.dynamics.lambert import calculate_c3, solve_lambert
from adam_core.orbits import Orbits
from adam_core.propagator import Propagator
from adam_core.ray_cluster import initialize_use_ray
from adam_core.time import Timestamp
from adam_core.utils import get_perturber_state
from adam_core.utils.iter import _iterate_chunk_indices
from adam_core.utils.plots.logos import AsteroidInstituteLogoLight, get_logo_base64

# Module-level logger named after this module (standard ``logging`` convention).
logger = logging.getLogger(__name__)


def generate_saturated_colorscale(
    base_color: str, n_levels: int = 8, max_alpha: float = 0.8, min_alpha: float = 0.1
) -> List[List]:
    """
    Generate a colorscale from a base color with full saturation and a
    transparency that increases with position along the scale.

    Parameters
    ----------
    base_color : str
        Base color name (e.g., 'red', 'blue') or hex code (e.g., '#FF0000',
        or the CSS shorthand '#F00'). Unknown names fall back to pure red.
    n_levels : int, optional
        Number of levels in the colorscale (default: 8). Must be at least 2,
        because a Plotly colorscale needs entries at positions 0 and 1.
    max_alpha : float, optional
        Alpha (opacity) at position 1, the most opaque level (default: 0.8)
    min_alpha : float, optional
        Alpha (opacity) at position 0, the most transparent level (default: 0.1)

    Returns
    -------
    List[List]
        Plotly colorscale format with RGBA: [[position, color], ...]

    Raises
    ------
    ValueError
        If ``n_levels`` is less than 2.
    """
    # Positions are i / (n_levels - 1); fewer than two levels would divide by
    # zero (n_levels == 1) or yield an empty, invalid Plotly colorscale.
    if n_levels < 2:
        raise ValueError(
            f"n_levels must be at least 2 to span positions 0 and 1, got {n_levels}"
        )

    # Color mapping for common base colors to RGB with full saturation
    color_map = {
        "red": (255, 0, 0),  # Pure red, full saturation
        "blue": (0, 0, 255),  # Pure blue, full saturation
        "green": (0, 255, 0),  # Pure green, full saturation
        "orange": (255, 165, 0),  # Pure orange, full saturation
        "purple": (128, 0, 128),  # Pure purple, full saturation
        "yellow": (255, 255, 0),  # Pure yellow, full saturation
        "cyan": (0, 255, 255),  # Pure cyan, full saturation
        "magenta": (255, 0, 255),  # Pure magenta, full saturation
    }

    # Parse base color
    if base_color.startswith("#"):
        # Hex color
        hex_color = base_color.lstrip("#")
        if len(hex_color) == 3:
            # Expand CSS shorthand hex, e.g. "F00" -> "FF0000"
            hex_color = "".join(ch * 2 for ch in hex_color)
        base_rgb = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
    elif base_color.lower() in color_map:
        base_rgb = color_map[base_color.lower()]
    else:
        # Default to pure red if unknown
        base_rgb = (255, 0, 0)

    colorscale = []
    for i in range(n_levels):
        position = i / (n_levels - 1)

        if position == 0:
            # NOTE(review): lightness_factor weights the *base* color here
            # (90% base + 10% white), so this entry is at least as bright as
            # every other level; with more than two levels the brightness ramp
            # is therefore not monotonic. Preserved as-is -- confirm intent.
            lightness_factor = 0.9
            r = int(base_rgb[0] * lightness_factor + 255 * (1 - lightness_factor))
            g = int(base_rgb[1] * lightness_factor + 255 * (1 - lightness_factor))
            b = int(base_rgb[2] * lightness_factor + 255 * (1 - lightness_factor))
        else:
            # Power curve for a smooth transition; channels scale from 30% of
            # the base color up to the full base color at position 1.
            intensity = np.power(position, 0.8)
            r = int(base_rgb[0] * (0.3 + 0.7 * intensity))
            g = int(base_rgb[1] * (0.3 + 0.7 * intensity))
            b = int(base_rgb[2] * (0.3 + 0.7 * intensity))

        # Ensure values are within valid RGB range
        r = max(0, min(255, r))
        g = max(0, min(255, g))
        b = max(0, min(255, b))

        # Alpha grows linearly with position (position 0 = most transparent)
        alpha = min_alpha + (max_alpha - min_alpha) * position

        colorscale.append([position, f"rgba({r}, {g}, {b}, {alpha:.2f})"])

    return colorscale
def generate_perceptual_colorscale(
    base_color: str,
    n_levels: int = 8,
    min_lightness: float = 0.3,
    max_lightness: float = 0.9,
    max_alpha: float = 0.8,
    min_alpha: float = 0.1,
) -> List[List]:
    """
    Generate a colorscale with evenly spaced lightness steps, full
    saturation and a transparency that increases with position, intended for
    overlaying filled contours.

    Parameters
    ----------
    base_color : str
        Base color name (e.g., 'red', 'blue') or hex code (e.g., '#FF0000',
        or the CSS shorthand '#F00'). Unknown names fall back to pure red.
    n_levels : int, optional
        Number of levels in the colorscale (default: 8). Must be at least 2,
        because a Plotly colorscale needs entries at positions 0 and 1.
    min_lightness : float, optional
        Lightness at position 1 (0-1, default: 0.3 for good contrast)
    max_lightness : float, optional
        Lightness at position 0 (0-1, default: 0.9 for visibility with transparency)
    max_alpha : float, optional
        Alpha (opacity) at position 1, the darkest level (default: 0.8)
    min_alpha : float, optional
        Alpha (opacity) at position 0, the lightest level (default: 0.1)

    Returns
    -------
    List[List]
        Plotly colorscale format with RGBA: [[position, color], ...]

    Raises
    ------
    ValueError
        If ``n_levels`` is less than 2.
    """
    # Positions are i / (n_levels - 1); fewer than two levels would divide by
    # zero (n_levels == 1) or yield an empty, invalid Plotly colorscale.
    if n_levels < 2:
        raise ValueError(
            f"n_levels must be at least 2 to span positions 0 and 1, got {n_levels}"
        )

    # Full saturation color mapping for maximum color purity
    color_map = {
        "red": (255, 0, 0),  # Pure red
        "blue": (0, 0, 255),  # Pure blue
        "green": (0, 255, 0),  # Pure green
        "orange": (255, 165, 0),  # Pure orange
        "purple": (128, 0, 128),  # Pure purple
        "yellow": (255, 255, 0),  # Pure yellow
        "cyan": (0, 255, 255),  # Pure cyan
        "magenta": (255, 0, 255),  # Pure magenta
    }

    # Parse base color
    if base_color.startswith("#"):
        hex_color = base_color.lstrip("#")
        if len(hex_color) == 3:
            # Expand CSS shorthand hex, e.g. "F00" -> "FF0000"
            hex_color = "".join(ch * 2 for ch in hex_color)
        base_rgb = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
    elif base_color.lower() in color_map:
        base_rgb = color_map[base_color.lower()]
    else:
        base_rgb = (255, 0, 0)  # Default pure red

    # Convert base color to normalized RGB for calculations
    base_r, base_g, base_b = [x / 255.0 for x in base_rgb]

    colorscale = []
    for i in range(n_levels):
        position = i / (n_levels - 1)

        # Lightness decreases linearly from max_lightness to min_lightness
        lightness = max_lightness - (position * (max_lightness - min_lightness))

        if position == 0:
            # Lightest: mix with white while maintaining hue
            white_mix = 1 - lightness
            r = base_r * lightness + white_mix
            g = base_g * lightness + white_mix
            b = base_b * lightness + white_mix
        else:
            # Scale so the brightest channel equals `lightness`, preserving
            # the channel ratios (hue) of the base color.
            max_component = max(base_r, base_g, base_b)
            if max_component > 0:
                scale_factor = lightness / max_component
                r = base_r * scale_factor
                g = base_g * scale_factor
                b = base_b * scale_factor
            else:
                # Black base color: fall back to a grey of the target lightness
                r = g = b = lightness

        # Convert back to 0-255 range and ensure validity
        r_int = max(0, min(255, int(r * 255)))
        g_int = max(0, min(255, int(g * 255)))
        b_int = max(0, min(255, int(b * 255)))

        # Alpha grows linearly with position (position 0 = most transparent)
        alpha = min_alpha + (max_alpha - min_alpha) * position

        colorscale.append([position, f"rgba({r_int}, {g_int}, {b_int}, {alpha:.2f})"])

    return colorscale
# class LambertOutput(qv.Table):
class LambertSolutions(qv.Table):
    """
    Lambert transfer solutions, one row per (departure, arrival) pairing.

    Each row stores the departure body's Cartesian state at ``departure_time``,
    the arrival body's Cartesian state at ``arrival_time``, and the transfer
    velocities produced by the Lambert solver at both ends. The ``frame``
    attribute applies to the whole table; ``origin`` is stored per row.
    """

    # Departure body identity and state at departure_time
    departure_body_id = qv.LargeStringColumn()
    departure_time = Timestamp.as_column()
    departure_body_x = qv.Float64Column()
    departure_body_y = qv.Float64Column()
    departure_body_z = qv.Float64Column()
    departure_body_vx = qv.Float64Column()
    departure_body_vy = qv.Float64Column()
    departure_body_vz = qv.Float64Column()

    # Arrival body identity and state at arrival_time
    arrival_body_id = qv.LargeStringColumn()
    arrival_time = Timestamp.as_column()
    arrival_body_x = qv.Float64Column()
    arrival_body_y = qv.Float64Column()
    arrival_body_z = qv.Float64Column()
    arrival_body_vx = qv.Float64Column()
    arrival_body_vy = qv.Float64Column()
    arrival_body_vz = qv.Float64Column()

    # Transfer-orbit velocities at departure and arrival
    solution_departure_vx = qv.Float64Column()
    solution_departure_vy = qv.Float64Column()
    solution_departure_vz = qv.Float64Column()
    solution_arrival_vx = qv.Float64Column()
    solution_arrival_vy = qv.Float64Column()
    solution_arrival_vz = qv.Float64Column()

    frame = qv.StringAttribute(default="unspecified")
    origin = Origin.as_column()

    def _make_orbits(self, orbit_id, time, x, y, z, vx, vy, vz) -> Orbits:
        """Build an Orbits table from state components in this table's frame and origin."""
        cartesian = CartesianCoordinates.from_kwargs(
            time=time,
            x=x,
            y=y,
            z=z,
            vx=vx,
            vy=vy,
            vz=vz,
            origin=self.origin,
            frame=self.frame,
        )
        return Orbits.from_kwargs(orbit_id=orbit_id, coordinates=cartesian)

    def _columns_to_array(self, names: List[str]) -> npt.NDArray[np.float64]:
        """Convert the named columns to a NumPy array via ``np.array``.

        The v-infinity methods below reduce over axis=1, so rows are expected
        to correspond to solutions.
        """
        return np.array(self.table.select(names))

    def departure_body_orbit(self) -> Orbits:
        """
        Return the departure body orbit.
        """
        return self._make_orbits(
            self.departure_body_id,
            self.departure_time,
            self.departure_body_x,
            self.departure_body_y,
            self.departure_body_z,
            self.departure_body_vx,
            self.departure_body_vy,
            self.departure_body_vz,
        )

    def arrival_body_orbit(self) -> Orbits:
        """
        Return the arrival body orbit.
        """
        return self._make_orbits(
            self.arrival_body_id,
            self.arrival_time,
            self.arrival_body_x,
            self.arrival_body_y,
            self.arrival_body_z,
            self.arrival_body_vx,
            self.arrival_body_vy,
            self.arrival_body_vz,
        )

    def solution_departure_orbit(self) -> Orbits:
        """
        Return the solution departure orbit: the departure body's position
        paired with the transfer velocity at departure.
        """
        n_rows = len(self.solution_departure_vx)
        ids = [f"solution_departure_orbit_{k}" for k in range(n_rows)]
        return self._make_orbits(
            ids,
            self.departure_time,
            self.departure_body_x,
            self.departure_body_y,
            self.departure_body_z,
            self.solution_departure_vx,
            self.solution_departure_vy,
            self.solution_departure_vz,
        )

    def solution_arrival_orbit(self) -> Orbits:
        """
        Return the solution arrival orbit: the arrival body's position paired
        with the transfer velocity at arrival.
        """
        n_rows = len(self.solution_arrival_vx)
        ids = [f"solution_arrival_orbit_{k}" for k in range(n_rows)]
        return self._make_orbits(
            ids,
            self.arrival_time,
            self.arrival_body_x,
            self.arrival_body_y,
            self.arrival_body_z,
            self.solution_arrival_vx,
            self.solution_arrival_vy,
            self.solution_arrival_vz,
        )

    def c3_departure(self) -> npt.NDArray[np.float64]:
        """
        Return the C3 in au^2/d^2.
        """
        transfer_velocity = self._columns_to_array(
            ["solution_departure_vx", "solution_departure_vy", "solution_departure_vz"]
        )
        body_velocity = self._columns_to_array(
            ["departure_body_vx", "departure_body_vy", "departure_body_vz"]
        )
        return calculate_c3(transfer_velocity, body_velocity)

    def c3_arrival(self) -> npt.NDArray[np.float64]:
        """
        Return the C3 in au^2/d^2.
        """
        transfer_velocity = self._columns_to_array(
            ["solution_arrival_vx", "solution_arrival_vy", "solution_arrival_vz"]
        )
        body_velocity = self._columns_to_array(
            ["arrival_body_vx", "arrival_body_vy", "arrival_body_vz"]
        )
        return calculate_c3(transfer_velocity, body_velocity)

    def vinf_departure(self) -> npt.NDArray[np.float64]:
        """
        Return the v infinity in au/d.
        """
        transfer_velocity = self._columns_to_array(
            ["solution_departure_vx", "solution_departure_vy", "solution_departure_vz"]
        )
        body_velocity = self._columns_to_array(
            ["departure_body_vx", "departure_body_vy", "departure_body_vz"]
        )
        return np.linalg.norm(transfer_velocity - body_velocity, axis=1)

    def vinf_arrival(self) -> npt.NDArray[np.float64]:
        """
        Return the v infinity in au/d.
        """
        transfer_velocity = self._columns_to_array(
            ["solution_arrival_vx", "solution_arrival_vy", "solution_arrival_vz"]
        )
        body_velocity = self._columns_to_array(
            ["arrival_body_vx", "arrival_body_vy", "arrival_body_vz"]
        )
        return np.linalg.norm(transfer_velocity - body_velocity, axis=1)

    def time_of_flight(self) -> npt.NDArray[np.float64]:
        """
        Return the time of flight in days.
        """
        arrival_mjd = self.arrival_time.mjd().to_numpy(zero_copy_only=False)
        departure_mjd = self.departure_time.mjd().to_numpy(zero_copy_only=False)
        return arrival_mjd - departure_mjd
def departure_spherical_coordinates(
    departure_origin: OriginCodes,
    times: Timestamp,
    frame: str,
    vx: npt.NDArray[np.float64],
    vy: npt.NDArray[np.float64],
    vz: npt.NDArray[np.float64],
) -> SphericalCoordinates:
    """
    Return the spherical coordinates of the departure direction vectors.

    The velocity vectors are normalized to unit length, placed on the unit
    sphere around ``departure_origin`` with zero velocity, and transformed to
    the equatorial frame so the result can be read as RA / Dec of the
    departure direction.

    Parameters
    ----------
    departure_origin : OriginCodes
        Origin of the departure vectors; also used as the output origin.
    times : Timestamp
        Times of the departure vectors.
    frame : str
        Frame in which the departure vectors are expressed.
    vx, vy, vz : npt.NDArray[np.float64]
        Components of the departure vectors.

    Returns
    -------
    SphericalCoordinates
        Equatorial spherical coordinates of the departure unit vectors.
    """
    assert (
        len(vx) == len(vy) == len(vz) == len(times)
    ), "All arrays must have the same length"
    assert len(vx) > 0, "At least one departure vector is required"

    # Only the direction matters, so scale every vector to unit length.
    speed = np.sqrt(vx**2 + vy**2 + vz**2)
    unit_x, unit_y, unit_z = (component / speed for component in (vx, vy, vz))

    # Directions are encoded as positions on the unit sphere around the
    # departure origin; velocities are zeroed because they are irrelevant here.
    unit_vectors = CartesianCoordinates.from_kwargs(
        time=times,
        x=unit_x,
        y=unit_y,
        z=unit_z,
        vx=np.zeros_like(vx),
        vy=np.zeros_like(vy),
        vz=np.zeros_like(vz),
        origin=Origin.from_OriginCodes(departure_origin, size=len(vx)),
        frame=frame,
    )

    # Express in the equatorial frame, keeping the departure origin.
    return transform_coordinates(
        unit_vectors,
        SphericalCoordinates,
        frame_out="equatorial",
        origin_out=departure_origin,
    )
def lambert_worker(
    departure_orbits: Orbits,
    arrival_orbits: Orbits,
    propagation_origin: OriginCodes,
    prograde: bool = True,
    max_iter: int = 35,
    tol: float = 1e-10,
) -> LambertSolutions:
    """
    Solve Lambert's problem for row-wise paired departure and arrival states.

    Row ``i`` of ``departure_orbits`` is paired with row ``i`` of
    ``arrival_orbits``; the transfer time is the difference of their epochs.

    Parameters
    ----------
    departure_orbits : Orbits
        Departure body states, one per transfer.
    arrival_orbits : Orbits
        Arrival body states, one per transfer (same length as ``departure_orbits``).
    propagation_origin : OriginCodes
        Central body whose gravitational parameter is used by the solver.
    prograde : bool, optional
        If True, solve for prograde transfers; otherwise retrograde.
    max_iter : int, optional
        Maximum number of iterations for the Lambert solver.
    tol : float, optional
        Numerical tolerance for the Lambert solver.

    Returns
    -------
    LambertSolutions
        One row per input pair with both body states and the solved transfer
        velocities. Empty if the inputs are empty.

    Raises
    ------
    ValueError
        If the departure and arrival orbits have different lengths.
    """
    if len(departure_orbits) != len(arrival_orbits):
        raise ValueError(
            "departure_orbits and arrival_orbits must have the same length, "
            f"got {len(departure_orbits)} and {len(arrival_orbits)}"
        )
    # Nothing to solve; also avoids indexing mu()[0] on a zero-length origin.
    if len(departure_orbits) == 0:
        return LambertSolutions.empty()

    departure_coordinates = departure_orbits.coordinates
    arrival_coordinates = arrival_orbits.coordinates

    r1 = departure_coordinates.r
    r2 = arrival_coordinates.r

    # Take both epochs on a common (TDB) scale so that inputs carrying
    # different time scales do not bias the time of flight.
    departure_mjd = (
        departure_coordinates.time.rescale("tdb").mjd().to_numpy(zero_copy_only=False)
    )
    arrival_mjd = (
        arrival_coordinates.time.rescale("tdb").mjd().to_numpy(zero_copy_only=False)
    )
    tof = arrival_mjd - departure_mjd

    origins = Origin.from_OriginCodes(propagation_origin, size=len(r1))
    # Every row shares the same origin, so one gravitational parameter applies.
    mu = origins.mu()[0]
    v1, v2 = solve_lambert(r1, r2, tof, mu, prograde, max_iter, tol)

    # Use actual orbit IDs from the Orbits objects
    departure_body_ids = departure_orbits.orbit_id.to_pylist()
    arrival_body_ids = arrival_orbits.orbit_id.to_pylist()

    return LambertSolutions.from_kwargs(
        departure_body_id=departure_body_ids,
        departure_time=departure_coordinates.time,
        departure_body_x=departure_coordinates.x,
        departure_body_y=departure_coordinates.y,
        departure_body_z=departure_coordinates.z,
        departure_body_vx=departure_coordinates.vx,
        departure_body_vy=departure_coordinates.vy,
        departure_body_vz=departure_coordinates.vz,
        arrival_body_id=arrival_body_ids,
        arrival_time=arrival_coordinates.time,
        arrival_body_x=arrival_coordinates.x,
        arrival_body_y=arrival_coordinates.y,
        arrival_body_z=arrival_coordinates.z,
        arrival_body_vx=arrival_coordinates.vx,
        arrival_body_vy=arrival_coordinates.vy,
        arrival_body_vz=arrival_coordinates.vz,
        solution_departure_vx=v1[:, 0],
        solution_departure_vy=v1[:, 1],
        solution_departure_vz=v1[:, 2],
        solution_arrival_vx=v2[:, 0],
        solution_arrival_vy=v2[:, 1],
        solution_arrival_vz=v2[:, 2],
        frame=departure_coordinates.frame,
        origin=origins,
    )


# Ray remote wrapper used by generate_porkchop_data when running in parallel.
lambert_worker_remote = ray.remote(lambert_worker)
def prepare_and_propagate_orbits(
    body: Union[Orbits, OriginCodes],
    start_time: Timestamp,
    end_time: Timestamp,
    propagation_origin: OriginCodes = OriginCodes.SUN,
    step_size: float = 1.0,
    propagator_class: Optional[type[Propagator]] = None,
    max_processes: Optional[int] = 1,
) -> Orbits:
    """
    Prepare and propagate orbits for a single body over a specified time range.

    Sample times run from ``start_time`` (inclusive) to ``end_time``
    (exclusive) in steps of ``step_size`` days on the TDB scale. Only the
    first element of ``start_time`` and ``end_time`` is used.

    Parameters
    ----------
    body : Union[Orbits, OriginCodes]
        The body to propagate (either an Orbits object or an OriginCode for a major body).
    start_time : Timestamp
        The start time for propagation.
    end_time : Timestamp
        The end time for propagation (exclusive).
    propagation_origin : OriginCodes, optional
        The origin of the propagation (default: SUN).
    step_size : float, optional
        The step size in days (default: 1.0). Must be positive.
    propagator_class : Optional[type[Propagator]], optional
        The propagator class to use for orbit propagation. Required when
        ``body`` is an Orbits object.
    max_processes : Optional[int], optional
        The maximum number of processes to use.

    Returns
    -------
    Orbits
        The body's states at each sample time, in the ecliptic frame relative
        to ``propagation_origin``.

    Raises
    ------
    ValueError
        If ``step_size`` is not positive.
    TypeError
        If ``body`` is an Orbits object and ``propagator_class`` is None.
    """
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")

    if isinstance(body, Orbits):
        # Fail fast with a clear message instead of "'NoneType' object is not callable".
        if propagator_class is None:
            raise TypeError(
                "propagator_class is required when body is an Orbits object"
            )
        # Express the input orbits relative to the propagation origin.
        body = body.set_column(
            "coordinates",
            transform_coordinates(
                body.coordinates,
                representation_out=CartesianCoordinates,
                frame_out="ecliptic",
                origin_out=propagation_origin,
            ),
        )

    times = Timestamp.from_mjd(
        np.arange(
            start_time.rescale("tdb").mjd()[0].as_py(),
            end_time.rescale("tdb").mjd()[0].as_py(),
            step_size,
        ),
        scale="tdb",
    )

    # get orbits for the body at specified times
    if isinstance(body, Orbits):
        propagator = propagator_class()
        orbits = propagator.propagate_orbits(body, times, max_processes=max_processes)
    else:
        # For major bodies, build an Orbits object whose orbit_id is the body's name
        coordinates = get_perturber_state(
            body, times, frame="ecliptic", origin=propagation_origin
        )
        orbit_ids = np.repeat(body.name, len(coordinates))
        orbits = Orbits.from_kwargs(
            orbit_id=orbit_ids,
            coordinates=coordinates,
        )

    return orbits
def generate_porkchop_data(
    departure_orbits: Orbits,
    arrival_orbits: Orbits,
    propagation_origin: OriginCodes = OriginCodes.SUN,
    prograde: bool = True,
    max_iter: int = 35,
    tol: float = 1e-10,
    max_processes: Optional[int] = 1,
) -> LambertSolutions:
    """
    Generate data for a porkchop plot by solving Lambert's problem for a grid
    of departure and arrival times.

    Every departure state is paired with every arrival state whose epoch is
    strictly later; each such pair is solved independently.

    Parameters
    ----------
    departure_orbits : Orbits
        The departure orbits.
    arrival_orbits : Orbits
        The arrival orbits (same frame and single origin as the departure orbits).
    propagation_origin : OriginCodes
        The origin of the propagation.
    prograde : bool, optional
        If True, assume prograde motion. If False, assume retrograde motion.
    max_iter : int, optional
        The maximum number of iterations for Lambert's solver.
    tol : float, optional
        The numerical tolerance for Lambert's solver.
    max_processes : Optional[int], optional
        The maximum number of processes to use. None means all CPUs.

    Returns
    -------
    porkchop_data : LambertSolutions
        The porkchop data, in the same pair order regardless of whether ray
        is used. Empty when there are no valid (departure, arrival) pairs.
    """
    # Nothing to pair; also avoids indexing into empty origin columns below.
    if len(departure_orbits) == 0 or len(arrival_orbits) == 0:
        return LambertSolutions.empty()

    assert (
        departure_orbits.coordinates.frame == arrival_orbits.coordinates.frame
    ), "Departure and arrival frames must be the same"
    assert len(departure_orbits.coordinates.origin.code.unique()) == 1
    assert len(arrival_orbits.coordinates.origin.code.unique()) == 1
    assert (
        departure_orbits.coordinates.origin.code[0]
        == arrival_orbits.coordinates.origin.code[0]
    ), "Departure and arrival origins must be the same"

    # First let's make sure departure and arrival orbits are time-ordered
    departure_orbits = departure_orbits.sort_by(
        ["coordinates.time.days", "coordinates.time.nanos"]
    )
    arrival_orbits = arrival_orbits.sort_by(
        ["coordinates.time.days", "coordinates.time.nanos"]
    )

    dep_times_mjd = departure_orbits.coordinates.time.mjd().to_numpy(
        zero_copy_only=False
    )
    arr_times_mjd = arrival_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

    # Grids indexed [arrival, departure] holding row indices and epochs
    dep_indices, arr_indices = np.meshgrid(
        np.arange(len(departure_orbits)), np.arange(len(arrival_orbits))
    )
    dep_time_grid, arr_time_grid = np.meshgrid(dep_times_mjd, arr_times_mjd)

    # Keep only pairs whose arrival is strictly after departure.
    # Boolean-mask indexing already returns flat 1-D arrays.
    valid_pairs = arr_time_grid > dep_time_grid
    dep_indices_flat = dep_indices[valid_pairs]
    arr_indices_flat = arr_indices[valid_pairs]

    stacked_departure_orbits = departure_orbits.take(dep_indices_flat)
    stacked_arrival_orbits = arrival_orbits.take(arr_indices_flat)

    # If no valid combinations exist, return empty results
    if len(stacked_departure_orbits) == 0:
        return LambertSolutions.empty()

    if max_processes is None:
        max_processes = mp.cpu_count()

    use_ray = initialize_use_ray(max_processes)
    if not use_ray:
        return lambert_worker(
            stacked_departure_orbits,
            stacked_arrival_orbits,
            propagation_origin,
            prograde,
            max_iter,
            tol,
        )

    # Map each in-flight future to its chunk start so finished chunks can be
    # reassembled in input order, then concatenate once at the end (repeated
    # concatenation inside the loop would copy the growing table every time).
    pending = {}
    results = {}
    for start, end in _iterate_chunk_indices(stacked_departure_orbits, chunk_size=100):
        future = lambert_worker_remote.remote(
            stacked_departure_orbits[start:end],
            stacked_arrival_orbits[start:end],
            propagation_origin,
            prograde,
            max_iter,
            tol,
        )
        pending[future] = start

        # Bound the number of in-flight tasks.
        if len(pending) > 1.5 * max_processes:
            finished, _ = ray.wait(list(pending), num_returns=1)
            done = finished[0]
            results[pending.pop(done)] = ray.get(done)

    while pending:
        finished, _ = ray.wait(list(pending), num_returns=1)
        done = finished[0]
        results[pending.pop(done)] = ray.get(done)

    return qv.concatenate([results[start] for start in sorted(results)])
[docs] def plot_porkchop_plotly( porkchop_data: LambertSolutions, width: int = 900, height: int = 700, c3_departure_min: Optional[float] = None, c3_departure_max: Optional[float] = None, vinf_arrival_min: Optional[float] = None, vinf_arrival_max: Optional[float] = None, tof_min: Optional[float] = None, tof_max: Optional[float] = None, c3_base_colorscale: str = "Reds", vinf_base_colorscale: str = "Blues", tof_line_color: str = "black", xlim_mjd: Optional[Tuple[float, float]] = None, ylim_mjd: Optional[Tuple[float, float]] = None, title: str = "Porkchop Plot", show_optimal: bool = True, show_hover: bool = False, logo: bool = True, ): """ Plot the porkchop plot from Lambert trajectory data using Plotly. Parameters ---------- porkchop_data : LambertOutput The porkchop data. width : int, optional The width of the plot. height : int, optional The height of the plot. c3_departure_min : float, optional The minimum C3 departure value. c3_departure_max : float, optional The maximum C3 departure value. vinf_arrival_min : float, optional The minimum V∞ arrival value. vinf_arrival_max : float, optional The maximum V∞ arrival value. tof_min : float, optional The minimum time of flight value. tof_max : float, optional The maximum time of flight value. c3_base_colorscale : str, optional The base colorscale for C3. vinf_base_colorscale : str, optional The base colorscale for V∞. tof_line_color : str, optional The color of the time of flight line. xlim_mjd : tuple, optional The x-axis limits in MJD. ylim_mjd : tuple, optional The y-axis limits in MJD. title : str, optional The title of the plot. show_optimal : bool, optional Whether to show the optimal V∞ point. show_hover : bool, optional Whether to show the hover information. logo : bool, optional Whether to show the logo. Returns ------- fig : plotly.graph_objects.Figure The Plotly figure object. 
""" # --- Extract basic raw data --- c3_departure_au_d2 = porkchop_data.c3_departure() # C3 departure in (AU/day)^2 vinf_arrival_au_day = porkchop_data.vinf_arrival() # V∞ arrival in AU/day time_of_flight_days = porkchop_data.time_of_flight() departure_times = porkchop_data.departure_time arrival_times = porkchop_data.arrival_time # Convert to metric units using unit conversion functions c3_departure_km2_s2 = c3_departure_au_d2 * (au_per_day_to_km_per_s(1.0) ** 2) vinf_arrival_km_s = au_per_day_to_km_per_s(vinf_arrival_au_day) # Define default C3 range if not provided if c3_departure_min is None: c3_departure_min = 0 if c3_departure_max is None: c3_departure_max = np.max(c3_departure_km2_s2) assert c3_departure_max > c3_departure_min, "C3 max must be greater than C3 min" c3_step = (c3_departure_max - c3_departure_min) / 10 # 10 levels by default assert c3_step < ( c3_departure_max - c3_departure_min ), "C3 step must be less than the C3 range" # Define default V∞ range if not provided if vinf_arrival_min is None: vinf_arrival_min = 0 if vinf_arrival_max is None: vinf_arrival_max = np.max(vinf_arrival_km_s) vinf_step = (vinf_arrival_max - vinf_arrival_min) / 10 # 10 levels by default if tof_min is None: tof_min = 0 if tof_max is None: tof_max = np.max(time_of_flight_days) tof_step = max(5, (tof_max - tof_min) / 10) # 10 levels, minimum step of 5 days tof_step = round(tof_step / 5) * 5 # Round to multiple of 5 # Validate all step sizes are positive assert c3_step > 0, f"c3_step must be positive, got {c3_step}" assert vinf_step > 0, f"vinf_step must be positive, got {vinf_step}" assert tof_step > 0, f"tof_step must be positive, got {tof_step}" # Extract raw MJD values for all points departure_times_mjd = departure_times.mjd().to_numpy(zero_copy_only=False) arrival_times_mjd = arrival_times.mjd().to_numpy(zero_copy_only=False) # --- Apply all filtering to the actual data in one place --- # We want to keep all solutions that are not NaN and are within the specified 
ranges of c3, vinf and tof data_mask = ( ~np.isnan(c3_departure_km2_s2) & ~np.isnan(vinf_arrival_km_s) # Also filter out V∞ NaN values & (c3_departure_km2_s2 <= c3_departure_max) & (c3_departure_km2_s2 >= c3_departure_min) & (vinf_arrival_km_s >= vinf_arrival_min) & (vinf_arrival_km_s <= vinf_arrival_max) & (time_of_flight_days >= tof_min) & (time_of_flight_days <= tof_max) ) # Filter all our data arrays using the combined mask filtered_departure_mjd = departure_times_mjd[data_mask] filtered_arrival_mjd = arrival_times_mjd[data_mask] filtered_c3_km2_s2 = c3_departure_km2_s2[data_mask] filtered_vinf_km_s = vinf_arrival_km_s[data_mask] filtered_tof_days = time_of_flight_days[data_mask] # Recalculate min/max and step sizes based on filtered data if len(filtered_c3_km2_s2) > 0: c3_departure_min_filtered = np.min(filtered_c3_km2_s2) c3_departure_max_filtered = np.max(filtered_c3_km2_s2) c3_step_filtered = (c3_departure_max_filtered - c3_departure_min_filtered) / 10 if c3_step_filtered <= 0: c3_step_filtered = 1.0 # Fallback for constant data else: c3_departure_min_filtered = c3_departure_min c3_departure_max_filtered = c3_departure_max c3_step_filtered = c3_step if len(filtered_vinf_km_s) > 0: vinf_arrival_min_filtered = np.min(filtered_vinf_km_s) vinf_arrival_max_filtered = np.max(filtered_vinf_km_s) vinf_step_filtered = ( vinf_arrival_max_filtered - vinf_arrival_min_filtered ) / 10 if vinf_step_filtered <= 0: vinf_step_filtered = 1.0 # Fallback for constant data else: vinf_arrival_min_filtered = vinf_arrival_min vinf_arrival_max_filtered = vinf_arrival_max vinf_step_filtered = vinf_step if len(filtered_tof_days) > 0: tof_min_filtered = np.min(filtered_tof_days) tof_max_filtered = np.max(filtered_tof_days) tof_step_filtered = max(5, (tof_max_filtered - tof_min_filtered) / 10) tof_step_filtered = round(tof_step_filtered / 5) * 5 # Round to multiple of 5 if tof_step_filtered <= 0: tof_step_filtered = 5 # Fallback minimum step else: tof_min_filtered = tof_min 
tof_max_filtered = tof_max tof_step_filtered = tof_step # Get unique times from the filtered data - this guarantees all data points have corresponding unique times unique_departure_mjd, dep_indices = np.unique( filtered_departure_mjd, return_inverse=True ) unique_arrival_mjd, arr_indices = np.unique( filtered_arrival_mjd, return_inverse=True ) # Check if we have enough unique times to create a grid if len(unique_departure_mjd) < 2 or len(unique_arrival_mjd) < 2: warnings.warn( "Porkchop plotting: Not enough unique times for grid. Returning empty figure." ) fig_empty = go.Figure() fig_empty.update_layout( title=title + " (Insufficient data for grid)", xaxis_title="Departure Date", yaxis_title="Arrival Date", width=width, height=height, autosize=False, xaxis=dict(type="date"), yaxis=dict(type="date"), ) return fig_empty # Convert to datetime objects for plotting axes unique_departure_dates_dt = [ Time(mjd, format="mjd").datetime for mjd in unique_departure_mjd ] unique_arrival_dates_dt = [ Time(mjd, format="mjd").datetime for mjd in unique_arrival_mjd ] # --- Unit Conversions and Grid Setup --- # Create the grid including date combinations that do not have valid Lambert solutions grid_departure_mjd, grid_arrival_mjd = np.meshgrid( unique_departure_mjd, unique_arrival_mjd ) # Initialize grid arrays with NaN and fill using the filtered data # Since we used return_inverse=True, dep_indices and arr_indices are guaranteed to be valid grid_c3_departure_km2_s2 = np.full( (len(unique_arrival_mjd), len(unique_departure_mjd)), np.nan, dtype=np.float64 ) grid_vinf_arrival_km_s = np.full( (len(unique_arrival_mjd), len(unique_departure_mjd)), np.nan, dtype=np.float64 ) # Fill the grid directly - no validity masking needed since we pre-filtered the data grid_c3_departure_km2_s2[arr_indices, dep_indices] = filtered_c3_km2_s2 grid_vinf_arrival_km_s[arr_indices, dep_indices] = filtered_vinf_km_s grid_tof_days = grid_arrival_mjd - grid_departure_mjd # --- Use original grids with NaN 
values for native Plotly handling --- grid_c3_for_plot = grid_c3_departure_km2_s2 grid_vinf_for_plot = grid_vinf_arrival_km_s # Set up the date limits for the plot # Convert the min/max MJD values to datetime objects for Plotly xlim_dt = [ Time(np.min(grid_departure_mjd), format="mjd").datetime, Time(np.max(grid_departure_mjd), format="mjd").datetime, ] ylim_dt = [ Time(np.min(grid_arrival_mjd), format="mjd").datetime, Time(np.max(grid_arrival_mjd), format="mjd").datetime, ] # If explicit limits were provided, use those instead if xlim_mjd: xlim_dt = [ Time(xlim_mjd[0], format="mjd").datetime, Time(xlim_mjd[1], format="mjd").datetime, ] if ylim_mjd: ylim_dt = [ Time(ylim_mjd[0], format="mjd").datetime, Time(ylim_mjd[1], format="mjd").datetime, ] # --- Generate custom colorscales with better saturation at minimum values --- # Map common Plotly colorscale names to base colors colorscale_to_color = { "Reds": "red", "Blues": "blue", "Greens": "green", "Oranges": "orange", "Purples": "purple", } # Generate C3 colorscale with full saturation and built-in transparency if c3_base_colorscale in colorscale_to_color: # Using saturated colorscale with transparency built into the colorscale c3_colorscale = generate_saturated_colorscale( colorscale_to_color[c3_base_colorscale], n_levels=8, max_alpha=0.7, # Maximum opacity for darkest colors min_alpha=0.15, # Minimum opacity for lightest colors ) else: c3_colorscale = c3_base_colorscale # Generate V∞ colorscale with full saturation and built-in transparency if vinf_base_colorscale in colorscale_to_color: # Using saturated colorscale with transparency built into the colorscale vinf_colorscale = generate_saturated_colorscale( colorscale_to_color[vinf_base_colorscale], n_levels=8, max_alpha=0.7, # Maximum opacity for darkest colors min_alpha=0.15, # Minimum opacity for lightest colors ) else: vinf_colorscale = vinf_base_colorscale # --- Create hover information grids if requested --- hover_info = "none" custom_data = None 
hover_template = None if show_hover: # Create date strings for hover display grid_departure_date_strings = np.array( [ [ Time(mjd, format="mjd").strftime("%Y-%m-%d") for mjd in unique_departure_mjd ] for _ in unique_arrival_mjd ] ) grid_arrival_date_strings = np.array( [ [ Time(mjd, format="mjd").strftime("%Y-%m-%d") for _ in unique_departure_mjd ] for mjd in unique_arrival_mjd ] ) # Stack all the data we want in hover info # Shape: (n_arrival, n_departure, 5) for [c3, vinf, tof, dep_date, arr_date] custom_data = np.stack( [ grid_c3_departure_km2_s2, # C3 in km²/s² grid_vinf_arrival_km_s, # V∞ in km/s grid_tof_days, # ToF in days grid_departure_date_strings, # Departure date strings grid_arrival_date_strings, # Arrival date strings ], axis=-1, ) hover_info = "text" hover_template = ( "Departure: %{customdata[3]}<br>" "Arrival: %{customdata[4]}<br>" "Time of Flight: %{customdata[2]:.1f} days<br>" "C3 Departure: %{customdata[0]:.2f} km²/s²<br>" "V∞ Arrival: %{customdata[1]:.2f} km/s<br>" "<extra></extra>" ) # --- Create Dual Contour Traces --- plotly_traces = [] # C3 Departure Contour Trace (warm colorscale with built-in transparency) plotly_traces.append( go.Contour( x=unique_departure_dates_dt, y=unique_arrival_dates_dt, z=grid_c3_for_plot, zauto=False, zmin=c3_departure_min_filtered, zmax=c3_departure_max_filtered, colorscale=c3_colorscale, opacity=1.0, # Use full opacity since transparency is built into colorscale hoverinfo=hover_info, hovertemplate=hover_template, customdata=custom_data, contours=dict( coloring="fill", showlabels=True, labelfont=dict(size=10, color="darkred"), start=c3_departure_min_filtered, end=c3_departure_max_filtered, size=c3_step_filtered, labelformat=".1f", ), ncontours=10, # Ensure exactly 10 contour levels line=dict(width=1.0, smoothing=1.3), name="C3 Departure", showscale=False, # Remove colorbar from main trace connectgaps=False, # Don't connect across gaps to match V∞ behavior visible=True, showlegend=True, ) ) # V∞ Arrival Contour 
Trace (cool colorscale with built-in transparency) plotly_traces.append( go.Contour( x=unique_departure_dates_dt, y=unique_arrival_dates_dt, z=grid_vinf_for_plot, zauto=False, zmin=vinf_arrival_min_filtered, zmax=vinf_arrival_max_filtered, colorscale=vinf_colorscale, opacity=1.0, # Use full opacity since transparency is built into colorscale hoverinfo=hover_info, hovertemplate=hover_template, customdata=custom_data, contours=dict( coloring="fill", showlabels=True, labelfont=dict(size=10, color="darkblue"), start=vinf_arrival_min_filtered, end=vinf_arrival_max_filtered, size=vinf_step_filtered, labelformat=".1f", ), ncontours=10, # Ensure exactly 10 contour levels line=dict(width=1.0, smoothing=1.3), name="V∞ Arrival", showscale=False, # Remove colorbar from main trace connectgaps=False, # Faster rendering by not connecting across gaps visible="legendonly", showlegend=True, ) ) # --- ToF Contours --- plotly_traces.append( go.Contour( x=unique_departure_dates_dt, y=unique_arrival_dates_dt, z=grid_tof_days, # Original ToF grid with NaNs colorscale=[[0, tof_line_color], [1, tof_line_color]], contours=dict( coloring="lines", showlabels=True, labelfont=dict(size=10, color=tof_line_color), start=tof_min_filtered, end=tof_max_filtered, size=tof_step_filtered, ), line=dict(color=tof_line_color, width=1, dash="longdash"), name="ToF (days)", showscale=False, hoverinfo="skip", # Skip hover for ToF contours connectgaps=False, # Don't connect across NaN gaps visible=True, ) ) # --- Optimal Points (separate for C3 and V∞) --- if show_optimal: # Optimal C3 Departure Point from filtered data if len(filtered_c3_km2_s2) > 0: min_c3_filtered_idx = np.nanargmin(filtered_c3_km2_s2) # Get the corresponding departure and arrival times from filtered data best_c3_dep_mjd = filtered_departure_mjd[min_c3_filtered_idx] best_c3_arr_mjd = filtered_arrival_mjd[min_c3_filtered_idx] # Convert to datetime objects for plotting best_c3_dep_dt = Time(best_c3_dep_mjd, format="mjd").datetime 
best_c3_arr_dt = Time(best_c3_arr_mjd, format="mjd").datetime # Check if the optimal C3 point falls within our current plot range c3_optimal_in_range = ( xlim_dt[0] <= best_c3_dep_dt <= xlim_dt[1] and ylim_dt[0] <= best_c3_arr_dt <= ylim_dt[1] ) if c3_optimal_in_range: plotly_traces.append( go.Scatter( x=[best_c3_dep_dt], y=[best_c3_arr_dt], mode="markers", marker=dict( symbol="circle", color="darkred", size=10, line=dict(color="white", width=2), ), showlegend=True, name="Optimal C3", visible=True, hoverinfo="skip", # Skip hover for optimal points ) ) # Optimal V∞ Arrival Point from filtered data if len(filtered_vinf_km_s) > 0: min_vinf_filtered_idx = np.nanargmin(filtered_vinf_km_s) # Get the corresponding departure and arrival times from filtered data best_vinf_dep_mjd = filtered_departure_mjd[min_vinf_filtered_idx] best_vinf_arr_mjd = filtered_arrival_mjd[min_vinf_filtered_idx] # Convert to datetime objects for plotting best_vinf_dep_dt = Time(best_vinf_dep_mjd, format="mjd").datetime best_vinf_arr_dt = Time(best_vinf_arr_mjd, format="mjd").datetime # Check if the optimal V∞ point falls within our current plot range vinf_optimal_in_range = ( xlim_dt[0] <= best_vinf_dep_dt <= xlim_dt[1] and ylim_dt[0] <= best_vinf_arr_dt <= ylim_dt[1] ) if vinf_optimal_in_range: plotly_traces.append( go.Scatter( x=[best_vinf_dep_dt], y=[best_vinf_arr_dt], mode="markers", marker=dict( symbol="circle", color="darkblue", size=10, line=dict(color="white", width=2), ), showlegend=True, name="Optimal V∞", visible=True, hoverinfo="skip", # Skip hover for optimal points ) ) # --- Figure Creation and Layout Update --- fig = go.Figure(data=plotly_traces) if logo: images = [ dict( source=get_logo_base64(AsteroidInstituteLogoLight), xref="paper", yref="paper", x=0.98, y=0.02, sizex=0.12, sizey=0.12, xanchor="right", yanchor="bottom", ) ] else: images = [] fig.update_layout( title_text=title, xaxis_title="Departure Date", yaxis_title="Arrival Date", xaxis=dict( tickformat="%Y-%m-%d", 
tickangle=-45, range=xlim_dt, showgrid=True, gridcolor="lightgray", gridwidth=1, ), yaxis=dict( tickformat="%Y-%m-%d", range=ylim_dt, showgrid=True, gridcolor="lightgray", gridwidth=1, ), plot_bgcolor="white", width=width, height=height, autosize=False, hovermode="closest", images=images, legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0, itemsizing="constant", # Keep legend items same size when hidden font=dict(size=12), # Larger legend text bgcolor="rgba(255,255,255,0.8)", # Semi-transparent background bordercolor="Black", borderwidth=1, ), ) return fig