import logging
import multiprocessing as mp
import time
from typing import Literal, Optional, Tuple, Type, Union
import numpy as np
import numpy.typing as npt
import pyarrow.compute as pc
import quivr as qv
import ray
from scipy.linalg import solve
from ..coordinates import CartesianCoordinates, CoordinateCovariances
from ..coordinates.residuals import Residuals
from ..orbit_determination import OrbitDeterminationObservations
from ..orbits import Orbits
from ..propagator import Propagator
from ..ray_cluster import initialize_use_ray
from ..utils.iter import _iterate_chunk_indices, _iterate_chunks
from .fitted_orbits import FittedOrbitMembers, FittedOrbits
from .outliers import calculate_max_outliers
logger = logging.getLogger(__name__)
__all__ = ["differential_correction"]
def od_worker(
    orbit_ids: npt.NDArray[np.str_],
    orbits: FittedOrbits,
    orbit_members: FittedOrbitMembers,
    observations: OrbitDeterminationObservations,
    propagator: Type[Propagator],
    rchi2_threshold: float = 100,
    min_obs: int = 5,
    min_arc_length: float = 1.0,
    contamination_percentage: float = 0.0,
    delta: float = 1e-6,
    max_iter: int = 20,
    method: Literal["central", "finite"] = "central",
    propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
    """
    Differentially correct a batch of orbits, one orbit at a time.

    For every ID in ``orbit_ids`` the matching orbit and its member
    observations are looked up, the observations are sorted by time (and
    origin code), and the pair is handed to ``od``. The per-orbit results are
    appended to two running tables which are returned at the end.

    Parameters
    ----------
    orbit_ids : array of str
        IDs of the orbits to process, in processing order.
    orbits : FittedOrbits
        Table that contains (at least) the orbits named in ``orbit_ids``.
    orbit_members : FittedOrbitMembers
        Orbit ID to observation ID linkages.
    observations : OrbitDeterminationObservations
        Observations referenced by ``orbit_members``.
    propagator : Type[Propagator]
        Propagator class, forwarded to ``od``.
    rchi2_threshold, min_obs, min_arc_length, contamination_percentage, \
delta, max_iter, method, propagator_kwargs
        Forwarded unchanged to ``od``.

    Returns
    -------
    od_orbits : FittedOrbits
        Concatenation of the orbits returned by ``od``.
    od_orbit_members : FittedOrbitMembers
        Concatenation of the orbit members returned by ``od``.
    """
    # Everything that is identical for each od() call in this batch.
    od_settings = dict(
        rchi2_threshold=rchi2_threshold,
        min_obs=min_obs,
        min_arc_length=min_arc_length,
        contamination_percentage=contamination_percentage,
        delta=delta,
        max_iter=max_iter,
        method=method,
        propagator=propagator,
        propagator_kwargs=propagator_kwargs,
    )
    time_ordering = (
        "coordinates.time.days",
        "coordinates.time.nanos",
        "coordinates.origin.code",
    )

    corrected = FittedOrbits.empty()
    corrected_members = FittedOrbitMembers.empty()

    for orbit_id in orbit_ids:
        time_start = time.time()
        logger.debug(f"Differentially correcting orbit {orbit_id}...")

        current_orbit = orbits.select("orbit_id", orbit_id)

        # Observations linked to this orbit, in time order
        linked = orbit_members.apply_mask(pc.equal(orbit_members.orbit_id, orbit_id))
        is_linked = pc.is_in(observations.id, linked.obs_id)
        linked_observations = observations.apply_mask(is_linked)
        linked_observations = linked_observations.sort_by(list(time_ordering))

        fit, fit_members = od(current_orbit, linked_observations, **od_settings)

        duration = time.time() - time_start
        logger.debug(f"OD for orbit {orbit_id} completed in {duration:.3f}s.")

        # Grow the running tables, compacting them whenever they fragment
        corrected = qv.concatenate([corrected, fit])
        if corrected.fragmented():
            corrected = qv.defragment(corrected)

        corrected_members = qv.concatenate([corrected_members, fit_members])
        if corrected_members.fragmented():
            corrected_members = qv.defragment(corrected_members)

    return corrected, corrected_members
# The task options are given to the decorator: a bare
# ``od_worker_remote.options(...)`` statement only builds a new wrapper and
# throws it away, leaving the remote function itself unchanged.
@ray.remote(num_returns=1, num_cpus=1)
def od_worker_remote(
    orbit_ids: npt.NDArray[np.str_],
    orbit_ids_indices: Tuple[int, int],
    orbits: FittedOrbits,
    orbit_members: FittedOrbitMembers,
    observations: OrbitDeterminationObservations,
    propagator: Type[Propagator],
    rchi2_threshold: float = 100,
    min_obs: int = 5,
    min_arc_length: float = 1.0,
    contamination_percentage: float = 0.0,
    delta: float = 1e-6,
    max_iter: int = 20,
    method: Literal["central", "finite"] = "central",
    propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
    """
    Ray task wrapper around ``od_worker`` for one slice of orbit IDs.

    The full ``orbit_ids`` array is passed in together with a
    ``(start, stop)`` pair; only ``orbit_ids[start:stop]`` is processed. This
    lets the caller share a single array between tasks and send just the
    slice bounds per task.

    Parameters
    ----------
    orbit_ids : array of str
        All orbit IDs being processed by the caller.
    orbit_ids_indices : Tuple[int, int]
        Half-open ``(start, stop)`` bounds of the slice of ``orbit_ids``
        handled by this task.
    orbits, orbit_members, observations, propagator, rchi2_threshold, \
min_obs, min_arc_length, contamination_percentage, delta, max_iter, method, \
propagator_kwargs
        Forwarded unchanged to ``od_worker``.

    Returns
    -------
    od_orbits : FittedOrbits
    od_orbit_members : FittedOrbitMembers
        The tuple returned by ``od_worker``. With ``num_returns=1`` the tuple
        is delivered as a single object, so callers unpack it after
        ``ray.get``.
    """
    orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]]
    return od_worker(
        orbit_ids_chunk,
        orbits,
        orbit_members,
        observations,
        rchi2_threshold=rchi2_threshold,
        min_obs=min_obs,
        min_arc_length=min_arc_length,
        contamination_percentage=contamination_percentage,
        delta=delta,
        max_iter=max_iter,
        method=method,
        propagator=propagator,
        propagator_kwargs=propagator_kwargs,
    )
def od(
    orbit: FittedOrbits,
    observations: OrbitDeterminationObservations,
    propagator: Type[Propagator],
    rchi2_threshold: float = 100,
    min_obs: int = 5,
    min_arc_length: float = 1.0,
    contamination_percentage: float = 0.0,
    delta: float = 1e-6,
    max_iter: int = 20,
    method: Literal["central", "finite"] = "central",
    propagator_kwargs: dict = {},
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
    """
    Differentially correct a single orbit against its observations.

    Each iteration perturbs the six Cartesian state components one at a time,
    regenerates ephemerides, and builds a 2 x 6 matrix of numerical partial
    derivatives per observation. The weighted normal equations are then solved
    for a state update. An updated orbit is kept when its reduced chi2 is
    lower than the current best (the first successful update is always kept)
    and the arc length is at least ``min_arc_length``. Iteration stops as soon
    as a kept orbit has a reduced chi2 at or below ``rchi2_threshold``.
    Once ``max_iter`` iterations have passed without a solution, the fit is
    reset to the input orbit and restarted with the highest-chi2 observations
    masked out as outliers, up to the number allowed by
    ``calculate_max_outliers``.

    Parameters
    ----------
    orbit : FittedOrbits
        The orbit to correct. The code indexes row 0 of its coordinate values,
        so a single-row table is expected.
    observations : OrbitDeterminationObservations
        Observations (with coordinates, covariances and observers) to fit.
    propagator : Type[Propagator]
        Propagator class; instantiated here with ``propagator_kwargs`` and
        used through ``generate_ephemeris``.
    rchi2_threshold : float, optional
        Reduced chi2 at or below which a kept orbit counts as the solution.
    min_obs : int, optional
        Orbits with fewer observations than this are not processed.
    min_arc_length : float, optional
        Minimum span (max minus min of the MJD times of the non-outlier
        observations) required to accept an orbit.
    contamination_percentage : float, optional
        Passed to ``calculate_max_outliers`` to set how many observations may
        be discarded as outliers.
    delta : float, optional
        Initial relative step: each state component is perturbed by
        ``component * delta``.
    max_iter : int, optional
        Iteration budget per outlier stage; the hard cap is
        ``max_iter * (max_outliers + 1)``.
    method : {'central', 'finite'}, optional
        'central' differences the +delta and -delta ephemerides; 'finite'
        differences the +delta ephemeris and the nominal one.
    propagator_kwargs : dict, optional
        Keyword arguments for the propagator constructor.

    Returns
    -------
    od_orbit : FittedOrbits
        One-row table with the accepted orbit, or an empty table when no
        solution was found, the orbit was not processable, or no iteration
        ever produced an accepted update.
    od_orbit_members : FittedOrbitMembers
        One row per input observation with ``solution``/``outlier`` flags, or
        an empty table in the cases listed above.

    Raises
    ------
    ValueError
        If ``method`` is neither 'central' nor 'finite'. Note that the
        propagator is instantiated before this check runs.
    """
    # Initialize the propagator
    prop = propagator(**propagator_kwargs)
    if method not in ["central", "finite"]:
        err = "method should be one of 'central' or 'finite'."
        raise ValueError(err)
    obs_ids_all = observations.id.to_numpy(zero_copy_only=False)
    coords = observations.coordinates
    # Only columns 1 and 2 of the six coordinate dimensions are fitted
    # (presumably the two on-sky angles -- TODO confirm against the
    # coordinate class used by the observations).
    coords_sigma = coords.covariance.sigmas[:, 1:3]
    observers = observations.observers
    times_all = coords.time.mjd().to_numpy(zero_copy_only=False)
    # FLAG: can we stop iterating to find a solution?
    converged = False
    # FLAG: has an orbit with reduced chi2 less than the reduced chi2 of the input orbit been found?
    improved = False
    # FLAG: has an orbit with reduced chi2 less than the rchi2_threshold been found?
    solution_found = False
    # FLAG: is this orbit processable? Does it have at least min_obs observations?
    processable = True
    # FLAG: is this the first iteration with a successful differential correction? That solution is
    # always stored as the one that needs to be improved: input orbits may not have been previously
    # corrected with the current set of observations, so this forces at least one successful
    # iteration to have been taken.
    first_solution = True
    num_obs = len(observations)
    if num_obs < min_obs:
        logger.debug("This orbit has fewer than {} observations.".format(min_obs))
        # The loop below is skipped entirely and empty tables are returned;
        # none of the state set up in the else-branch is needed in that case.
        processable = False
    else:
        max_outliers = calculate_max_outliers(
            num_obs, min_obs, contamination_percentage
        )
        logger.debug(f"Maximum number of outliers allowed: {max_outliers}")
        outliers_tried = 0
        # Calculate chi2 for residuals on the given observations
        # for the current orbit, the goal is for the orbit to improve
        # such that the chi2 improves
        # Names with a trailing underscore hold the state of the *input*
        # orbit; they are the reset point when an outlier stage starts.
        orbit_prev_ = orbit
        ephemeris_prev_ = prop.generate_ephemeris(
            orbit_prev_, observers, chunk_size=1, max_processes=1
        )
        # Calculate residuals and chi2
        residuals_prev_ = Residuals.calculate(
            coords,
            ephemeris_prev_.coordinates,
        )
        # NOTE(review): this array is computed once from the input orbit and
        # is never refreshed when orbit_prev is replaced below -- confirm that
        # using the initial residuals as the right-hand side on every
        # iteration is intended.
        residuals_prev_array = np.stack(
            residuals_prev_.values.to_numpy(zero_copy_only=False)
        )[:, 1:3]
        num_obs_ = len(observations)
        chi2_prev_ = residuals_prev_.chi2.to_numpy()
        chi2_total_prev_ = np.sum(chi2_prev_)
        # 2 measurements per observation minus 6 fitted parameters
        rchi2_prev_ = np.sum(chi2_prev_) / (2 * num_obs - 6)
        # Save the initial orbit in case we need to reset
        # to it later
        orbit_prev = orbit_prev_
        ephemeris_prev = ephemeris_prev_
        residuals_prev = residuals_prev_
        num_obs = num_obs_
        chi2_prev = chi2_prev_
        chi2_total_prev = chi2_total_prev_
        rchi2_prev = rchi2_prev_
        # True for observations currently used in the fit (no outliers yet)
        ids_mask = np.array([True for i in range(num_obs)])
        # Overrides the observation times read above with the ephemeris times
        times_all = ephemeris_prev.coordinates.time.mjd().to_numpy()
        obs_id_outlier = []
        delta_prev = delta
        iterations = 0
        # Multipliers/divisors used to rescale the step after a failed iteration
        DELTA_INCREASE_FACTOR = 5
        DELTA_DECREASE_FACTOR = 100
        # max_iter_i: budget up to the end of the current outlier stage
        # max_iter_outliers: absolute cap over all outlier stages
        max_iter_i = max_iter
        max_iter_outliers = max_iter * (max_outliers + 1)
    while not converged and processable:
        iterations += 1
        # We add 1 here because the iterations are counted as they start, not
        # as they finish. There are a lot of 'continue' statements down below that
        # will exit the current iteration if something fails which makes accounting for
        # iterations at the start of an iteration easier.
        if iterations == max_iter_outliers + 1:
            logger.debug("Maximum number of iterations completed.")
            break
        if iterations == max_iter_i + 1 and (
            solution_found or (max_outliers == outliers_tried)
        ):
            logger.debug("Maximum number of iterations completed.")
            break
        logger.debug(f"Starting iteration number: {iterations}/{max_iter_outliers}")
        # Make sure delta is well bounded
        if delta_prev < 1e-14:
            delta_prev *= DELTA_INCREASE_FACTOR
            logger.debug("Delta is too small, increasing.")
        elif delta_prev > 1e-2:
            delta_prev /= DELTA_DECREASE_FACTOR
            logger.debug("Delta is too large, decreasing.")
        else:
            pass
        delta_iter = delta_prev
        logger.debug(f"Starting iteration {iterations} with delta {delta_iter}.")
        # Initialize the partials derivatives matrix
        # The trailing axis holds one slab per (non-outlier) observation;
        # ATWA and ATWb are collapsed over that axis further down.
        num_params = 6
        A = np.zeros((2, num_params, num_obs))
        ATWA = np.zeros((num_params, num_params, num_obs))
        ATWb = np.zeros((num_params, 1, num_obs))
        # Generate ephemeris with current nominal orbit
        ephemeris_nom = prop.generate_ephemeris(
            orbit_prev, observers, chunk_size=1, max_processes=1
        )
        # Modify each component of the state by a small delta
        d = np.zeros((1, 6))
        for i in range(num_params):
            # zero the delta vector
            d *= 0.0
            # x, y, z [au]: 0, 1, 2
            # vx, vy, vz [au per day]: 3, 4, 5
            # Both branches currently apply the same relative step.
            if i < 3:
                delta_iter = delta_prev
                d[0, i] = orbit_prev.coordinates.values[0, i] * delta_iter
            elif i < 6:
                delta_iter = delta_prev
                d[0, i] = orbit_prev.coordinates.values[0, i] * delta_iter
            # Modify component i of the orbit by a small delta
            cartesian_elements_p = orbit_prev.coordinates.values + d[0, :6]
            orbit_iter_p = Orbits.from_kwargs(
                coordinates=CartesianCoordinates.from_kwargs(
                    x=cartesian_elements_p[:, 0],
                    y=cartesian_elements_p[:, 1],
                    z=cartesian_elements_p[:, 2],
                    vx=cartesian_elements_p[:, 3],
                    vy=cartesian_elements_p[:, 4],
                    vz=cartesian_elements_p[:, 5],
                    time=orbit_prev.coordinates.time,
                    origin=orbit_prev.coordinates.origin,
                    frame=orbit_prev.coordinates.frame,
                )
            )
            # Calculate the modified ephemerides
            ephemeris_mod_p = prop.generate_ephemeris(
                orbit_iter_p, observers, chunk_size=1, max_processes=1
            )
            # Denominator of the difference quotient: d for a one-sided
            # difference, 2d for a central difference.
            delta_denom = d[0, i]
            if method == "central":
                # Modify component i of the orbit by a small delta
                cartesian_elements_n = orbit_prev.coordinates.values - d[0, :6]
                orbit_iter_n = Orbits.from_kwargs(
                    coordinates=CartesianCoordinates.from_kwargs(
                        x=cartesian_elements_n[:, 0],
                        y=cartesian_elements_n[:, 1],
                        z=cartesian_elements_n[:, 2],
                        vx=cartesian_elements_n[:, 3],
                        vy=cartesian_elements_n[:, 4],
                        vz=cartesian_elements_n[:, 5],
                        time=orbit_prev.coordinates.time,
                        origin=orbit_prev.coordinates.origin,
                        frame=orbit_prev.coordinates.frame,
                    )
                )
                # Calculate the modified ephemerides
                ephemeris_mod_n = prop.generate_ephemeris(
                    orbit_iter_n, observers, chunk_size=1, max_processes=1
                )
                delta_denom *= 2
            else:
                ephemeris_mod_n = ephemeris_nom
            # Residuals.calculate is reused here to difference the two
            # perturbed ephemerides (the numerator of the partial derivative).
            residuals_mod = Residuals.calculate(
                ephemeris_mod_p.coordinates,
                ephemeris_mod_n.coordinates,
            )
            residuals_mod = np.stack(
                residuals_mod.values.to_numpy(zero_copy_only=False)
            )
            residuals_mod_array = residuals_mod[:, 1:3]
            # Fill column i of A for every observation that is not masked out
            for n in range(num_obs):
                try:
                    A[:, i : i + 1, n] = (
                        residuals_mod_array[ids_mask][n : n + 1].T / delta_denom
                    )
                except RuntimeError:
                    print(orbit_prev.orbit_id)
        # Accumulate the weighted normal equations, one observation at a time.
        # NOTE(review): A is filled from rows selected with ids_mask, whereas
        # coords_sigma and residuals_prev_array are indexed with n directly
        # (i.e. the first num_obs rows). Once outliers are masked these may
        # refer to different observations -- confirm.
        for n in range(num_obs):
            W = np.diag(1 / coords_sigma[n] ** 2)
            ATWA[:, :, n] = A[:, :, n].T @ W @ A[:, :, n]
            ATWb[:, :, n] = A[:, :, n].T @ W @ residuals_prev_array[n : n + 1].T
        ATWA = np.sum(ATWA, axis=2)
        ATWb = np.sum(ATWb, axis=2)
        try:
            # Calculate the condition number of both matrices
            ATWA_condition = np.linalg.cond(ATWA)
            ATWb_condition = np.linalg.cond(ATWb)
        except np.linalg.LinAlgError:
            # NaN compares False against the threshold below, so a failed
            # condition calculation does not by itself reject the iteration.
            ATWA_condition = np.nan
            ATWb_condition = np.nan
            logger.debug(
                f"Matrix condition calculation failed for {orbit.orbit_id[0].as_py()}"
            )
        # Ill-conditioned system: retry with a smaller step
        if (ATWA_condition > 1e15) or (ATWb_condition > 1e15):
            delta_prev /= DELTA_DECREASE_FACTOR
            continue
        # NaNs in the normal equations: retry with a larger step
        if np.any(np.isnan(ATWA)) or np.any(np.isnan(ATWb)):
            delta_prev *= DELTA_INCREASE_FACTOR
            continue
        else:
            try:
                # solve() returns a (6, 1) column; transposed to (1, 6) so it
                # can be added to the (1, 6) state values below.
                delta_state = solve(
                    ATWA,
                    ATWb,
                ).T
                covariance_matrix = np.linalg.inv(ATWA)
                variances = np.diag(covariance_matrix)
                if np.any(variances <= 0) or np.any(np.isnan(variances)):
                    delta_prev /= DELTA_DECREASE_FACTOR
                    logger.debug(
                        "Variances are negative, 0.0, or NaN. Discarding solution."
                    )
                    continue
                # Reject fits whose position uncertainty exceeds the
                # heliocentric/origin distance itself.
                r_variances = variances[0:3]
                r_sigma = np.sqrt(np.sum(r_variances))
                r = orbit_prev.coordinates.r_mag
                if (r_sigma / r) > 1:
                    delta_prev /= DELTA_DECREASE_FACTOR
                    logger.debug(
                        "Covariance matrix is largely unconstrained. Discarding solution."
                    )
                    continue
                if np.any(np.isnan(covariance_matrix)):
                    delta_prev *= DELTA_INCREASE_FACTOR
                    logger.debug(
                        "Covariance matrix contains NaNs. Discarding solution."
                    )
                    continue
            except np.linalg.LinAlgError:
                delta_prev *= DELTA_INCREASE_FACTOR
                continue
        # NOTE(review): delta_state has shape (1, 6), so [:3] slices rows and
        # both norms below cover all six components rather than only the
        # position part the log messages mention -- confirm intent.
        # NOTE(review): the first branch multiplies delta by
        # DELTA_DECREASE_FACTOR (i.e. increases it by 100x) instead of using
        # DELTA_INCREASE_FACTOR -- confirm.
        if np.linalg.norm(delta_state[:3]) < 1e-16:
            logger.debug("Change in state is less than 1e-16 au, discarding solution.")
            delta_prev *= DELTA_DECREASE_FACTOR
            continue
        if np.linalg.norm(delta_state[:3]) > 100:
            delta_prev /= DELTA_DECREASE_FACTOR
            logger.debug("Change in state is more than 100 au, discarding solution.")
            continue
        # Apply the correction to obtain the candidate orbit for this iteration
        cartesian_elements = orbit_prev.coordinates.values + delta_state
        orbit_iter = Orbits.from_kwargs(
            orbit_id=orbit_prev.orbit_id,
            coordinates=CartesianCoordinates.from_kwargs(
                x=cartesian_elements[:, 0],
                y=cartesian_elements[:, 1],
                z=cartesian_elements[:, 2],
                vx=cartesian_elements[:, 3],
                vy=cartesian_elements[:, 4],
                vz=cartesian_elements[:, 5],
                covariance=CoordinateCovariances.from_matrix(
                    covariance_matrix.reshape(1, 6, 6)
                ),
                time=orbit_prev.coordinates.time,
                origin=orbit_prev.coordinates.origin,
                frame=orbit_prev.coordinates.frame,
            ),
        )
        # Sanity limit on the candidate's speed (threshold of 1 in the
        # coordinate velocity units)
        if np.linalg.norm(orbit_iter.coordinates.v_mag) > 1:
            delta_prev *= DELTA_INCREASE_FACTOR
            logger.debug("Orbit is moving extraordinarily fast, discarding solution.")
            continue
        # Generate ephemeris with the candidate orbit
        ephemeris_iter = prop.generate_ephemeris(
            orbit_iter, observers, chunk_size=1, max_processes=1
        )
        residuals = Residuals.calculate(coords, ephemeris_iter.coordinates)
        # chi2 is computed for all observations but only the non-outliers
        # contribute to the total and the reduced value.
        chi2_iter = residuals.chi2.to_numpy()
        chi2_total_iter = np.sum(chi2_iter[ids_mask])
        rchi2_iter = chi2_total_iter / (2 * num_obs - num_params)
        arc_length = times_all[ids_mask].max() - times_all[ids_mask].min()
        # If the new orbit has lower residuals than the previous,
        # accept the orbit and continue iterating until max iterations has been
        # reached. Once max iterations have been reached and the orbit still has not converged
        # to an acceptable solution, try removing an observation as an outlier and iterate again.
        if (
            (rchi2_iter < rchi2_prev) or first_solution
        ) and arc_length >= min_arc_length:
            if first_solution:
                logger.debug(
                    "Storing first successful differential correction iteration for these observations."
                )
                first_solution = False
            else:
                logger.debug("Potential improvement orbit has been found.")
            orbit_prev = orbit_iter
            residuals_prev = residuals
            chi2_prev = chi2_iter
            chi2_total_prev = chi2_total_iter
            rchi2_prev = rchi2_iter
            # "improved" is measured against the input orbit's reduced chi2
            if rchi2_prev <= rchi2_prev_:
                improved = True
            if rchi2_prev <= rchi2_threshold:
                logger.debug("Potential solution orbit has been found.")
                solution_found = True
                converged = True
        elif (
            max_outliers > 0
            and outliers_tried <= max_outliers
            and iterations > max_iter_i
            and not solution_found
        ):
            logger.debug("Attempting to identify possible outliers.")
            # Previous fits have failed, lets reset the current best fit orbit back to its original
            # state and re-run fitting, this time removing outliers
            orbit_prev = orbit_prev_
            ephemeris_prev = ephemeris_prev_
            residuals_prev = residuals_prev_
            num_obs = num_obs_
            chi2_prev = chi2_prev_
            chi2_total_prev = chi2_total_prev_
            rchi2_prev = rchi2_prev_
            delta_prev = delta
            # Select i highest observations that contribute to
            # chi2 (and thereby the residuals)
            # chi2_prev was just reset, so the ranking uses the input orbit's
            # per-observation chi2.
            remove = chi2_prev.argsort()[-(outliers_tried + 1) :]
            # Grab the obs_ids for these outliers
            obs_id_outlier = obs_ids_all[remove]
            num_obs = len(observations) - len(obs_id_outlier)
            ids_mask = np.isin(obs_ids_all, obs_id_outlier, invert=True)
            arc_length = times_all[ids_mask].max() - times_all[ids_mask].min()
            logger.debug("Possible outlier(s): {}".format(obs_id_outlier))
            outliers_tried += 1
            # Only grant another max_iter iterations when the remaining arc
            # is still long enough to be acceptable.
            if arc_length >= min_arc_length:
                max_iter_i = max_iter * (outliers_tried + 1)
            else:
                logger.debug(
                    "Removing the outlier will cause the arc length to go below the minimum."
                )
        # If the new orbit does not have lower residuals, try changing
        # delta to see if we get an improvement
        else:
            # Currently a no-op: the delta adjustment below is disabled.
            # logger.debug("Orbit did not improve since previous iteration, decrease delta and continue.")
            # delta_prev /= DELTA_DECREASE_FACTOR
            pass
        logger.debug(
            "Current r-chi2: {}, Previous r-chi2: {}, Max Iterations: {}, Outliers Tried: {}".format(
                rchi2_iter, rchi2_prev, max_iter_i, outliers_tried
            )
        )
    logger.debug("Solution found: {}".format(solution_found))
    logger.debug("First solution: {}".format(first_solution))
    # Nothing to report unless an accepted update reached the threshold
    if not solution_found or not processable or first_solution:
        od_orbit = FittedOrbits.empty()
        od_orbit_members = FittedOrbitMembers.empty()
    else:
        obs_times = observations.coordinates.time.mjd().to_numpy(zero_copy_only=False)[
            ids_mask
        ]
        arc_length_ = obs_times.max() - obs_times.min()
        # NOTE(review): exact float equality between an arc length derived
        # from the ephemeris times and one derived from the observation times;
        # this is also skipped under ``python -O`` -- confirm it is wanted.
        assert arc_length == arc_length_
        od_orbit = FittedOrbits.from_kwargs(
            orbit_id=orbit_prev.orbit_id,
            object_id=orbit_prev.object_id,
            coordinates=orbit_prev.coordinates,
            arc_length=[arc_length_],
            num_obs=[num_obs],
            chi2=[chi2_total_prev],
            reduced_chi2=[rchi2_prev],
            iterations=[iterations],
            success=[improved],
            status_code=[0],
        )
        # One member row per input observation; outliers stay in the table
        # and are flagged instead of being dropped.
        od_orbit_members = FittedOrbitMembers.from_kwargs(
            orbit_id=np.full(
                len(obs_ids_all), orbit_prev.orbit_id[0].as_py(), dtype="object"
            ),
            obs_id=obs_ids_all,
            residuals=residuals_prev,
            solution=np.isin(obs_ids_all, obs_id_outlier, invert=True),
            outlier=np.isin(obs_ids_all, obs_id_outlier),
        )
    return od_orbit, od_orbit_members
# Public API (a stray "[docs]" link marker from the rendered documentation was removed here).
def _merge_od_chunk(
    od_orbits: FittedOrbits,
    od_orbit_members: FittedOrbitMembers,
    od_orbits_chunk: FittedOrbits,
    od_orbit_members_chunk: FittedOrbitMembers,
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
    """Append one chunk of results to the running tables, defragmenting as needed."""
    od_orbits = qv.concatenate([od_orbits, od_orbits_chunk])
    if od_orbits.fragmented():
        od_orbits = qv.defragment(od_orbits)

    od_orbit_members = qv.concatenate([od_orbit_members, od_orbit_members_chunk])
    if od_orbit_members.fragmented():
        od_orbit_members = qv.defragment(od_orbit_members)

    return od_orbits, od_orbit_members


def differential_correction(
    orbits: Union[FittedOrbits, ray.ObjectRef],
    orbit_members: Union[FittedOrbitMembers, ray.ObjectRef],
    observations: Union[OrbitDeterminationObservations, ray.ObjectRef],
    propagator: Type[Propagator],
    min_obs: int = 5,
    min_arc_length: float = 1.0,
    contamination_percentage: float = 20,
    rchi2_threshold: float = 100,
    delta: float = 1e-8,
    max_iter: int = 20,
    method: Literal["central", "finite"] = "central",
    propagator_kwargs: dict = {},
    chunk_size: int = 10,
    max_processes: Optional[int] = 1,
    orbit_ids: Optional[npt.NDArray[np.str_]] = None,
    obs_ids: Optional[npt.NDArray[np.str_]] = None,
) -> Tuple[FittedOrbits, FittedOrbitMembers]:
    """
    Differentially correct (via finite/central differencing).

    The orbits are split into chunks of ``chunk_size`` and each chunk is
    processed by ``od_worker``, either as Ray tasks or serially in this
    process, depending on what ``initialize_use_ray`` returns.

    Parameters
    ----------
    orbits : FittedOrbits or ray.ObjectRef
        Orbits to correct, or a reference to them in the Ray object store.
    orbit_members : FittedOrbitMembers or ray.ObjectRef
        Orbit to observation linkages, or a reference to them.
    observations : OrbitDeterminationObservations or ray.ObjectRef
        Observations, or a reference to them.
    propagator : Type[Propagator]
        Propagator class forwarded to the workers.
    min_obs : int, optional
        Forwarded to the workers.
    min_arc_length : float, optional
        Forwarded to the workers.
    contamination_percentage : float, optional
        Forwarded to the workers. Note that the default here (20) differs
        from the workers' own default (0.0).
    rchi2_threshold : float, optional
        Forwarded to the workers.
    delta : float, optional
        Forwarded to the workers. Note that the default here (1e-8) differs
        from the workers' own default (1e-6).
    max_iter : int, optional
        Forwarded to the workers.
    method : {'central', 'finite'}, optional
        Forwarded to the workers.
    propagator_kwargs : dict, optional
        Forwarded to the workers.
    chunk_size : int, optional
        Number of orbits to send to each job.
    max_processes : int or None, optional
        Passed to ``initialize_use_ray`` as the number of CPUs. ``None`` means
        ``multiprocessing.cpu_count()``. When Ray is used, at most
        ``1.5 * max_processes`` tasks are kept in flight.
    orbit_ids : array of str, optional
        If given, restrict ``orbits`` and ``orbit_members`` to these orbit
        IDs. The restriction is only applied to inputs passed as
        ``ray.ObjectRef``.
    obs_ids : array of str, optional
        If given, restrict ``orbit_members`` and ``observations`` to these
        observation IDs. The restriction is only applied to inputs passed as
        ``ray.ObjectRef``.

    Returns
    -------
    od_orbits : FittedOrbits
        The corrected orbits returned by the workers (empty if there was
        nothing to process).
    od_orbit_members : FittedOrbitMembers
        The corresponding orbit members.
    """
    time_start = time.perf_counter()
    logger.info("Running differential correction...")

    # Resolve object-store references to local tables. The original reference
    # is remembered so that it can be handed to the remote workers as-is.
    # NOTE(review): the masks below filter only the local copies; when Ray is
    # used, the workers still receive the original (unfiltered) references.
    if isinstance(orbits, ray.ObjectRef):
        orbits_ref = orbits
        orbits = ray.get(orbits)
        logger.info("Retrieved orbits from the object store.")

        if orbit_ids is not None:
            orbits = orbits.apply_mask(pc.is_in(orbits.orbit_id, orbit_ids))
            logger.info("Applied mask to orbits.")
    else:
        orbits_ref = None

    if isinstance(orbit_members, ray.ObjectRef):
        orbit_members_ref = orbit_members
        orbit_members = ray.get(orbit_members)
        logger.info("Retrieved orbit members from the object store.")

        if obs_ids is not None:
            orbit_members = orbit_members.apply_mask(
                pc.is_in(orbit_members.obs_id, obs_ids)
            )
            logger.info("Applied mask to orbit members.")
        if orbit_ids is not None:
            orbit_members = orbit_members.apply_mask(
                pc.is_in(orbit_members.orbit_id, orbit_ids)
            )
            logger.info("Applied mask to orbit members.")
    else:
        orbit_members_ref = None

    if isinstance(observations, ray.ObjectRef):
        observations_ref = observations
        observations = ray.get(observations)
        logger.info("Retrieved observations from the object store.")

        if obs_ids is not None:
            observations = observations.apply_mask(pc.is_in(observations.id, obs_ids))
            logger.info("Applied mask to observations.")
    else:
        observations_ref = None

    # Nothing to do: return empty tables without touching Ray
    if len(orbits) == 0 or len(orbit_members) == 0:
        logger.info("Received no orbits or orbit members.")
        od_orbits = FittedOrbits.empty()
        od_orbit_members = FittedOrbitMembers.empty()

        time_end = time.perf_counter()
        logger.info(f"Differentially corrected {len(od_orbits)} orbits.")
        logger.info(
            f"Differential correction completed in {time_end - time_start:.3f} seconds."
        )
        return od_orbits, od_orbit_members

    # From here on orbit_ids is the list of orbits actually being processed
    orbit_ids = orbits.orbit_id.to_numpy(zero_copy_only=False)

    od_orbits = FittedOrbits.empty()
    od_orbit_members = FittedOrbitMembers.empty()

    if max_processes is None:
        max_processes = mp.cpu_count()

    use_ray = initialize_use_ray(num_cpus=max_processes)
    if use_ray:
        # Only references created here are freed at the end; references
        # supplied by the caller are left alone.
        refs_to_free = []

        orbit_ids_ref = ray.put(orbit_ids)
        orbit_ids = ray.get(orbit_ids_ref)
        refs_to_free.append(orbit_ids_ref)
        logger.info("Placed orbit IDs in the object store.")

        if orbits_ref is None:
            orbits_ref = ray.put(orbits)
            orbits = ray.get(orbits_ref)
            refs_to_free.append(orbits_ref)
            logger.info("Placed orbits in the object store.")

        if orbit_members_ref is None:
            orbit_members_ref = ray.put(orbit_members)
            orbit_members = ray.get(orbit_members_ref)
            refs_to_free.append(orbit_members_ref)
            logger.info("Placed orbit members in the object store.")

        if observations_ref is None:
            observations_ref = ray.put(observations)
            refs_to_free.append(observations_ref)
            observations = ray.get(observations_ref)
            logger.info("Placed observations in the object store.")

        futures = []
        for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size):
            futures.append(
                od_worker_remote.remote(
                    orbit_ids_ref,
                    orbit_ids_indices,
                    orbits_ref,
                    orbit_members_ref,
                    observations_ref,
                    rchi2_threshold=rchi2_threshold,
                    min_obs=min_obs,
                    min_arc_length=min_arc_length,
                    contamination_percentage=contamination_percentage,
                    delta=delta,
                    max_iter=max_iter,
                    method=method,
                    propagator=propagator,
                    propagator_kwargs=propagator_kwargs,
                )
            )

            # Back-pressure: once enough tasks are queued, wait for one to
            # finish and collect it before submitting more.
            if len(futures) >= max_processes * 1.5:
                finished, futures = ray.wait(futures, num_returns=1)
                od_orbits_chunk, od_orbit_members_chunk = ray.get(finished[0])
                od_orbits, od_orbit_members = _merge_od_chunk(
                    od_orbits,
                    od_orbit_members,
                    od_orbits_chunk,
                    od_orbit_members_chunk,
                )

        # Drain whatever is still running
        while futures:
            finished, futures = ray.wait(futures, num_returns=1)
            od_orbits_chunk, od_orbit_members_chunk = ray.get(finished[0])
            od_orbits, od_orbit_members = _merge_od_chunk(
                od_orbits,
                od_orbit_members,
                od_orbits_chunk,
                od_orbit_members_chunk,
            )

        if len(refs_to_free) > 0:
            ray.internal.free(refs_to_free)
            logger.info(
                f"Removed {len(refs_to_free)} references from the object store."
            )

    else:
        for orbit_ids_chunk in _iterate_chunks(orbit_ids, chunk_size):
            od_orbits_chunk, od_orbit_members_chunk = od_worker(
                orbit_ids_chunk,
                orbits,
                orbit_members,
                observations,
                rchi2_threshold=rchi2_threshold,
                min_obs=min_obs,
                min_arc_length=min_arc_length,
                contamination_percentage=contamination_percentage,
                delta=delta,
                max_iter=max_iter,
                method=method,
                propagator=propagator,
                propagator_kwargs=propagator_kwargs,
            )
            od_orbits, od_orbit_members = _merge_od_chunk(
                od_orbits,
                od_orbit_members,
                od_orbits_chunk,
                od_orbit_members_chunk,
            )

    time_end = time.perf_counter()
    logger.info(f"Differentially corrected {len(od_orbits)} orbits.")
    logger.info(
        f"Differential correction completed in {time_end - time_start:.3f} seconds."
    )
    return od_orbits, od_orbit_members