# Public API of this module.
# NOTE(review): `information` and `scale_information` are not defined in the
# visible part of this module — confirm they exist further down, otherwise
# `from <module> import *` raises AttributeError.
__all__ = ["information", "scale_information"]
# Experimental helpers, intentionally NOT exported:
# "observed_information", "hessian", "information_numpyrox"
import jax.numpy as jnp
from jax import jacfwd, jacrev
from .utils import params_to_dict
def get_2d_matrix(p, param_order):
    """Assemble a dense 2D matrix from a nested (dict-of-dict) pytree.

    ``p[k1][k2]`` is the block coupling parameter ``k1`` (rows) with parameter
    ``k2`` (columns), as produced e.g. by ``numpyro_ext.information`` or by
    ``jacfwd(jacrev(f))`` applied to a dict of parameters. Blocks are placed
    in the order given by ``param_order``.

    Parameters may have different sizes. The size of parameter ``k`` is the
    last dimension of its diagonal block ``p[k][k]``. Scalar (0-d) blocks are
    treated as 1x1.

    Args:
        p: nested pytree (as output from numpyro_ext.information)
        param_order: list of parameter keys; fixes the row/column order

    Returns:
        2D array of shape (S, S), where S is the sum of the parameter sizes,
        in JAX's default float dtype. Returns a (0, 0) array when
        ``param_order`` is empty.
    """
    # Same dtype the previous zeros-based implementation produced
    # (float32 unless jax x64 mode is enabled).
    float_dtype = jnp.zeros(()).dtype
    if len(param_order) == 0:
        return jnp.zeros((0, 0), dtype=float_dtype)

    # Per-parameter block size; atleast_2d maps a 0-d (scalar) block to 1x1.
    sizes = {k: jnp.atleast_2d(p[k][k]).shape[-1] for k in param_order}

    # Reshape every block so scalar/vector couplings (shape () or (n,))
    # fit their (rows, cols) slot before stitching them together.
    rows = [
        [
            jnp.reshape(jnp.asarray(p[k1][k2]), (sizes[k1], sizes[k2]))
            for k2 in param_order
        ]
        for k1 in param_order
    ]
    return jnp.block(rows).astype(float_dtype)
def information_numpyrox(numpyro_model, pdic, **kwargs):
    """Inverse Fisher information of a numpyro model, via numpyro-ext.

    Calls ``numpyro_ext.information`` with ``invert=True``, so the returned
    matrix is the INVERSE of the information matrix, not the information
    matrix itself. The result is flattened into a dense 2D array with
    ``get_2d_matrix``.

    Args:
        numpyro_model: numpyro model
        pdic: dict containing parameters at which to evaluate
        kwargs: additional arguments forwarded to the numpyro model

    Returns:
        tuple (matrix, site_names): the 2D inverse-information matrix
        evaluated at pdic, and the site names in the row/column order used.
    """
    # Imported lazily so numpyro-ext stays an optional dependency.
    from numpyro_ext import information

    inverse_info_tree = information(numpyro_model, invert=True)(pdic, **kwargs)
    site_names = [*inverse_info_tree]
    dense = get_2d_matrix(inverse_info_tree, param_order=site_names)
    return dense, site_names
def negative_log_likelihood(jttv, pdic, lnmass=False):
    """Negative log-likelihood for iid Gaussian transit-time errors.

    Computes ``0.5 * sum(((t_obs - t_model) / sigma) ** 2)``. The
    normalization constant is omitted.

    Args:
        jttv: JaxTTV object; provides ``get_transit_times_obs``,
            ``tcobs_flatten`` and ``errorobs_flatten``
        pdic: dict containing parameters
        lnmass: unused; kept only for call-signature compatibility

    Returns:
        negative log likelihood (scalar)
    """
    model_times = jttv.get_transit_times_obs(pdic)[0]
    scaled_residuals = (jttv.tcobs_flatten - model_times) / jttv.errorobs_flatten
    chi_squared = jnp.sum(scaled_residuals ** 2)
    return 0.5 * chi_squared
def observed_information(jttv, pdic, keys):
    """Observed Fisher information: Hessian of the iid Gaussian negative log-likelihood.

    The Hessian of ``negative_log_likelihood`` is taken with respect to the
    parameters listed in ``keys`` only, and evaluated at ``pdic``. It uses
    forward-over-reverse autodiff. This is the full second derivative, NOT
    the Gauss-Newton approximation ``grad.T Sigma_inv grad``. Parameters not
    in ``keys`` are held fixed at their ``pdic`` values.

    Note:
        Per the original author, for keys=['ecosw', 'esinw', 'pmass', 'period', 'tic']
        this returns the same matrix as the ``hessian`` function in this module.

    Args:
        jttv: JaxTTV object
        pdic: dict containing parameters; must include ecosw, esinw, period,
            tic, lnode, cosi and one of pmass / lnpmass
        keys: parameter names defining the rows/columns (and their order) of
            the returned matrix; a subset of
            {ecosw, esinw, period, tic, lnode, cosi, pmass, lnpmass}
            that is present in pdic

    Returns:
        2D observed information matrix (Hessian of the negative log-likelihood)

    Raises:
        AssertionError: if pdic or keys do not satisfy the requirements above.
    """
    from copy import deepcopy

    orbital_keys = {'ecosw', 'esinw', 'period', 'tic', 'lnode', 'cosi'}
    mass_keys = {'pmass', 'lnpmass'}
    assert orbital_keys.issubset(pdic.keys()), \
        "pdic keys must contain all of ecosw, esinw, period, tic, lnode, cosi."
    assert not mass_keys.isdisjoint(pdic.keys()), \
        "pdic keys must contain either pmass or lnpmass."
    assert set(keys).issubset(orbital_keys | mass_keys), \
        "keys must be a subset of {ecosw, esinw, period, tic, lnode, cosi}+{pmass or lnpmass}"
    # Fail fast instead of raising KeyError after the (expensive) Hessian.
    assert set(keys).issubset(pdic.keys()), "all keys must be present in pdic."

    # jacfwd fails for the newton-raphson method, so differentiate a copy that
    # uses interpolation; the caller's object is left untouched.
    if jttv.transit_time_method != "interpolation":
        jttv_copy = deepcopy(jttv)
        jttv_copy.transit_time_method = "interpolation"
    else:
        jttv_copy = jttv

    # Differentiate only w.r.t. the requested parameters. Second partials do
    # not depend on which other variables are free, so the requested blocks
    # are unchanged, but the Jacobians are smaller.
    free_params = {k: pdic[k] for k in keys}

    def _nll_of_free(free):
        return negative_log_likelihood(jttv_copy, {**pdic, **free})

    hessian_pytree = jacfwd(jacrev(_nll_of_free))(free_params)
    return get_2d_matrix(hessian_pytree, keys)
def hessian(self, pdic):
    """Hessian of the iid Gaussian negative log-likelihood w.r.t. a flat parameter vector.

    The vector is the concatenation of pdic['ecosw'], pdic['esinw'],
    pdic['pmass'], pdic['period'] and pdic['tic'], in that order. Transit
    times are recomputed from the parameters as follows:
    1. Jacobi initial conditions.
    2. Integration with ``integrate_xv``.
    3. ``find_transit_times_kepler_all``.
    The second derivative is then taken with forward-over-reverse autodiff.

    Note:
        CURRENTLY WORKS ONLY FOR ['ecosw', 'esinw', 'pmass', 'period', 'tic'].
        Per the original author, for these keys this returns the same matrix as
        ``observed_information``, but faster.

    Args:
        self: JaxTTV object (this module-level function takes it explicitly)
        pdic: parameter dictionary

    Returns:
        hessian (second derivative of the negative log likelihood)
    """
    from jnkepler.jaxttv.utils import initialize_jacobi_xv
    from jnkepler.jaxttv.findtransit import find_transit_times_kepler_all
    from jnkepler.jaxttv.symplectic import integrate_xv

    param_keys = ['ecosw', 'esinw', 'pmass', 'period', 'tic']

    # Data that does not depend on the parameters; read once, outside the
    # differentiated function.
    orbit_idx = self.pidx.astype(int) - 1
    tc_observed = self.tcobs_flatten
    tc_error = self.errorobs_flatten

    # NOTE(review): an earlier comment said jacfwd fails for the Newton-Raphson
    # method; confirm find_transit_times_kepler_all is forward-differentiable.
    def _negloglike(flat_params):
        params = params_to_dict(flat_params, self.nplanet, param_keys)
        x_init, v_init, masses = initialize_jacobi_xv(params, self.t_start)
        t_grid, xv_grid = integrate_xv(
            x_init, v_init, masses, self.times, nitr=self.nitr_kepler)
        tc_model = find_transit_times_kepler_all(
            orbit_idx, tc_observed, t_grid, xv_grid, masses, nitr=self.nitr_transit)
        return 0.5 * jnp.sum(((tc_observed - tc_model) / tc_error) ** 2)

    flat_start = jnp.hstack([jnp.array(pdic[name]) for name in param_keys])
    second_derivative = jacfwd(jacrev(_negloglike))(flat_start)
    return second_derivative