Source code for jnkepler.jaxttv.symplectic

"""
Symplectic integrator.

JAX-native implementation following the symplectic approach used in TTVFast (https://github.com/kdeck/TTVFast).
"""

__all__ = [
    "integrate_xv", "kepler_step_map", "kick_kepler_map"
]

import jax.numpy as jnp
from functools import partial
from jax import jit, vmap, grad, config, checkpoint, custom_vjp
from jax.lax import scan, while_loop
from .conversion import G
config.update('jax_enable_x64', True)


def dEstep(x, ecosE0, esinE0, dM):
    """Refine an estimate of delta(eccentric anomaly) by one corrector step.

        The residual of the incremental Kepler equation,
        ``F(dE) = dE + (1 - cos dE) * esinE0 - sin(dE) * ecosE0 - dM``,
        is evaluated at ``x`` together with its first three derivatives,
        and a nested Newton-type correction using all of them is applied.

        Args:
            x: current estimate of dE
            ecosE0, esinE0: e*cos(E0) and e*sin(E0) at the initial state
            dM: delta(mean anomaly)

        Returns:
            improved estimate of dE

    """
    half = x / 2.0
    sin_half, cos_half = jnp.sin(half), jnp.cos(half)

    # full-angle sine/cosine rebuilt from the half-angle values
    sin_x = 2.0 * sin_half * cos_half
    cos_x = cos_half * cos_half - sin_half * sin_half

    # residual F(x); uses 2*sin^2(x/2) = 1 - cos(x), 2*sin(x/2)cos(x/2) = sin(x)
    resid = x + 2.0 * sin_half * (sin_half * esinE0 - cos_half * ecosE0) - dM

    # e*cos(E0 + x), which also gives the first and third derivatives
    ecosE_new = cos_x * ecosE0 - sin_x * esinE0
    d1 = 1.0 - ecosE_new                                   # F'
    d2_half = (sin_x * ecosE0 + cos_x * esinE0) / 2.0     # F'' / 2
    d3_sixth = ecosE_new / 6.0                             # F''' / 6

    # successive corrections, each reusing the previous one
    step = -resid / d1
    step = -resid / (d1 + step * d2_half)
    step = -resid / (d1 + step * (d2_half + step * d3_sixth))

    return x + step


def _solve_dE_while(ecosE0, esinE0, dM, max_iter, tol):
    """Iterate `dEstep` from the initial guess dE = dM inside a `while_loop`.

    The loop stops once `max_iter` updates have been made or the largest
    absolute change of dE between successive iterates is no longer above
    `tol`. Returns the last iterate.
    """

    def keep_going(state):
        count, _, change = state
        return jnp.logical_and(count < max_iter, change > tol)

    def advance(state):
        count, current, _ = state
        updated = dEstep(current, ecosE0, esinE0, dM)
        # a single scalar error (max over all elements) drives the loop
        change = jnp.max(jnp.abs(updated - current))
        return (count + 1, updated, change)

    # (iteration counter, current dE, last change); the change starts at
    # +inf so that at least one update runs when max_iter > 0
    start = (jnp.int32(0), dM, jnp.array(jnp.inf, dtype=dM.dtype))
    final = while_loop(keep_going, advance, start)
    return final[1]


@partial(custom_vjp, nondiff_argnums=(3, 4))
def solve_dE(ecosE0, esinE0, dM, max_iter=10, tol=1e-12):
    """
    Find the eccentric-anomaly increment dE for one Kepler step.

    dE is the root of
        F(dE) = dE + (1 - cos dE) * (e sin E0) - (sin dE) * (e cos E0) - dM.

    Notes:
        - The root is found iteratively by `_solve_dE_while`, using at most
          `max_iter` updates and stopping tolerance `tol`.
        - The function is wrapped in `custom_vjp`; `max_iter` and `tol`
          (positional indices 3 and 4) are declared non-differentiable.

    Args:
        ecosE0: e cos(E0), shape (Norbit,).
        esinE0: e sin(E0), shape (Norbit,).
        dM: Mean-anomaly increment n*dt, shape (Norbit,).
        max_iter: Maximum number of iterations.
        tol: Convergence tolerance on |dE_{k+1} - dE_k|.

    Returns:
        dE: Eccentric-anomaly increment, shape (Norbit,).
    """
    return _solve_dE_while(ecosE0, esinE0, dM, max_iter=max_iter, tol=tol)


def solve_dE_fwd(ecosE0, esinE0, dM, max_iter=10, tol=1e-12):
    """Forward rule for `solve_dE`: return dE plus the residuals for the VJP.

    The residual tuple is ``(dE, ecosE0, esinE0, dM)``.
    """
    solution = _solve_dE_while(ecosE0, esinE0, dM, max_iter, tol)
    residuals = (solution, ecosE0, esinE0, dM)
    return solution, residuals


def solve_dE_bwd(max_iter, tol, res, dE_bar):
    """Backward rule for `solve_dE` via implicit differentiation of F(dE)=0.

    With F = dE + (1 - cos dE) * esinE0 - sin(dE) * ecosE0 - dM, each
    sensitivity is d(dE)/dp = -(dF/dp) / (dF/ddE).

    Args:
        max_iter, tol: non-differentiable arguments (unused here).
        res: residuals ``(dE, ecosE0, esinE0, dM)`` saved by the forward rule.
        dE_bar: cotangent of the output dE.

    Returns:
        tuple of cotangents for (ecosE0, esinE0, dM).
    """
    dE, ecosE0, esinE0, _ = res  # the stored dM is not needed below

    sin_dE = jnp.sin(dE)
    cos_dE = jnp.cos(dE)

    # dF/d(dE); note that no clamp is applied when this is close to zero
    dF_ddE = 1.0 + sin_dE * esinE0 - cos_dE * ecosE0

    # dF/d(ecosE0) = -sin dE, dF/d(esinE0) = 1 - cos dE, dF/d(dM) = -1
    grad_ecosE0 = dE_bar * (sin_dE / dF_ddE)
    grad_esinE0 = dE_bar * (-(1.0 - cos_dE) / dF_ddE)
    grad_dM = dE_bar * (1.0 / dF_ddE)
    return (grad_ecosE0, grad_esinE0, grad_dM)


# Register the custom VJP of solve_dE: solve_dE_fwd is the forward rule and
# solve_dE_bwd is the backward rule.
solve_dE.defvjp(solve_dE_fwd, solve_dE_bwd)


def kepler_step(x, v, gm, dt, nitr=10):
    """Kepler step (two-body drift).

        Advance Cartesian positions/velocities by `dt` along two-body
        Keplerian orbits with gravitational parameter `gm`.

        Notes:
            The eccentric-anomaly increment ``dE`` comes from `solve_dE`,
            called with ``max_iter=nitr`` and a fixed ``tol=1e-12``; `nitr`
            is therefore an upper bound on the number of iterations.
            The new state is the linear combination
            ``x_new = f*x + g*v``, ``v_new = fdot*x + gdot*v``.

        Args:
            x: positions (Norbit, xyz)
            v: velocities (Norbit, xyz)
            gm: 'GM' in Kepler's 3rd law
            dt: time step
            nitr: maximum number of Newton iterations

        Returns:
            tuple:
                - new positions (Norbit, xyz)
                - new velocities (Norbit, xyz)

    """
    radius = jnp.sqrt(jnp.sum(x * x, axis=1))
    speed_sq = jnp.sum(v * v, axis=1)
    x_dot_v = jnp.sum(x * v, axis=1)

    # semi-major axis from 1/a = 2/r - v^2/GM, then mean motion
    sma = 1.0 / (2.0 / radius - speed_sq / gm)
    mean_motion = jnp.sqrt(gm / (sma * sma * sma))

    ecosE0 = 1.0 - radius / sma
    esinE0 = x_dot_v / (mean_motion * sma * sma)

    dE = solve_dE(ecosE0, esinE0, mean_motion * dt, max_iter=nitr, tol=1e-12)

    # sin/cos of dE rebuilt from half-angle values
    half = dE / 2.0
    sin_half, cos_half = jnp.sin(half), jnp.cos(half)
    sin_dE = 2.0 * sin_half * cos_half
    cos_dE = cos_half * cos_half - sin_half * sin_half

    f = 1.0 - (sma / radius) * (2.0 * sin_half * sin_half)
    g = (2.0 * sin_half * (esinE0 * sin_half + cos_half * radius / sma)) \
        / mean_motion

    denom = 1.0 - cos_dE * ecosE0 + sin_dE * esinE0
    fdot = -(sma / (radius * denom)) * mean_motion * sin_dE
    # gdot follows from the constraint f*gdot - fdot*g = 1
    gdot = (1.0 + g * fdot) / f

    x_new = f[:, None] * x + g[:, None] * v
    v_new = fdot[:, None] * x + gdot[:, None] * v

    return x_new, v_new


def mmat_from_masses(masses):
    """Build the lower-triangular mass matrix applied to Jacobi coordinates.

    The result is the identity plus a strictly lower-triangular part whose
    column ``j`` holds ``masses[j+1] / sum(masses[:j+2])``. `Hintgrad` uses
    it as ``xast = M @ xjac``.

    Args:
        masses: Masses of the bodies, shape (Nbody,).

    Returns:
        mmat: Mass matrix, shape (Nbody-1, Nbody-1).

    """
    n_orbit = len(masses) - 1
    # planet mass divided by the cumulative mass up to and including it
    mass_ratio = masses[1:] / jnp.cumsum(masses)[1:]
    # every row is a copy of mass_ratio; keep only entries below the diagonal
    below_diag = jnp.tril(jnp.tile(mass_ratio, (n_orbit, 1)), k=-1)
    return jnp.eye(n_orbit) + below_diag


def Hintgrad(xjac, vjac, masses):
    """gradient of the interaction Hamiltonian times (star mass / planet mass)

        The gradient is taken with respect to the Jacobi positions `xjac`.
        Three terms are summed: one in Jacobi coordinates, and a star-planet
        and a planet-planet term in the coordinates ``xast = M @ xjac``,
        which are mapped back through ``M.T``.

        Args:
            xjac: Jacobi positions (Norbit, xyz)
            vjac: Jacobi velocities (Norbit, xyz); not used in the computation
            masses: masses of the bodies (Nbody,), solar unit

        Returns:
            gradient of interaction Hamiltonian x (star mass / planet mass),
            shape (Norbit, xyz)

    """
    m_star = masses[0]
    m_planets = masses[1:]                     # (N,)
    mass_ratio = m_planets / m_star            # (N,)
    jac_to_ast = mmat_from_masses(masses)      # (N,N)

    # ---- Jacobi term: gradient of +sum_i mu_i / |xjac_i| ----
    rsq_jac = jnp.sum(xjac * xjac, axis=1)
    r_jac = jnp.sqrt(rsq_jac)
    inv_rcube_jac = 1.0 / (rsq_jac * r_jac)
    grad_jac = -(mass_ratio[:, None] * xjac) * inv_rcube_jac[:, None]

    # ---- transformed coordinates ----
    xast = jac_to_ast @ xjac

    # star-planet term: gradient of -sum_i mu_i / |xast_i| w.r.t. xast
    rsq_ast = jnp.sum(xast * xast, axis=1)
    r_ast = jnp.sqrt(rsq_ast)
    inv_rcube_ast = 1.0 / (rsq_ast * r_ast)
    grad_star_planet = +(mass_ratio[:, None] * xast) * inv_rcube_ast[:, None]

    # planet-planet term: -0.5 * sum_{i,j} mu_i mu_j / |xast_i - xast_j|
    sep = xast[:, None, :] - xast[None, :, :]            # (N,N,3)
    sep_sq = jnp.sum(sep * sep, axis=-1)                 # (N,N)
    # zero separations (the i == j entries) are masked so they contribute 0
    nonzero = sep_sq != 0.0
    sep_sq_safe = jnp.where(nonzero, sep_sq, 1.0)
    inv_sep_cube = jnp.where(
        nonzero, 1.0 / (sep_sq_safe * jnp.sqrt(sep_sq_safe)), 0.0
    )
    pair_weight = (mass_ratio[:, None] * mass_ratio[None, :]) * inv_sep_cube

    # the factor 0.5 cancels against the double counting of each pair
    grad_planet_planet = +jnp.sum(pair_weight[:, :, None] * sep, axis=1)

    grad_ast = grad_star_planet + grad_planet_planet

    # chain rule for xast = M @ xjac: contribution is M^T @ dH/dxast
    grad_total = grad_jac + jac_to_ast.T @ grad_ast

    # rescale each row by (star mass / planet mass)
    return grad_total * (m_star / m_planets)[:, None]


def nbody_kicks(x, v, ki, masses, dt):
    """Kick the velocities with the interaction-Hamiltonian gradient.

        The velocity change is ``-ki * dt * Hintgrad(x, v, masses)``;
        positions are returned unchanged.

        Args:
            x: positions (Norbit, xyz)
            v: velocities (Norbit, xyz)
            ki: GM values
            masses: masses of the bodies (Nbody,), solar unit
            dt: time step

        Returns:
            tuple:
                - positions
                - kicked velocities

    """
    kick = -ki[:, None] * dt * Hintgrad(x, v, masses)
    v_kicked = v + kick
    return x, v_kicked


def integrate_xv(x, v, masses, times, nitr=10):
    """symplectic integration of the orbits

        The state is first transformed with `real_to_mapTO`, drifted by half
        of the first time step, and then advanced with one kick followed by
        one Kepler drift per entry of ``diff(times)``.

        Args:
            x: initial Jacobi positions (Norbit, xyz)
            v: initial Jacobi velocities (Norbit, xyz)
            masses: masses of the bodies (Nbody,), in units of solar mass
            times: cumulative sum of time steps
            nitr: maximum number of Newton iterations in each Kepler step
                of the main loop

        Returns:
            tuple:
                - times (initial time omitted; dt/2 ahead of the input),
                  i.e. ``times[1:] + 0.5 * (times[1] - times[0])``
                - Jacobi position/velocity array (Nstep, x or v, Norbit, xyz)

    """
    mass_cumsum = jnp.cumsum(masses)
    ki = G * masses[0] * mass_cumsum[1:] / \
        jnp.hstack([masses[0], mass_cumsum[1:][:-1]])

    steps = jnp.diff(times)
    first_dt = steps[0]

    # transformation between the mapping and real Hamiltonian
    x, v = real_to_mapTO(x, v, ki, masses, first_dt)

    # dt/2 ahead of the starting time
    x, v = kepler_step(x, v, ki, first_dt * 0.5, nitr=nitr)

    def advance(state, dt):
        """Advance the system by dt: kick, then Kepler drift."""
        x_in, v_in = state
        x_in, v_in = nbody_kicks(x_in, v_in, ki, masses, dt)
        x_out, v_out = kepler_step(x_in, v_in, ki, dt, nitr=nitr)
        return [x_out, v_out], jnp.array([x_out, v_out])

    _, xv = scan(checkpoint(advance), [x, v], steps)

    return times[1:] + 0.5 * first_dt, xv
def kepler_step_map(xjac, vjac, masses, dt, nitr=10):
    """vmap version of kepler_step; map along the first axis (Ntime)

        Args:
            xjac: Jacobi positions (Ntime, Norbit, xyz)
            vjac: Jacobi velocities (Ntime, Norbit, xyz)
            masses: masses of the bodies (Nbody,), in units of solar mass
            dt: common time step
            nitr: maximum number of Newton iterations in kepler_step

        Returns:
            tuple:
                - new Jacobi positions, mapped along the first axis
                - new Jacobi velocities, mapped along the first axis

    """
    mass_cumsum = jnp.cumsum(masses)
    ki = G * masses[0] * mass_cumsum[1:] / \
        jnp.hstack([masses[0], mass_cumsum[1:][:-1]])

    def drift(x, v):
        return kepler_step(x, v, ki, dt, nitr=nitr)

    return vmap(drift, in_axes=(0, 0), out_axes=0)(xjac, vjac)
def kick_kepler_map(xjac, vjac, masses, dt, nitr=10):
    """vmap version of nbody_kicks + kepler_step; map along the first axis (Ntime)

        Note that the kick is applied with a step of ``2*dt`` while the
        subsequent Kepler drift uses ``dt``.

        Args:
            xjac: jacobi positions (Ntime, Norbit, xyz)
            vjac: jacobi velocities (Ntime, Norbit, xyz)
            masses: masses of the bodies (Nbody,), in units of solar mass
            dt: common time step
            nitr: maximum number of Newton iterations in kepler_step

        Returns:
            tuple:
                - new jacobi positions, mapped along the first axis
                - new jacobi velocities, mapped along the first axis

    """
    mass_cumsum = jnp.cumsum(masses)
    ki = G * masses[0] * mass_cumsum[1:] / \
        jnp.hstack([masses[0], mass_cumsum[1:][:-1]])

    def kick_then_drift(x, v):
        x_kicked, v_kicked = nbody_kicks(x, v, ki, masses, 2 * dt)
        return kepler_step(x_kicked, v_kicked, ki, dt, nitr=nitr)

    return vmap(kick_then_drift, in_axes=(0, 0), out_axes=0)(xjac, vjac)
def compute_corrector_coefficientsTO():
    """coefficients for the third-order corrector

        Returns:
            tuple (TOa1, TOa2, TOb1, TOb2) with
            ``TOa2 = -TOa1 = sqrt(7/40)`` and
            ``TOb2 = -TOb1 = 0.5 / (48 * sqrt(7/40))``

    """
    corr_alpha = jnp.sqrt(7. / 40.)
    corr_beta = 1. / (48.0 * corr_alpha)

    TOa1, TOa2 = -corr_alpha, corr_alpha
    TOb1, TOb2 = -0.5 * corr_beta, 0.5 * corr_beta

    return TOa1, TOa2, TOb1, TOb2


def corrector_step(x, v, ki, masses, a, b, nitr=10):
    """corrector step

        Applies a Kepler drift by ``-a``, an N-body kick by ``b``, and a
        Kepler drift by ``+a``.

        Args:
            x: positions (Norbit, xyz)
            v: velocities (Norbit, xyz)
            ki: GM values
            masses: masses of the bodies (Nbody,), solar unit
            a, b: corrector steps
            nitr: maximum number of Newton iterations passed to kepler_step
                (defaults to 10, the default of kepler_step)

        Returns:
            new positions and velocities

    """
    _x, _v = kepler_step(x, v, ki, -a, nitr=nitr)
    _x, _v = nbody_kicks(_x, _v, ki, masses, b)
    _x, _v = kepler_step(_x, _v, ki, a, nitr=nitr)
    return _x, _v


def real_to_mapTO(x, v, ki, masses, dt, nitr=10):
    """transformation between real and mapping coordinates

        Two corrector steps are applied in sequence: first with
        (TOa2, TOb2), then with (TOa1, TOb1), each scaled by `dt`.

        Args:
            x: positions (Norbit, xyz)
            v: velocities (Norbit, xyz)
            ki: GM values
            masses: masses of the bodies (Nbody,), solar unit
            dt: time step
            nitr: maximum number of Newton iterations forwarded to
                corrector_step (defaults to 10)

        Returns:
            mapped positions and velocities

    """
    TOa1, TOa2, TOb1, TOb2 = compute_corrector_coefficientsTO()
    _x, _v = corrector_step(x, v, ki, masses, TOa2 * dt, TOb2 * dt, nitr=nitr)
    _x, _v = corrector_step(_x, _v, ki, masses, TOa1 * dt, TOb1 * dt,
                            nitr=nitr)
    return _x, _v