"""Utility functions for JaxTTV (module ``jnkepler.jaxttv.utils``)."""

from .symplectic import kepler_step_map
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "initialize_jacobi_xv", "get_energy_diff", "get_energy_diff_jac",
    "params_to_elements", "elements_to_pdic", "convert_elements", "findidx_map", "params_to_dict", "dict_to_params", "em_to_dict"
]

import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, config
from .conversion import m_to_u, tic_to_m, tic_to_u, elements_to_xv, xv_to_elements, G, xvjac_to_xvcm
# Enable 64-bit (double-precision) arrays in JAX. This is a process-wide
# setting: importing this module changes JAX's default precision globally.
config.update('jax_enable_x64', True)


# Earth mass in units of the solar mass; dividing a solar-unit mass by this
# gives the mass in Earth masses (as done in ``elements_to_pdic``).
M_earth = 3.0034893e-6


def params_to_dict(params, npl, keys):
    """Split a flat parameter array into a per-key dictionary.

    The flat array is laid out key by key, each key owning a contiguous run
    of ``npl`` entries (one per planet).

    Args:
        params: 1D parameter array ``[arr for key1, arr for key2, ...]`` where
            each ``arr`` has length ``npl``.
        npl: number of planets.
        keys: parameter keys ``[key1, key2, ...]`` in storage order.

    Returns:
        dict mapping each key to its ``npl``-long slice of ``params``.
    """
    return {
        name: params[idx * npl:(idx + 1) * npl]
        for idx, name in enumerate(keys)
    }
def dict_to_params(pdic, npl, keys):
    """Flatten a parameter dict into a 1D array (inverse of ``params_to_dict``).

    Args:
        pdic (Mapping[str, array_like]): parameter dict, e.g.
            ``{'a': arr(len=npl), 'b': arr(len=npl), ...}``.
        npl (int): number of planets, i.e. the required (flattened) length of
            every value.
        keys (Sequence[str]): order of parameters; concatenation follows this
            order.

    Returns:
        numpy.ndarray: 1D parameter array of length ``len(keys) * npl``
        (values are converted with ``np.asarray``). An empty ``keys`` gives an
        empty array.

    Raises:
        KeyError: if a key in ``keys`` is missing from ``pdic``.
        ValueError: if the flattened length of any value differs from ``npl``.
    """
    arrs = []
    for k in keys:
        if k not in pdic:
            raise KeyError(f"Key {k!r} not in pdic.")
        a = np.asarray(pdic[k]).ravel()
        if a.size != npl:
            raise ValueError(f"Length mismatch at {k!r}: {a.size} != {npl}")
        arrs.append(a)
    if not arrs:
        # np.concatenate([]) raises; map an empty key list to an empty array so
        # that the round trip with params_to_dict(..., keys=[]) still works.
        return np.zeros(0)
    return np.concatenate(arrs, axis=0)
def em_to_dict(elements, masses):
    """Build a parameter dict from v0.1.0-style element and mass arrays.

    Note:
        Kept mainly for running tests; not needed for v>=0.2.

    Args:
        elements: elements in JaxTTV format; columns are, in order, period,
            ecosw, esinw, cosi, lnode, tic.
        masses: masses of star + planets (solar units), star first.

    Returns:
        parameter dict with the six element keys plus ``pmass`` (planet masses).
    """
    column_names = ("period", "ecosw", "esinw", "cosi", "lnode", "tic")
    pdic = {name: elements[:, col] for col, name in enumerate(column_names)}
    pdic['pmass'] = masses[1:]
    return pdic
def initialize_jacobi_xv(par_dict, t_epoch):
    """Compute initial Jacobi positions/velocities from a parameter dict.

    Note:
        Here the elements are interpreted as Jacobi elements using the total
        interior mass (see Section 2.2 of Rein & Tamayo 2015): orbit j uses the
        summed mass of the star and planets 0..j.

    Args:
        par_dict: parameter dictionary. It must contain ``period`` (array, one
            entry per orbit) and:

            - either (``ecosw``, ``esinw``) or (``ecc``, ``omega``); the first
              pair is used when both are present
            - ``cosi`` (optional; defaults to 0, i.e. inclination pi/2)
            - ``lnode`` (optional; defaults to 0)
            - either ``tic`` (time of inferior conjunction) or ``ma`` (mean
              anomaly); ``tic`` is used when both are present
            - ``smass``: stellar mass, solar unit (optional; defaults to 1)
            - either ``pmass`` (planetary mass, solar unit) or ``lnpmass``
              (its natural log); ``pmass`` is used when both are present
        t_epoch: epoch at which elements are defined

    Returns:
        tuple:
            - Jacobi positions at t_epoch (Norbit, xyz)
            - Jacobi velocities at t_epoch (Norbit, xyz)
            - masses: 1D array of stellar and planetary masses (Nbody,)

    Raises:
        ValueError: if neither eccentricity pair, neither ``tic`` nor ``ma``,
            or neither ``pmass`` nor ``lnpmass`` is provided.
    """
    keys = par_dict.keys()
    period = par_dict["period"]

    if "ecosw" in keys and "esinw" in keys:
        ecosw, esinw = par_dict['ecosw'], par_dict['esinw']
        ecc = jnp.sqrt(ecosw**2 + esinw**2)
        omega = jnp.arctan2(esinw, ecosw)
    elif "ecc" in keys and "omega" in keys:
        ecc, omega = par_dict["ecc"], par_dict["omega"]
    else:
        raise ValueError(
            "Either (ecosw, esinw) or (ecc, omega) needs to be provided.")

    # Defaults: cosi = 0 (inc = pi/2) and lnode = 0; `period * 0.` gives
    # zeros with the same shape as `period`.
    if "cosi" in keys:
        inc = jnp.arccos(par_dict["cosi"])
    else:
        inc = jnp.arccos(period * 0.)

    if "lnode" in keys:
        lnode = par_dict["lnode"]
    else:
        lnode = period * 0.

    if "tic" in keys:
        ma = tic_to_m(par_dict["tic"], period, ecc, omega, t_epoch)
    elif "ma" in keys:
        ma = par_dict["ma"]
    else:
        raise ValueError(
            "Either tic (time of inf. conjunction) or ma (mean anom.) needs to be provided.")
    u = m_to_u(ma, ecc)  # eccentric anomaly

    if "smass" in keys:
        smass = par_dict["smass"]
    else:
        # with smass = 1, pmass is effectively the planet-to-star mass ratio
        smass = 1.

    if "pmass" in keys:
        masses = jnp.hstack([smass, par_dict['pmass']])
    elif "lnpmass" in keys:
        masses = jnp.hstack([smass, jnp.exp(par_dict['lnpmass'])])
    else:
        raise ValueError(
            "Either pmass (solar unit) or lnpmass needs to be provided.")

    xjac, vjac = [], []
    for j in range(len(period)):
        # masses[:j+2] = star + planets 0..j: total interior mass for orbit j
        xj, vj = elements_to_xv(
            period[j], ecc[j], inc[j], omega[j], lnode[j], u[j],
            jnp.sum(masses[:j+2]))
        xjac.append(xj)
        vjac.append(vj)

    return jnp.array(xjac), jnp.array(vjac), masses
@jit
def get_energy(x, v, masses):
    """Compute the total (kinetic + potential) energy of the system in the CM frame.

    Args:
        x: CM positions (Nbody, xyz)
        v: CM velocities (Nbody, xyz)
        masses: masses of the bodies (Nbody,)

    Returns:
        total energy in units of Msun*(AU/day)^2
    """
    # kinetic energy: sum_i m_i |v_i|^2 / 2
    kinetic = jnp.sum(0.5 * masses * jnp.sum(v * v, axis=1))

    # pairwise squared separations and mass products, (Nbody, Nbody)
    dx = x[:, None] - x[None, :]
    r2 = jnp.sum(dx * dx, axis=2)
    mprod = masses[:, None] * masses[None, :]

    # Count each pair once (i > j, the strictly lower triangle). The diagonal
    # has r2 == 0, so substitute 1 *before* sqrt/division: masking an inf
    # afterwards (as jnp.tril would) gives the right value but NaN gradients.
    idx = jnp.arange(x.shape[0])
    lower = idx[:, None] > idx[None, :]
    inv_r = jnp.where(lower, 1. / jnp.sqrt(jnp.where(lower, r2, 1.)), 0.)
    potential = -G * jnp.sum(mprod * inv_r)

    return kinetic + potential


# map along the 1st axes of x and v (Nstep); masses are shared by all steps
get_energy_map = jit(vmap(get_energy, (0, 0, None), 0))
@jit
def get_energy_diff(xva, masses):
    """Return the fractional energy change between the first and last steps.

    Args:
        xva: positions, velocities, accelerations in the CoM frame,
            shape (Nstep, x or v or a, Nbody, xyz); only x and v are used.
        masses: masses of the bodies (Nbody,)

    Returns:
        fractional change in total energy, E(last) / E(first) - 1
    """
    first, last = xva[0], xva[-1]
    endpoints = jnp.array([first, last])
    energies = get_energy_map(endpoints[:, 0], endpoints[:, 1], masses)
    e_start, e_end = energies[0], energies[1]
    return e_end / e_start - 1.
@jit
def get_energy_diff_jac(xvjac, masses, dt):
    """Return the fractional energy change from Jacobi-coordinate output.

    The first and last snapshots are advanced by ``kepler_step_map`` with step
    ``dt``, converted to CoM coordinates with ``xvjac_to_xvcm``, and their total
    energies compared. (NOTE(review): the Kepler step presumably synchronizes
    the stored state with the symplectic step — confirm against the integrator.)

    Args:
        xvjac: Jacobi positions and velocities (Nstep, x or v, Norbit, xyz)
        masses: masses of the bodies (Nbody,)
        dt: step size passed to ``kepler_step_map``

    Returns:
        fractional change in total energy, E(last) / E(first) - 1
    """
    endpoints = jnp.array([xvjac[0], xvjac[-1]])
    x_kep, v_kep = kepler_step_map(
        endpoints[:, 0, :, :], endpoints[:, 1, :, :], masses, dt)
    x_cm, v_cm = xvjac_to_xvcm(x_kep, v_kep, masses)
    energies = get_energy_map(x_cm, v_cm, masses)
    return energies[1] / energies[0] - 1.
def elements_to_pdic(elements, masses, outkeys=None, force_coplanar=True):
    """Convert JaxTTV elements/masses into a parameter dictionary.

    Note:
        This function is for v<0.2. Here ``pmass`` is ``masses[1:] / M_earth``
        (Earth masses), while ``mass`` is ``masses[1:]`` unchanged.

    Args:
        elements: Jacobi orbital elements, one row per planet:
            (period, ecosw, esinw, cosi, Omega, T_inf_conjunction)
        masses: masses of the bodies (Nbody,), star first (presumably solar
            units, given the division by M_earth)
        outkeys: if specified, only include these keys in the output
        force_coplanar: if True, set cosi=0 (incl=pi/2) and lnode=0

    Returns:
        dictionary of the parameters
    """
    npl = len(masses) - 1
    el = jnp.asarray(elements)[:npl]
    copl = 0. if force_coplanar else 1.

    pdic = {}
    pdic['pmass'] = masses[1:] / M_earth
    pdic['period'] = el[:, 0]
    pdic['ecosw'] = el[:, 1]
    pdic['esinw'] = el[:, 2]
    pdic['cosi'] = el[:, 3] * copl
    pdic['lnode'] = el[:, 4] * copl
    pdic['tic'] = el[:, 5]
    pdic['ecc'] = jnp.sqrt(pdic['ecosw']**2 + pdic['esinw']**2)
    pdic['omega'] = jnp.arctan2(pdic['esinw'], pdic['ecosw'])
    pdic['lnmass'] = jnp.log(masses[1:])
    pdic['mass'] = masses[1:]
    # ecc is floored at 1e-2 to avoid 0/0 for circular orbits; for ecc < 1e-2
    # cosw/sinw are therefore not unit-normalized.
    ecc_floor = jnp.fmax(pdic['ecc'], 1e-2)
    pdic['cosw'] = pdic['ecosw'] / ecc_floor
    pdic['sinw'] = pdic['esinw'] / ecc_floor

    if outkeys is None:
        return pdic
    return {key: val for key, val in pdic.items() if key in outkeys}
def params_to_elements(params, npl):
    """Convert a JaxTTV parameter array into element and mass arrays.

    The array holds ``npl`` rows of orbital elements (flattened, row-major)
    followed by ``npl`` log planet masses.

    Args:
        params: JaxTTV parameter array (1D, array_like)
        npl: number of orbits (planets)

    Returns:
        tuple:
            - Jacobi orbital elements, shape (npl, nelem); each row is
              (period, ecosw, esinw, cosi, Omega, T_inf_conjunction)
            - masses of the bodies (Nbody,): stellar mass fixed to 1 (exp(0))
              followed by exp() of the last ``npl`` entries of ``params``
    """
    # asarray lets plain sequences through (they have no .reshape)
    params = jnp.asarray(params)
    elements = params[:-npl].reshape(npl, -1)
    masses = jnp.exp(jnp.hstack([0, params[-npl:]]))
    return elements, masses
def convert_elements(par_dict, t_epoch, WHsplit=False):
    """Convert JaxTTV parameters into conventional orbital elements.

    Args:
        par_dict: parameter dict (as accepted by ``initialize_jacobi_xv``)
        t_epoch: epoch at which elements are defined
        WHsplit: if True, elements are converted assuming the Kepler part of
            the Wisdom-Holman splitting (use this when the output is fed to
            TTVFast); otherwise the total interior mass is used.

    Returns:
        tuple:
            - array: (semi-major axis, period, eccentricity, inclination,
              argument of periastron, longitude of ascending node, mean anomaly)
              x (orbits); angles are in radians
            - mass array
    """
    xjac, vjac, masses = initialize_jacobi_xv(par_dict, t_epoch)
    interior = jnp.cumsum(masses)  # interior[k] = masses[0] + ... + masses[k]
    if not WHsplit:
        # total interior mass
        ki = G * interior[1:]
    else:
        # H_Kepler of the WH splitting (TTVFast): G * m0 * M_j / M_{j-1},
        # where interior[:-1] starts with interior[0] == masses[0]
        ki = G * masses[0] * interior[1:] / interior[:-1]
    return xv_to_elements(xjac, vjac, ki), masses
def _nearest_index(arr1, val):
    """Index of the element of ``arr1`` closest to scalar ``val`` (first on ties)."""
    return jnp.argmin(jnp.abs(arr1 - val))


# Built once at import time. Wrapping a fresh closure in jit(vmap(...)) on every
# call (as before) produces a new function each time, which defeats JAX's jit
# cache and forces re-tracing (and typically re-compilation) per call.
_nearest_index_map = jit(vmap(_nearest_index, (None, 0), 0))


def findidx_map(arr1, arr2):
    """Pick up elements of arr1 nearest to each element in arr2.

    Args:
        arr1: 1D array from which elements are picked up
        arr2: 1D array of the values for which nearest matches are searched

    Returns:
        indices of arr1 nearest to each element in arr2 (same length as arr2);
        ties resolve to the smallest index
    """
    return _nearest_index_map(arr1, arr2)