import logging
import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
import numpy as np
import numpy.typing as npt
import pyarrow as pa
import pyarrow.compute as pc
import quivr as qv
import ray
from ray import ObjectRef
from ..constants import Constants as c
from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.origin import Origin, OriginCodes
from ..coordinates.spherical import SphericalCoordinates
from ..coordinates.transform import transform_coordinates
from ..dynamics.aberrations import add_light_time
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..orbits.variants import VariantEphemeris, VariantOrbits
from ..photometry.magnitude import (
calculate_apparent_magnitude_v,
calculate_apparent_magnitude_v_and_phase_angle,
calculate_phase_angle,
)
from ..ray_cluster import initialize_use_ray
from ..time import Timestamp
from ..utils.chunking import process_in_chunks
from ..utils.iter import _iterate_chunks
from .types import EphemerisType, ObserverType, OrbitType, TimestampType
from .utils import ensure_input_origin_and_frame, ensure_input_time_scale
# Logger for this module, following the ``logging.getLogger(__name__)`` convention.
logger = logging.getLogger(name=__name__)
# Module-level alias for ``Constants.C`` (presumably the speed of light in the
# package's internal units -- confirm in ``..constants``).
C = c.C
def _alignment_indices_string_keys(
*,
obs_code: npt.NDArray[np.object_],
obs_days: npt.NDArray[np.int64],
obs_nanos: npt.NDArray[np.int64],
eph_code: npt.NDArray[np.object_],
eph_days: npt.NDArray[np.int64],
eph_nanos: npt.NDArray[np.int64],
) -> npt.NDArray[np.int64]:
keys_obs = np.array(
[f"{c}|{d}|{n}" for c, d, n in zip(obs_code, obs_days, obs_nanos)],
dtype=object,
)
keys_eph = np.array(
[f"{c}|{d}|{n}" for c, d, n in zip(eph_code, eph_days, eph_nanos)],
dtype=object,
)
if len(np.unique(keys_obs)) != len(keys_obs):
raise ValueError(
"Observer keys are not unique; cannot map ephemeris rows unambiguously."
)
sorter = np.argsort(keys_obs)
pos = np.searchsorted(keys_obs[sorter], keys_eph)
idx = sorter[pos]
if np.any(keys_obs[idx] != keys_eph):
raise ValueError("Failed to align ephemeris rows to observers (key mismatch).")
return np.asarray(idx, dtype=np.int64)
def _alignment_indices_struct_index_in(
    *,
    obs_code: pa.Array,
    obs_days: pa.Array,
    obs_nanos: pa.Array,
    eph_code: pa.Array,
    eph_days: pa.Array,
    eph_nanos: pa.Array,
) -> npt.NDArray[np.int64]:
    """
    Map each ephemeris row to the observer row sharing its (code, days, nanos) key.

    Arrow compute ``index_in`` has no struct kernel in the supported versions, so
    composite keys are built as ``"<code>|<days>|<nanos>"`` strings with
    vectorized element-wise joins (no Python loops).

    Returns
    -------
    numpy.ndarray of int64
        Observer row index for each ephemeris row.

    Raises
    ------
    ValueError
        If the observer keys are not unique, or if any ephemeris key is absent
        from the observer keys.
    """
    separator = pa.scalar("|", type=pa.large_string())

    def _composite_key(code, days, nanos):
        # Cast every component to large_string, then join them as "code|days|nanos".
        code_s, days_s, nanos_s = (
            pc.cast(part, pa.large_string()) for part in (code, days, nanos)
        )
        return pc.binary_join_element_wise(
            pc.binary_join_element_wise(code_s, days_s, separator),
            nanos_s,
            separator,
        )

    observer_keys = _composite_key(obs_code, obs_days, obs_nanos)
    ephemeris_keys = _composite_key(eph_code, eph_days, eph_nanos)

    # Every observer key must identify exactly one observer row.
    n_distinct = int(pc.count_distinct(observer_keys).as_py())
    if n_distinct != len(observer_keys):
        raise ValueError(
            "Observer keys are not unique; cannot map ephemeris rows unambiguously."
        )

    # index_in yields null for keys missing from the value set; mark those as -1.
    lookup = pc.fill_null(pc.index_in(ephemeris_keys, value_set=observer_keys), -1)
    positions = np.asarray(lookup.to_numpy(zero_copy_only=False), dtype=np.int64)
    if (positions < 0).any():
        raise ValueError("Failed to align ephemeris rows to observers (key mismatch).")
    return positions
def _hg_params_for_ephemeris_rows_arrow(
    *,
    orbit_id: pa.Array | pa.ChunkedArray,
    H_v: pa.Array | pa.ChunkedArray,
    G: pa.Array | pa.ChunkedArray,
    ephemeris_orbit_id: pa.Array | pa.ChunkedArray,
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """
    Look up (H_v, G) for every ephemeris row by orbit_id, vectorized in Arrow.

    Orbit-table rows whose H_v or G is not finite (NaN/inf, or null) are ignored.
    Ephemeris rows whose orbit_id has no usable parameters receive NaN.

    Returns
    -------
    (H_v, G) : tuple of numpy float64 arrays
        Both arrays have ``len(ephemeris_orbit_id)`` entries.

    Raises
    ------
    ValueError
        If one orbit_id carries more than one distinct finite H_v or G value.
    """
    ids = pc.cast(orbit_id, pa.large_string())
    h_values = pc.cast(H_v, pa.float64())
    g_values = pc.cast(G, pa.float64())
    usable = pc.and_(pc.is_finite(h_values), pc.is_finite(g_values))

    if not pc.any(usable).as_py():
        n_rows = len(ephemeris_orbit_id)
        return (
            np.full(n_rows, np.nan, dtype=np.float64),
            np.full(n_rows, np.nan, dtype=np.float64),
        )

    params = pa.table({"orbit_id": ids, "H_v": h_values, "G": g_values}).filter(usable)
    per_orbit = params.group_by("orbit_id").aggregate(
        [("H_v", "min"), ("H_v", "max"), ("G", "min"), ("G", "max")]
    )

    # min != max within a group means one orbit_id carries two different values.
    h_differs = pc.not_equal(per_orbit["H_v_min"], per_orbit["H_v_max"])
    g_differs = pc.not_equal(per_orbit["G_min"], per_orbit["G_max"])
    conflicting = pc.or_(h_differs, g_differs)
    if pc.any(conflicting).as_py():
        offenders = pc.filter(per_orbit["orbit_id"], conflicting)
        bad_id = offenders[0].as_py() if len(offenders) > 0 else "<unknown>"
        raise ValueError(f"Conflicting physical parameters for orbit_id={bad_id}")

    # Row of each ephemeris orbit_id in the grouped table (null when absent).
    lookup = pc.index_in(
        pc.cast(ephemeris_orbit_id, pa.large_string()),
        value_set=per_orbit["orbit_id"],
    )
    # take() propagates null indices as nulls, which become NaN in NumPy.
    h_out = pc.take(per_orbit["H_v_min"], lookup)
    g_out = pc.take(per_orbit["G_min"], lookup)
    return (
        np.asarray(h_out.to_numpy(zero_copy_only=False), dtype=np.float64),
        np.asarray(g_out.to_numpy(zero_copy_only=False), dtype=np.float64),
    )
def propagation_worker(
    orbits: Union[Orbits, VariantOrbits],
    times: Timestamp,
    propagator: Type["Propagator"],
    **kwargs,
) -> Union[Orbits, VariantOrbits]:
    """
    Build ``propagator(**kwargs)`` and propagate ``orbits`` to ``times`` with it.

    Parameters
    ----------
    orbits : Orbits or VariantOrbits
        Orbits to propagate.
    times : Timestamp
        Target epochs.
    propagator : type
        Propagator class. A fresh instance is constructed on every call.
    **kwargs
        Forwarded to the propagator constructor.

    Returns
    -------
    Orbits or VariantOrbits
        Whatever ``_propagate_orbits`` of the new instance returns.
    """
    return propagator(**kwargs)._propagate_orbits(orbits, times)
def attach_magnitude_or_phase(
    ephemeris: EphemerisType,
    orbits: OrbitType,
    observers: ObserverType,
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> EphemerisType:
    """
    Optionally add predicted V magnitudes and/or phase angles to an ephemeris.

    ``ephemeris`` is returned unchanged when neither flag is set, when it is
    empty, or when every aberrated x coordinate is null. Magnitudes are computed
    only if at least one row's orbit has finite H_v and G. Observers are matched
    to ephemeris rows by (observatory code, time), after rescaling observer times
    to the ephemeris time scale. Geometry is evaluated in heliocentric ecliptic
    Cartesian coordinates.

    Parameters
    ----------
    ephemeris : Ephemeris or VariantEphemeris
        Ephemeris with ``aberrated_coordinates`` populated.
    orbits : Orbits or VariantOrbits, or a Ray ObjectRef to one
        Source of ``physical_parameters`` (H_v, G) keyed by orbit_id.
    observers : Observers, or a Ray ObjectRef to one
        Observers covering every (code, time) in ``ephemeris``.
    predict_magnitudes, predict_phase_angle : bool
        Which quantities to attach.

    Returns
    -------
    Ephemeris or VariantEphemeris
        Copy of ``ephemeris`` with the ``alpha`` and/or ``predicted_magnitude_v``
        columns set. Non-finite values (and rows without usable H_v/G) are null.

    Raises
    ------
    ValueError
        If observers cannot be matched unambiguously to ephemeris rows, or if an
        orbit_id has conflicting (H_v, G).
    """
    nothing_requested = not (predict_magnitudes or predict_phase_angle)
    if nothing_requested or len(ephemeris) == 0:
        return ephemeris
    # With no aberrated positions there is no geometry to evaluate.
    if pc.all(pc.is_null(ephemeris.aberrated_coordinates.x)).as_py():
        return ephemeris

    # Callers may pass Ray object references; resolve them locally.
    if isinstance(orbits, ObjectRef):
        orbits = ray.get(orbits)
    if isinstance(observers, ObjectRef):
        observers = ray.get(observers)

    compute_mags = bool(predict_magnitudes)
    compute_alpha = bool(predict_phase_angle)

    H_v: npt.NDArray[np.float64] | None = None
    G: npt.NDArray[np.float64] | None = None
    has_params: npt.NDArray[np.bool_] | None = None
    if compute_mags:
        H_v, G = _hg_params_for_ephemeris_rows_arrow(
            orbit_id=orbits.orbit_id,
            H_v=orbits.physical_parameters.H_v,
            G=orbits.physical_parameters.G,
            ephemeris_orbit_id=ephemeris.orbit_id,
        )
        has_params = np.isfinite(H_v) & np.isfinite(G)
        # Without any usable (H_v, G), magnitudes cannot be predicted.
        compute_mags = bool(np.any(has_params))

    if not (compute_alpha or compute_mags):
        return ephemeris

    # Put observer times on the ephemeris time scale so key matching is exact.
    eph_time = ephemeris.coordinates.time
    if eph_time.scale != observers.coordinates.time.scale:
        observers = observers.set_column(
            "coordinates.time",
            observers.coordinates.time.rescale(eph_time.scale),
        )

    row_map = _alignment_indices_struct_index_in(
        obs_code=observers.code,
        obs_days=observers.coordinates.time.days,
        obs_nanos=observers.coordinates.time.nanos,
        eph_code=ephemeris.coordinates.origin.code,
        eph_days=eph_time.days,
        eph_nanos=eph_time.nanos,
    )
    per_row_observers = observers.take(pa.array(row_map, type=pa.int64()))

    def _to_heliocentric_ecliptic(coords):
        return transform_coordinates(
            coords,
            CartesianCoordinates,
            frame_out="ecliptic",
            origin_out=OriginCodes.SUN,
        )

    object_helio = _to_heliocentric_ecliptic(ephemeris.aberrated_coordinates)
    observer_helio = per_row_observers.set_column(
        "coordinates",
        _to_heliocentric_ecliptic(per_row_observers.coordinates),
    )

    mags = None
    alpha_deg = None
    if compute_mags:
        assert H_v is not None and G is not None and has_params is not None
        if compute_alpha:
            mags, alpha_deg = calculate_apparent_magnitude_v_and_phase_angle(
                H_v=H_v,
                object_coords=object_helio,
                observer=observer_helio,
                G=G,
            )
        else:
            mags = calculate_apparent_magnitude_v(
                H_v=H_v,
                object_coords=object_helio,
                observer=observer_helio,
                G=G,
            )
    else:
        # Reaching here means only the phase angle was requested or feasible.
        alpha_deg = calculate_phase_angle(object_helio, observer_helio)

    if alpha_deg is not None:
        alpha_values = np.asarray(alpha_deg, dtype=np.float64)
        ephemeris = ephemeris.set_column(
            "alpha",
            pa.array(alpha_values, mask=~np.isfinite(alpha_values), type=pa.float64()),
        )
    if mags is not None:
        assert has_params is not None
        mag_values = np.asarray(mags, dtype=np.float64)
        keep = has_params & np.isfinite(mag_values)
        ephemeris = ephemeris.set_column(
            "predicted_magnitude_v",
            pa.array(mag_values, mask=~keep, type=pa.float64()),
        )
    return ephemeris
def _uses_default_ephemeris_mixin(propagator: "Propagator") -> bool:
    """
    Report whether ``propagator`` still uses ``EphemerisMixin._generate_ephemeris``.

    Overrides of ``_generate_ephemeris`` are not required to accept the
    ``predict_magnitudes`` / ``predict_phase_angle`` keywords, so those flags are
    only passed to the default mixin implementation.
    """
    bound_method = propagator._generate_ephemeris
    # A bound method exposes its plain function via __func__. Anything without
    # one is treated as "not the default implementation".
    underlying = getattr(bound_method, "__func__", None)
    return underlying is EphemerisMixin._generate_ephemeris
@ray.remote
def propagation_worker_ray(
    idx: npt.NDArray[np.int64],
    orbits: OrbitType,
    times: TimestampType,
    propagator: "Propagator",
) -> OrbitType:
    """
    Ray task: propagate the subset ``orbits.take(idx)`` to ``times``.

    Parameters
    ----------
    idx : numpy.ndarray of int
        Row indices of ``orbits`` handled by this task.
    orbits : Orbits or VariantOrbits
        Full orbit table. Callers submit it as an object-store reference, which
        Ray resolves before the task runs.
    times : Timestamp
        Target epochs (also submitted as an object-store reference).
    propagator : Propagator
        Pickled propagator instance whose ``_propagate_orbits`` is used.

    Returns
    -------
    Orbits or VariantOrbits
        Propagated orbits for the selected rows.
    """
    orbits_chunk = orbits.take(idx)
    return propagator._propagate_orbits(orbits_chunk, times)
@ray.remote
def ephemeris_worker_ray(
    idx: npt.NDArray[np.int64],
    orbits: OrbitType,
    observers: ObserverType,
    propagator: "Propagator",
    predict_magnitudes: bool,
    predict_phase_angle: bool,
) -> EphemerisType:
    """
    Ray task: generate ephemerides for ``orbits.take(idx)`` as seen by ``observers``.

    Photometry (magnitudes / phase angles) is attached exactly once. The default
    mixin implementation attaches it itself, so the photometry flags are passed
    to it. Custom ``_generate_ephemeris`` overrides do not take the flags, so
    photometry is attached here afterwards. This mirrors the serial path in
    ``EphemerisMixin.generate_ephemeris``.

    Returns
    -------
    Ephemeris or VariantEphemeris
        Ephemerides for the selected rows.
    """
    orbits_chunk = orbits.take(idx)
    if _uses_default_ephemeris_mixin(propagator):
        # The default implementation already calls attach_magnitude_or_phase;
        # attaching again would redo the alignment, transforms and photometry.
        return propagator._generate_ephemeris(
            orbits_chunk,
            observers,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
        )

    ephemeris = propagator._generate_ephemeris(orbits_chunk, observers)
    return attach_magnitude_or_phase(
        ephemeris,
        orbits=orbits_chunk,
        observers=observers,
        predict_magnitudes=predict_magnitudes,
        predict_phase_angle=predict_phase_angle,
    )
class EphemerisMixin:
    """
    Mixin that builds ephemeris generation on top of ``propagate_orbits``.

    ``_generate_ephemeris`` is a generic default implementation that subclasses
    may use as-is or override. ``generate_ephemeris`` is the public entry point.
    It handles covariance sampling, Ray parallelism and photometry attachment.
    """

    def _generate_ephemeris(
        self,
        orbits: OrbitType,
        observers: ObserverType,
        *,
        predict_magnitudes: bool = True,
        predict_phase_angle: bool = False,
    ) -> EphemerisType:
        """
        Default ephemeris implementation, which subclasses may use or override.

        Steps:
        1. Propagate ``orbits`` to the observer epochs.
        2. Correct each state for light travel time to the observer.
        3. Convert the topocentric state to equatorial spherical coordinates,
           with times in UTC.
        4. Optionally attach predicted V magnitudes and phase angles.

        Parameters
        ----------
        orbits : Orbits or VariantOrbits
            Orbits to observe.
        observers : Observers
            Observers. They are sorted internally by (time, code).
        predict_magnitudes, predict_phase_angle : bool
            Forwarded to ``attach_magnitude_or_phase``.

        Returns
        -------
        Ephemeris or VariantEphemeris
            One row per (orbit, observer), ordered like ``propagate_orbits``:
            (orbit_id, time) or (orbit_id, variant_id, time).

        Raises
        ------
        ValueError
            If any light travel time is not finite.
        TypeError
            If ``orbits`` is neither Orbits nor VariantOrbits.
        """
        # No orbits or no observers means no ephemeris rows. Returning early also
        # avoids concatenating an empty list of observer tiles below.
        if len(orbits) == 0 or len(observers) == 0:
            if isinstance(orbits, VariantOrbits):
                return VariantEphemeris.empty()
            return Ephemeris.empty()

        # Sort observers by time then code. propagate_orbits orders its output by
        # (orbit_id[, variant_id], time), so one tile of the sorted observers per
        # orbit lines observer rows up with propagated rows.
        observers = observers.sort_by(
            ["coordinates.time.days", "coordinates.time.nanos", "code"]
        )
        observers_barycentric = observers.set_column(
            "coordinates",
            transform_coordinates(
                observers.coordinates,
                CartesianCoordinates,
                frame_out="ecliptic",
                origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
            ),
        )
        observers_barycentric_tiled = qv.concatenate(
            [observers_barycentric] * len(orbits)
        )

        propagated_orbits = self.propagate_orbits(
            orbits, observers.coordinates.time, max_processes=1
        )
        propagated_orbits_barycentric = propagated_orbits.set_column(
            "coordinates",
            transform_coordinates(
                propagated_orbits.coordinates,
                CartesianCoordinates,
                frame_out="ecliptic",
                origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
            ),
        )

        state = propagated_orbits_barycentric.coordinates.values
        state_mjd = propagated_orbits_barycentric.coordinates.time.mjd().to_numpy(
            zero_copy_only=False
        )
        observer_positions = observers_barycentric_tiled.coordinates.r

        # process_in_chunks pads every chunk to chunk_size (fixed shapes for JAX).
        # Write into preallocated padded buffers and trim the padding afterwards,
        # which avoids quadratic reallocation from repeated concatenation.
        chunk_size = 200
        n_rows = int(state.shape[0])
        n_padded = ((n_rows + chunk_size - 1) // chunk_size) * chunk_size
        aberrated_state = np.empty((n_padded, 6), dtype=np.float64)
        light_time = np.empty((n_padded,), dtype=np.float64)
        start = 0
        for state_chunk, mjd_chunk, observer_chunk in zip(
            process_in_chunks(state, chunk_size),
            process_in_chunks(state_mjd, chunk_size),
            process_in_chunks(observer_positions, chunk_size),
        ):
            aberrated_chunk, light_time_chunk = add_light_time(
                state_chunk,
                mjd_chunk,
                observer_chunk,
                lt_tol=1e-12,
                mu=c.MU,
                max_iter=100,
                tol=1e-15,
            )
            stop = start + chunk_size
            aberrated_state[start:stop, :] = np.asarray(
                aberrated_chunk, dtype=np.float64
            )
            light_time[start:stop] = np.asarray(light_time_chunk, dtype=np.float64)
            start = stop
        aberrated_state = aberrated_state[:n_rows]
        light_time = light_time[:n_rows]

        # Guard against pathological light-time values before building timestamps.
        if not np.all(np.isfinite(light_time)):
            raise ValueError(
                "Light travel time is NaN or too large and propagation will break."
            )

        # Emission time = observation time minus light travel time (in days).
        emission_times = (
            propagated_orbits_barycentric.coordinates.time.add_fractional_days(
                pa.array(-light_time)
            )
        )
        aberrated_orbits = Orbits.from_kwargs(
            orbit_id=propagated_orbits_barycentric.orbit_id,
            object_id=propagated_orbits_barycentric.object_id,
            coordinates=CartesianCoordinates.from_kwargs(
                x=aberrated_state[:, 0],
                y=aberrated_state[:, 1],
                z=aberrated_state[:, 2],
                vx=aberrated_state[:, 3],
                vy=aberrated_state[:, 4],
                vz=aberrated_state[:, 5],
                covariance=propagated_orbits_barycentric.coordinates.covariance,
                time=emission_times,
                origin=propagated_orbits_barycentric.coordinates.origin,
                frame=propagated_orbits_barycentric.coordinates.frame,
            ),
        )

        topocentric_state = (
            aberrated_orbits.coordinates.values
            - observers_barycentric_tiled.coordinates.values
        )
        topocentric_coordinates = CartesianCoordinates.from_kwargs(
            x=topocentric_state[:, 0],
            y=topocentric_state[:, 1],
            z=topocentric_state[:, 2],
            vx=topocentric_state[:, 3],
            vy=topocentric_state[:, 4],
            vz=topocentric_state[:, 5],
            covariance=None,
            # Ephemeris times are observation times (at the observer), not the
            # emission times of the aberrated orbit.
            time=observers_barycentric_tiled.coordinates.time,
            origin=Origin.from_kwargs(
                code=observers_barycentric_tiled.code.to_numpy(zero_copy_only=False)
            ),
            frame="ecliptic",
        )
        spherical_coordinates = SphericalCoordinates.from_cartesian(
            topocentric_coordinates
        )
        spherical_coordinates = transform_coordinates(
            spherical_coordinates, SphericalCoordinates, frame_out="equatorial"
        )
        # Ephemerides are generally compared in UTC, so rescale the time.
        spherical_coordinates = spherical_coordinates.set_column(
            "time",
            spherical_coordinates.time.rescale("utc"),
        )

        if isinstance(orbits, Orbits):
            ephemeris = Ephemeris.from_kwargs(
                orbit_id=propagated_orbits_barycentric.orbit_id,
                object_id=propagated_orbits_barycentric.object_id,
                coordinates=spherical_coordinates,
                light_time=light_time,
                aberrated_coordinates=aberrated_orbits.coordinates,
            )
        elif isinstance(orbits, VariantOrbits):
            # Propagated order is (orbit_id, variant_id, time), matching the
            # observer tiling. Sort the input the same way and repeat each
            # variant's weight once per observer so row i gets its variant weight.
            orbits_sorted = orbits.sort_by(
                [
                    "orbit_id",
                    "variant_id",
                    "coordinates.time.days",
                    "coordinates.time.nanos",
                ]
            )
            n_obs = len(observers)
            weights_np = orbits_sorted.weights.to_numpy(zero_copy_only=False)
            weights_cov_np = orbits_sorted.weights_cov.to_numpy(zero_copy_only=False)
            ephemeris = VariantEphemeris.from_kwargs(
                orbit_id=propagated_orbits_barycentric.orbit_id,
                object_id=propagated_orbits_barycentric.object_id,
                variant_id=propagated_orbits_barycentric.variant_id,
                coordinates=spherical_coordinates,
                light_time=light_time,
                weights=np.repeat(weights_np, n_obs),
                weights_cov=np.repeat(weights_cov_np, n_obs),
                aberrated_coordinates=aberrated_orbits.coordinates,
            )
        else:
            raise TypeError(
                f"Unsupported orbits type for ephemeris generation: {type(orbits)}"
            )

        # Rows stay in propagate_orbits order, (orbit_id[, variant_id], time).
        # Callers re-sort or use .select()/.apply_mask() as needed.
        return attach_magnitude_or_phase(
            ephemeris,
            orbits=orbits,
            observers=observers,
            predict_magnitudes=predict_magnitudes,
            predict_phase_angle=predict_phase_angle,
        )

    def generate_ephemeris(
        self,
        orbits: OrbitType,
        observers: ObserverType,
        covariance: bool = False,
        covariance_method: Literal[
            "auto", "sigma-point", "monte-carlo"
        ] = "monte-carlo",
        num_samples: int = 1000,
        chunk_size: int = 100,
        max_processes: Optional[int] = 1,
        seed: Optional[int] = None,
        predict_magnitudes: bool = True,
        predict_phase_angle: bool = False,
    ) -> Union[Ephemeris, VariantEphemeris]:
        """
        Generate ephemerides for each orbit in ``orbits`` as seen by each observer.

        Parameters
        ----------
        orbits : Orbits (N) or VariantOrbits, or a Ray ObjectRef to one
            Orbits for which to generate ephemerides.
        observers : Observers (M), or a Ray ObjectRef to one
            Observers for which to generate the ephemerides of each orbit.
        covariance : bool, optional
            If True (Orbits only), sample variants from each orbit's covariance
            matrix and generate ephemerides for every sample. The ephemeris
            covariance is then the covariance of the samples.
        covariance_method : {'sigma-point', 'monte-carlo', 'auto'}, optional
            Sampling method. 'auto' chooses based on the covariance matrix.
            Default is 'monte-carlo'.
        num_samples : int, optional
            Number of samples drawn when sampling with monte-carlo.
        chunk_size : int, optional
            Maximum number of orbits (or variants) per Ray task. It is reduced
            when needed so the work is spread over at least ``max_processes``
            tasks.
        max_processes : int or None, optional
            Number of processes. None means ``multiprocessing.cpu_count()``.
            Values greater than 1 run the work as Ray tasks, with
            ``initialize_use_ray`` called with ``num_cpus=max_processes``. A value
            of 1 runs serially in-process.
        seed : int or None, optional
            Seed forwarded to variant sampling.
        predict_magnitudes : bool, optional
            Attach predicted V magnitudes where orbits have finite H_v and G.
        predict_phase_angle : bool, optional
            Attach phase angles.

        Returns
        -------
        Ephemeris or VariantEphemeris
            Predicted ephemerides, ordered by (orbit_id, time), or by
            (orbit_id, variant_id, time) for variant ephemerides. Use .sort_by()
            or .select()/.apply_mask() for other orderings or grouping.

        Raises
        ------
        AssertionError
            If ``covariance`` is requested for VariantOrbits.
        ValueError
            If a worker returns an unexpected result type.
        """
        # Resolve a Ray reference up front so the serial and Ray paths both work
        # with the actual table. Keep the reference to avoid re-uploading it.
        orbits_ref: Optional[ObjectRef] = None
        if isinstance(orbits, ObjectRef):
            orbits_ref = orbits
            orbits = ray.get(orbits_ref)

        # Covariance sampling is only defined for plain Orbits.
        if covariance is not False and not isinstance(orbits, Orbits):
            raise AssertionError("Covariance is not supported for VariantOrbits")

        ephemeris: Ephemeris = Ephemeris.empty()
        variant_ephemeris: VariantEphemeris = VariantEphemeris.empty()
        variants = VariantOrbits.empty()
        if covariance is True and not orbits.coordinates.covariance.is_all_nan():
            variants = VariantOrbits.create(
                orbits,
                method=covariance_method,
                num_samples=num_samples,
                seed=seed,
            )

        if max_processes is None:
            max_processes = mp.cpu_count()

        uses_default_ephemeris = _uses_default_ephemeris_mixin(self)
        # Custom _generate_ephemeris overrides cannot take the photometry flags.
        # In the serial path photometry is attached afterwards; Ray workers do
        # the same per chunk.
        needs_attach = max_processes <= 1 and not uses_default_ephemeris

        if max_processes > 1:
            initialize_use_ray(num_cpus=max_processes)
            observers_ref = (
                observers if isinstance(observers, ObjectRef) else ray.put(observers)
            )
            if orbits_ref is None:
                orbits_ref = ray.put(orbits)

            def _task_chunk_size(n_rows: int) -> int:
                # Aim for at least max_processes chunks so all workers get work.
                if n_rows > 0:
                    return min(chunk_size, max(1, n_rows // max_processes))
                return chunk_size

            futures_inputs = []
            for idx_chunk in _iterate_chunks(
                np.arange(0, len(orbits)), _task_chunk_size(len(orbits))
            ):
                futures_inputs.append(
                    (
                        idx_chunk,
                        orbits_ref,
                        observers_ref,
                        self,
                        predict_magnitudes,
                        predict_phase_angle,
                    )
                )
            if covariance is True and len(variants) > 0:
                variants_ref = ray.put(variants)
                for variant_chunk_idx in _iterate_chunks(
                    np.arange(0, len(variants)), _task_chunk_size(len(variants))
                ):
                    futures_inputs.append(
                        (
                            variant_chunk_idx,
                            variants_ref,
                            observers_ref,
                            self,
                            predict_magnitudes,
                            predict_phase_angle,
                        )
                    )

            ephemeris_parts: list[Ephemeris] = []
            variant_ephemeris_parts: list[VariantEphemeris] = []

            def _collect(finished_ref: ObjectRef) -> None:
                # Route a finished task's result to the matching parts list.
                result = ray.get(finished_ref)
                if isinstance(result, Ephemeris):
                    ephemeris_parts.append(result)
                elif isinstance(result, VariantEphemeris):
                    variant_ephemeris_parts.append(result)
                else:
                    raise ValueError(
                        f"Unexpected result type from ephemeris worker: {type(result)}"
                    )

            # Bound the number of in-flight tasks; results arrive in completion
            # order and are sorted below.
            futures = []
            for future_input in futures_inputs:
                futures.append(ephemeris_worker_ray.remote(*future_input))
                if len(futures) >= max_processes * 1.5:
                    finished, futures = ray.wait(futures, num_returns=1)
                    _collect(finished[0])
            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                _collect(finished[0])

            if ephemeris_parts:
                ephemeris = (
                    ephemeris_parts[0]
                    if len(ephemeris_parts) == 1
                    else qv.concatenate(ephemeris_parts)
                )
            if variant_ephemeris_parts:
                variant_ephemeris = (
                    variant_ephemeris_parts[0]
                    if len(variant_ephemeris_parts) == 1
                    else qv.concatenate(variant_ephemeris_parts)
                )

            # Concatenation followed completion order; sort to canonical order.
            if len(ephemeris) > 0:
                ephemeris = ephemeris.sort_by(
                    [
                        "orbit_id",
                        "coordinates.time.days",
                        "coordinates.time.nanos",
                        "coordinates.origin.code",
                    ]
                )
            if len(variant_ephemeris) > 0:
                variant_ephemeris = variant_ephemeris.sort_by(
                    [
                        "orbit_id",
                        "variant_id",
                        "coordinates.time.days",
                        "coordinates.time.nanos",
                        "coordinates.origin.code",
                    ]
                )
        else:
            # The serial path calls _generate_ephemeris directly, so it needs
            # the observers table itself rather than a reference.
            if isinstance(observers, ObjectRef):
                observers = ray.get(observers)

            if uses_default_ephemeris:
                results = self._generate_ephemeris(
                    orbits,
                    observers,
                    predict_magnitudes=predict_magnitudes,
                    predict_phase_angle=predict_phase_angle,
                )
            else:
                results = self._generate_ephemeris(orbits, observers)

            if isinstance(results, Ephemeris):
                ephemeris = results
            elif isinstance(results, VariantEphemeris):
                variant_ephemeris = results
            else:
                raise ValueError(
                    f"Unexpected result type from generate_ephemeris: {type(results)}"
                )

            if covariance is True and len(variants) > 0:
                if uses_default_ephemeris:
                    variant_ephemeris = self._generate_ephemeris(
                        variants,
                        observers,
                        predict_magnitudes=predict_magnitudes,
                        predict_phase_angle=predict_phase_angle,
                    )
                else:
                    variant_ephemeris = self._generate_ephemeris(variants, observers)

        # Outputs are always returned in UTC. If that guarantee is ever relaxed,
        # ensure_input_time_scale(..., observers.coordinates.time) could be used
        # instead to keep the observers' time scale.
        if covariance is False and len(variant_ephemeris) > 0:
            # VariantOrbits input: return the per-variant ephemeris directly.
            variant_ephemeris = variant_ephemeris.set_column(
                "coordinates.time",
                variant_ephemeris.coordinates.time.rescale("utc"),
            )
            if needs_attach:
                variant_ephemeris = attach_magnitude_or_phase(
                    variant_ephemeris,
                    orbits=orbits,
                    observers=observers,
                    predict_magnitudes=predict_magnitudes,
                    predict_phase_angle=predict_phase_angle,
                )
            return variant_ephemeris

        if covariance is True and len(variant_ephemeris) > 0:
            ephemeris = variant_ephemeris.collapse(ephemeris)

        ephemeris = ephemeris.set_column(
            "coordinates.time",
            ephemeris.coordinates.time.rescale("utc"),
        )
        if needs_attach:
            ephemeris = attach_magnitude_or_phase(
                ephemeris,
                orbits=orbits,
                observers=observers,
                predict_magnitudes=predict_magnitudes,
                predict_phase_angle=predict_phase_angle,
            )
        return ephemeris
class Propagator(ABC, EphemerisMixin):
    """
    Abstract base class for orbit propagators.

    Subclasses implement ``_propagate_orbits``. This class supplies the public
    ``propagate_orbits`` (chunking, covariance sampling, Ray parallelism) and,
    through ``EphemerisMixin``, ephemeris generation. Additional capabilities
    can be added with further abstract mixins.

    Important: propagator instances are sent to Ray worker processes, so
    subclasses must be pickleable and must implement ``__getstate__`` and
    ``__setstate__``. The defaults here raise ``NotImplementedError``.
    """

    @abstractmethod
    def _propagate_orbits(self, orbits: OrbitType, times: TimestampType) -> OrbitType:
        """
        Propagate ``orbits`` to ``times``.

        THIS FUNCTION SHOULD BE DEFINED BY THE USER.
        """
        pass

    def __getstate__(self):
        """
        Return the picklable state of the propagator.

        Subclasses must define what is picklable, for example::

            def __getstate__(self):
                state = self.__dict__.copy()
                state.pop("_stateful_attribute_that_is_not_pickleable")
                return state
        """
        raise NotImplementedError(
            "Propagator must implement __getstate__ for multiprocessing serialization.\n"
            "Example implementation: \n"
            "def __getstate__(self):\n"
            " state = self.__dict__.copy()\n"
            " state.pop('_stateful_attribute_that_is_not_pickleable')\n"
            " return state"
        )

    def __setstate__(self, state):
        """
        Restore the propagator from pickled state.

        Subclasses must recreate whatever was dropped in ``__getstate__``, e.g.::

            def __setstate__(self, state):
                self.__dict__.update(state)
                self._stateful_attribute_that_is_not_pickleable = None
        """
        raise NotImplementedError(
            "Propagator must implement __setstate__ for multiprocessing serialization.\n"
            "Example implementation: \n"
            "def __setstate__(self, state):\n"
            " self.__dict__.update(state)\n"
            " self._stateful_attribute_that_is_not_pickleable = None"
        )

    def propagate_orbits(
        self,
        orbits: Union[OrbitType, ObjectRef],
        times: Union[TimestampType, ObjectRef],
        covariance: bool = False,
        covariance_method: Literal[
            "auto", "sigma-point", "monte-carlo"
        ] = "monte-carlo",
        num_samples: int = 1000,
        chunk_size: int = 100,
        max_processes: Optional[int] = 1,
        seed: Optional[int] = None,
    ) -> Union[Orbits, VariantOrbits]:
        """
        Propagate each orbit in ``orbits`` to each time in ``times``.

        This method handles parallelization. Subclasses may override it to change
        that behavior.

        Parameters
        ----------
        orbits : Orbits (N) or VariantOrbits, or a Ray ObjectRef to one
            Orbits to propagate.
        times : Timestamp (M), or a Ray ObjectRef to one
            Target times. They are sorted chronologically before the backend is
            called, so integrators (e.g. ASSIST, REBOUND) receive time-ordered
            epochs.
        covariance : bool, optional
            If True (Orbits only), sample variants from each orbit's covariance
            matrix, propagate every sample, and collapse the samples into the
            propagated covariance.
        covariance_method : {'sigma-point', 'monte-carlo', 'auto'}, optional
            Sampling method. 'auto' chooses based on the covariance matrix.
            Default is 'monte-carlo'.
        num_samples : int, optional
            Number of samples drawn when sampling with monte-carlo.
        chunk_size : int, optional
            Number of orbits (or variants) per Ray task.
        max_processes : int or None, optional
            Number of processes. None means ``multiprocessing.cpu_count()``.
            Values greater than 1 run the work as Ray tasks, with
            ``initialize_use_ray`` called with ``num_cpus=max_processes``. A value
            of 1 runs serially in-process.
        seed : int or None, optional
            Seed forwarded to variant sampling.

        Returns
        -------
        Orbits or VariantOrbits
            Propagated orbits. Rows are ordered (orbit_id, time) for Orbits and
            (orbit_id, variant_id, time) for VariantOrbits. Use .sort_by() or
            .select()/.apply_mask() for other orderings or grouping. Empty input
            yields an empty table of the same type.

        Raises
        ------
        AssertionError
            If ``covariance`` is requested for VariantOrbits.
        ValueError
            If a Ray worker returns an unexpected result type.
        """
        # Resolve Ray references first so the serial and Ray paths both work with
        # real tables. Keep the orbits reference to avoid re-uploading it.
        orbits_ref: Optional[ObjectRef] = None
        if isinstance(orbits, ObjectRef):
            orbits_ref = orbits
            orbits = ray.get(orbits_ref)
        if isinstance(times, ObjectRef):
            times = ray.get(times)

        input_is_variants = isinstance(orbits, VariantOrbits)
        if covariance is True and input_is_variants:
            raise AssertionError("Covariance is not supported for VariantOrbits")

        # Nothing to propagate. This also avoids concatenating an empty list of
        # Ray results below.
        if len(orbits) == 0:
            return type(orbits).empty()

        if max_processes is None:
            max_processes = mp.cpu_count()

        # Sort chronologically so backends (ASSIST, REBOUND, etc.) receive
        # time-ordered epochs for efficient integration.
        times = times.sort_by(["days", "nanos"])

        variant_order = [
            "orbit_id",
            "variant_id",
            "coordinates.time.days",
            "coordinates.time.nanos",
        ]
        propagated_variants: Optional[VariantOrbits] = None

        if max_processes > 1:
            initialize_use_ray(num_cpus=max_processes)
            times_ref = ray.put(times)
            if orbits_ref is None:
                orbits_ref = ray.put(orbits)

            futures_inputs = [
                (idx_chunk, orbits_ref, times_ref, self)
                for idx_chunk in _iterate_chunks(np.arange(0, len(orbits)), chunk_size)
            ]
            if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                variants = VariantOrbits.create(
                    orbits,
                    method=covariance_method,
                    num_samples=num_samples,
                    seed=seed,
                )
                variants_ref = ray.put(variants)
                futures_inputs.extend(
                    (variant_chunk_idx, variants_ref, times_ref, self)
                    for variant_chunk_idx in _iterate_chunks(
                        np.arange(0, len(variants)), chunk_size
                    )
                )

            orbit_results: List[Orbits] = []
            # For VariantOrbits input these are the propagated inputs. Otherwise
            # they are propagated covariance samples.
            variant_results: List[VariantOrbits] = []

            def _collect(finished_ref: ObjectRef) -> None:
                # Route a finished task's result to the matching results list.
                result = ray.get(finished_ref)
                if isinstance(result, Orbits):
                    orbit_results.append(result)
                elif isinstance(result, VariantOrbits):
                    variant_results.append(result)
                else:
                    raise ValueError(
                        f"Unexpected result type from propagation worker: {type(result)}"
                    )

            # Submit jobs while bounding the number of in-flight tasks.
            futures = []
            for future_input in futures_inputs:
                futures.append(propagation_worker_ray.remote(*future_input))
                if len(futures) >= max_processes * 1.5:
                    finished, futures = ray.wait(futures, num_returns=1)
                    _collect(finished[0])
            while futures:
                finished, futures = ray.wait(futures, num_returns=1)
                _collect(finished[0])

            if input_is_variants:
                propagated = qv.concatenate(variant_results)
            else:
                propagated = qv.concatenate(orbit_results)
                if variant_results:
                    propagated_variants = qv.concatenate(variant_results).sort_by(
                        variant_order
                    )
        else:
            propagated = self._propagate_orbits(orbits, times)
            if covariance is True and not orbits.coordinates.covariance.is_all_nan():
                variants = VariantOrbits.create(
                    orbits,
                    method=covariance_method,
                    num_samples=num_samples,
                    seed=seed,
                )
                propagated_variants = self._propagate_orbits(variants, times).sort_by(
                    variant_order
                )

        if propagated_variants is not None:
            propagated = propagated_variants.collapse(propagated)

        # Preserve the time scale of the requested times.
        propagated = ensure_input_time_scale(propagated, times)
        # Return results in each input orbit's original origin and frame.
        propagated = ensure_input_origin_and_frame(orbits, propagated)

        # Internal order is (orbit_id[, variant_id], time), so the observer tiling
        # [observers] * len(orbits) in _generate_ephemeris matches row for row.
        if isinstance(propagated, VariantOrbits):
            return propagated.sort_by(variant_order)
        return propagated.sort_by(
            ["orbit_id", "coordinates.time.days", "coordinates.time.nanos"]
        )