"""Routines for finding transit times."""
__all__ = [
"find_transit_times_all",
"find_transit_times_fast",
"find_transit_params_all",
"find_transit_params_fast",
"find_transit_times_kepler_all",
]
import jax.numpy as jnp
from jax import checkpoint, config, vmap
from jax.lax import scan
from .conversion import G, jacobi_to_astrocentric, xvjac_to_xvacm
from .hermite4 import hermite4_step_map
from .symplectic import kepler_kick_kepler_map, kepler_step_map, kick_kepler_map
from .utils import find_nearest_idx, find_nearest_idx_sorted
config.update("jax_enable_x64", True)
def _get_tcflag(xjac, vjac):
    """Flag integration steps that immediately follow a transit center.

    A transit center is considered bracketed between steps ``k`` and
    ``k + 1`` when the sky-plane product ``x . v`` (x and y components only)
    is negative at step ``k``, positive at step ``k + 1``, and the ``z``
    coordinate at step ``k + 1`` is positive.

    Args:
        xjac: Jacobi positions of shape ``(Nstep, Norbit, 3)``.
        vjac: Jacobi velocities of shape ``(Nstep, Norbit, 3)``.

    Returns:
        Boolean array of shape ``(Nstep - 1, Norbit)``; an entry is True when
        the corresponding step lies just after a transit center.
    """
    sky_dot = jnp.sum(xjac[:, :, :2] * vjac[:, :, :2], axis=2)
    was_negative = sky_dot[:-1] < 0.0
    is_positive = sky_dot[1:] > 0.0
    z_after_positive = xjac[1:, :, 2] > 0.0
    return was_negative & is_positive & z_after_positive
def _get_g_map(xjac, vjac, pidxarr):
    """Evaluate the sky-plane product ``x . v`` for one selected orbit per row.

    Args:
        xjac: Positions whose leading axis is mapped over; each slice is
            indexed as ``[orbit, component]`` (e.g. ``(Ntransit, Norbit, 3)``).
        vjac: Velocities with the same layout as ``xjac``.
        pidxarr: Orbit index to evaluate for each row, shape ``(Ntransit,)``.

    Returns:
        Flattened array holding ``x[j, 0] * v[j, 0] + x[j, 1] * v[j, 1]`` for
        the selected orbit ``j`` of every row.
    """
    def sky_dot(xrow, vrow, idx):
        return jnp.sum(xrow[idx, :2] * vrow[idx, :2])

    per_row = vmap(sky_dot, in_axes=(0, 0, 0), out_axes=0)
    return per_row(xjac, vjac, pidxarr).ravel()
def _find_tc_idx(t, tcflag, j, tcobs):
    """Pick the flagged step of orbit ``j`` closest to each target time.

    Args:
        t: Time array of shape ``(Nstep,)``.
        tcflag: Boolean array of shape ``(Nstep - 1, Norbit)``; True marks a
            step just after a transit center.
        j: Orbit index.
        tcobs: Target transit time(s) for orbit ``j`` (scalar or array).

    Returns:
        Indices into ``t[1:]`` chosen by ``find_nearest_idx`` among the steps
        where ``tcflag[:, j]`` is True.
    """
    flagged = tcflag[:, j]
    # Unflagged steps are replaced by -inf so that, assuming
    # find_nearest_idx minimises the distance to the target, they are not
    # selected while any flagged step exists.
    candidate_times = jnp.where(flagged, t[1:], -jnp.inf)
    targets = jnp.atleast_1d(tcobs)
    return find_nearest_idx(candidate_times, targets)


# Map over (orbit index, target time) pairs; ``t`` and ``tcflag`` are shared.
_find_tc_idx_map = vmap(_find_tc_idx, in_axes=(None, None, 0, 0), out_axes=0)
def _get_nrstep(x, v, a, j):
    """Return the Newton correction ``-g / (dg/dt)`` for one planet.

    ``g`` is the sky-plane (x, y) product of the planet's position and
    velocity relative to body 0. Its time derivative is
    ``|v|^2 + x . a`` in the same plane.

    Args:
        x: Center-of-mass positions of shape ``(Nbody, 3)``.
        v: Center-of-mass velocities of shape ``(Nbody, 3)``.
        a: Center-of-mass accelerations of shape ``(Nbody, 3)``.
        j: Orbit index starting from 0 (body ``j + 1``).

    Returns:
        Scalar Newton time step for the selected planet.
    """
    body = j + 1  # row 0 is used as the reference body
    rel_x = x[body, :] - x[0, :]
    rel_v = v[body, :] - v[0, :]
    rel_a = a[body, :] - a[0, :]
    sky_x, sky_v, sky_a = rel_x[:2], rel_v[:2], rel_a[:2]
    g_val = jnp.sum(sky_x * sky_v)
    g_rate = jnp.sum(sky_v * sky_v) + jnp.sum(sky_x * sky_a)
    return -g_val / g_rate


# Batched over a leading axis of all four arguments.
_get_nrstep_map = vmap(_get_nrstep, in_axes=(0, 0, 0, 0), out_axes=0)
def _find_transit_candidates(pidxarr, tcobsarr, t, xvjac):
    """Locate, for every target transit, the nearest step that follows a transit.

    Args:
        pidxarr: Orbit index of each target transit.
        tcobsarr: Target transit times.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions/velocities of shape ``(Nstep, 2, Norbit, 3)``.

    Returns:
        Flattened integer indices into ``t[1:]``.
    """
    positions, velocities = xvjac[:, 0], xvjac[:, 1]
    after_transit = _get_tcflag(positions, velocities)
    nearest = _find_tc_idx_map(t, after_transit, pidxarr, tcobsarr)
    return nearest.ravel()
def _prepare_transit_newton_init(tcidx, pidxarr, t, xvjac, masses):
    """Build the starting states and first Newton step for transit refinement.

    Args:
        tcidx: Indices into ``t[1:]`` of the candidate steps.
        pidxarr: Orbit index for each candidate.
        t: Time array of shape ``(Nstep,)``; the first interval is used as dt.
        xvjac: Jacobi positions/velocities of shape ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.

    Returns:
        Tuple ``(tc, xcm, vcm, nrstep)``: candidate times shifted by ``-dt/2``,
        the center-of-mass positions and velocities at those times, and the
        first Newton correction for each candidate.
    """
    half_back = -0.5 * jnp.diff(t)[0]
    start = xvjac[1:][tcidx]
    # Bring back the system by dt/2 so that the states are at the
    # conclusions of the symplectic step.
    xjac0, vjac0 = kepler_step_map(
        start[:, 0, :, :], start[:, 1, :, :], masses, half_back
    )
    tc = t[1:][tcidx] + half_back
    xcm0, vcm0, acm0 = xvjac_to_xvacm(xjac0, vjac0, masses)
    first_step = _get_nrstep_map(xcm0, vcm0, acm0, pidxarr)
    return tc, xcm0, vcm0, first_step
def _scan_transit_newton(xcm_init, vcm_init, nrstep_init, pidxarr, masses, nitr):
    """Iterate Newton refinement of the transit times with ``jax.lax.scan``.

    Each iteration advances the states by the current step with
    ``hermite4_step_map`` and then computes the next Newton step there.

    Args:
        xcm_init: Initial center-of-mass positions, one set per transit.
        vcm_init: Initial center-of-mass velocities, one set per transit.
        nrstep_init: First Newton step for each transit.
        pidxarr: Orbit index for each transit.
        masses: Mass array of shape ``(Nbody,)``.
        nitr: Number of Newton iterations.

    Returns:
        ``scan`` output: the final carry ``(x, v, step)`` and the steps
        computed in each iteration, stacked along the leading axis.
    """
    def _to_transit_major(arr):
        # NOTE(review): presumably moves the transit axis of the Hermite
        # output to the front, matching the input layout -- confirm.
        return jnp.transpose(arr, axes=[2, 0, 1])

    @checkpoint
    def newton_iteration(carry, _):
        x_prev, v_prev, dt_step = carry
        x_raw, v_raw, a_raw = hermite4_step_map(x_prev, v_prev, masses, dt_step)
        x_new, v_new, a_new = map(_to_transit_major, (x_raw, v_raw, a_raw))
        next_step = _get_nrstep_map(x_new, v_new, a_new, pidxarr)
        return (x_new, v_new, next_step), next_step

    initial_carry = (xcm_init, vcm_init, nrstep_init)
    return scan(newton_iteration, initial_carry, jnp.arange(nitr))
def _find_transit_newton_core(pidxarr, tcobsarr, t, xvjac, masses, nitr):
    """Run the full Newton-based transit search.

    Returns:
        Tuple ``(tc, final_carry)``. ``tc`` holds the refined transit times.
        ``final_carry`` is the last ``(x, v, step)`` carry of the Newton scan.
    """
    candidate_idx = _find_transit_candidates(pidxarr, tcobsarr, t, xvjac)
    tc_start, xcm0, vcm0, step0 = _prepare_transit_newton_init(
        candidate_idx, pidxarr, t, xvjac, masses
    )
    final_carry, newton_steps = _scan_transit_newton(
        xcm0, vcm0, step0, pidxarr, masses, nitr
    )
    # Total offset is the initial step plus the step of every iteration,
    # including the last one, which is not applied to the returned states.
    tc = tc_start + (step0 + jnp.sum(newton_steps, axis=0))
    return tc, final_carry
def _find_tc_idx_sorted(t, tcobs):
    """Return indices into ``t[1:]`` nearest to the target transit time(s).

    Args:
        t: Time array of shape ``(Nstep,)``.
        tcobs: Target transit time(s); scalars are promoted to 1-D.

    Returns:
        Indices into ``t[1:]`` as computed by ``find_nearest_idx_sorted``.
    """
    step_ends = t[1:]
    targets = jnp.atleast_1d(tcobs)
    return find_nearest_idx_sorted(step_ends, targets)
def _advance_to_tcobs_fast(tcidx, pidxarr, tcobsarr, t, xvjac, masses):
    """Propagate the nearest stored states to the observed transit times.

    Args:
        tcidx: Indices into ``t[1:]`` nearest to ``tcobsarr``.
        pidxarr: Planet index for each transit.
        tcobsarr: Observed transit times.
        t: Time array of shape ``(Nstep,)``; the first interval is used as dt.
        xvjac: Jacobi-frame positions/velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.

    Returns:
        Tuple ``(tcobsarr, xcm, vcm, nrstep)``. ``xcm`` and ``vcm`` are the
        center-of-mass positions and velocities at ``tcobsarr``. ``nrstep``
        is the Newton transit-time correction evaluated there.
    """
    half_back = -0.5 * jnp.diff(t)[0]
    start = xvjac[1:][tcidx]
    # Bring back the system by dt/2 so that the states are at the
    # conclusions of the symplectic step.
    x_boundary, v_boundary = kepler_step_map(
        start[:, 0, :, :],
        start[:, 1, :, :],
        masses,
        half_back,
    )
    t_boundary = t[1:][tcidx] + half_back
    # Advance from the step boundary to the observed transit times.
    x_obs, v_obs = kepler_kick_kepler_map(
        x_boundary,
        v_boundary,
        masses,
        tcobsarr - t_boundary,
    )
    xcm_obs, vcm_obs, acm_obs = xvjac_to_xvacm(x_obs, v_obs, masses)
    correction = _get_nrstep_map(xcm_obs, vcm_obs, acm_obs, pidxarr)
    return tcobsarr, xcm_obs, vcm_obs, correction
def _find_transit_times_fast_core(pidxarr, tcobsarr, t, xvjac, masses):
    """Fast transit finder: observed time plus a single Newton correction."""
    nearest = _find_tc_idx_sorted(t, tcobsarr)
    advanced = _advance_to_tcobs_fast(
        nearest, pidxarr, tcobsarr, t, xvjac, masses
    )
    observed, correction = advanced[0], advanced[3]
    return observed + correction
def _find_transit_params_fast_core(pidxarr, tcobsarr, t, xvjac, masses):
    """Fast transit finder that also returns the states at the corrected times.

    Returns:
        Tuple ``(tc, (xcm, vcm, correction))``. ``xcm`` and ``vcm`` are
        obtained by a single ``hermite4_step_map`` step of length
        ``correction`` from the states at ``tcobsarr``.
    """
    nearest = _find_tc_idx_sorted(t, tcobsarr)
    observed, xcm_obs, vcm_obs, correction = _advance_to_tcobs_fast(
        nearest, pidxarr, tcobsarr, t, xvjac, masses
    )
    x_raw, v_raw, _ = hermite4_step_map(xcm_obs, vcm_obs, masses, correction)
    xcm_tc = jnp.transpose(x_raw, axes=[2, 0, 1])
    vcm_tc = jnp.transpose(v_raw, axes=[2, 0, 1])
    return observed + correction, (xcm_tc, vcm_tc, correction)
[docs]
def find_transit_times_all(pidxarr, tcobsarr, t, xvjac, masses, nitr=5):
    """Find transit times for all requested transits using Newton refinement.

    Args:
        pidxarr: Orbit indices starting from 0, with shape ``(Ntransit,)``.
        tcobsarr: Flattened array of observed transit times of shape
            ``(Ntransit,)``.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions and velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.
        nitr: Number of Newton-Raphson iterations.

    Returns:
        Transit times as a one-dimensional array.
    """
    transit_times, _final_carry = _find_transit_newton_core(
        pidxarr, tcobsarr, t, xvjac, masses, nitr
    )
    return transit_times
[docs]
def find_transit_times_fast(pidxarr, tcobsarr, t, xvjac, masses):
    """Find transit times for all requested transits using the fast algorithm.

    Each transit time is the observed time plus one Newton correction,
    evaluated after propagating the nearest stored state to ``tcobsarr``.

    Args:
        pidxarr: Orbit indices starting from 0, with shape ``(Ntransit,)``.
        tcobsarr: Flattened array of observed transit times of shape
            ``(Ntransit,)``.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions and velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.

    Returns:
        Transit times as a one-dimensional array.
    """
    transit_times = _find_transit_times_fast_core(
        pidxarr, tcobsarr, t, xvjac, masses
    )
    return transit_times
[docs]
def find_transit_params_all(pidxarr, tcobsarr, t, xvjac, masses, nitr=5):
    """Find transit times and phase-space states using Newton refinement.

    Args:
        pidxarr: Orbit indices starting from 0, with shape ``(Ntransit,)``.
        tcobsarr: Flattened array of observed transit times of shape
            ``(Ntransit,)``.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions and velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.
        nitr: Number of Newton-Raphson iterations.

    Returns:
        Tuple containing
            - transit times as a one-dimensional array, and
            - the final Newton carry ``(x, v, step)``. ``x`` and ``v`` are
              center-of-mass positions and velocities after the last applied
              correction. ``step`` is the last computed correction, which is
              included in the transit times but not applied to ``x`` and
              ``v``.
    """
    transit_times, final_carry = _find_transit_newton_core(
        pidxarr, tcobsarr, t, xvjac, masses, nitr
    )
    return transit_times, final_carry
[docs]
def find_transit_params_fast(pidxarr, tcobsarr, t, xvjac, masses):
    """Find transit times and phase-space states using the fast algorithm.

    Args:
        pidxarr: Orbit indices starting from 0, with shape ``(Ntransit,)``.
        tcobsarr: Flattened array of observed transit times of shape
            ``(Ntransit,)``.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions and velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.

    Returns:
        Tuple containing
            - transit times (``tcobsarr`` plus the correction) as a
              one-dimensional array, and
            - ``(x, v, correction)``. ``x`` and ``v`` are center-of-mass
              positions and velocities obtained by one Hermite step of length
              ``correction`` from ``tcobsarr``. ``correction`` is the time
              correction evaluated at ``tcobsarr``.
    """
    transit_times, states = _find_transit_params_fast_core(
        pidxarr, tcobsarr, t, xvjac, masses
    )
    return transit_times, states
"""TTVFast algorithm."""
def _get_elements(x, v, gm):
    """Compute the orbital quantities needed by the TTVFast interpolation.

    Args:
        x: Positions of shape ``(Norbit, 3)``.
        v: Velocities of shape ``(Norbit, 3)``.
        gm: ``GM`` for each orbit (scalar or broadcastable to ``(Norbit,)``).

    Returns:
        Tuple ``(n, e cos E0, e sin E0, a / r0)``: mean motion, eccentricity
        times cosine and sine of the initial eccentric anomaly, and the ratio
        of the semi-major axis to the initial distance.
    """
    r0 = jnp.sqrt(jnp.sum(x * x, axis=1))
    speed_sq = jnp.sum(v * v, axis=1)
    rv_dot = jnp.sum(x * v, axis=1)
    # Vis-viva: 1/a = 2/r - v^2/GM.
    semi = 1.0 / (2.0 / r0 - speed_sq / gm)
    mean_motion = jnp.sqrt(gm / (semi * semi * semi))
    ecos_e0 = 1.0 - r0 / semi
    esin_e0 = rv_dot / (mean_motion * semi * semi)
    return mean_motion, ecos_e0, esin_e0, semi / r0
def _find_transit_times_kepler(xast, vast, kast, dt, nitr):
    """Find transit times via the TTVFast interpolation scheme.

    Note:
        This function is adapted from TTVFast
        https://github.com/kdeck/TTVFast, based on the scheme developed by
        Nesvorny et al. (2013, ApJ, 777, 3).

    Each orbit is propagated along a Keplerian orbit with ``GM = kast``.
    The propagation uses Lagrange ``f``/``g`` coefficients written in terms
    of the eccentric-anomaly increment ``dE``. ``nitr`` Newton iterations
    drive the sky-plane product ``x . v`` of the propagated state towards
    zero. The final ``dE`` is then converted to a time offset through
    Kepler's equation.

    Args:
        xast: Astrocentric positions of shape ``(Norbit, 3)``.
        vast: Astrocentric velocities of shape ``(Norbit, 3)``.
        kast: Astrocentric ``GM``.
        dt: Integration time step. It is only used for the initial guess
            ``dE0 = n * dt / 2``, so its sign selects the search direction.
        nitr: Number of iterations used in the interpolation solve.

    Returns:
        Time to the transit center for each orbit, shape ``(Norbit,)``.
    """
    n, ecosE0, esinE0, a_r0 = _get_elements(xast, vast, kast)
    # Sky-plane (x, y) quantities of the initial state.
    rsquared = jnp.sum(xast[:, :2] * xast[:, :2], axis=1)
    vsquared = jnp.sum(vast[:, :2] * vast[:, :2], axis=1)
    xdotv = jnp.sum(xast[:, :2] * vast[:, :2], axis=1)

    def dEstep_transit(dE, _):
        """One Newton update of ``dE`` toward sky-plane ``x . v == 0``."""
        # Half-angle sine/cosine of dE; sx and cx below are sin(dE) and
        # cos(dE) obtained from the double-angle identities.
        x2 = dE / 2.0
        sx2, cx2 = jnp.sin(x2), jnp.cos(x2)
        # Lagrange f = 1 - (a / r0) * (1 - cos dE).
        f = 1.0 - a_r0 * 2.0 * sx2 * sx2
        sx, cx = 2.0 * sx2 * cx2, cx2 * cx2 - sx2 * sx2
        # Lagrange g expressed in dE (units of time).
        g = (2.0 * sx2 * (esinE0 * sx2 + cx2 / a_r0)) / n
        # fp equals r / a at the propagated point.
        fp = 1.0 - cx * ecosE0 + sx * esinE0
        fdot = -(a_r0 / fp) * n * sx
        # fp2 is d(fp)/d(dE).
        fp2 = sx * ecosE0 + cx * esinE0
        gdot = 1.0 - 2.0 * sx2 * sx2 / fp
        # Derivatives of gdot, f, g and fdot with respect to dE.
        dgdotdz = -sx / fp + 2.0 * sx2 * sx2 / fp / fp * fp2
        dfdz = -a_r0 * sx
        dgdz = (sx * esinE0 - (ecosE0 - 1.0) * cx) / n
        dfdotdz = -n * a_r0 / fp * (cx + sx / fp * fp2)
        # Sky-plane x . v of the propagated state, with
        # x = f x0 + g v0 and v = fdot x0 + gdot v0.
        dotproduct = (
            f * fdot * rsquared
            + g * gdot * vsquared
            + (f * gdot + g * fdot) * xdotv
        )
        # Derivative of dotproduct with respect to dE (chain rule).
        dotproductderiv = (
            dfdz * (gdot * xdotv + fdot * rsquared)
            + dfdotdz * (f * rsquared + g * xdotv)
            + dgdz * (fdot * xdotv + gdot * vsquared)
            + dgdotdz * (g * vsquared + f * xdotv)
        )
        # Newton update; the per-iteration scan output is unused.
        return dE - dotproduct / dotproductderiv, None

    dE0 = n * dt / 2.0
    dE, _ = scan(dEstep_transit, dE0, jnp.arange(nitr))
    # Kepler's equation: n * (time offset) =
    # dE + esinE0 * (1 - cos dE) - ecosE0 * sin dE.
    x2 = dE / 2.0
    sx2, cx2 = jnp.sin(x2), jnp.cos(x2)
    sx = 2.0 * sx2 * cx2
    transitM = dE + esinE0 * 2.0 * sx2 * sx2 - sx * ecosE0
    return transitM / n


# Vectorised over a leading transit axis of xast, vast and kast; dt and nitr
# are shared by all transits.
_find_transit_times_kepler_map = vmap(
    _find_transit_times_kepler, (0, 0, 0, None, None), 0
)
[docs]
def find_transit_times_kepler_all(pidxarr, tcobsarr, t, xvjac, masses, nitr=3):
    """Find transit times via the legacy TTVFast-style interpolation scheme.

    Note:
        This function is kept for backward compatibility and will be deprecated.
        It may fail for large ``dt``. The step size is taken from
        ``t[1] - t[0]``, so a uniform time grid is assumed.

    Args:
        pidxarr: Orbit indices starting from 0, with shape ``(Ntransit,)``.
        tcobsarr: Flattened array of observed transit times of shape
            ``(Ntransit,)``.
        t: Time array of shape ``(Nstep,)``.
        xvjac: Jacobi positions and velocities of shape
            ``(Nstep, 2, Norbit, 3)``.
        masses: Mass array of shape ``(Nbody,)``.
        nitr: Number of iterations used in the Kepler interpolation step.

    Returns:
        Transit times as a one-dimensional array.
    """
    # Index into ``t[1:]`` of the step just after each requested transit.
    tcidx = _find_transit_candidates(pidxarr, tcobsarr, t, xvjac)

    # ``t[1:][tcidx]`` is the step after the transit and ``t[tcidx]`` the
    # step before it. Indexing ``t[tcidx]`` directly (rather than
    # ``t[1:][tcidx - 1]``) keeps ``tcidx == 0`` pointing at ``t[0]``
    # instead of wrapping around to ``t[-1]``.
    tc_ahead, tc_behind = t[1:][tcidx], t[tcidx]
    xvjac_ahead, xvjac_behind = xvjac[1:][tcidx], xvjac[tcidx]

    # Bring back the system by dt/2 so that the states are at the
    # conclusions of the symplectic step. If the transit is not bracketed
    # after this shift, advance the system by dt again.
    dt = jnp.diff(t)[0]
    dt2 = 0.5 * dt
    xjac_ahead_mindt2, vjac_ahead_mindt2 = kepler_step_map(
        xvjac_ahead[:, 0, :, :], xvjac_ahead[:, 1, :, :], masses, -dt2
    )
    xjac_behind_mindt2, vjac_behind_mindt2 = kepler_step_map(
        xvjac_behind[:, 0, :, :], xvjac_behind[:, 1, :, :], masses, -dt2
    )
    xjac_ahead_plusdt2, vjac_ahead_plusdt2 = kick_kepler_map(
        xvjac_ahead[:, 0, :, :], xvjac_ahead[:, 1, :, :], masses, dt2
    )

    # True where the transit is still bracketed by the two rewound states.
    bracketed = _get_g_map(xjac_ahead_mindt2, vjac_ahead_mindt2, pidxarr) > 0.0
    # Broadcast the per-transit flag over the (Norbit, 3) state axes.
    state_mask = bracketed[:, None, None]
    xjac_ahead = jnp.where(state_mask, xjac_ahead_mindt2, xjac_ahead_plusdt2)
    xjac_behind = jnp.where(state_mask, xjac_behind_mindt2, xjac_ahead_mindt2)
    vjac_ahead = jnp.where(state_mask, vjac_ahead_mindt2, vjac_ahead_plusdt2)
    vjac_behind = jnp.where(state_mask, vjac_behind_mindt2, vjac_ahead_mindt2)
    tc_ahead = jnp.where(bracketed, tc_ahead - dt2, tc_ahead + dt2)
    tc_behind = jnp.where(bracketed, tc_behind - dt2, tc_behind + dt2)

    xast_ahead, vast_ahead = jacobi_to_astrocentric(
        xjac_ahead, vjac_ahead, masses)
    xast_behind, vast_behind = jacobi_to_astrocentric(
        xjac_behind, vjac_behind, masses)
    kast = G * (masses[1:] + masses[0])
    kastarr = kast[pidxarr]

    # Each vmapped call returns times for every orbit, shape
    # (Ntransit, Norbit); transit i needs entry (i, pidxarr[i]). Gathering
    # directly avoids the (Ntransit, Ntransit) intermediate that
    # ``jnp.diag(out[:, pidxarr])`` would build.
    rows = jnp.arange(len(pidxarr))
    tau_ahead = tc_ahead + _find_transit_times_kepler_map(
        xast_ahead, vast_ahead, kastarr, -dt, nitr
    )[rows, pidxarr]
    tau_behind = tc_behind + _find_transit_times_kepler_map(
        xast_behind, vast_behind, kastarr, dt, nitr
    )[rows, pidxarr]

    # Weighted combination of the two estimates. ``tau_ahead`` is weighted by
    # ``tau_behind - tc_behind`` and ``tau_behind`` by ``tc_ahead - tau_ahead``.
    # The denominator is the sum of the weights because
    # ``tc_ahead - tc_behind == dt``.
    tc = (
        (tau_behind - tc_behind) * tau_ahead
        + (tc_ahead - tau_ahead) * tau_behind
    ) / (dt + tau_behind - tau_ahead)
    return tc