Source code for jnkepler.infer


# Public API of this module: names exported by ``from ... import *``.
__all__ = ["optim_svi", "fit_t_distribution"]

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
from numpyro.infer.initialization import init_to_value, init_to_sample
from scipy.stats import t as tdist
from scipy.stats import norm
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy import special


def optim_svi(numpyro_model, step_size, num_steps, p_initial=None, **kwargs):
    """Optimize model parameters with Stochastic Variational Inference (SVI).

    An ``AutoLaplaceApproximation`` guide is optimized with Adam against the
    ``Trace_ELBO`` loss, and the guide's median for the optimized parameters
    is returned.

    Args:
        numpyro_model: numpyro model
        step_size: step size given to the Adam optimizer
        num_steps: number of optimization steps
        p_initial: initial parameter set (dict); if None, ``init_to_sample``
            is used to initialize the guide
        **kwargs: extra keyword arguments forwarded to the ``SVI`` constructor

    Returns:
        dict: dictionary containing optimized parameters
    """
    adam = numpyro.optim.Adam(step_size=step_size)

    # Choose how the guide's location is initialized, then build the guide once.
    if p_initial is not None:
        init_loc_fn = init_to_value(values=p_initial)
    else:
        init_loc_fn = init_to_sample
    laplace_guide = AutoLaplaceApproximation(
        numpyro_model, init_loc_fn=init_loc_fn)

    svi = SVI(numpyro_model, laplace_guide, adam, loss=Trace_ELBO(), **kwargs)

    # The PRNG key is fixed, so repeated calls use the same random seed.
    result = svi.run(random.PRNGKey(0), num_steps)

    # Report the guide's median at the optimized parameters.
    return laplace_guide.median(result.params)
def fit_t_distribution(y, plot=True, fit_mean=False, save=None, xrange=5):
    """fit Student's t distribution to a sample y

    The log of the degrees of freedom and the log of the variance (and
    optionally the mean) are sampled with NUTS (500 warmup + 500 samples,
    fixed PRNG key), and the posterior means/standard deviations are returned.

    Args:
        y: 1D array
        plot: if True, plot results (PDF and CDF panels)
        fit_mean: if True, mean of the distribution is also fitted
        save: if not None and plot is True, the figure is written to
            ``save + "residual.png"``
        xrange: model curves are drawn over ``[-xrange, xrange]``

    Returns:
        dict: dictionary with the following keys:

            - lndf_loc: mean of log(dof)
            - lndf_scale: std of log(dof)
            - lnvar_loc: mean of log(variance)
            - lnvar_scale: std of log(variance)
            - mean_loc: mean of mean (if fitted)
            - mean_scale: std of mean (if fitted)
    """
    def model(y):
        logdf = numpyro.sample(
            "lndf", dist.Uniform(jnp.log(0.1), jnp.log(100)))
        logvar = numpyro.sample("lnvar", dist.Uniform(-2, 10))
        df = numpyro.deterministic("df", jnp.exp(logdf))
        v1 = numpyro.deterministic("v1", jnp.exp(logvar))
        if fit_mean:
            mean = numpyro.sample(
                "mean", dist.Uniform(-jnp.std(y), jnp.std(y)))
            numpyro.sample("obs", dist.StudentT(
                loc=mean, scale=jnp.sqrt(v1), df=df), obs=y)
        else:
            numpyro.sample("obs", dist.StudentT(
                scale=jnp.sqrt(v1), df=df), obs=y)

    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=500, num_samples=500)
    rng_key = random.PRNGKey(0)
    mcmc.run(rng_key, y)
    mcmc.print_summary()
    samples = mcmc.get_samples()

    lndf, lnvar = np.mean(samples['lndf']), np.mean(samples['lnvar'])
    lndf_sd, lnvar_sd = np.std(samples['lndf']), np.std(samples['lnvar'])
    pout = {'lndf_loc': lndf, 'lndf_scale': lndf_sd,
            'lnvar_loc': lnvar, 'lnvar_scale': lnvar_sd}
    if fit_mean:
        mean, mean_sd = np.mean(samples['mean']), np.std(samples['mean'])
        pout['mean_loc'] = mean
        pout['mean_scale'] = mean_sd
    else:
        # The model without a fitted mean is centred on zero.
        mean = 0.

    if plot:
        sd = np.std(y)
        fig, ax = plt.subplots(1, 2, figsize=(16, 4))
        ax[1].set_yscale("log")
        ax[1].set_ylabel("PDF")
        ax[0].set_ylabel("CDF")
        ax[0].set_xlabel("residual / assigned error")
        ax[1].set_xlabel("residual / assigned error")

        # Raw strings keep the mathtext backslash without relying on an
        # invalid escape sequence; the runtime text is unchanged.
        label_sd = r'normal, $\mathrm{SD}=%.2f$' % sd
        label_unit = r'normal, $\mathrm{SD}=1$'
        label_t = 'Student\'s t\n(lndf=%.2f, lnvar=%.2f, mean=%.2f)' % (
            lndf, lnvar, mean)
        # Student's t built from the posterior-mean parameters.
        t_fit = tdist(loc=mean, scale=np.exp(lnvar*0.5), df=np.exp(lndf))

        # --- PDF panel ---
        bin_width = _knuth_bin_width(y)
        # At least one bin, even when all values of y coincide.
        bins = max(1, int(np.ceil((y.max() - y.min()) / bin_width)))
        ax[1].hist(y, histtype='step', lw=3, alpha=0.6,
                   bins=bins, density=True, color='gray')
        # Read the limits from the PDF axes explicitly rather than from
        # whichever axes happens to be current.
        ymin, ymax = ax[1].get_ylim()
        ax[1].set_ylim(ymin/5., ymax*1.5)
        x0 = np.linspace(-xrange, xrange, 100)
        ax[1].plot(x0, norm(scale=sd).pdf(x0), lw=1, color='C0',
                   ls='dashed', label=label_sd)
        ax[1].plot(x0, norm.pdf(x0), lw=1, color='C0',
                   ls='dotted', label=label_unit)
        ax[1].plot(x0, t_fit.pdf(x0), label=label_t)

        # --- CDF panel ---
        # Empirical CDF from a histogram with one bin per data point, padded
        # with 0 on the left and 1 on the right out to the plotted x range.
        hist, edge = np.histogram(y, bins=len(y))
        ax[0].plot(np.r_[x0[0], edge[0], edge[:-1], edge[-1], x0[-1]],
                   np.r_[0, 0, np.cumsum(hist)/len(y), 1, 1],
                   lw=3, alpha=0.6, color='gray')
        ax[0].plot(x0, norm(loc=0, scale=sd).cdf(x0), lw=1, color='C0',
                   ls='dashed', label=label_sd)
        ax[0].plot(x0, norm.cdf(x0), lw=1, color='C0',
                   ls='dotted', label=label_unit)
        ax[0].plot(x0, t_fit.cdf(x0), label=label_t)
        ax[0].legend(loc='upper left', fontsize=14)

        if save is not None:
            plt.savefig(save+"residual.png", dpi=200, bbox_inches="tight")

    return pout
def _knuth_bin_width(x, Mmax=200, return_bins=False):
    """
    Compute optimal histogram bin width using Knuth (2006) rule with a
    Freedman-Diaconis lower bound.

    The function performs a discrete search over the number of bins
    M = 1..min(Mmax, N) and selects the value that maximizes the Knuth
    log-posterior. The final number of bins is never smaller than the
    Freedman-Diaconis bin count, which ensures a reasonable binning even for
    pathological distributions.

    The input is never modified: the data are sorted into a fresh copy.

    Args:
        x (array_like): 1D input data array. Must be finite and non-empty.
        Mmax (int, optional): Maximum number of bins to consider in the Knuth
            search. Defaults to 200 (or N if smaller).
        return_bins (bool, optional): If True, also return the computed bin
            edges.

    Returns:
        float or tuple: If `return_bins` is False, returns the optimal bin
        width (`float`). If `return_bins` is True, returns a tuple
        `(width, bins)` where:

            - `width` (`float`): Optimal bin width.
            - `bins` (`ndarray`): Array of bin edges.

    Raises:
        ValueError: if `x` is empty or contains non-finite values.
    """
    x = np.asarray(x, float).ravel()
    if x.size == 0 or not np.isfinite(x).all():
        raise ValueError("x must be finite and non-empty")
    # np.sort returns a sorted copy. An in-place sort here could reorder the
    # caller's array (asarray/ravel may return views of it) and would fail
    # outright on read-only inputs.
    x = np.sort(x)
    n = x.size
    xmin, xmax = x[0], x[-1]

    # all identical -> trivial
    if xmax == xmin:
        w = 1.0
        return (w, np.array([xmin - 0.5, xmax + 0.5])) if return_bins else w

    # --- Freedman–Diaconis bins (floor) ---
    q25, q75 = np.percentile(x, [25, 75])
    iqr = q75 - q25
    if iqr <= 0:
        # Degenerate IQR: fall back to a sqrt(n) bin count.
        M_fd = max(1, int(np.sqrt(n)))
    else:
        dx_fd = 2.0 * iqr / (n ** (1/3))
        M_fd = max(1, int(np.ceil((xmax - xmin) / dx_fd))) if dx_fd > 0 else 1

    # --- Knuth discrete search ---
    Mmax = int(max(1, min(Mmax, n)))
    log_gamma_half = special.gammaln(0.5)  # loop invariant
    best_val, best_M = -np.inf, 1
    for M in range(1, Mmax + 1):
        bins = np.linspace(xmin, xmax, M + 1)  # match astropy's choice
        nk, _ = np.histogram(x, bins=bins)
        val = (
            n * np.log(M)
            + special.gammaln(0.5 * M)
            - M * log_gamma_half
            - special.gammaln(n + 0.5 * M)
            + np.sum(special.gammaln(nk + 0.5))
        )
        if val > best_val:
            best_val, best_M = val, M

    # Never use fewer bins than the Freedman–Diaconis count.
    M = max(best_M, M_fd)
    bins = np.linspace(xmin, xmax, M + 1)
    w = bins[1] - bins[0]
    return (w, bins) if return_bins else w