from scipy.special import gammaln
import jax
__all__ = [
"ttv_default_parameter_bounds",
"ttv_optim_least_squares",
"ttv_optim_curve_fit",
"scale_pdic",
"unscale_pdic",
"get_flat_param_index",
"make_phase_to_tic_transform",
]
from jax import jacfwd, jacrev, jit
import numpy as np
import jax.numpy as jnp
from scipy.optimize import least_squares
from copy import deepcopy
import time
import warnings
from .utils import params_to_dict, dict_to_params
[docs]
def ttv_default_parameter_bounds(jttv, npl=None, t0_guess=None, p_guess=None,
dtic=0.2, dp_frac=1e-2, emax=0.2,
mmin=1e-7, mmax=1e-3):
"""Get parameter bounds for TTV optimization.
Args:
jttv: JaxTTV object.
npl (int, optional): Number of planets. Defaults to jttv.nplanet if None.
t0_guess (array-like, optional): Initial guess for transit times,
length must be npl.
p_guess (array-like, optional): Initial guess for orbital periods,
length must be npl.
dtic (float, optional): Half-width of bounds around t0_guess for
transit time.
dp_frac (float, optional): Fractional width of bounds around p_guess
for period.
emax (float, optional): Maximum ecosw/esinw bound.
mmin (float, optional): Minimum mass bound.
mmax (float, optional): Maximum mass bound.
Returns:
dict: Dictionary of parameter bounds with keys as parameter names
and values as [lower_bound_array, upper_bound_array].
"""
if npl is None:
npl = jttv.nplanet
if t0_guess is None:
t0_guess = np.array([tcobs_[0] for tcobs_ in jttv.tcobs])
else:
t0_guess = np.array(t0_guess)
assert len(
t0_guess) == npl, f"t0_guess length {len(t0_guess)} != npl {npl}"
if p_guess is None:
p_guess = np.array(jttv.p_init)
else:
p_guess = np.array(p_guess)
assert len(p_guess) == npl, f"p_guess length {len(p_guess)} != npl {npl}"
ones = np.ones(npl)
param_bounds = {
"tic": [t0_guess - dtic, t0_guess + dtic],
"period": [p_guess * (1 - dp_frac), p_guess * (1 + dp_frac)],
"ecosw": [-emax * ones, emax * ones],
"esinw": [-emax * ones, emax * ones],
"lnpmass": [np.log(mmin) * ones, np.log(mmax) * ones],
"pmass": [mmin * ones, mmax * ones],
}
return param_bounds
[docs]
def scale_pdic(pdic, param_bounds):
"""scale parameters using bounds
Args:
pdic: dict of physical parameters
param_bounds: dictionary of (lower bound array, upper bound array)
Returns:
dict: dictionary of scaled parameters
"""
pdic_scaled = {}
for key in param_bounds.keys():
pdic_scaled[key + "_scaled"] = (
(pdic[key] - param_bounds[key][0])
/ (param_bounds[key][1] - param_bounds[key][0])
)
return pdic_scaled
[docs]
def unscale_pdic(pdic_scaled, param_bounds):
"""unscale parameters using bounds
Args:
pdic: dict of scaled parameters
param_bounds: dictionary of (lower bound array, upper bound array)
Returns:
dict: dictionary of physical parameters in original scales
"""
pdic = {}
for key in param_bounds.keys():
pdic[key] = (
param_bounds[key][0]
+ (param_bounds[key][1] - param_bounds[key][0]) *
pdic_scaled[key + "_scaled"]
)
return pdic
def _get_cached_residual_functions(
jttv,
npl,
keys,
transit_orbit_idx=None,
jac=False,
diff_mode="fwd",
):
"""Return cached jitted residual / jacobian functions for repeated fits."""
if not hasattr(jttv, "_lsq_cache"):
jttv._lsq_cache = {}
if diff_mode not in ("fwd", "rev"):
raise ValueError(
f"diff_mode must be 'fwd' or 'rev', got {diff_mode!r}"
)
cache_key = (
npl,
tuple(keys),
None if transit_orbit_idx is None else tuple(transit_orbit_idx),
bool(jac),
diff_mode,
)
if cache_key not in jttv._lsq_cache:
def _model(p_flat):
pdic = params_to_dict(p_flat, npl, keys)
return jttv.get_transit_times_obs(
pdic,
transit_orbit_idx=transit_orbit_idx,
)[0]
def _resid(p_flat):
return (_model(p_flat) - jttv.tcobs_flatten) / jttv.errorobs_flatten
resid = jit(_resid)
if jac:
if diff_mode == "fwd":
jac_resid = jit(jacfwd(_resid))
else:
jac_resid = jit(jacrev(_resid))
else:
jac_resid = None
jttv._lsq_cache[cache_key] = {
"resid": resid,
"jac_resid": jac_resid,
"is_warmed": False,
}
return jttv._lsq_cache[cache_key]
def student_t_2nll_loss(nu, scale=1.0):
"""
Student-t loss for scipy.optimize.least_squares (2*NLL convention).
This callable follows SciPy's custom loss interface:
input : z = f**2
output : array with shape (3, m), containing rho, rho', rho''
Here f is the residual vector passed to least_squares. If your residuals
are already normalized by observational errors, then `scale` is an
additional Student-t scale parameter in those normalized units.
The returned rho(z) is the full 2*NLL per data point, including constants.
Therefore, with f_scale=1.0 in least_squares, res.cost corresponds to
the total NLL.
"""
nu = float(nu)
scale = float(scale)
if nu <= 0:
raise ValueError("nu must be positive.")
if scale <= 0:
raise ValueError("scale must be positive.")
a = nu * scale**2
const = (
2.0 * np.log(scale)
+ np.log(nu * np.pi)
+ 2.0 * gammaln(nu / 2.0)
- 2.0 * gammaln((nu + 1.0) / 2.0)
)
def loss(z):
rho = (nu + 1.0) * np.log1p(z / a) + const
drho = (nu + 1.0) / (a + z)
d2rho = -(nu + 1.0) / (a + z) ** 2
return np.vstack((rho, drho, d2rho))
return loss
[docs]
def get_flat_param_index(keys, npl, key, planet_idx):
"""
Return the flat index for a given parameter key and planet index.
Args:
keys: ordered list of parameter keys used in the flattened vector
npl: total number of planets in the model
key: parameter name, e.g. 'period', 'tic'
planet_idx: zero-based planet index
Returns:
int: flat index into the concatenated parameter vector
"""
if key not in keys:
raise ValueError(f"{key} not found in keys={keys}")
if not (0 <= planet_idx < npl):
raise ValueError(f"planet_idx={planet_idx} out of range for npl={npl}")
return keys.index(key) * npl + planet_idx
[docs]
def ttv_optim_least_squares(
jttv,
param_bounds_,
pinit=None,
n_start=1,
loss="linear",
loss_kwargs=None,
jac=False,
diff_mode="auto",
plot=True,
save=None,
transit_orbit_idx=None,
random_state=None,
max_nfev=None,
param_transform=None,
return_optspace=False,
):
"""Simple TTV fit using scipy.optimize.least_squares.
Args:
jttv: JaxTTV object
param_bounds_: bounds for parameters, dict of {key: (lower, upper)}
pinit: initial guess of parameters (dict)
n_start: number of initial guesses when pinit is not provided.
If pinit is given, a single optimization is run. Otherwise,
multi-start randomizes only the initial planet masses. The first
start uses a deterministic midpoint-like initial guess, and later
starts randomize only `lnpmass`.
loss: loss specification for scipy.optimize.least_squares.
This can be:
- a built-in SciPy loss string such as 'linear', 'soft_l1',
'huber', 'cauchy', 'arctan'
- 'student_t' to use Student-t negative log likelihood
- a custom callable compatible with scipy.optimize.least_squares
loss_kwargs: optional dict for loss-specific options. Default is None.
For loss='student_t':
{'nu': ..., 'scale': ...}
Default is {'nu': 4.0, 'scale': 1.0}.
For built-in or custom non-linear losses:
{'f_scale': ...} can be passed through to least_squares.
For loss='student_t', the custom loss returns 2*NLL per data
point, so least_squares.cost corresponds to the total NLL.
For loss='linear', cost = 0.5 * chi2.
jac: if True, use a JAX-based analytic Jacobian for the residual
function.
diff_mode: differentiation mode for the analytic Jacobian when
jac=True. Must be one of:
- 'auto': use 'fwd' for transit_time_method='fast' and
'rev' for transit_time_method='newton'
- 'rev'
- 'fwd'
Note: for transit_time_method='newton', 'fwd' is overridden to
'rev' because forward-mode differentiation may fail there.
plot: if True, TTV models are plotted with data.
save: path to save TTV plots.
transit_orbit_idx: list of indices to specify which planets are
transiting (needed when non-transiting planets are included)
random_state: int or np.random.RandomState, for reproducibility.
If None, multi-start initializations are not reproducible across runs.
max_nfev: maximum number of function evaluations
param_transform: optional callable mapping optimizer-space parameters
to model-space parameters before residual evaluation.
This is useful, for example, when optimizing orbital phase
instead of time of inferior conjunction for a non-transiting planet.
If jac=True, this should be JAX-compatible.
return_optspace: if True, also return the best-fit parameter dictionary
in optimizer space as a second output. This is useful for warm
starts when param_transform is used.
Returns:
dict or tuple:
- If return_optspace=False (default), returns the best-fit JaxTTV
parameter dictionary in model space.
- If return_optspace=True, returns
(pdic_opt_modelspace, pdic_opt_optspace).
"""
param_bounds = deepcopy(param_bounds_)
if loss_kwargs is None:
loss_kwargs = {}
if diff_mode not in ("auto", "rev", "fwd"):
raise ValueError(
f"diff_mode must be 'auto', 'rev', or 'fwd', got {diff_mode!r}"
)
# resolve differentiation mode
transit_time_method = jttv.transit_time_method
if diff_mode == "auto":
effective_diff_mode = (
"rev" if transit_time_method == "newton" else "fwd"
)
else:
effective_diff_mode = diff_mode
if jac and transit_time_method == "newton" and effective_diff_mode == "fwd":
warnings.warn(
"diff_mode='fwd' is not supported reliably with "
"transit_time_method='newton'; using diff_mode='rev' instead."
)
effective_diff_mode = "rev"
# check non-transiting planets
npl = len(param_bounds["period"][0])
if npl != jttv.nplanet:
print(f"# {npl - jttv.nplanet} non-transiting planets.")
if transit_orbit_idx is None:
raise ValueError(
"transit_orbit_idx must be provided when non-transiting planets "
"are included."
)
transit_orbit_idx = np.asarray(transit_orbit_idx)
if transit_orbit_idx.ndim != 1:
raise ValueError(
f"transit_orbit_idx must be 1D, got shape {transit_orbit_idx.shape}."
)
if len(transit_orbit_idx) != jttv.nplanet:
raise ValueError(
f"transit_orbit_idx must have length {jttv.nplanet}, "
f"got {len(transit_orbit_idx)}."
)
# keys to optimize
if "cosi" not in param_bounds.keys() or "lnode" not in param_bounds.keys():
warnings.warn(
"Bounds for cosi/lnode not provided: assuming coplanar orbits."
)
keys = ["period", "ecosw", "esinw", "tic", "lnpmass"]
else:
keys = ["period", "ecosw", "esinw", "cosi", "lnode", "tic", "lnpmass"]
params_lower = np.hstack([param_bounds[key][0] for key in keys])
params_upper = np.hstack([param_bounds[key][1] for key in keys])
bounds = (params_lower, params_upper)
# slice for lnpmass in the flattened parameter vector
offset = 0
mass_slice = None
for key in keys:
n = len(param_bounds[key][0])
if key == "lnpmass":
mass_slice = slice(offset, offset + n)
break
offset += n
if mass_slice is None:
raise ValueError("lnpmass not found in optimization keys.")
if isinstance(random_state, np.random.RandomState):
rng = random_state
else:
rng = np.random.RandomState(random_state)
cache = _get_cached_residual_functions(
jttv,
npl=npl,
keys=keys,
transit_orbit_idx=transit_orbit_idx,
jac=jac,
diff_mode=effective_diff_mode,
)
resid_base = cache["resid"]
jac_resid_base = cache["jac_resid"]
def transform_p(p):
if param_transform is None:
return p
return param_transform(p)
def resid_jax(p):
return resid_base(transform_p(p))
if jac:
if param_transform is None:
jac_resid_jax = jac_resid_base
else:
# include chain rule through param_transform
if effective_diff_mode == "rev":
jac_resid_jax = jax.jit(jax.jacrev(resid_jax))
else:
jac_resid_jax = jax.jit(jax.jacfwd(resid_jax))
def resid_np(p):
return np.array(resid_jax(jnp.asarray(p)), dtype=float, copy=True)
def jac_np(p):
return np.array(jac_resid_jax(jnp.asarray(p)), dtype=float, copy=True)
def chi2_np(p):
r = resid_np(p)
return float(np.sum(r**2))
# warm up once per cache key / transform choice
if pinit is not None:
p_warm = np.hstack([pinit[key] for key in keys])
else:
p_warm = 0.5 * (params_lower + params_upper)
p_warm = np.clip(p_warm, params_lower, params_upper)
if (not cache["is_warmed"]) or (param_transform is not None):
_ = resid_np(p_warm)
if jac:
_ = jac_np(p_warm)
if param_transform is None:
cache["is_warmed"] = True
if pinit is not None and n_start != 1:
print("# pinit is provided; ignoring n_start and running a single optimization.")
n_start_eff = 1 if pinit is not None else n_start
# resolve loss for least_squares
if loss == "student_t":
lsq_loss = student_t_2nll_loss(
nu=loss_kwargs.get("nu", 4.0),
scale=loss_kwargs.get("scale", 1.0),
)
lsq_f_scale = 1.0
else:
lsq_loss = loss
lsq_f_scale = loss_kwargs.get("f_scale", 1.0)
best_popt = None
best_cost = np.inf
best_chi2 = np.inf
print(
"# running least squares optimization "
f"(n_start={n_start_eff}, loss={loss}, "
f"transit_time_method={transit_time_method}, "
f"jac={'on' if jac else 'off'}, "
f"diff_mode={effective_diff_mode if jac else 'n/a'})..."
)
t0_all = time.time()
# baseline starting point used when pinit is not given
p0_base = 0.499 * params_lower + 0.501 * params_upper
for i in range(n_start_eff):
if pinit is not None:
p0 = np.hstack([pinit[key] for key in keys])
else:
p0 = p0_base.copy()
if i > 0:
lo = params_lower[mass_slice]
hi = params_upper[mass_slice]
p0[mass_slice] = lo + rng.rand(hi.size) * (hi - lo)
p0 = np.clip(p0, params_lower, params_upper)
chi2_init = chi2_np(p0)
t0 = time.time()
try:
res = least_squares(
resid_np,
p0,
jac=jac_np if jac else "2-point",
bounds=bounds,
method="trf",
loss=lsq_loss,
f_scale=lsq_f_scale,
max_nfev=max_nfev,
)
except (RuntimeError, ValueError) as e:
print(f"# start {i}: least_squares failed ({e})")
continue
dt = time.time() - t0
chi2_fin = float(np.sum(res.fun**2))
cost_fin = float(res.cost)
pmass0_str = np.array2string(
np.exp(p0[mass_slice]) / 3.003e-6,
precision=1,
separator=", ",
)
print(
f"# start {i}: initial pmass={pmass0_str}, "
f"chi2={chi2_init:.2f} --> {chi2_fin:.2f}, "
f"cost={cost_fin:.2f}, nfev={res.nfev}, elapsed={dt:.1f} s"
)
if cost_fin < best_cost:
best_cost = cost_fin
best_chi2 = chi2_fin
best_popt = res.x
print("# ------------------------------------------------------------")
print(
"# best objective over all starts: "
f"cost={best_cost:.2f}, chi2={best_chi2:.2f} "
f"({len(jttv.tcobs_flatten)} data)"
)
print("# total elapsed time: %.1f sec" % (time.time() - t0_all))
print("# ------------------------------------------------------------")
if best_popt is None:
raise RuntimeError("All fits failed.")
# convert from optimizer-space to model-space parameters
best_popt_model = np.array(
transform_p(jnp.asarray(best_popt)),
dtype=float,
copy=True,
)
pdic_opt = params_to_dict(best_popt_model, npl, keys)
if transit_time_method == "fast":
fast_validation_threshold = float(
np.median(np.asarray(jttv.errorobs_flatten, dtype=float)))
tc_fast = np.asarray(
jttv.get_transit_times_obs(
pdic_opt,
transit_orbit_idx=transit_orbit_idx,
)[0],
dtype=float,
)
jttv_newton = deepcopy(jttv)
jttv_newton.transit_time_method = "newton"
tc_newton = np.asarray(
jttv_newton.get_transit_times_obs(
pdic_opt,
transit_orbit_idx=transit_orbit_idx,
)[0],
dtype=float,
)
max_abs_dt = float(np.max(np.abs(tc_fast - tc_newton)))
if max_abs_dt > fast_validation_threshold:
raise RuntimeError(
"Optimization with transit_time_method='fast' failed validation: "
"the final model differs from a Newton-based transit-time evaluation by "
f"max_abs_dt={max_abs_dt:.2e} d, which exceeds the median timing error "
f"({fast_validation_threshold:.2e} d). "
"Consider rerunning the optimization with transit_time_method='newton'. "
"Once a reliable solution has been found, transit_time_method='fast' can "
"still be used for subsequent NUTS initialization and sampling."
)
if plot:
tcall = jttv.get_transit_times_all_list(
pdic_opt,
transit_orbit_idx=transit_orbit_idx,
)
jttv.plot_model(tcall, marker=".", save=save)
pdic_opt["pmass"] = jnp.exp(pdic_opt["lnpmass"])
if return_optspace:
pdic_optspace = params_to_dict(best_popt, npl, keys)
pdic_optspace["pmass"] = jnp.exp(pdic_optspace["lnpmass"])
return pdic_opt, pdic_optspace
return pdic_opt
[docs]
def ttv_optim_curve_fit(*args, **kwargs):
"""Deprecated wrapper for ttv_optim_least_squares."""
warnings.warn(
"ttv_optim_curve_fit is deprecated and will be removed in a future "
"release. Use ttv_optim_least_squares instead.",
FutureWarning,
stacklevel=2,
)
return ttv_optim_least_squares(*args, **kwargs)