Source code for adam_core.coordinates.jacobian

from typing import Callable, Hashable, Optional, Tuple

import jax.numpy as jnp
import numpy as np
from jax import config, jacfwd, vmap

config.update("jax_enable_x64", True)

__all__ = ["calc_jacobian"]


def calc_jacobian(
    coords: np.ndarray,
    _func: Callable,
    in_axes: Optional[Hashable] = (0,),
    out_axes: Optional[int] = 0,
    **kwargs,
) -> np.ndarray:
    """
    Calculate the jacobian for the given callable in D dimensions for every N coordinate.

    Parameters
    ----------
    coords : `~numpy.ndarray` (N, D)
        Coordinates that correspond to the input covariance matrices.
    _func : function
        A function that takes a single coord (D) as input and returns the transformed
        coordinate (D). See for example:
        `adam_core.coordinates.transform._cartesian_to_spherical`
        or `adam_core.coordinates.transform._cartesian_to_keplerian`.
        Derivatives are taken with respect to its first positional argument only.
    in_axes : Optional[Hashable]
        An integer or ``None`` indicates which array axis to map over for all arguments
        (with ``None`` indicating not to map any axis), and a tuple indicates which
        axis to map for each corresponding positional argument. Axis integers must be
        in the range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
        dimensions (axes) of the corresponding input array.
        When ``kwargs`` are supplied and a tuple is used, it needs one entry for
        ``coords`` followed by one entry per forwarded argument, in the same order.
        From: https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap
    out_axes : Optional[int]
        An integer, None, or (nested) standard Python container (tuple/list/dict)
        thereof indicating where the mapped axis should appear in the output. All
        outputs with a mapped axis must have a non-None out_axes specification.
        From: https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap
    **kwargs
        Extra arguments forwarded to ``_func``. They are passed *positionally*, after
        ``coords``, in the order they were given. Their order must therefore match
        ``_func``'s positional parameters; the keyword names themselves are ignored.

    Returns
    -------
    jacobian : `~numpy.ndarray` (N, D, D)
        Array containing function partial derivatives for each coordinate.
        If ``_func`` returns a tuple, only the jacobian of its first output is returned.
    """
    # Build the jacobian function for the input function once, then vectorize it
    # over the N coordinates rather than differentiating inside a Python loop.
    jacobian_func = jacfwd(_func, argnums=0)
    vmapped_jacobian_func = vmap(
        jacobian_func,
        in_axes=in_axes,
        out_axes=out_axes,
    )
    # kwargs are deliberately forwarded positionally (see docstring), because
    # vmap's in_axes only applies to positional arguments.
    jacobian = vmapped_jacobian_func(coords, *kwargs.values())

    # If the vmapped function returns more outputs, then only return the first one.
    # All relevant functions in adam_core return their primary result first, though
    # a more general solution may be needed in the future.
    if isinstance(jacobian, tuple):
        jacobian = jacobian[0]

    # np.array copies the JAX result into a regular, writable NumPy array.
    return np.array(jacobian)