from __future__ import annotations
import logging
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, List
import numpy as np
import numpy.typing as npt
import pyarrow as pa
import requests
from astroquery.jplsbdb import SBDB
from ...coordinates.cometary import CometaryCoordinates
from ...coordinates.covariances import CoordinateCovariances, sigmas_to_covariances
from ...coordinates.origin import Origin
from ...time import Timestamp
from ..orbits import Orbits
from ..physical_parameters import PhysicalParameters
logger = logging.getLogger(__name__)
_SBDB_API_URL = "https://ssd-api.jpl.nasa.gov/sbdb.api"
_SBDB_API_FAIR_USE_MAX_CONCURRENT_REQUESTS = 1
_thread_local = threading.local()
def _get_requests_session() -> requests.Session:
"""
Return a per-thread `requests.Session`.
Why: using a session enables connection pooling, which reduces overhead when querying many IDs.
A per-thread session is a safe default if callers choose to enable limited concurrency.
"""
sess = getattr(_thread_local, "session", None)
if sess is None:
sess = requests.Session()
_thread_local.session = sess
return sess
def _convert_SBDB_covariances(
sbdb_covariances: npt.ArrayLike,
) -> npt.ArrayLike:
"""
Convert SBDB covariance matrices to Cometary covariance matrices.
Parameters
----------
sbdb_covariances : `~numpy.ndarray` (N, 6, 6)
Covariance matrices pulled from JPL's Small Body Database Browser.
Returns
-------
covariances : `~numpy.ndarray` (N, 6, 6)
Cometary covariance matrices.
"""
covariances = np.zeros_like(sbdb_covariances)
# sigma_q{x}
covariances[:, 0, 0] = sbdb_covariances[:, 1, 1] # sigma_qq
covariances[:, 1, 0] = covariances[:, 0, 1] = sbdb_covariances[:, 0, 1] # sigma_qe
covariances[:, 2, 0] = covariances[:, 0, 2] = sbdb_covariances[:, 5, 1] # sigma_qi
covariances[:, 3, 0] = covariances[:, 0, 3] = sbdb_covariances[
:, 3, 1
] # sigma_qraan
covariances[:, 4, 0] = covariances[:, 0, 4] = sbdb_covariances[:, 4, 1] # sigma_qap
covariances[:, 5, 0] = covariances[:, 0, 5] = sbdb_covariances[:, 2, 1] # sigma_qtp
# sigma_e{x}
covariances[:, 1, 1] = sbdb_covariances[:, 0, 0] # sigma_ee
covariances[:, 2, 1] = covariances[:, 1, 2] = sbdb_covariances[:, 5, 0] # sigma_ei
covariances[:, 3, 1] = covariances[:, 1, 3] = sbdb_covariances[
:, 3, 0
] # sigma_eraan
covariances[:, 4, 1] = covariances[:, 1, 4] = sbdb_covariances[:, 4, 0] # sigma_eap
covariances[:, 5, 1] = covariances[:, 1, 5] = sbdb_covariances[:, 2, 0] # sigma_etp
# sigma_i{x}
covariances[:, 2, 2] = sbdb_covariances[:, 5, 5] # sigma_ii
covariances[:, 3, 2] = covariances[:, 2, 3] = sbdb_covariances[
:, 3, 5
] # sigma_iraan
covariances[:, 4, 2] = covariances[:, 2, 4] = sbdb_covariances[:, 4, 5] # sigma_iap
covariances[:, 5, 2] = covariances[:, 2, 5] = sbdb_covariances[:, 2, 5] # sigma_itp
# sigma_raan{x}
covariances[:, 3, 3] = sbdb_covariances[:, 3, 3] # sigma_raanraan
covariances[:, 4, 3] = covariances[:, 3, 4] = sbdb_covariances[
:, 4, 3
] # sigma_raanap
covariances[:, 5, 3] = covariances[:, 3, 5] = sbdb_covariances[
:, 2, 3
] # sigma_raantp
# sigma_ap{x}
covariances[:, 4, 4] = sbdb_covariances[:, 4, 4] # sigma_apap
covariances[:, 5, 4] = covariances[:, 4, 5] = sbdb_covariances[
:, 2, 4
] # sigma_aptp
# sigma_tp{x}
covariances[:, 5, 5] = sbdb_covariances[:, 2, 2] # sigma_tptp
return covariances
def _get_sbdb_elements(obj_ids: List[str]) -> List[OrderedDict]:
"""
Get orbital elements and other object properties
from JPL's Small Body Database Browser.
Parameters
----------
obj_ids : List
Object IDs to query.
Returns
-------
results : List
List of dictionaries containing orbital elements and other object properties.
"""
results = []
SBDB.clear_cache() # Yikes!
for obj_id in obj_ids:
result = SBDB.query(
obj_id,
covariance="mat",
id_type="search",
full_precision=True,
solution_epoch=False,
)
results.append(result)
return results
def _orbits_from_sbdb_results(ids: npt.ArrayLike, results: List[OrderedDict]) -> Orbits:
"""
Convert SBDB query results into an `Orbits` table.
What: shared implementation for both the legacy astroquery-based query and the new
direct-HTTP query.
Why: keeping a single conversion path ensures both entrypoints return identical results.
"""
orbit_ids = []
object_ids = []
classes = []
coords_cometary = np.zeros((len(results), 6), dtype=np.float64)
covariances_sbdb = np.zeros((len(results), 6, 6), dtype=np.float64)
times = np.zeros((len(results)), dtype=np.float64)
for i, result in enumerate(results):
if "object" not in result:
raise NotFoundError("object {} was not found", ids[i])
orbit_ids.append(f"{i:05d}")
object_ids.append(result["object"]["fullname"])
classes.append(result["object"]["orbit_class"]["code"])
orbit = result["orbit"]
elements = orbit["elements"]
epoch = orbit["epoch"]
if "covariance" in orbit:
labels = orbit["covariance"]["labels"]
if len(labels) > 6:
logger.debug(
"Covariance matrix has more parameters than just orbital elements. "
"Ignoring non-orbital elements in covariance matrix."
)
labels = labels[:6]
expected_labels = ["e", "q", "tp", "node", "peri", "i"]
if labels != expected_labels:
raise ValueError(
"Expected covariance matrix labels to be {expected_labels} instead got {labels}."
)
# Limit covariances to just the orbital elements
# The SBDB API documentation states that physical parameter covariances
# are appended to the rows and columns of the covariance matrix with the
# orbital elements remaining in the first 6 rows and columns.
# See: Orbit Subsection: covariance in https://ssd-api.jpl.nasa.gov/doc/sbdb.html
covariances_sbdb[i, :, :] = orbit["covariance"]["data"][:6, :6]
if "elements" in orbit["covariance"]:
# If elements is provided inside covariance, then it's
# the elements at the epoch which was used to
# calculate covariance, so we should prefer it.
elements = orbit["covariance"]["elements"]
epoch = orbit["covariance"]["epoch"]
else:
sigmas = np.array(
[
[
elements["e_sig"],
elements["q_sig"].value,
elements["tp_sig"].value,
elements["om_sig"].value,
elements["w_sig"].value,
elements["i_sig"].value,
]
]
)
covariances_sbdb[i, :, :] = sigmas_to_covariances(sigmas)[0]
times[i] = epoch.value
coords_cometary[i, 0] = elements["q"].value
coords_cometary[i, 1] = elements["e"]
coords_cometary[i, 2] = elements["i"].value
coords_cometary[i, 3] = elements["om"].value
coords_cometary[i, 4] = elements["w"].value
coords_cometary[i, 5] = (
Timestamp.from_jd([elements["tp"].value], scale="tdb").mjd()[0].as_py()
)
covariances_cometary = _convert_SBDB_covariances(covariances_sbdb)
times = Timestamp.from_jd(times, scale="tdb")
origin = Origin.from_kwargs(code=["SUN" for i in range(len(times))])
frame = "ecliptic"
coordinates = CometaryCoordinates.from_kwargs(
time=times,
q=coords_cometary[:, 0],
e=coords_cometary[:, 1],
i=coords_cometary[:, 2],
raan=coords_cometary[:, 3],
ap=coords_cometary[:, 4],
tp=coords_cometary[:, 5],
covariance=CoordinateCovariances.from_matrix(covariances_cometary),
origin=origin,
frame=frame,
)
orbit_ids = np.array(orbit_ids, dtype="object")
object_ids = np.array(object_ids, dtype="object")
classes = np.array(classes)
# Legacy astroquery path does not request phys-par; fill with nulls.
phys_rows = [(None, None, None, None)] * len(results)
physical_parameters = _physical_parameters_from_sbdb(phys_rows)
return Orbits.from_kwargs(
orbit_id=orbit_ids,
object_id=object_ids,
coordinates=coordinates.to_cartesian(),
physical_parameters=physical_parameters,
)
[docs]
def query_sbdb(ids: npt.ArrayLike) -> Orbits:
"""
Query JPL's Small-Body Database (SBDB) for orbits. The epoch at
which the orbits are returned are near the epoch as published by the
Minor Planet Center.
By default, the orbit's covariance matrices are also queried for. If they
are not available, then the 1-sigma uncertainties are used to construct
the covariance matrices.
Parameters
----------
ids : list
List of object IDs to query.
Returns
-------
orbits : `~adam_core.orbits.orbits.Orbits`
Orbits object containing the queried orbits.
Raises
------
NotFoundError: If any of the queries object IDs are not found.
"""
results = _get_sbdb_elements(ids)
return _orbits_from_sbdb_results(ids, results)
def _sbdb_api_get_json(
object_id: str,
*,
timeout_s: float,
max_attempts: int,
) -> dict[str, Any]:
"""
Query JPL's public SBDB JSON API for a single object, with retries.
Why: we want explicit timeout/retry behavior, and we want to avoid global cache clearing.
Notes:
- Per the JPL SSD/CNEOS API fair use policy, clients should not send concurrent requests.
This helper does not enforce that policy; the public entrypoint defaults to sequential
requests (max_concurrent_requests=1).
"""
obj = str(object_id).strip()
if not obj:
raise ValueError("object_id must be non-empty")
if timeout_s <= 0:
raise ValueError("timeout_s must be > 0")
if max_attempts <= 0:
raise ValueError("max_attempts must be > 0")
params = {
"sstr": obj,
"cov": "mat",
"full-prec": "true",
"phys-par": "true",
}
last_err: Exception | None = None
for attempt in range(max_attempts):
try:
resp = _get_requests_session().get(
_SBDB_API_URL, params=params, timeout=timeout_s
)
resp.raise_for_status()
return resp.json()
except (
requests.exceptions.Timeout,
requests.exceptions.ConnectionError,
) as err:
last_err = err
except requests.exceptions.HTTPError as err:
# Retry on transient server errors and explicit throttling.
status = err.response.status_code if err.response is not None else None
if status is not None and (status >= 500 or status == 429):
last_err = err
else:
raise
except Exception:
# Non-retryable (JSON decode, unexpected failure, etc.)
raise
# Exponential backoff with a small cap.
sleep_s = min(8.0, 0.5 * (2**attempt))
time.sleep(sleep_s)
raise RuntimeError(f"SBDB query failed after {max_attempts} attempts: {last_err}")
def _sbdb_float(value: Any) -> float:
"""
Convert a JSON scalar (string/number) into a float.
SBDB returns many numeric fields as strings (including scientific notation). We normalize
those to floats for orbit construction.
"""
if value is None:
raise ValueError("Expected a numeric value, got None.")
if isinstance(value, (float, int, np.floating, np.integer)):
return float(value)
if isinstance(value, str):
return float(value)
raise TypeError(f"Expected a numeric JSON scalar, got {type(value)!r}.")
def _sbdb_elements_map(elements: Any) -> dict[str, dict[str, Any]]:
"""
Convert SBDB's `orbit.elements` list into a dict keyed by element short-name.
"""
if not isinstance(elements, list):
raise ValueError("Expected SBDB orbit elements to be a list.")
out: dict[str, dict[str, Any]] = {}
for el in elements:
if not isinstance(el, dict):
continue
name = el.get("name")
if name is None:
continue
out[str(name)] = el
return out
def _sbdb_element_value(
elements_by_name: dict[str, dict[str, Any]], name: str
) -> float:
el = elements_by_name.get(name)
if el is None:
raise ValueError(f"SBDB orbit elements missing {name!r}.")
if "value" not in el or el["value"] is None:
raise ValueError(f"SBDB orbit element {name!r} is missing a value.")
return _sbdb_float(el["value"])
def _sbdb_element_sigma(
elements_by_name: dict[str, dict[str, Any]], name: str
) -> float:
el = elements_by_name.get(name)
if el is None:
raise ValueError(f"SBDB orbit elements missing {name!r}.")
sigma = el.get("sigma")
if sigma is None:
# Some SBDB payloads omit per-element uncertainties when no covariance matrix is provided.
# We treat this as "unknown uncertainty" and propagate NaNs through the diagonal fallback.
return float("nan")
try:
return _sbdb_float(sigma)
except Exception:
return float("nan")
def _sbdb_phys_par_value(el: dict[str, Any] | None) -> float | None:
"""Extract numeric value from a phys_par entry (value may be scalar or in el['value'])."""
if el is None:
return None
v = el.get("value")
if v is None:
return None
try:
return _sbdb_float(v)
except (ValueError, TypeError):
return None
def _sbdb_phys_par_sigma(el: dict[str, Any] | None) -> float | None:
"""Extract sigma from a phys_par entry; None if missing."""
if el is None:
return None
s = el.get("sigma")
if s is None:
return None
try:
return _sbdb_float(s)
except (ValueError, TypeError):
return None
def _sbdb_phys_par_from_payload(
payload: dict[str, Any],
) -> tuple[float | None, float | None, float | None, float | None]:
"""
Extract H (V-band), H_sigma, G, G_sigma from SBDB phys_par when requested with phys-par=1.
SBDB documents H as "absolute magnitude (magnitude at 1 au from Sun and observer)" (V-band);
G as "magnitude slope parameter" (H-G system). API may use name "H" or "H_mag" depending on source.
Ref: https://ssd-api.jpl.nasa.gov/doc/sbdb.html#phys_par
"""
phys_list = payload.get("phys_par") or []
by_name: dict[str, dict[str, Any]] = {}
for d in phys_list:
if isinstance(d, dict) and d.get("name") is not None:
by_name[str(d["name"])] = d
H_v = _sbdb_phys_par_value(by_name.get("H")) or _sbdb_phys_par_value(
by_name.get("H_mag")
)
H_v_sigma = _sbdb_phys_par_sigma(by_name.get("H")) or _sbdb_phys_par_sigma(
by_name.get("H_mag")
)
G = _sbdb_phys_par_value(by_name.get("G"))
G_sigma = _sbdb_phys_par_sigma(by_name.get("G"))
return (H_v, H_v_sigma, G, G_sigma)
def _physical_parameters_from_sbdb(
rows: list[tuple[float | None, float | None, float | None, float | None]],
) -> PhysicalParameters:
"""Build PhysicalParameters from SBDB phys_par extractions (one row per payload)."""
if not rows:
return PhysicalParameters.from_kwargs(H_v=[], H_v_sigma=[], G=[], G_sigma=[])
H_v = np.array(
[r[0] if r[0] is not None else np.nan for r in rows],
dtype=np.float64,
)
H_v_sigma = np.array(
[r[1] if r[1] is not None else np.nan for r in rows],
dtype=np.float64,
)
G = np.array(
[r[2] if r[2] is not None else np.nan for r in rows],
dtype=np.float64,
)
G_sigma = np.array(
[r[3] if r[3] is not None else np.nan for r in rows],
dtype=np.float64,
)
return PhysicalParameters.from_kwargs(
H_v=H_v,
H_v_sigma=H_v_sigma,
G=G,
G_sigma=G_sigma,
)
def _orbits_from_sbdb_payloads(
ids: list[str],
payloads: list[dict[str, Any]],
) -> Orbits:
"""
Convert raw SBDB JSON payloads into an `Orbits` table.
This mirrors the behavior of the legacy `query_sbdb` implementation:
- Prefer covariance-provided elements/epoch when present.
- Use covariance matrix when available; otherwise build a diagonal covariance from sigmas.
"""
if len(ids) != len(payloads):
raise ValueError("ids and payloads must have the same length.")
expected_labels = ["e", "q", "tp", "node", "peri", "i"]
orbit_ids: list[str] = []
object_ids: list[str] = []
phys_rows: list[tuple[float | None, float | None, float | None, float | None]] = []
coords_cometary = np.zeros((len(payloads), 6), dtype=np.float64)
covariances_sbdb = np.zeros((len(payloads), 6, 6), dtype=np.float64)
times_jd = np.zeros((len(payloads)), dtype=np.float64)
for i, (obj_id, payload) in enumerate(zip(ids, payloads)):
if "object" not in payload:
raise NotFoundError("object {} was not found", obj_id)
if "orbit" not in payload:
raise ValueError(f"SBDB payload for {obj_id!r} missing 'orbit'.")
obj = payload["object"] or {}
orbit_ids.append(f"{i:05d}")
object_ids.append(str(obj.get("fullname")))
orbit = payload["orbit"] or {}
elements_list = orbit.get("elements")
epoch_jd = _sbdb_float(orbit.get("epoch"))
cov = orbit.get("covariance")
cov_matrix: np.ndarray | None = None
if isinstance(cov, dict) and cov.get("data") is not None:
labels = cov.get("labels")
if isinstance(labels, list):
labels6 = [str(x) for x in labels[:6]]
if labels6 != expected_labels:
raise ValueError(
f"Expected covariance matrix labels to be {expected_labels} "
f"in the first 6 entries, got {labels6}."
)
data = np.asarray(cov["data"], dtype=np.float64)
if data.ndim != 2 or data.shape[0] < 6 or data.shape[1] < 6:
raise ValueError("Expected SBDB covariance matrix to be at least 6x6.")
cov_matrix = data[:6, :6]
# If covariance provides elements, prefer them (and the covariance epoch).
if "elements" in cov and cov["elements"] is not None:
elements_list = cov["elements"]
if cov.get("epoch") is not None:
epoch_jd = _sbdb_float(cov.get("epoch"))
if elements_list is None:
raise ValueError(f"SBDB payload for {obj_id!r} missing orbit elements.")
elements_by_name = _sbdb_elements_map(elements_list)
if cov_matrix is None:
# Fallback: build a diagonal covariance from per-element sigmas.
sigmas = np.array(
[
[
_sbdb_element_sigma(elements_by_name, "e"),
_sbdb_element_sigma(elements_by_name, "q"),
_sbdb_element_sigma(elements_by_name, "tp"),
_sbdb_element_sigma(elements_by_name, "om"),
_sbdb_element_sigma(elements_by_name, "w"),
_sbdb_element_sigma(elements_by_name, "i"),
]
],
dtype=np.float64,
)
cov_matrix = sigmas_to_covariances(sigmas)[0]
covariances_sbdb[i, :, :] = cov_matrix
times_jd[i] = epoch_jd
q = _sbdb_element_value(elements_by_name, "q")
e = _sbdb_element_value(elements_by_name, "e")
inc = _sbdb_element_value(elements_by_name, "i")
om = _sbdb_element_value(elements_by_name, "om")
w = _sbdb_element_value(elements_by_name, "w")
tp_jd = _sbdb_element_value(elements_by_name, "tp")
tp_mjd = Timestamp.from_jd([tp_jd], scale="tdb").mjd()[0].as_py()
coords_cometary[i, 0] = q
coords_cometary[i, 1] = e
coords_cometary[i, 2] = inc
coords_cometary[i, 3] = om
coords_cometary[i, 4] = w
coords_cometary[i, 5] = tp_mjd
phys_rows.append(_sbdb_phys_par_from_payload(payload))
covariances_cometary = _convert_SBDB_covariances(covariances_sbdb)
times = Timestamp.from_jd(times_jd, scale="tdb")
origin = Origin.from_kwargs(code=["SUN" for _ in range(len(times))])
coordinates = CometaryCoordinates.from_kwargs(
time=times,
q=coords_cometary[:, 0],
e=coords_cometary[:, 1],
i=coords_cometary[:, 2],
raan=coords_cometary[:, 3],
ap=coords_cometary[:, 4],
tp=coords_cometary[:, 5],
covariance=CoordinateCovariances.from_matrix(covariances_cometary),
origin=origin,
frame="ecliptic",
)
physical_parameters = _physical_parameters_from_sbdb(phys_rows)
return Orbits.from_kwargs(
orbit_id=np.array(orbit_ids, dtype="object"),
object_id=np.array(object_ids, dtype="object"),
coordinates=coordinates.to_cartesian(),
physical_parameters=physical_parameters,
)
def _get_sbdb_payloads_new(
obj_ids: List[str],
*,
max_concurrent_requests: int,
timeout_s: float,
max_attempts: int,
) -> list[dict[str, Any]]:
"""
Fetch SBDB JSON payloads via direct HTTP, optionally with limited concurrency.
Important: JPL's SSD/CNEOS API fair use policy requests only one in-flight request at a time.
The public entrypoint defaults to `max_concurrent_requests=1` to comply.
"""
if max_concurrent_requests <= 0:
raise ValueError("max_concurrent_requests must be > 0")
n = len(obj_ids)
if n == 0:
return []
max_workers = min(int(max_concurrent_requests), n)
if max_workers > 1:
logger.warning(
"query_sbdb_new is configured with max_concurrent_requests=%s. "
"JPL's SSD/CNEOS API fair use policy requests only one in-flight request at a time; "
"concurrent requests may be rejected.",
max_workers,
)
results: list[dict[str, Any] | None] = [None] * n
def fetch_one(i: int, object_id: str) -> tuple[int, dict[str, Any]]:
payload = _sbdb_api_get_json(
object_id, timeout_s=timeout_s, max_attempts=max_attempts
)
return i, payload
with ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(fetch_one, i, obj_id) for i, obj_id in enumerate(obj_ids)]
for fut in as_completed(futures):
i, res = fut.result()
results[i] = res
out: list[dict[str, Any]] = []
for r in results:
if r is None:
raise RuntimeError("SBDB payload missing from concurrent fetch.")
out.append(r)
return out
[docs]
def query_sbdb_new(
ids: npt.ArrayLike,
*,
max_concurrent_requests: int = 1,
timeout_s: float = 60.0,
max_attempts: int = 5,
allow_missing: bool = False,
orbit_id_from_input: bool = False,
) -> Orbits:
"""
Query JPL SBDB for orbits using direct HTTP requests (new implementation).
This is intended to be a drop-in alternative to `query_sbdb` that:
- avoids `SBDB.clear_cache()`,
- provides explicit timeout/retry controls, and
- can optionally fetch multiple objects concurrently.
Notes
-----
JPL's SSD/CNEOS API fair use policy requests only one in-flight API request at a time.
Therefore, the default `max_concurrent_requests=1` is the recommended setting.
Parameters
----------
allow_missing : bool, optional
If True, do not raise when an ID is not found in SBDB. Instead, return an `Orbits` table
containing only the successfully resolved IDs (potentially empty).
orbit_id_from_input : bool, optional
If True, set the returned `Orbits.orbit_id` values to the input IDs (after any missing
filtering). This is useful when callers need to map rows back to the requested identifiers.
"""
# Normalize ids into a list of strings while preserving the caller's order.
if isinstance(ids, (str, bytes)):
obj_ids = [str(ids)]
else:
obj_ids = [str(x) for x in ids]
payloads = _get_sbdb_payloads_new(
obj_ids,
max_concurrent_requests=max_concurrent_requests,
timeout_s=timeout_s,
max_attempts=max_attempts,
)
if allow_missing:
kept_ids: list[str] = []
kept_payloads: list[dict[str, Any]] = []
for obj_id, payload in zip(obj_ids, payloads):
if "object" not in payload:
continue
kept_ids.append(obj_id)
kept_payloads.append(payload)
if not kept_ids:
return Orbits.empty()
orbits = _orbits_from_sbdb_payloads(kept_ids, kept_payloads)
if orbit_id_from_input:
orbits = orbits.set_column(
"orbit_id", pa.array(kept_ids, type=pa.large_string())
)
return orbits
orbits = _orbits_from_sbdb_payloads(obj_ids, payloads)
if orbit_id_from_input:
orbits = orbits.set_column(
"orbit_id", pa.array(obj_ids, type=pa.large_string())
)
return orbits
[docs]
class NotFoundError(Exception):
def __init__(self, message, object_id):
self.message = message
self.object_id = object_id
def __str__(self):
return self.message