Source code for jnkepler.nbodytransit.nbodytransit


__all__ = ["NbodyTransit", "q_to_u", "b_to_cosi"]

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import warnings
from functools import partial
from ..jaxttv import JaxTTV
from ..jaxttv.utils import *
from ..jaxttv.conversion import *
from ..jaxttv.findtransit import find_transit_params_all, find_transit_params_fast
from ..jaxttv.symplectic import integrate_xv, kepler_step_map
from ..jaxttv.hermite4 import integrate_xv as integrate_xv_hermite4
from ..jaxttv.rv import *
from .transit import *
from jax import jit, grad, config
config.update('jax_enable_x64', True)


[docs] class NbodyTransit(JaxTTV): """main class for photodynamical analysis. Note: Unlike in JaxTTV class, non-transiting objects are not yet supported. """
[docs] def set_lcobs(self, times_lc, overlapping_transit=False, exposure_time=29.4/1440., supersample_factor=10, print_info=True): """initialization Args: times_lc: times in the light curve overlapping_transit: if True, the code accounts for simultaneous transits; overlap between the planets is not yet supported exposure_time: exposure time (same unit as times_lc) supersample_factor: flux is computed at intervals given by exposure_time / supersample_factor and then summed print_info: if True, show the parameters """ self.times_lc = times_lc self.supersample_num = int(supersample_factor // 2 * 2 + 1) if self.supersample_num < 2: self.times_super = times_lc else: dt = exposure_time / (self.supersample_num - 1) dtarr = np.linspace(0, exposure_time, self.supersample_num) dtarr -= np.median(dtarr) self.times_super = (times_lc[:, None] + dtarr).ravel() times_transit_idx, times_planet_idx = [], [] for j in range(self.nplanet): tcj = np.where(self.pidx == j+1, self.tcobs_flatten, -np.inf) _tidx = find_nearest_idx(tcj, self.times_super) _pidx = jnp.ones_like(_tidx) * j times_transit_idx.append(_tidx) times_planet_idx.append(_pidx) self.times_transit_idx = jnp.array(times_transit_idx) self.times_planet_idx = jnp.array(times_planet_idx) self.times_transit_idx_nool = find_nearest_idx( self.tcobs_flatten, self.times_super) self.times_planet_idx_nool = self.pidx[self.times_transit_idx_nool].astype( int) - 1 self.overlapping_transit = overlapping_transit if print_info: print("# exposure time (min):".ljust( 35) + "%.1f" % (exposure_time*1440)) print("# supersampling factor:".ljust(35) + "%d" % (self.supersample_num-1)) if not self.overlapping_transit: warnings.warn( "set overlapping_transit to True if overlapping transits are present in the data.")
[docs] def get_xvsky_tc(self, par_dict): """compute sky-plane positions and velocities at transit centers Args: par_dict (dict): dict containing TTV parameters Returns: tuple: - transit centers (Ntransit) - astrocentric positions in the sky plane at transit centers (Ntransit, xy) - astrocentric velocities in the sky plane at transit centers (Ntransit, xy) - times - mass array (star+planets, solar unit) """ xjac0, vjac0, masses = initialize_jacobi_xv( par_dict, self.t_start) # initial Jacobi position/velocity times, xvjac = integrate_xv( xjac0, vjac0, masses, self.times, nitr=self.nitr_kepler) # integration pidxarr = self.pidx.astype(int) # idx for planet, starting from 1 tcobs1d = self.tcobs_flatten if self.transit_time_method == "newton": tc, (xcm, vcm, _) = find_transit_params_all( pidxarr - 1, tcobs1d, times, xvjac, masses ) else: tc, (xcm, vcm, _) = find_transit_params_fast( pidxarr - 1, tcobs1d, times, xvjac, masses ) xsky_tc, vsky_tc = get_xvast_map(xcm, vcm, pidxarr) return tc, xsky_tc, vsky_tc, times, xvjac, masses
[docs] @partial(jit, static_argnums=(0,)) def get_flux(self, par_dict): """compute nbody flux Args: par_dict (dict): dict containing input parameters: - TTV parameters - srad: stellar radius (in units of solar radius) - radius_ratio: planet-to-star radius ratio (Norbit,) - u1, u2: quadratic limb-darkening coefficients Returns: tuple: - transit light curve (len(times_lc),) - transit times (1D flattened array) """ _par_dict = initialize_transit_params(par_dict) tc, xsky_tc, vsky_tc, _, _, _ = self.get_xvsky_tc(_par_dict) if self.overlapping_transit: nbodyflux_ss = compute_nbody_flux(_par_dict["srad"], _par_dict["radius_ratio"], _par_dict["u1"], _par_dict["u2"], tc, xsky_tc, vsky_tc, self.times_super, self.times_transit_idx, self.times_planet_idx) else: nbodyflux_ss = compute_nbody_flux_nooverlap(_par_dict["srad"], _par_dict["radius_ratio"], _par_dict["u1"], _par_dict["u2"], tc, xsky_tc, vsky_tc, self.times_super, self.times_transit_idx_nool, self.times_planet_idx_nool) nbodyflux = jnp.mean( nbodyflux_ss.reshape(-1, self.supersample_num), axis=1) return nbodyflux, tc
[docs] @partial(jit, static_argnums=(0,)) def get_flux_and_rv(self, par_dict, times_rv): """compute nbody flux and RV Args: par_dict (dict): dict containing input parameters: - TTV parameters - srad: stellar radius (in units of solar radius) - radius_ratio: planet-to-star radius ratio (Norbit,) - u1, u2: quadratic limb-darkening coefficients times_rv: times at which RVs are evaluated Returns: tuple: - transit light curve (len(times_lc),) - transit times (1D flattened array) - stellar RVs at times_rvs (m/s), positive when the star is moving away """ _par_dict = initialize_transit_params(par_dict) tc, xsky_tc, vsky_tc, times, xvjac, masses = self.get_xvsky_tc( _par_dict) if self.overlapping_transit: nbodyflux_ss = compute_nbody_flux(_par_dict["srad"], _par_dict["radius_ratio"], _par_dict["u1"], _par_dict["u2"], tc, xsky_tc, vsky_tc, self.times_super, self.times_transit_idx, self.times_planet_idx) else: nbodyflux_ss = compute_nbody_flux_nooverlap(_par_dict["srad"], _par_dict["radius_ratio"], _par_dict["u1"], _par_dict["u2"], tc, xsky_tc, vsky_tc, self.times_super, self.times_transit_idx_nool, self.times_planet_idx_nool) nbodyflux = jnp.mean( nbodyflux_ss.reshape(-1, self.supersample_num), axis=1) nbodyrv = rv_from_xvjac(times_rv, times, xvjac, masses) return nbodyflux, tc, nbodyrv
def initialize_transit_params(par_dict_): """initialize transit parameters Args: par_dict (dict) Returns: dict: following parameters are newly added: - stellar radius (solar unit); stellar mass = 1 assumed if only density is specified in par_dict - radius ratio - coefficients for quadratic limb-darkening law """ keys = par_dict_.keys() par_dict = par_dict_.copy() if "smass" in keys and "srad" in keys: pass elif "sdens" in keys: par_dict["smass"] = par_dict_["smass"] if "smass" in keys else 1. par_dict["srad"] = (par_dict["smass"] / par_dict_["sdens"])**(1./3.) else: raise ValueError( "Either (smass, srad) (stellar mass/radius in solar unit) or sdens (stellar mean density in solar unit) needs to be provided." ) if "q1" in keys and "q2" in keys: par_dict["u1"], par_dict["u2"] = q_to_u( par_dict_["q1"], par_dict_["q2"]) elif "u1" in keys and "u2" in keys: pass else: raise ValueError( "Either (q1, q2) or (u1, u2) needs to be provided for quadratic limb-darkening." ) if not "radius_ratio" in keys: raise ValueError( "radius_ratio needs to be provided for planet-to-star radius ratios.") if "b" in keys: par_dict["cosi"] = b_to_cosi(par_dict_["b"], par_dict_["period"], par_dict_[ "ecosw"], par_dict_["esinw"], par_dict_["srad"], par_dict_["smass"]) elif "cosi" in keys: pass else: raise ValueError( "Either b or cosi needs to be provided to specify orbital inclinations." ) return par_dict
[docs] def q_to_u(q1, q2): """convert q1, q2 into u1, u2 Args: q1, q2: quadratic limb-darkening coefficients as parameterized in Kipping, D. M. 2013, MNRAS, 435, 2152 Returns: tuple: quadratic limb-darkening coefficients u1, u2 """ usum = jnp.sqrt(q1) u1 = 2 * q2 * usum u2 = usum - u1 return u1, u2
[docs] def b_to_cosi(b, period, ecosw, esinw, rstar, mstar): """convert b into cosi following Eq.7 of Winn (2010), arXiv:1001.2010 Args: b: impact parameter (normalized to stellar radius) period: orbital period ecosw: eccentricity * cos(argument of periastron) esinw: eccentricity * sin(argument of periastron) rstar: stellar radius (solar unit) mstar: stellar mass (solar unit) Returns: float: cosine of inclination """ # adopting G, M_sun, R_sun from astropy.constants a_over_r = 4.2083 * period**(2./3.) / rstar * mstar**(1./3.) efactor = (1. - ecosw**2 - esinw**2) / (1. + esinw) return b / a_over_r / efactor