Source code for adam_core.propagator.propagator

import logging
import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union

import numpy as np
import numpy.typing as npt
import pyarrow as pa
import pyarrow.compute as pc
import quivr as qv
import ray
from ray import ObjectRef

from ..constants import Constants as c
from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.origin import Origin, OriginCodes
from ..coordinates.spherical import SphericalCoordinates
from ..coordinates.transform import transform_coordinates
from ..dynamics.aberrations import add_light_time
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..orbits.variants import VariantEphemeris, VariantOrbits
from ..photometry.magnitude import (
    calculate_apparent_magnitude_v,
    calculate_apparent_magnitude_v_and_phase_angle,
    calculate_phase_angle,
)
from ..ray_cluster import initialize_use_ray
from ..time import Timestamp
from ..utils.chunking import process_in_chunks
from ..utils.iter import _iterate_chunks
from .types import EphemerisType, ObserverType, OrbitType, TimestampType
from .utils import ensure_input_origin_and_frame, ensure_input_time_scale

logger = logging.getLogger(__name__)

C = c.C


def _alignment_indices_string_keys(
    *,
    obs_code: npt.NDArray[np.object_],
    obs_days: npt.NDArray[np.int64],
    obs_nanos: npt.NDArray[np.int64],
    eph_code: npt.NDArray[np.object_],
    eph_days: npt.NDArray[np.int64],
    eph_nanos: npt.NDArray[np.int64],
) -> npt.NDArray[np.int64]:
    """
    Map every ephemeris row to the observer row sharing its (code, days, nanos) key.

    Keys are built as ``"{code}|{days}|{nanos}"`` strings and matched with a
    sorted search over the observer keys.

    Parameters
    ----------
    obs_code, obs_days, obs_nanos : `~numpy.ndarray` (M)
        Key components of the observer rows.
    eph_code, eph_days, eph_nanos : `~numpy.ndarray` (K)
        Key components of the ephemeris rows.

    Returns
    -------
    idx : `~numpy.ndarray` (K)
        For each ephemeris row, the index of the matching observer row (int64).

    Raises
    ------
    ValueError
        If the observer keys are not unique, or if any ephemeris key has no
        matching observer key.
    """
    keys_obs = np.array(
        [
            f"{code}|{days}|{nanos}"
            for code, days, nanos in zip(obs_code, obs_days, obs_nanos)
        ],
        dtype=object,
    )
    keys_eph = np.array(
        [
            f"{code}|{days}|{nanos}"
            for code, days, nanos in zip(eph_code, eph_days, eph_nanos)
        ],
        dtype=object,
    )
    if len(np.unique(keys_obs)) != len(keys_obs):
        raise ValueError(
            "Observer keys are not unique; cannot map ephemeris rows unambiguously."
        )

    # Nothing to align.
    if len(keys_eph) == 0:
        return np.empty(0, dtype=np.int64)
    # Ephemeris rows exist but there is no observer to match them against.
    if len(keys_obs) == 0:
        raise ValueError("Failed to align ephemeris rows to observers (key mismatch).")

    sorter = np.argsort(keys_obs)
    sorted_keys = keys_obs[sorter]
    pos = np.searchsorted(sorted_keys, keys_eph)
    # searchsorted returns len(sorted_keys) for a key greater than every observer
    # key; clip so the lookup below stays in bounds and the mismatch is reported
    # as a ValueError by the equality check rather than as an IndexError.
    pos = np.minimum(pos, len(sorted_keys) - 1)
    idx = sorter[pos]
    if np.any(keys_obs[idx] != keys_eph):
        raise ValueError("Failed to align ephemeris rows to observers (key mismatch).")
    return np.asarray(idx, dtype=np.int64)


def _alignment_indices_struct_index_in(
    *,
    obs_code: pa.Array,
    obs_days: pa.Array,
    obs_nanos: pa.Array,
    eph_code: pa.Array,
    eph_days: pa.Array,
    eph_nanos: pa.Array,
) -> npt.NDArray[np.int64]:
    """
    Arrow-based lookup of the observer row matching each ephemeris row.

    Rows are identified by a ``code|days|nanos`` string key built with Arrow
    compute kernels; each ephemeris key is then located in the observer keys
    with ``index_in``.

    Returns
    -------
    idx : `~numpy.ndarray` (K)
        Index into the observer rows for each ephemeris row (int64).

    Raises
    ------
    ValueError
        If the observer keys are not unique, or if an ephemeris key is not
        found among the observer keys.
    """
    # NOTE: Arrow compute `index_in` does not support struct kernels in our supported
    # versions, so we build keys via vectorized Arrow string concatenation (no Python loops).
    string_type = pa.large_string()
    separator = pa.scalar("|", type=string_type)

    def _joined_key(code, days, nanos):
        # Cast each component to a string and join as "code|days|nanos".
        code_s = pc.cast(code, string_type)
        days_s = pc.cast(days, string_type)
        nanos_s = pc.cast(nanos, string_type)
        code_days = pc.binary_join_element_wise(code_s, days_s, separator)
        return pc.binary_join_element_wise(code_days, nanos_s, separator)

    observer_keys = _joined_key(obs_code, obs_days, obs_nanos)
    ephemeris_keys = _joined_key(eph_code, eph_days, eph_nanos)

    # Uniqueness check: keys must uniquely identify an observer row.
    n_distinct = int(pc.count_distinct(observer_keys).as_py())
    if n_distinct != len(observer_keys):
        raise ValueError(
            "Observer keys are not unique; cannot map ephemeris rows unambiguously."
        )

    # Unmatched keys come back as null; turn them into -1 so they can be detected.
    matches = pc.index_in(ephemeris_keys, value_set=observer_keys)
    matches = pc.fill_null(matches, -1)
    idx_np = np.asarray(matches.to_numpy(zero_copy_only=False), dtype=np.int64)
    if (idx_np < 0).any():
        raise ValueError("Failed to align ephemeris rows to observers (key mismatch).")
    return idx_np


def _hg_params_for_ephemeris_rows_arrow(
    *,
    orbit_id: pa.Array | pa.ChunkedArray,
    H_v: pa.Array | pa.ChunkedArray,
    G: pa.Array | pa.ChunkedArray,
    ephemeris_orbit_id: pa.Array | pa.ChunkedArray,
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """
    Look up (H_v, G) for every ephemeris row by its orbit_id, fully vectorized.

    Only orbit rows where both H_v and G are finite take part in the lookup.
    Ephemeris rows whose orbit_id has no such row receive NaN for both values.

    Returns
    -------
    H_v, G : `~numpy.ndarray` (K), `~numpy.ndarray` (K)
        Float64 arrays aligned to ``ephemeris_orbit_id``.

    Raises
    ------
    ValueError
        If the same orbit_id carries more than one distinct finite H_v or G.
    """
    orbit_ids = pc.cast(orbit_id, pa.large_string())
    h_values = pc.cast(H_v, pa.float64())
    g_values = pc.cast(G, pa.float64())

    usable = pc.and_(pc.is_finite(h_values), pc.is_finite(g_values))
    if not pc.any(usable).as_py():
        # No orbit has a complete parameter pair: everything is NaN.
        n_rows = len(ephemeris_orbit_id)
        nan_h = np.full(n_rows, np.nan, dtype=np.float64)
        nan_g = np.full(n_rows, np.nan, dtype=np.float64)
        return nan_h, nan_g

    params = pa.table({"orbit_id": orbit_ids, "H_v": h_values, "G": g_values})
    params = params.filter(usable)

    per_orbit = params.group_by("orbit_id").aggregate(
        [("H_v", "min"), ("H_v", "max"), ("G", "min"), ("G", "max")]
    )
    # min != max within a group means the orbit_id has conflicting values.
    h_conflict = pc.not_equal(per_orbit["H_v_min"], per_orbit["H_v_max"])
    g_conflict = pc.not_equal(per_orbit["G_min"], per_orbit["G_max"])
    conflict = pc.or_(h_conflict, g_conflict)
    if pc.any(conflict).as_py():
        offenders = pc.filter(per_orbit["orbit_id"], conflict)
        bad_id = offenders[0].as_py() if len(offenders) > 0 else "<unknown>"
        raise ValueError(f"Conflicting physical parameters for orbit_id={bad_id}")

    # Position of each ephemeris orbit_id within the grouped table (null if absent).
    lookup = pc.index_in(
        pc.cast(ephemeris_orbit_id, pa.large_string()),
        value_set=per_orbit["orbit_id"],
    )
    # pc.take propagates nulls for missing indices; those convert to NaN below.
    h_taken = pc.take(per_orbit["H_v_min"], lookup)
    g_taken = pc.take(per_orbit["G_min"], lookup)
    h_np = np.asarray(h_taken.to_numpy(zero_copy_only=False), dtype=np.float64)
    g_np = np.asarray(g_taken.to_numpy(zero_copy_only=False), dtype=np.float64)
    return h_np, g_np


def propagation_worker(
    orbits: Union[Orbits, VariantOrbits],
    times: Timestamp,
    propagator: Type["Propagator"],
    **kwargs,
) -> Union[Orbits, VariantOrbits]:
    """
    Build a propagator from its class and propagate ``orbits`` to ``times``.

    Parameters
    ----------
    orbits : Orbits or VariantOrbits
        Orbits to propagate.
    times : Timestamp
        Times to which to propagate the orbits.
    propagator : type
        Propagator class; it is instantiated as ``propagator(**kwargs)``.
    **kwargs
        Keyword arguments forwarded to the propagator constructor.

    Returns
    -------
    propagated : Orbits or VariantOrbits
        Whatever the instance's ``_propagate_orbits`` returns.
    """
    instance = propagator(**kwargs)
    return instance._propagate_orbits(orbits, times)
def attach_magnitude_or_phase(
    ephemeris: EphemerisType,
    orbits: OrbitType,
    observers: ObserverType,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> EphemerisType:
    """
    Fill the ``predicted_magnitude_v`` and/or ``alpha`` columns of an ephemeris.

    The ephemeris is returned untouched when nothing is requested, when it is
    empty, when all of its aberrated x coordinates are null, or when magnitudes
    are the only request and no orbit has a finite (H_v, G) pair.

    Parameters
    ----------
    ephemeris : EphemerisType
        Ephemeris rows to annotate.
    orbits : OrbitType
        Orbits (or a ray ObjectRef to them) providing ``orbit_id`` and
        ``physical_parameters`` (H_v, G).
    observers : ObserverType
        Observers (or a ray ObjectRef to them); matched to ephemeris rows by
        (code, days, nanos).
    predict_magnitudes : bool
        Compute ``predicted_magnitude_v``.
    predict_phase_angle : bool
        Compute ``alpha``.

    Returns
    -------
    ephemeris : EphemerisType
        The ephemeris with the requested columns set. Non-finite results, and
        magnitudes for rows without finite (H_v, G), are stored as nulls.
    """
    nothing_requested = not (predict_magnitudes or predict_phase_angle)
    if nothing_requested or len(ephemeris) == 0:
        return ephemeris

    # Without aberrated coordinates there is no geometry to work from.
    if pc.all(pc.is_null(ephemeris.aberrated_coordinates.x)).as_py():
        return ephemeris

    orbits_value = ray.get(orbits) if isinstance(orbits, ObjectRef) else orbits
    if isinstance(observers, ObjectRef):
        observers_value = ray.get(observers)
    else:
        observers_value = observers

    want_mags = bool(predict_magnitudes)
    want_alpha = bool(predict_phase_angle)

    H_v: npt.NDArray[np.float64] | None = None
    G: npt.NDArray[np.float64] | None = None
    has_params: npt.NDArray[np.bool_] | None = None
    if want_mags:
        H_v, G = _hg_params_for_ephemeris_rows_arrow(
            orbit_id=orbits_value.orbit_id,
            H_v=orbits_value.physical_parameters.H_v,
            G=orbits_value.physical_parameters.G,
            ephemeris_orbit_id=ephemeris.orbit_id,
        )
        has_params = np.isfinite(H_v) & np.isfinite(G)
        # Magnitudes are pointless if no row has usable parameters.
        if not has_params.any():
            want_mags = False

    if not (want_alpha or want_mags):
        return ephemeris

    # Alignment keys include the time, so both tables must share a time scale.
    ephemeris_scale = ephemeris.coordinates.time.scale
    if ephemeris_scale != observers_value.coordinates.time.scale:
        observers_value = observers_value.set_column(
            "coordinates.time",
            observers_value.coordinates.time.rescale(
                ephemeris.coordinates.time.scale
            ),
        )

    row_to_observer = _alignment_indices_struct_index_in(
        obs_code=observers_value.code,
        obs_days=observers_value.coordinates.time.days,
        obs_nanos=observers_value.coordinates.time.nanos,
        eph_code=ephemeris.coordinates.origin.code,
        eph_days=ephemeris.coordinates.time.days,
        eph_nanos=ephemeris.coordinates.time.nanos,
    )
    observers_aligned = observers_value.take(
        pa.array(row_to_observer, type=pa.int64())
    )

    # Express both the object and the observer as Sun-centered ecliptic Cartesian states.
    obj_helio = transform_coordinates(
        ephemeris.aberrated_coordinates,
        CartesianCoordinates,
        frame_out="ecliptic",
        origin_out=OriginCodes.SUN,
    )
    observer_coords_helio = transform_coordinates(
        observers_aligned.coordinates,
        CartesianCoordinates,
        frame_out="ecliptic",
        origin_out=OriginCodes.SUN,
    )
    obs_helio = observers_aligned.set_column("coordinates", observer_coords_helio)

    alpha_deg = None
    mags = None
    if want_mags:
        assert H_v is not None and G is not None and has_params is not None
        if want_alpha:
            mags, alpha_deg = calculate_apparent_magnitude_v_and_phase_angle(
                H_v=H_v,
                object_coords=obj_helio,
                observer=obs_helio,
                G=G,
            )
        else:
            mags = calculate_apparent_magnitude_v(
                H_v=H_v,
                object_coords=obj_helio,
                observer=obs_helio,
                G=G,
            )
    else:
        # Only the phase angle is left to compute (the early return above
        # guarantees want_alpha is True here).
        alpha_deg = calculate_phase_angle(obj_helio, obs_helio)

    if alpha_deg is not None:
        alpha_deg = np.asarray(alpha_deg, dtype=np.float64)
        alpha_column = pa.array(
            alpha_deg, mask=~np.isfinite(alpha_deg), type=pa.float64()
        )
        ephemeris = ephemeris.set_column("alpha", alpha_column)

    if mags is not None:
        assert has_params is not None
        mags = np.asarray(mags, dtype=np.float64)
        # Null out rows lacking parameters or with a non-finite magnitude.
        valid = has_params & np.isfinite(mags)
        predicted = pa.array(mags, mask=~valid, type=pa.float64())
        ephemeris = ephemeris.set_column("predicted_magnitude_v", predicted)

    return ephemeris
def _uses_default_ephemeris_mixin(propagator: "Propagator") -> bool:
    """
    True iff `propagator._generate_ephemeris` is the EphemerisMixin implementation.

    Why: subclasses may override `_generate_ephemeris` without accepting extra keyword
    arguments. We only pass photometry flags into the default mixin implementation.
    """
    # Bound methods expose the underlying function as __func__; anything else
    # (no __func__) is treated as "not the default".
    func = getattr(propagator._generate_ephemeris, "__func__", None)
    return func is EphemerisMixin._generate_ephemeris


@ray.remote
def propagation_worker_ray(
    idx: npt.NDArray[np.int64],
    orbits: OrbitType,
    times: TimestampType,
    propagator: "Propagator",
) -> OrbitType:
    """
    Ray task: propagate the rows ``idx`` of ``orbits`` to ``times``.

    Parameters
    ----------
    idx : `~numpy.ndarray`
        Row indices selecting the chunk of ``orbits`` handled by this task.
    orbits : OrbitType
        Orbits table the chunk is taken from.
    times : TimestampType
        Times forwarded to ``propagator._propagate_orbits``.
    propagator : Propagator
        Propagator instance performing the propagation.

    Returns
    -------
    propagated : OrbitType
        Result of ``propagator._propagate_orbits`` for the chunk.
    """
    orbits_chunk = orbits.take(idx)
    propagated = propagator._propagate_orbits(orbits_chunk, times)
    return propagated


@ray.remote
def ephemeris_worker_ray(
    idx: npt.NDArray[np.int64],
    orbits: OrbitType,
    observers: ObserverType,
    propagator: "Propagator",
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> EphemerisType:
    """
    Ray task: generate ephemerides for the rows ``idx`` of ``orbits``.

    When the propagator uses the default ``EphemerisMixin._generate_ephemeris``
    the photometry flags are passed straight through. Otherwise the overriding
    implementation is called with ``(orbits_chunk, observers)`` only and the
    magnitude / phase-angle columns are attached afterwards.

    Returns
    -------
    ephemeris : EphemerisType
        Ephemerides for the selected chunk.
    """
    orbits_chunk = orbits.take(idx)
    if _uses_default_ephemeris_mixin(propagator):
        ephemeris = propagator._generate_ephemeris(
            orbits_chunk,
            observers,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
        )
    else:
        ephemeris = propagator._generate_ephemeris(orbits_chunk, observers)
        ephemeris = attach_magnitude_or_phase(
            ephemeris,
            orbits=orbits_chunk,
            observers=observers,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
        )
    return ephemeris
class EphemerisMixin:
    """
    Mixin with signature for generating ephemerides.
    Subclasses should implement the _generate_ephemeris method.
    """

    def _generate_ephemeris(
        self,
        orbits: OrbitType,
        observers: ObserverType,
        *,
        predict_magnitudes: bool = True,
        predict_phase_angle: bool = False,
    ) -> EphemerisType:
        """
        A generic ephemeris implementation, which can be used or overridden by subclasses.

        Orbits are propagated to the (sorted) observer times, corrected for light
        travel time in the barycentric ecliptic frame, differenced against the
        observer states and converted to equatorial spherical coordinates with
        UTC times.

        Parameters
        ----------
        orbits : Orbits or VariantOrbits
            Orbits for which to generate ephemerides.
        observers : Observers
            Observers; sorted here by (days, nanos, code).
        predict_magnitudes : bool, optional
            Forwarded to `attach_magnitude_or_phase`.
        predict_phase_angle : bool, optional
            Forwarded to `attach_magnitude_or_phase`.

        Returns
        -------
        ephemeris : Ephemeris or VariantEphemeris
            `Ephemeris` for `Orbits` input, `VariantEphemeris` for `VariantOrbits`.

        Raises
        ------
        TypeError
            If ``orbits`` is neither `Orbits` nor `VariantOrbits`.
        ValueError
            If any computed light travel time is not finite.
        """
        # Fail fast: only these two table types can be turned into an ephemeris
        # below (previously any other type fell through with no result bound).
        if not isinstance(orbits, (Orbits, VariantOrbits)):
            raise TypeError(
                f"Unsupported orbits type for ephemeris generation: {type(orbits)}"
            )

        # Sort observers by time and code to ensure consistent ordering
        # As further propagation will order by time as well
        observers = observers.sort_by(
            ["coordinates.time.days", "coordinates.time.nanos", "code"]
        )

        observers_barycentric = observers.set_column(
            "coordinates",
            transform_coordinates(
                observers.coordinates,
                CartesianCoordinates,
                frame_out="ecliptic",
                origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
            ),
        )
        # One copy of the observers per orbit so rows line up with the
        # propagated orbits.
        observers_barycentric_tiled = qv.concatenate(
            [observers_barycentric] * len(orbits)
        )

        # Propagate orbits to sorted observer times
        # Returns orbits sorted by orbit_id and time
        propagated_orbits = self.propagate_orbits(
            orbits, observers.coordinates.time, max_processes=1
        )

        # Transform both the orbits and observers to the barycenter if they are not already.
        propagated_orbits_barycentric = propagated_orbits.set_column(
            "coordinates",
            transform_coordinates(
                propagated_orbits.coordinates,
                CartesianCoordinates,
                frame_out="ecliptic",
                origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
            ),
        )

        # Process in padded chunks
        aberrated_state: np.ndarray
        light_time: np.ndarray
        propagated_orbits_barycentric_values = (
            propagated_orbits_barycentric.coordinates.values
        )
        propagated_orbits_barycentric_time = (
            propagated_orbits_barycentric.coordinates.time.mjd().to_numpy(
                zero_copy_only=False
            )
        )
        observers_barycentric_tiled_values = observers_barycentric_tiled.coordinates.r

        chunk_size = 200
        n = int(propagated_orbits_barycentric_values.shape[0])
        # `process_in_chunks` pads each chunk to a fixed size for JAX; preallocate the padded
        # output arrays and slice off padding after the loop. This avoids O(n^2) reallocation
        # and memcpy from repeated np.concatenate calls.
        n_padded = ((n + chunk_size - 1) // chunk_size) * chunk_size
        aberrated_state = np.empty((n_padded, 6), dtype=np.float64)
        light_time = np.empty((n_padded,), dtype=np.float64)
        k = 0
        for (
            propagated_orbits_barycentric_chunk,
            propagated_orbits_barycentric_time_chunk,
            observers_barycentric_tiled_chunk,
        ) in zip(
            process_in_chunks(propagated_orbits_barycentric_values, chunk_size),
            process_in_chunks(propagated_orbits_barycentric_time, chunk_size),
            process_in_chunks(observers_barycentric_tiled_values, chunk_size),
        ):
            aberrated_chunk, light_time_chunk = add_light_time(
                propagated_orbits_barycentric_chunk,
                propagated_orbits_barycentric_time_chunk,
                observers_barycentric_tiled_chunk,
                lt_tol=1e-12,
                mu=c.MU,
                max_iter=100,
                tol=1e-15,
            )
            aberrated_state[k : k + chunk_size, :] = np.asarray(
                aberrated_chunk, dtype=np.float64
            )
            light_time[k : k + chunk_size] = np.asarray(
                light_time_chunk, dtype=np.float64
            )
            k += chunk_size

        # Remove padding
        aberrated_state = aberrated_state[:n]
        light_time = light_time[:n]

        # Guard against pathological light-time values before constructing timestamps.
        if not np.all(np.isfinite(light_time)):
            raise ValueError(
                "Light travel time is NaN or too large and propagation will break."
            )

        # Compute emission times by subtracting light-time (in days) from the
        # propagated (observer) times.
        emission_times = (
            propagated_orbits_barycentric.coordinates.time.add_fractional_days(
                pa.array(-light_time)
            )
        )

        propagated_orbits_aberrated = Orbits.from_kwargs(
            orbit_id=propagated_orbits_barycentric.orbit_id,
            object_id=propagated_orbits_barycentric.object_id,
            coordinates=CartesianCoordinates.from_kwargs(
                x=aberrated_state[:, 0],
                y=aberrated_state[:, 1],
                z=aberrated_state[:, 2],
                vx=aberrated_state[:, 3],
                vy=aberrated_state[:, 4],
                vz=aberrated_state[:, 5],
                covariance=propagated_orbits_barycentric.coordinates.covariance,
                time=emission_times,
                origin=propagated_orbits_barycentric.coordinates.origin,
                frame=propagated_orbits_barycentric.coordinates.frame,
            ),
        )

        topocentric_state = (
            propagated_orbits_aberrated.coordinates.values
            - observers_barycentric_tiled.coordinates.values
        )
        topocentric_coordinates = CartesianCoordinates.from_kwargs(
            x=topocentric_state[:, 0],
            y=topocentric_state[:, 1],
            z=topocentric_state[:, 2],
            vx=topocentric_state[:, 3],
            vy=topocentric_state[:, 4],
            vz=topocentric_state[:, 5],
            covariance=None,
            # The ephemeris times are at the point of the observer,
            # not the aberrated orbit
            time=observers_barycentric_tiled.coordinates.time,
            origin=Origin.from_kwargs(
                code=observers_barycentric_tiled.code.to_numpy(zero_copy_only=False)
            ),
            frame="ecliptic",
        )

        spherical_coordinates = SphericalCoordinates.from_cartesian(
            topocentric_coordinates
        )
        spherical_coordinates = transform_coordinates(
            spherical_coordinates, SphericalCoordinates, frame_out="equatorial"
        )

        # Ephemeris are generally compared in UTC, so rescale the time
        spherical_coordinates = spherical_coordinates.set_column(
            "time",
            spherical_coordinates.time.rescale("utc"),
        )

        if isinstance(orbits, Orbits):
            ephemeris = Ephemeris.from_kwargs(
                orbit_id=propagated_orbits_barycentric.orbit_id,
                object_id=propagated_orbits_barycentric.object_id,
                coordinates=spherical_coordinates,
                light_time=light_time,
                aberrated_coordinates=propagated_orbits_aberrated.coordinates,
            )
        else:
            # VariantOrbits (guaranteed by the type check at the top).
            # Propagated order is (orbit_id, variant_id, time) to match observer tiling
            # [observers]*len(orbits). Sort input the same way and repeat each variant's
            # weight for each time so row i gets the weight for the variant at that row.
            orbits_sorted = orbits.sort_by(
                [
                    "orbit_id",
                    "variant_id",
                    "coordinates.time.days",
                    "coordinates.time.nanos",
                ]
            )
            n_obs = len(observers)
            weights_np = orbits_sorted.weights.to_numpy(zero_copy_only=False)
            weights_cov_np = orbits_sorted.weights_cov.to_numpy(zero_copy_only=False)
            ephemeris = VariantEphemeris.from_kwargs(
                orbit_id=propagated_orbits_barycentric.orbit_id,
                object_id=propagated_orbits_barycentric.object_id,
                variant_id=propagated_orbits_barycentric.variant_id,
                coordinates=spherical_coordinates,
                light_time=light_time,
                weights=np.repeat(weights_np, n_obs),
                weights_cov=np.repeat(weights_cov_np, n_obs),
                aberrated_coordinates=propagated_orbits_aberrated.coordinates,
            )

        # Return in same order as propagate_orbits: (orbit_id, variant_id, time) or
        # (orbit_id, time). No sort here; callers use .sort_by() or .select()/.apply_mask().
        ephemeris = attach_magnitude_or_phase(
            ephemeris,
            orbits=orbits,
            observers=observers,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
        )
        return ephemeris

    def generate_ephemeris(
        self,
        orbits: OrbitType,
        observers: ObserverType,
        covariance: bool = False,
        covariance_method: Literal[
            "auto", "sigma-point", "monte-carlo"
        ] = "monte-carlo",
        num_samples: int = 1000,
        chunk_size: int = 100,
        max_processes: Optional[int] = 1,
        seed: Optional[int] = None,
        predict_magnitudes: bool = True,
        predict_phase_angle: bool = False,
    ) -> Union[Ephemeris, VariantEphemeris]:
        """
        Generate ephemerides for each orbit in orbits as observed by each observer
        in observers.

        Parameters
        ----------
        orbits : `~adam_core.orbits.orbits.Orbits` (N)
            Orbits for which to generate ephemerides. A ray ObjectRef to the
            orbits is also accepted and is dereferenced before use.
        observers : `~adam_core.observers.observers.Observers` (M)
            Observers for which to generate the ephemerides of each orbit. A ray
            ObjectRef is also accepted.
        covariance: bool, optional
            Propagate the covariance matrices of the orbits. This is done by sampling the
            orbits from their covariance matrices and propagating each sample and for each
            sample also generating ephemerides. The covariance
            of the ephemerides is then the covariance of the samples.
        covariance_method : {'sigma-point', 'monte-carlo', 'auto'}, optional
            The method to use for sampling the covariance matrix. If 'auto' is selected then the method
            will be automatically selected based on the covariance matrix. The default is 'monte-carlo'.
        num_samples : int, optional
            The number of samples to draw when sampling with monte-carlo.
        chunk_size : int, optional
            Number of orbits to send to each job. With more than one process the
            effective chunk size is reduced so that every worker gets work.
        max_processes : int or None, optional
            Number of processes to launch. If None then the number of
            processes will be equal to the number of cores on the machine. If 1
            then no multiprocessing will be used; values greater than 1 run the
            work as ray tasks.
            NOTE(review): earlier documentation stated this argument is ignored
            when a ray instance is already initialized — confirm against
            `initialize_use_ray`.
        seed : int or None, optional
            Seed forwarded to `VariantOrbits.create` when sampling variants for
            covariance propagation.
        predict_magnitudes : bool, optional
            Request the ``predicted_magnitude_v`` column
            (see `attach_magnitude_or_phase`).
        predict_phase_angle : bool, optional
            Request the ``alpha`` column (see `attach_magnitude_or_phase`).

        Returns
        -------
        ephemeris : `~adam_core.orbits.ephemeris.Ephemeris` or `~adam_core.orbits.variants.VariantEphemeris`
            Predicted ephemerides. Row order matches :meth:`propagate_orbits`:
            (orbit_id, time) or (orbit_id, variant_id, time) for variant ephemeris.
            Use .sort_by() or .select()/.apply_mask() for other orderings or grouping.

        Raises
        ------
        AssertionError
            If ``covariance`` is requested for anything other than `Orbits`.
        ValueError
            If a worker or `_generate_ephemeris` returns an unexpected type.
        """
        # Resolve an orbits ObjectRef up front so the covariance check, variant
        # creation and the single-process path all operate on the actual table.
        orbits_ref = None
        if isinstance(orbits, ObjectRef):
            orbits_ref = orbits
            orbits = ray.get(orbits_ref)

        # If sending in VariantOrbits, we make sure not to run covariance.
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if covariance is not False and not isinstance(orbits, Orbits):
            raise AssertionError("Covariance is not supported for VariantOrbits")

        # Check if we need to propagate orbit variants so we can propagate covariance
        # matrices
        ephemeris: Ephemeris = Ephemeris.empty()
        variant_ephemeris: VariantEphemeris = VariantEphemeris.empty()
        variants = VariantOrbits.empty()
        if covariance is True and not orbits.coordinates.covariance.is_all_nan():
            variants = VariantOrbits.create(
                orbits,
                method=covariance_method,
                num_samples=num_samples,
                seed=seed,
            )

        if max_processes is None:
            max_processes = mp.cpu_count()

        uses_default_ephemeris = _uses_default_ephemeris_mixin(self)
        # Workers attach photometry themselves; only the single-process path with
        # an overridden `_generate_ephemeris` needs it attached here.
        needs_attach = max_processes <= 1 and not uses_default_ephemeris

        if max_processes > 1:
            initialize_use_ray(num_cpus=max_processes)

            # Add orbits and observers to object store if
            # they haven't already been added
            if not isinstance(observers, ObjectRef):
                observers_ref = ray.put(observers)
            else:
                observers_ref = observers

            if orbits_ref is None:
                orbits_ref = ray.put(orbits)

            # Create futures
            futures_inputs = []
            idx = np.arange(0, len(orbits))
            # Use at least max_processes chunks so all workers get work.
            effective_chunk_size = chunk_size
            if len(orbits) > 0:
                effective_chunk_size = min(
                    chunk_size, max(1, len(orbits) // max_processes)
                )
            for idx_chunk in _iterate_chunks(idx, effective_chunk_size):
                futures_inputs.append(
                    (
                        idx_chunk,
                        orbits_ref,
                        observers_ref,
                        self,
                        predict_magnitudes,
                        predict_phase_angle,
                    )
                )

            # Add variants to propagate to futures inputs
            if covariance is True and len(variants) > 0:
                variants_ref = ray.put(variants)

                idx = np.arange(0, len(variants))
                var_chunk_size = min(
                    chunk_size, max(1, len(variants) // max_processes)
                )
                for variant_chunk_idx in _iterate_chunks(idx, var_chunk_size):
                    futures_inputs.append(
                        (
                            variant_chunk_idx,
                            variants_ref,
                            observers_ref,
                            self,
                            predict_magnitudes,
                            predict_phase_angle,
                        )
                    )

            ephemeris_parts: list[Ephemeris] = []
            variant_ephemeris_parts: list[VariantEphemeris] = []

            def _collect(result) -> None:
                # Route one worker result into the matching parts list.
                if isinstance(result, Ephemeris):
                    ephemeris_parts.append(result)
                elif isinstance(result, VariantEphemeris):
                    variant_ephemeris_parts.append(result)
                else:
                    raise ValueError(
                        f"Unexpected result type from ephemeris worker: {type(result)}"
                    )

            # Get results as they finish (we sort later). The number of in-flight
            # tasks is capped at 1.5x the process count.
            futures = []
            for future_input in futures_inputs:
                futures.append(ephemeris_worker_ray.remote(*future_input))

                if len(futures) >= max_processes * 1.5:
                    finished, futures = ray.wait(futures, num_returns=1)
                    _collect(ray.get(finished[0]))

            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                _collect(ray.get(finished[0]))

            if ephemeris_parts:
                ephemeris = (
                    ephemeris_parts[0]
                    if len(ephemeris_parts) == 1
                    else qv.concatenate(ephemeris_parts)
                )
            if variant_ephemeris_parts:
                variant_ephemeris = (
                    variant_ephemeris_parts[0]
                    if len(variant_ephemeris_parts) == 1
                    else qv.concatenate(variant_ephemeris_parts)
                )

            # Concatenation was in completion order; sort to canonical order.
            if len(ephemeris) > 0:
                ephemeris = ephemeris.sort_by(
                    [
                        "orbit_id",
                        "coordinates.time.days",
                        "coordinates.time.nanos",
                        "coordinates.origin.code",
                    ]
                )
            if len(variant_ephemeris) > 0:
                variant_ephemeris = variant_ephemeris.sort_by(
                    [
                        "orbit_id",
                        "variant_id",
                        "coordinates.time.days",
                        "coordinates.time.nanos",
                        "coordinates.origin.code",
                    ]
                )

        else:
            # `_generate_ephemeris` works on the observers table itself, so
            # resolve an ObjectRef before calling it.
            observers_value = (
                ray.get(observers) if isinstance(observers, ObjectRef) else observers
            )

            if uses_default_ephemeris:
                results = self._generate_ephemeris(
                    orbits,
                    observers_value,
                    predict_magnitudes=predict_magnitudes,
                    predict_phase_angle=predict_phase_angle,
                )
            else:
                results = self._generate_ephemeris(orbits, observers_value)
            if isinstance(results, Ephemeris):
                ephemeris = results
            elif isinstance(results, VariantEphemeris):
                variant_ephemeris = results
            else:
                raise ValueError(
                    f"Unexpected result type from generate_ephemeris: {type(results)}"
                )

            if covariance is True and len(variants) > 0:
                if uses_default_ephemeris:
                    variant_ephemeris = self._generate_ephemeris(
                        variants,
                        observers_value,
                        predict_magnitudes=predict_magnitudes,
                        predict_phase_angle=predict_phase_angle,
                    )
                else:
                    variant_ephemeris = self._generate_ephemeris(
                        variants, observers_value
                    )

        if covariance is False and len(variant_ephemeris) > 0:
            # Output times are guaranteed to be UTC. If that guarantee is ever
            # dropped, `ensure_input_time_scale(variant_ephemeris, <observer times>)`
            # could be used here instead (dereferencing observers if it is an
            # ObjectRef).
            variant_ephemeris = variant_ephemeris.set_column(
                "coordinates.time",
                variant_ephemeris.coordinates.time.rescale("utc"),
            )

            if needs_attach:
                variant_ephemeris = attach_magnitude_or_phase(
                    variant_ephemeris,
                    orbits=orbits,
                    observers=observers,
                    predict_magnitudes=predict_magnitudes,
                    predict_phase_angle=predict_phase_angle,
                )
            return variant_ephemeris

        if covariance is True and len(variant_ephemeris) > 0:
            ephemeris = variant_ephemeris.collapse(ephemeris)

        # Same note as above: times are rescaled to UTC rather than to the
        # observers' input time scale.
        ephemeris = ephemeris.set_column(
            "coordinates.time",
            ephemeris.coordinates.time.rescale("utc"),
        )

        if needs_attach:
            ephemeris = attach_magnitude_or_phase(
                ephemeris,
                orbits=orbits,
                observers=observers,
                predict_magnitudes=predict_magnitudes,
                predict_phase_angle=predict_phase_angle,
            )
        return ephemeris
class Propagator(ABC, EphemerisMixin):
    """
    Abstract class for propagating orbits and related functions.

    Subclasses should implement the _propagate_orbits. For additional functions,
    subclasses can add abstract mixins.

    Important: subclasses should be pickleable! As this class uses multiprocessing
    to parallelize propagation and ephemeris generation. This means that
    subclasses should not have any unpickleable attributes.
    """

    @abstractmethod
    def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitType:
        """
        Propagate orbits to times.

        THIS FUNCTION SHOULD BE DEFINED BY THE USER.
        """
        pass

    def __getstate__(self):
        """
        Get the state of the propagator.

        Subclasses need to define what is picklable for multiprocessing.

        e.g.

        def __getstate__(self):
            state = self.__dict__.copy()
            state.pop("_stateful_attribute_that_is_not_pickleable")
            return state
        """
        raise NotImplementedError(
            "Propagator must implement __getstate__ for multiprocessing serialization.\n"
            "Example implementation: \n"
            "def __getstate__(self):\n"
            " state = self.__dict__.copy()\n"
            " state.pop('_stateful_attribute_that_is_not_pickleable')\n"
            " return state"
        )

    def __setstate__(self, state):
        """
        Set the state of the propagator.

        Subclasses need to define what is unpicklable for multiprocessing.

        e.g.

        def __setstate__(self, state):
            self.__dict__.update(state)
            self._stateful_attribute_that_is_not_pickleable = None
        """
        raise NotImplementedError(
            "Propagator must implement __setstate__ for multiprocessing serialization.\n"
            "Example implementation: \n"
            "def __setstate__(self, state):\n"
            " self.__dict__.update(state)\n"
            " self._stateful_attribute_that_is_not_pickleable = None"
        )

    def propagate_orbits(
        self,
        orbits: Union[OrbitType, ObjectRef],
        times: Union[TimestampType, ObjectRef],
        covariance: bool = False,
        covariance_method: Literal[
            "auto", "sigma-point", "monte-carlo"
        ] = "monte-carlo",
        num_samples: int = 1000,
        chunk_size: int = 100,
        max_processes: Optional[int] = 1,
        seed: Optional[int] = None,
    ) -> Union[Orbits, VariantOrbits]:
        """
        Propagate each orbit in orbits to each time in times.

        This method handles parallelization of the propagation of the orbits.
        Subclasses may override this method to modify parallelization behavior.

        Parameters
        ----------
        orbits : `~adam_core.orbits.orbits.Orbits` (N)
            Orbits to propagate. A ray ObjectRef is also accepted and is
            dereferenced before use.
        times : Timestamp (M)
            Times to which to propagate orbits. Sorted chronologically before calling
            the backend so integrators (e.g. ASSIST, REBOUND) receive time-ordered
            epochs for efficient stepping. A ray ObjectRef is also accepted.
        covariance : bool, optional
            Propagate the covariance matrices of the orbits. This is done by sampling the
            orbits from their covariance matrices and propagating each sample. The covariance
            of the propagated orbits is then the covariance of the samples.
        covariance_method : {'sigma-point', 'monte-carlo', 'auto'}, optional
            The method to use for sampling the covariance matrix. If 'auto' is selected then the method
            will be automatically selected based on the covariance matrix. The default is 'monte-carlo'.
        num_samples : int, optional
            The number of samples to draw when sampling with monte-carlo.
        chunk_size : int, optional
            Number of orbits to send to each job.
        max_processes : int or None, optional
            Maximum number of processes to launch. If None then the number of
            processes will be equal to the number of cores on the machine. If 1
            then no multiprocessing will be used; values greater than 1 run the
            work as ray tasks.
            NOTE(review): earlier documentation stated this argument is ignored
            when a ray instance is already initialized — confirm against
            `initialize_use_ray`.
        seed : int or None, optional
            Seed forwarded to `VariantOrbits.create` when sampling variants for
            covariance propagation.

        Returns
        -------
        propagated : `~adam_core.orbits.orbits.Orbits` or `~adam_core.orbits.variants.VariantOrbits`
            Propagated orbits. Rows are ordered (orbit_id, time) for Orbits and
            (orbit_id, variant_id, time) for VariantOrbits. Use .sort_by() or
            .select()/.apply_mask() for other orderings or grouping.

        Raises
        ------
        AssertionError
            If ``covariance`` is requested for `VariantOrbits` input.
        ValueError
            If a worker returns an unexpected type.
        """
        # Resolve an orbits ObjectRef once, up front, so that the covariance
        # check and BOTH execution paths work on the actual table. (Only the ray
        # path used to do this; the single-process path passed the reference on.)
        orbits_ref = None
        if isinstance(orbits, ObjectRef):
            orbits_ref = orbits
            orbits = ray.get(orbits_ref)

        # When the input is VariantOrbits, do not treat them as covariance.
        input_is_variants = isinstance(orbits, VariantOrbits)
        if covariance is True and input_is_variants:
            raise AssertionError("Covariance is not supported for VariantOrbits")

        if max_processes is None:
            max_processes = mp.cpu_count()

        # Resolve times and sort chronologically so backends (ASSIST, REBOUND, etc.)
        # receive time-ordered epochs for efficient integration.
        if isinstance(times, ObjectRef):
            times = ray.get(times)
        times = times.sort_by(["days", "nanos"])

        if max_processes > 1:
            propagated_list: List[Orbits] = []
            covariance_variants_list: List[VariantOrbits] = []
            propagated_variants_input_list: List[VariantOrbits] = []

            initialize_use_ray(num_cpus=max_processes)

            times_ref = ray.put(times)
            if orbits_ref is None:
                orbits_ref = ray.put(orbits)

            # Create futures inputs
            futures_inputs = []
            idx = np.arange(0, len(orbits))
            for idx_chunk in _iterate_chunks(idx, chunk_size):
                futures_inputs.append(
                    (
                        idx_chunk,
                        orbits_ref,
                        times_ref,
                        self,
                    )
                )

            # Add variants to propagate to futures inputs
            if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                variants = VariantOrbits.create(
                    orbits,
                    method=covariance_method,
                    num_samples=num_samples,
                    seed=seed,
                )

                variants_ref = ray.put(variants)

                idx = np.arange(0, len(variants))
                for variant_chunk_idx in _iterate_chunks(idx, chunk_size):
                    futures_inputs.append(
                        (
                            variant_chunk_idx,
                            variants_ref,
                            times_ref,
                            self,
                        )
                    )

            def _route(result) -> None:
                # File one worker result under nominal orbits, input variants,
                # or covariance variants.
                if isinstance(result, Orbits):
                    propagated_list.append(result)
                elif isinstance(result, VariantOrbits):
                    if input_is_variants:
                        propagated_variants_input_list.append(result)
                    else:
                        covariance_variants_list.append(result)
                else:
                    raise ValueError(
                        f"Unexpected result type from propagation worker: {type(result)}"
                    )

            # Submit and process jobs with queuing: in-flight tasks are capped at
            # 1.5x the process count.
            futures = []
            for future_input in futures_inputs:
                futures.append(propagation_worker_ray.remote(*future_input))

                if len(futures) >= max_processes * 1.5:
                    finished, futures = ray.wait(futures, num_returns=1)
                    _route(ray.get(finished[0]))

            # Process remaining futures
            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                _route(ray.get(finished[0]))

            # Concatenate propagated orbits
            if input_is_variants:
                propagated = qv.concatenate(propagated_variants_input_list)
                propagated_variants = None
            else:
                propagated = qv.concatenate(propagated_list)
                if len(covariance_variants_list) > 0:
                    propagated_variants = qv.concatenate(covariance_variants_list)
                    propagated_variants = propagated_variants.sort_by(
                        [
                            "orbit_id",
                            "variant_id",
                            "coordinates.time.days",
                            "coordinates.time.nanos",
                        ]
                    )
                else:
                    propagated_variants = None

        else:
            propagated = self._propagate_orbits(orbits, times)

            if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                variants = VariantOrbits.create(
                    orbits,
                    method=covariance_method,
                    num_samples=num_samples,
                    seed=seed,
                )

                propagated_variants = self._propagate_orbits(variants, times)
                propagated_variants = propagated_variants.sort_by(
                    [
                        "orbit_id",
                        "variant_id",
                        "coordinates.time.days",
                        "coordinates.time.nanos",
                    ]
                )
            else:
                propagated_variants = None

        if propagated_variants is not None:
            propagated = propagated_variants.collapse(propagated)

        # Preserve the time scale of the requested times
        propagated = ensure_input_time_scale(propagated, times)

        # Return the results with the original origin and frame
        # Preserve the original output origin for the input orbits
        # by orbit id
        propagated = ensure_input_origin_and_frame(orbits, propagated)

        # Internal order is (orbit_id, variant_id, time) so that observer tiling
        # [observers]*len(orbits) in _generate_ephemeris matches row-for-row.
        # Callers that need time-first or grouping should sort or use .select()/.apply_mask().
        if isinstance(propagated, VariantOrbits):
            return propagated.sort_by(
                [
                    "orbit_id",
                    "variant_id",
                    "coordinates.time.days",
                    "coordinates.time.nanos",
                ]
            )
        return propagated.sort_by(
            ["orbit_id", "coordinates.time.days", "coordinates.time.nanos"]
        )