# Public API of this module: autodiff and analytic Jacobian determinants.
__all__ = [
    # 3D Kepler problem
    "slogdet_jkep_jax",
    "det_jkep_am",
    "det_jkep_atau",
    "det_jkep_pm",
    "det_jkep_ptau",
    # 2D Kepler problem
    "slogdet_jkep2D_jax",
    "det_jkep2D_am",
    "det_jkep2D_pm",
]
from jax import jit, jacrev
from functools import partial
from jnkepler.keplerian import elements_to_xv
from jnkepler.jaxttv.conversion import G
import jax.numpy as jnp
[docs]
@partial(jit, static_argnums=(1,))
def slogdet_jkep_jax(params, keys):
    """
    Sign and log-absolute-value of the Jacobian determinant of the map from
    orbital elements to the Cartesian state (x, v) in the 3D Kepler problem,
    evaluated with JAX reverse-mode autodiff.

    Before calling ``elements_to_xv(0., params)``, dependent elements are
    derived from the differentiated ones so that the derivatives flow
    through them:

    * ``'a'`` in ``keys``: ``period = 2*pi*sqrt(a**3 / G / mass)``
    * otherwise ``'n'`` in ``keys``: ``period = 2*pi / n``
    * ``'M'`` in ``keys``: ``tau = t_ref - M * period / (2*pi)``

    Args:
        params (dict): Dictionary of orbital-element parameters. Depending on
            ``keys``, the entries ``'a'``/``'mass'``, ``'n'``, ``'M'``,
            ``'t_ref'`` and ``'period'`` are read here; everything else is
            consumed by ``elements_to_xv``.
        keys (tuple[str]): Names of the parameters with respect to which the
            Jacobian is taken, in column order. This argument is declared
            static for ``jit``, so it has to be hashable (pass a tuple, not a
            list). Typical choices are
            (a, ecc, inc, lnode, omega, M), (a, ecc, inc, lnode, omega, tau),
            (period, ecc, inc, lnode, omega, M).

    Returns:
        tuple: (sign, log_abs_det) as returned by `jnp.linalg.slogdet` for
        the matrix whose column j is the flattened derivative of the stacked
        (x, v) output with respect to ``keys[j]``.
    """
    two_pi = 2 * jnp.pi

    def _state_from_elements(elements):
        # Express the period through whichever size-like element is varied.
        if 'a' in keys:
            elements['period'] = two_pi * \
                jnp.sqrt(elements['a']**3 / G / elements['mass'])
        elif 'n' in keys:
            elements['period'] = two_pi / elements['n']

        # Mean anomaly at t_ref -> time of periastron passage.
        if 'M' in keys:
            elements['tau'] = elements['t_ref'] - \
                elements['M'] * elements['period'] / 2 / jnp.pi

        state = elements_to_xv(0., elements)
        return jnp.hstack([state['x'], state['v']])

    # jacrev returns a dict shaped like `params`; keep only the requested
    # entries and lay them out as matrix columns.
    jac_by_name = jacrev(_state_from_elements)(params)
    columns = [jac_by_name[name].reshape(-1) for name in keys]
    jac_matrix = jnp.stack(columns, axis=1)
    return jnp.linalg.slogdet(jac_matrix)
[docs]
@jit
def det_jkep_am(params):
    """
    Closed-form Jacobian determinant of the transformation
    (a, e, i, Omega, omega, M) → (x, y, z, vx, vy, vz).

    Evaluates ``0.5 * mu**1.5 * a**0.5 * e * sin(i)`` with
    ``mu = G * params['mass']``.

    Args:
        params (dict): Orbital-element parameters; the entries ``'mass'``,
            ``'a'``, ``'ecc'`` and ``'inc'`` are read.

    Returns:
        float: Analytic Jacobian determinant.
    """
    grav_param = G * params['mass']
    semi_major = params['a']
    ecc = params['ecc']
    sin_inc = jnp.sin(params['inc'])
    return 0.5 * grav_param**(1.5) * semi_major**(0.5) * ecc * sin_inc
[docs]
@jit
def det_jkep_atau(params):
    """
    Closed-form Jacobian determinant of the transformation
    (a, e, i, Omega, omega, tau) → (x, y, z, vx, vy, vz).

    Obtained from ``det_jkep_am`` by multiplying with ``-n``, where
    ``n = 2*pi / params['period']``.

    Args:
        params (dict): Orbital-element parameters; ``'period'`` is read here
            in addition to the entries used by ``det_jkep_am``.

    Returns:
        float: Analytic Jacobian determinant.
    """
    mean_motion = 2 * jnp.pi / params['period']
    # Swapping M for tau contributes the factor dM/dtau = -n.
    return -mean_motion * det_jkep_am(params)
[docs]
@jit
def det_jkep_pm(params):
    """
    Closed-form Jacobian determinant of the transformation
    (P, e, i, Omega, omega, M) → (x, y, z, vx, vy, vz).

    Evaluates ``mu**2 / (6*pi) * e * sin(i)`` with
    ``mu = G * params['mass']``.

    Args:
        params (dict): Orbital-element parameters; the entries ``'mass'``,
            ``'ecc'`` and ``'inc'`` are read.

    Returns:
        float: Analytic Jacobian determinant.
    """
    grav_param = G * params['mass']
    ecc = params['ecc']
    sin_inc = jnp.sin(params['inc'])
    return grav_param**2 / (6 * jnp.pi) * ecc * sin_inc
[docs]
@jit
def det_jkep_ptau(params):
    """
    Closed-form Jacobian determinant of the transformation
    (P, e, i, Omega, omega, tau) → (x, y, z, vx, vy, vz).

    Obtained from ``det_jkep_pm`` by multiplying with ``-n``, where
    ``n = 2*pi / params['period']``.

    Args:
        params (dict): Orbital-element parameters; ``'period'`` is read here
            in addition to the entries used by ``det_jkep_pm``.

    Returns:
        float: Analytic Jacobian determinant.
    """
    mean_motion = 2 * jnp.pi / params['period']
    # Swapping M for tau contributes the factor dM/dtau = -n.
    return -mean_motion * det_jkep_pm(params)
[docs]
@partial(jit, static_argnums=(1,))
def slogdet_jkep2D_jax(params, keys):
    """
    Sign and log-absolute-value of the Jacobian determinant of the map from
    orbital elements to the Cartesian state in the 2D Kepler problem,
    evaluated with JAX reverse-mode autodiff.

    The orbit is evaluated through ``elements_to_xv`` with ``lnode = 0`` and
    ``inc = pi/2`` forced, and only components 0 and 2 of the first position
    and velocity rows are kept as the 4-component state. Any ``'lnode'`` or
    ``'inc'`` entries in ``params`` are dropped before differentiation.

    Dependent elements are derived from the differentiated ones first:

    * ``'a'`` in ``keys``: ``period = 2*pi*sqrt(a**3 / G / mass)``
    * otherwise ``'n'`` in ``keys``: ``period = 2*pi / n``
    * ``'M'`` in ``keys``: ``tau = t_ref - M * period / (2*pi)``

    Args:
        params (dict): Dictionary of orbital-element parameters.
        keys (tuple[str]): Names of the parameters with respect to which the
            Jacobian is taken, in column order. This argument is declared
            static for ``jit``, so it has to be hashable (pass a tuple, not a
            list). Typical choices are
            (a, ecc, omega, M), (period, ecc, omega, M).

    Returns:
        tuple: (sign, log_abs_det) as returned by `jnp.linalg.slogdet` for
        the matrix whose column j is the flattened derivative of the
        4-component state with respect to ``keys[j]``.
    """
    two_pi = 2 * jnp.pi

    def _planar_state(elements):
        # Express the period through whichever size-like element is varied.
        if 'a' in keys:
            elements['period'] = two_pi * \
                jnp.sqrt(elements['a']**3 / G / elements['mass'])
        elif 'n' in keys:
            elements['period'] = two_pi / elements['n']

        # Mean anomaly at t_ref -> time of periastron passage.
        if 'M' in keys:
            elements['tau'] = elements['t_ref'] - \
                elements['M'] * elements['period'] / 2 / jnp.pi

        # Fix the orientation angles that are not free in the planar problem.
        state = elements_to_xv(
            0., elements | {'lnode': 0, 'inc': jnp.pi / 2.})
        pos, vel = state['x'][0], state['v'][0]
        return jnp.hstack([pos[0], pos[2], vel[0], vel[2]])

    # The fixed angles must not appear among the differentiated inputs.
    fixed_angles = {'lnode', 'inc'}
    planar_params = {
        name: value for name, value in params.items()
        if name not in fixed_angles
    }

    jac_by_name = jacrev(_planar_state)(planar_params)
    columns = [jac_by_name[name].reshape(-1) for name in keys]
    jac_matrix = jnp.stack(columns, axis=1)
    return jnp.linalg.slogdet(jac_matrix)
[docs]
@jit
def det_jkep2D_am(params):
    """
    Closed-form Jacobian determinant of the transformation
    (a, e, omega, M) → (x, z, vx, vz) in the 2D Kepler problem.

    Evaluates ``0.5 * mu * e / sqrt(1 - e**2)`` with
    ``mu = G * params['mass']``.

    Args:
        params (dict): Orbital-element parameters; the entries ``'mass'``
            and ``'ecc'`` are read.

    Returns:
        float: Analytic Jacobian determinant.
    """
    grav_param = G * params['mass']
    ecc = params['ecc']
    return 0.5 * grav_param * ecc / jnp.sqrt(1. - ecc**2)
[docs]
@jit
def det_jkep2D_pm(params):
    """
    Closed-form Jacobian determinant of the transformation
    (P, e, omega, M) → (x, z, vx, vz) in the 2D Kepler problem.

    Evaluates ``(mu / 3) * e / sqrt(1 - e**2) * (a / P)``, where
    ``mu = G * params['mass']`` and the ratio ``a / P`` is computed as
    ``(mu / (4 * pi**2 * P))**(1/3)``.

    Args:
        params (dict): Orbital-element parameters; the entries ``'mass'``,
            ``'period'`` and ``'ecc'`` are read.

    Returns:
        float: Analytic Jacobian determinant.
    """
    grav_param = G * params['mass']
    period = params['period']
    ecc = params['ecc']
    # Semi-major axis divided by period, expressed through mu and P only.
    axis_per_period = (grav_param / period / 4. / jnp.pi**2)**(1./3.)
    return (grav_param / 3.) * ecc / jnp.sqrt(1. - ecc**2) * axis_per_period