Source code for jnkepler.keplerian.orbit

# Names exported by this module; every entry is a function defined below.
__all__ = ["tic_to_tau", "radial_velocity_shape", "radial_velocity_shape_multi",
           "elements_to_xv", "elements_to_xv_scaled", "xv_to_elements"]

import jax.numpy as jnp
from jax import vmap
from ..jaxttv.conversion import m_to_u
from ..jaxttv.conversion import G
from ..jaxttv.conversion import elements_to_xv as _elements_to_xv
from ..jaxttv.conversion import xv_to_elements as _xv_to_elements


def tic_to_tau(tic, period, ecc, omega):
    """Convert a time of inferior conjunction into a time of periastron passage.

    At inferior conjunction the true anomaly f satisfies omega + f = pi/2,
    hence tan(f/2) = tan(pi/4 - omega/2)
                   = (1 - tan(omega/2)) / (1 + tan(omega/2)).
    The eccentric anomaly follows from the half-angle relation
    tan(u/2) = sqrt((1 - e)/(1 + e)) * tan(f/2), and Kepler's equation
    M = u - e*sin(u) gives the mean anomaly at conjunction.

    Args:
        tic: time of inferior conjunction (where omega + f = pi/2)
        period: orbital period
        ecc: eccentricity
        omega: argument of periastron (radian)

    Returns:
        float: time of periastron passage

    """
    half_omega_tan = jnp.tan(0.5 * omega)
    ecc_factor = jnp.sqrt((1. - ecc) / (1. + ecc))

    # tan(u/2) at conjunction; the numerator is formed first so that the
    # floating-point evaluation order is (factor * (1 - t)) / (1 + t).
    numerator = ecc_factor * (1. - half_omega_tan)
    ecc_anom_tic = 2 * jnp.arctan(numerator / (1. + half_omega_tan))

    # Kepler's equation: mean anomaly elapsed since periastron at conjunction.
    mean_anom_tic = ecc_anom_tic - ecc * jnp.sin(ecc_anom_tic)

    return tic - period / (2. * jnp.pi) * mean_anom_tic
def radial_velocity_shape(t, params):
    """Evaluate the unit-amplitude RV shape cos(omega + f) + e*cos(omega).

    Args:
        t: times at which the RV shape is evaluated
        params (dict): orbital elements, read by key:
            - period : orbital period
            - ecc : eccentricity
            - omega : argument of periastron (radian)
            - tau : time of periastron passage

    Returns:
        array: cos(omega + f) + e*cos(omega) at each time, where f is the
        true anomaly. No velocity amplitude is applied here.

    """
    mean_anom = 2 * jnp.pi * (t - params['tau']) / params['period']
    ecc = params['ecc']
    arg_peri = params['omega']

    # m_to_u maps mean anomaly to eccentric anomaly (helper imported from
    # jaxttv.conversion).
    ecc_anom = m_to_u(mean_anom, ecc)

    # Half-angle relation between eccentric and true anomaly.
    tan_half_f = jnp.sqrt((1. + ecc) / (1. - ecc)) * jnp.tan(0.5 * ecc_anom)
    true_anom = 2 * jnp.arctan(tan_half_f)

    return jnp.cos(arg_peri + true_anom) + ecc * jnp.cos(arg_peri)
def radial_velocity_shape_multi(t, params):
    """Evaluate radial_velocity_shape for several parameter sets at once.

    Args:
        t: times at which the RV shape is evaluated; the same times are
            used for every parameter set.
        params: dict of arrays. vmap maps over the leading axis of each
            leaf, so entry i of every array forms one parameter set.

    Returns:
        array: outputs of radial_velocity_shape stacked along a new
        leading axis, one entry per parameter set.

    """
    def _shape_for_one_set(single_params):
        # t is closed over, so it is shared rather than mapped.
        return radial_velocity_shape(t, single_params)

    return vmap(_shape_for_one_set)(params)
def elements_to_xv(t, params):
    """Convert orbital elements to Cartesian positions and velocities.

    Args:
        t (array_like): Times (days) at which the state is evaluated.
        params (dict): Per-orbit orbital elements:
            - period : orbital period (days)
            - ecc : eccentricity
            - inc : inclination (radian)
            - omega : argument of periastron (radian)
            - lnode : longitude of ascending node (radian)
            - tau : time of periastron passage (days)
            - mass : total mass (solar masses)

    Returns:
        dict: Cartesian state vectors:
            - x : positions, shape (T, N, 3) for multiple orbits or (T, 3)
              for a single orbit. Units: AU.
            - v : velocities, shape (T, N, 3) or (T, 3). Units: AU/day.

    """
    times = jnp.atleast_1d(t)
    element_keys = ('period', 'ecc', 'inc', 'omega', 'lnode', 'tau', 'mass')
    period, ecc, inc, omega, lnode, tau, mass = (
        jnp.atleast_1d(params[key]) for key in element_keys
    )

    # A single mass shared by several orbits is expanded to one per orbit.
    n_orbits = period.size
    if n_orbits > 1 and mass.size == 1:
        mass = jnp.repeat(mass, n_orbits)

    # (T, 1) times against per-orbit tau/period -> one row per time.
    mean_anom = 2 * jnp.pi * (times[:, None] - tau) / period
    ecc_anom = m_to_u(mean_anom, ecc)

    # Map the core conversion over the time axis of the eccentric anomaly
    # only; all other elements are shared across times.
    convert_per_time = vmap(
        _elements_to_xv,
        in_axes=(None, None, None, None, None, 0, None),
        out_axes=0,
    )
    pos, vel = convert_per_time(period, ecc, inc, omega, lnode, ecc_anom, mass)

    # Swap the last two axes so the output is laid out as (T, N, 3).
    pos = jnp.swapaxes(pos, 1, 2)
    vel = jnp.swapaxes(vel, 1, 2)

    # Drop the orbit axis when there is only one orbit.
    if pos.shape[1] == 1:
        pos, vel = pos[:, 0], vel[:, 0]

    return {'x': pos, 'v': vel}
def elements_to_xv_scaled(t, params):
    """Convert orbital elements to state vectors in units of semi-major axis.

    The total mass is set to n**2 / G, with n = 2*pi/period the mean
    motion, so that Kepler's third law gives a = 1. The input dict is
    copied and not modified.

    Args:
        t (array_like): Times (days) at which the state is evaluated.
        params (dict): Per-orbit orbital elements:
            - period : orbital period (days)
            - ecc : eccentricity
            - inc : inclination (radian)
            - omega : argument of periastron (radian)
            - lnode : longitude of ascending node (radian)
            - tau : time of periastron passage (days)

    Returns:
        dict: Cartesian state vectors divided by a:
            - x : array of shape (T, N, 3) or (T, 3). Dimensionless.
            - v : array of shape (T, N, 3) or (T, 3). Units: 1/day.

    """
    scaled_params = params.copy()
    mean_motion = 2 * jnp.pi / params['period']
    # Any 'mass' entry already present in params is overridden here.
    scaled_params['mass'] = mean_motion**2 / G
    return elements_to_xv(t, scaled_params)
def xv_to_elements(x, v, mass, t_ref=None):
    """Convert Cartesian state vectors to orbital elements.

    Args:
        x (array_like): Position vector(s) in AU, shape (3,) or (N, 3).
        v (array_like): Velocity vector(s) in AU/day, shape (3,) or (N, 3).
        mass (float or array_like): Total mass (solar masses).
        t_ref (float or array_like, optional): Reference epoch (days). If
            given, the time of periastron passage is added to the output
            as tau = t_ref - M/n with n = 2*pi/period.

    Returns:
        dict: Orbital elements:
            - a : semi-major axis (AU)
            - period : orbital period (days)
            - ecc : eccentricity
            - inc : inclination (radian)
            - omega : argument of periastron (radian)
            - lnode : longitude of ascending node (radian)
            - M : mean anomaly at t_ref (radian)
            - mass : total mass (solar masses), stored as passed in
            - tau : time of periastron passage (days), only if t_ref is given

        When a single state vector is supplied, every value is passed
        through jnp.squeeze; otherwise the values are returned as
        produced for the N orbits.

    """
    gm = jnp.atleast_1d(G * mass)
    pos = jnp.atleast_2d(x)
    vel = jnp.atleast_2d(v)

    element_names = ('a', 'period', 'ecc', 'inc', 'omega', 'lnode', 'M')
    element_values = _xv_to_elements(pos, vel, gm)
    result = {name: value for name, value in zip(element_names, element_values)}
    result['mass'] = mass

    if t_ref is not None:
        mean_motion = 2.0 * jnp.pi / result['period']
        result['tau'] = t_ref - result['M'] / mean_motion

    if pos.shape[0] != 1:
        return result

    # Single orbit: remove the length-1 axes introduced by atleast_*.
    return {name: jnp.squeeze(value) for name, value in result.items()}