# Source code for jnkepler.jaxttv.symplectic

""" Symplectic integrator. JAX-native implementation following the symplectic
approach used in TTVFast (https://github.com/kdeck/TTVFast).
"""

# Public API: the names exported by a star-import of this module.
__all__ = ["integrate_xv", "kepler_step_map",
           "kick_kepler_map", "kepler_kick_kepler_map"]

import jax.numpy as jnp
from functools import partial
from jax import jit, vmap, grad, config, checkpoint, custom_jvp
from jax.lax import scan, while_loop

from .conversion import G

# Enable 64-bit (double precision) types in JAX. This is a global JAX setting
# that takes effect as a side effect of importing this module.
config.update("jax_enable_x64", True)


def _compute_ki(masses):
    """Compute the two-body gravitational parameter of every Jacobi orbit.

    With eta_i denoting the cumulative mass of bodies 0..i, the value for
    orbit i (i = 1..Norbit) is ``G * masses[0] * eta_i / eta_{i-1}``.

    Args:
        masses: masses of the bodies (Nbody,), in units of solar mass

    Returns:
        GM values for each Jacobi orbit, shape (Norbit,)
    """
    eta = jnp.cumsum(masses)
    enclosed = eta[1:]  # eta_1 ... eta_Norbit
    # eta_0 ... eta_{Norbit-1}; eta_0 is simply the mass of body 0
    interior = jnp.hstack([masses[0], enclosed[:-1]])
    return G * masses[0] * enclosed / interior


def dEstep(x, ecosE0, esinE0, dM):
    """Refine an estimate of delta(eccentric anomaly) by one iteration.

    The root being sought is that of
        F(dE) = dE + (1 - cos dE) * esinE0 - sin(dE) * ecosE0 - dM,
    and the correction uses derivatives of F up to third order.

    Args:
        x: current estimate of dE
        ecosE0, esinE0: e*cos(E) and e*sin(E) at the initial state
        dM: delta(mean anomaly)

    Returns:
        dE: updated estimate for delta(eccentric anomaly)
    """
    # sin(x) and cos(x) are rebuilt from the half-angle values
    half = x / 2.0
    s_half, c_half = jnp.sin(half), jnp.cos(half)
    sin_x = 2.0 * s_half * c_half
    cos_x = c_half * c_half - s_half * s_half

    # residual F(x); 2*sin^2(x/2) stands in for (1 - cos x)
    resid = x + 2.0 * s_half * (s_half * esinE0 - c_half * ecosE0) - dM

    # e*cos(E0 + x), shared by the first and third derivative terms
    ecosE = cos_x * ecosE0 - sin_x * esinE0
    d1 = 1.0 - ecosE                                  # F'
    d2 = (sin_x * ecosE0 + cos_x * esinE0) / 2.0      # F'' / 2
    d3 = ecosE / 6.0                                  # F''' / 6

    # successive corrections, each reusing the previous one
    step = -resid / d1
    step = -resid / (d1 + step * d2)
    step = -resid / (d1 + step * (d2 + step * d3))
    return x + step


def _solve_dE_while(ecosE0, esinE0, dM, max_iter, tol):
    """Iterate `dEstep` in a `while_loop` until the update falls below `tol`.

    The iteration starts from dE = dM and stops as soon as either `max_iter`
    updates have been applied or the largest absolute change over all
    elements is no longer greater than `tol`.

    Args:
        ecosE0, esinE0: e*cos(E) and e*sin(E) at the initial state
        dM: delta(mean anomaly)
        max_iter: maximum number of updates
        tol: threshold on max |dE_new - dE_old|

    Returns:
        dE: estimate of delta(eccentric anomaly), in the common dtype of
        the three inputs
    """
    # Bring all inputs to one dtype so the loop carry keeps a fixed type.
    ecosE0, esinE0, dM = (jnp.asarray(q) for q in (ecosE0, esinE0, dM))
    dtype = jnp.result_type(ecosE0, esinE0, dM)
    ecosE0, esinE0, dM = (q.astype(dtype) for q in (ecosE0, esinE0, dM))

    def keep_going(state):
        count, _, change = state
        return jnp.logical_and(count < max_iter, change > tol)

    def refine(state):
        count, current, _ = state
        updated = dEstep(current, ecosE0, esinE0, dM)
        change = jnp.max(jnp.abs(updated - current))
        return (count + 1, updated, change)

    # The infinite initial "change" guarantees at least one update
    # (provided max_iter > 0).
    start = (jnp.int32(0), dM, jnp.array(jnp.inf, dtype=dtype))
    _, solution, _ = while_loop(keep_going, refine, start)
    return solution


@partial(custom_jvp, nondiff_argnums=(3, 4))
def solve_dE(ecosE0, esinE0, dM, max_iter=10, tol=1e-12):
    """
    Return the eccentric-anomaly increment dE for a Kepler step.

    dE is the root of
        F(dE) = dE + (1 - cos dE) * (e sin E0) - (sin dE) * (e cos E0) - dM.

    Notes:
    - The root is found iteratively by `_solve_dE_while`, which stops after
      `max_iter` updates or once successive estimates differ by at most `tol`.
    - The function is wrapped in `custom_jvp`; its derivative rule is
      registered separately, so autodiff does not trace the iterations.
    - `max_iter` and `tol` are declared non-differentiable
      (`nondiff_argnums=(3, 4)`).

    Args:
        ecosE0: e cos(E0), shape (Norbit,).
        esinE0: e sin(E0), shape (Norbit,).
        dM: Mean-anomaly increment n*dt, shape (Norbit,).
        max_iter: Maximum number of iterations.
        tol: Convergence tolerance based on |dE_{k+1} - dE_k|.

    Returns:
        dE: Eccentric-anomaly increment, shape (Norbit,).
    """
    return _solve_dE_while(ecosE0, esinE0, dM, max_iter=max_iter, tol=tol)


@solve_dE.defjvp
def solve_dE_jvp(max_iter, tol, primals, tangents):
    """JVP rule for `solve_dE`, obtained by implicit differentiation.

    With
        F(dE; ecosE0, esinE0, dM)
            = dE + (1 - cos dE) * esinE0 - sin(dE) * ecosE0 - dM = 0,
    the tangent of dE is -(sum of dF/dp * p_dot) / (dF/d dE) over the three
    differentiable inputs p.

    Args:
        max_iter, tol: non-differentiable arguments, forwarded to `solve_dE`
        primals: (ecosE0, esinE0, dM)
        tangents: tangents of (ecosE0, esinE0, dM)

    Returns:
        tuple of (dE, tangent of dE)
    """
    ecosE0, esinE0, dM = primals
    d_ecosE0, d_esinE0, d_dM = tangents

    dE = solve_dE(ecosE0, esinE0, dM, max_iter=max_iter, tol=tol)
    sin_dE, cos_dE = jnp.sin(dE), jnp.cos(dE)

    # dF/d(dE); note that no guard is applied against this approaching zero
    dF_ddE = 1.0 + sin_dE * esinE0 - cos_dE * ecosE0

    # Partial derivatives w.r.t. the inputs:
    #   dF/d(ecosE0) = -sin(dE)
    #   dF/d(esinE0) = 1 - cos(dE)
    #   dF/d(dM)     = -1
    # so the numerator below is -(sum of partials times input tangents).
    numerator = sin_dE * d_ecosE0 - (1.0 - cos_dE) * d_esinE0 + d_dM
    return dE, numerator / dF_ddE


def kepler_step(x, v, gm, dt, nitr=10):
    """Advance positions and velocities by `dt` along two-body orbits.

    Each row of `x`/`v` is propagated independently under the gravitational
    parameter `gm`, by expressing the new state as a linear combination of
    the old position and velocity (coefficients f, g, fdot, gdot).

    Notes:
        The eccentric-anomaly increment ``dE`` comes from `solve_dE`, called
        with ``max_iter=nitr`` and a fixed tolerance of 1e-12. `nitr` is
        therefore an upper bound on the number of iterations, not an
        unrolled loop length.

    Args:
        x: positions (Norbit, xyz)
        v: velocities (Norbit, xyz)
        gm: 'GM' in Kepler's 3rd law
        dt: time step
        nitr: maximum number of iterations in the Kepler solver

    Returns:
        tuple:
            - new positions (Norbit, xyz)
            - new velocities (Norbit, xyz)
    """
    # scalar quantities of the initial state, one per orbit
    radius = jnp.sqrt(jnp.sum(x * x, axis=1))
    speed_sq = jnp.sum(v * v, axis=1)
    x_dot_v = jnp.sum(x * v, axis=1)

    # semi-major axis (vis-viva) and mean motion (Kepler's 3rd law)
    a = 1.0 / (2.0 / radius - speed_sq / gm)
    n = jnp.sqrt(gm / (a * a * a))

    ecosE0 = 1.0 - radius / a
    esinE0 = x_dot_v / (n * a * a)

    # n * dt is the mean-anomaly increment over the step
    dE = solve_dE(ecosE0, esinE0, n * dt, max_iter=nitr, tol=1e-12)

    # sin(dE), cos(dE) via half-angle values; 2*sin^2(dE/2) = 1 - cos(dE)
    half = dE / 2.0
    s_half, c_half = jnp.sin(half), jnp.cos(half)
    sin_dE = 2.0 * s_half * c_half
    cos_dE = c_half * c_half - s_half * s_half

    f = 1.0 - (a / radius) * (2.0 * s_half * s_half)
    g = (2.0 * s_half * (esinE0 * s_half + c_half * radius / a)) / n
    # 1 - e*cos(E0 + dE)
    one_minus_ecosE = 1.0 - cos_dE * ecosE0 + sin_dE * esinE0
    fdot = -(a / (radius * one_minus_ecosE)) * n * sin_dE
    # from the identity f * gdot - g * fdot = 1
    gdot = (1.0 + g * fdot) / f

    new_x = f[:, None] * x + g[:, None] * v
    new_v = fdot[:, None] * x + gdot[:, None] * v
    return new_x, new_v


def mmat_from_masses(masses):
    """Build the unit lower-triangular mass matrix for Jacobi coordinates.

    The diagonal is 1. Every entry below the diagonal in column j equals
    ``masses[j+1] / sum(masses[:j+2])``; entries above the diagonal are 0.
    The original documentation states this is the same matrix as used in
    `jacobi_to_astrocentric` (defined outside this module).

    Args:
        masses: Masses of the bodies, shape (Nbody,).

    Returns:
        mmat: Mass matrix, shape (Nbody-1, Nbody-1).
    """
    norbit = len(masses) - 1
    # planet mass over the cumulative mass up to and including that planet
    mass_ratio = masses[1:] / jnp.cumsum(masses)[1:]
    # repeat the ratios on every row, then keep only the strictly lower part
    below_diagonal = jnp.tril(jnp.tile(mass_ratio, (norbit, 1)), k=-1)
    return jnp.eye(norbit) + below_diagonal


def Hintgrad(xjac, vjac, masses):
    """Gradient of the interaction Hamiltonian with respect to the Jacobi
    positions, scaled per orbit by (star mass / planet mass).

    With mu_i = m_i / m_0, the differentiated function is the sum of
        (1) + sum_i mu_i / |xjac_i|
        (2) - sum_i mu_i / |xast_i|
        (3) - 0.5 * sum_{i,j} mu_i mu_j / |xast_i - xast_j|
    where xast = M @ xjac and M is given by `mmat_from_masses`.

    Args:
        xjac: Jacobi positions (Norbit, xyz)
        vjac: Jacobi velocities (Norbit, xyz); not used in the computation
        masses: masses of the bodies (Nbody,), solar unit

    Returns:
        gradient of interaction Hamiltonian x (star mass / planet mass),
        shape (Norbit, xyz)
    """
    m_star = masses[0]
    m_planets = masses[1:]
    mu = m_planets / m_star
    jac_to_ast = mmat_from_masses(masses)

    # (1) gradient of +mu/|xjac| with respect to xjac
    rj_sq = jnp.sum(xjac * xjac, axis=1)
    rj = jnp.sqrt(rj_sq)
    rj_inv3 = 1.0 / (rj_sq * rj)
    grad_jacobi = -(mu[:, None] * xjac) * rj_inv3[:, None]

    # positions relative to the star
    xast = jac_to_ast @ xjac

    # (2) gradient of -mu/|xast| with respect to xast
    ra_sq = jnp.sum(xast * xast, axis=1)
    ra = jnp.sqrt(ra_sq)
    ra_inv3 = 1.0 / (ra_sq * ra)
    grad_star_planet = +(mu[:, None] * xast) * ra_inv3[:, None]

    # (3) planet-planet term: all pairwise separations, shape (N, N, 3)
    sep = xast[:, None, :] - xast[None, :, :]
    sep_sq = jnp.sum(sep * sep, axis=-1)
    # Zero separations (which include the i == j diagonal) contribute
    # nothing; a dummy value of 1 keeps the masked division finite.
    nonzero = sep_sq != 0.0
    sep_sq_safe = jnp.where(nonzero, sep_sq, 1.0)
    inv_sep3 = jnp.where(
        nonzero, 1.0 / (sep_sq_safe * jnp.sqrt(sep_sq_safe)), 0.0
    )
    pair_weight = (mu[:, None] * mu[None, :]) * inv_sep3
    # The factor 0.5 cancels against the double counting of each pair:
    # the gradient is + sum_j mu_i mu_j (x_i - x_j) / r^3.
    grad_planet_planet = +jnp.sum(pair_weight[:, :, None] * sep, axis=1)

    grad_astro = grad_star_planet + grad_planet_planet

    # chain rule: xast = M @ xjac  =>  dH/dxjac picks up M^T @ dH/dxast
    total = grad_jacobi + jac_to_ast.T @ grad_astro
    return total * (m_star / m_planets)[:, None]


def nbody_kicks(x, v, ki, masses, dt):
    """Kick the velocities with the interaction-Hamiltonian gradient.

    The velocity change is ``-ki * dt * Hintgrad(x, v, masses)`` for each
    orbit; positions are returned unchanged.

    Args:
        x: positions (Norbit, xyz)
        v: velocities (Norbit, xyz)
        ki: GM values
        masses: masses of the bodies (Nbody,), solar unit
        dt: time step

    Returns:
        tuple:
            - positions
            - kicked velocities
    """
    gradient = Hintgrad(x, v, masses)
    kick = -ki[:, None] * dt * gradient
    v_kicked = v + kick
    return x, v_kicked


def integrate_xv(x, v, masses, times, nitr=10):
    """Integrate Jacobi positions and velocities using the Wisdom-Holman map.

    This function currently assumes that `times` are uniformly spaced.
    The map is initialized using the first time interval only (corrector
    transformation plus a half Kepler step), and the returned times are
    offset by half of that first interval, so non-uniform `times` are not
    supported.

    Args:
        x: initial Jacobi positions in Cartesian coordinates (Norbit, xyz)
        v: initial Jacobi velocities in Cartesian coordinates (Norbit, xyz)
        masses: masses of the bodies (Nbody,), in units of solar mass
        times: 1D array of uniformly spaced output times
        nitr: maximum number of iterations in Kepler's equation solver

    Returns:
        tuple:
            - ``times[1:]`` shifted forward by half of the first interval,
              shape (Nstep,) with Nstep = len(times) - 1
            - integrated Jacobi positions and velocities, stacked as
              (Nstep, x or v, Norbit, xyz)
    """
    ki = _compute_ki(masses)
    dtarr = jnp.diff(times)

    # transformation between the mapping and real Hamiltonian
    x, v = real_to_mapTO(x, v, ki, masses, dtarr[0])

    # half Kepler step: the state is now dt/2 ahead of the starting time
    x, v = kepler_step(x, v, ki, dtarr[0] * 0.5, nitr=nitr)

    def step(xvin, dt):
        # one map step: kick by dt, then Kepler drift by dt
        x, v = xvin
        x, v = nbody_kicks(x, v, ki, masses, dt)
        xout, vout = kepler_step(x, v, ki, dt, nitr=nitr)
        return [xout, vout], jnp.array([xout, vout])

    # wrap the step in jax.checkpoint before scanning over the intervals
    step = checkpoint(step)

    _, xv = scan(step, [x, v], dtarr)
    return times[1:] + 0.5 * dtarr[0], xv
def kepler_step_map(xjac, vjac, masses, dt, nitr=10):
    """vmap version of kepler_step; map along the first axis (Ntime)

    Args:
        xjac: Jacobi positions (Ntime, Norbit, xyz)
        vjac: Jacobi velocities (Ntime, Norbit, xyz)
        masses: masses of the bodies (Nbody,), in units of solar mass
        dt: common time step (shared by all elements along the first axis)
        nitr: maximum number of iterations in Kepler's equation solver

    Returns:
        tuple:
            - new Jacobi positions (Ntime, Norbit, xyz)
            - new Jacobi velocities (Ntime, Norbit, xyz)
    """
    ki = _compute_ki(masses)

    def step(x, v):
        # dt and ki are closed over, so they are not mapped
        return kepler_step(x, v, ki, dt, nitr=nitr)

    step_map = vmap(step, (0, 0), 0)
    return step_map(xjac, vjac)
def kick_kepler_map(xjac, vjac, masses, dt, nitr=10):
    """vmap version of nbody_kicks + kepler_step; map along the first axis
    (Ntime)

    For each element, the velocities are kicked with a time step of
    ``2 * dt`` and the result is then advanced by a Kepler step of ``dt``.

    Args:
        xjac: Jacobi positions (Ntime, Norbit, xyz)
        vjac: Jacobi velocities (Ntime, Norbit, xyz)
        masses: masses of the bodies (Nbody,), in units of solar mass
        dt: common time step (shared by all elements along the first axis)
        nitr: maximum number of iterations in Kepler's equation solver

    Returns:
        tuple:
            - new Jacobi positions (Ntime, Norbit, xyz)
            - new Jacobi velocities (Ntime, Norbit, xyz)
    """
    ki = _compute_ki(masses)

    def kick_kepler(x, v):
        # NOTE(review): the kick spans 2*dt while the drift spans dt --
        # presumably this is the second part of a
        # kepler(dt) -> kick(2*dt) -> kepler(dt) step; confirm with callers.
        x, v = nbody_kicks(x, v, ki, masses, 2 * dt)
        return kepler_step(x, v, ki, dt, nitr=nitr)

    func_map = vmap(kick_kepler, (0, 0), 0)
    return func_map(xjac, vjac)
def kepler_kick_kepler_map(xjac, vjac, masses, dt, nitr=10):
    """vmap version of kepler_step + nbody_kicks + kepler_step.

    This applies a full KDK-like step written in the order
    kepler(dt/2) -> kick(dt) -> kepler(dt/2)
    to each element along the first axis.

    Args:
        xjac: Jacobi positions (Ntime, Norbit, xyz)
        vjac: Jacobi velocities (Ntime, Norbit, xyz)
        masses: masses of the bodies (Nbody,), in units of solar mass
        dt: full time step; either a scalar shared by all elements, or an
            array whose first axis (Ntime) is mapped together with
            `xjac` and `vjac`
        nitr: maximum number of iterations in Kepler's equation solver

    Returns:
        tuple:
            - new Jacobi positions (Ntime, Norbit, xyz)
            - new Jacobi velocities (Ntime, Norbit, xyz)
    """
    ki = _compute_ki(masses)

    def kepler_kick_kepler(x, v, dt):
        x, v = kepler_step(x, v, ki, 0.5 * dt, nitr=nitr)
        x, v = nbody_kicks(x, v, ki, masses, dt)
        return kepler_step(x, v, ki, 0.5 * dt, nitr=nitr)

    # vmap cannot map a rank-0 argument along axis 0, so a scalar dt is
    # passed through unmapped (in_axes=None); arrays are mapped as before.
    dt_axis = None if jnp.ndim(dt) == 0 else 0
    func_map = vmap(kepler_kick_kepler, (0, 0, dt_axis), 0)
    return func_map(xjac, vjac, dt)
def compute_corrector_coefficientsTO():
    """coefficients for the third-order corrector

    Returns:
        tuple (TOa1, TOa2, TOb1, TOb2) with
            alpha = sqrt(7/40), beta = 1 / (48 * alpha),
            TOa1 = -alpha, TOa2 = +alpha,
            TOb1 = -beta / 2, TOb2 = +beta / 2
    """
    corr_alpha = jnp.sqrt(7.0 / 40.0)
    corr_beta = 1.0 / (48.0 * corr_alpha)

    TOa1, TOa2 = -corr_alpha, corr_alpha
    TOb1, TOb2 = -0.5 * corr_beta, 0.5 * corr_beta

    return TOa1, TOa2, TOb1, TOb2


def corrector_step(x, v, ki, masses, a, b):
    """corrector step

    Applies a Kepler step of ``-a``, an N-body kick of ``b``, and a Kepler
    step of ``+a``. The Kepler steps use the default `nitr` of
    `kepler_step`.

    Args:
        x: positions (Norbit, xyz)
        v: velocities (Norbit, xyz)
        ki: GM values
        masses: masses of the bodies (Nbody,), solar unit
        a, b: corrector steps (Kepler step size and kick step size)

    Returns:
        new positions and velocities
    """
    _x, _v = kepler_step(x, v, ki, -a)
    _x, _v = nbody_kicks(_x, _v, ki, masses, b)
    _x, _v = kepler_step(_x, _v, ki, a)
    return _x, _v


def real_to_mapTO(x, v, ki, masses, dt):
    """transformation between real and mapping coordinates

    Two corrector steps are applied in sequence: first with the
    (TOa2, TOb2) coefficients, then with (TOa1, TOb1), each scaled by `dt`.

    Args:
        x: positions (Norbit, xyz)
        v: velocities (Norbit, xyz)
        ki: GM values
        masses: masses of the bodies (Nbody,), solar unit
        dt: time step

    Returns:
        mapped positions and velocities
    """
    TOa1, TOa2, TOb1, TOb2 = compute_corrector_coefficientsTO()
    _x, _v = corrector_step(x, v, ki, masses, TOa2 * dt, TOb2 * dt)
    _x, _v = corrector_step(_x, _v, ki, masses, TOa1 * dt, TOb1 * dt)
    return _x, _v