Source code for jnkepler.information
__all__ = ["information_from_model_independent_normal"]
import jax.numpy as jnp
import functools
from collections import OrderedDict
from jax import jacrev, jacfwd, random
from numpyro import handlers, infer
from numpyro.distributions.transforms import biject_to
def _to_unconstrained(model, params_constrained, keys, *model_args, **model_kwargs):
"""
Convert constrained parameter values to their unconstrained representations
using the model's sample sites.
Args:
model (callable): A NumPyro model function.
params_constrained (dict): Dictionary mapping parameter names to values
in the *constrained* space (e.g., positive scales, simplex).
keys (list): Keys to be included in the output dict.
*args: Positional arguments passed to `model` when tracing.
**kwargs: Keyword arguments passed to `model` when tracing.
Returns:
dict: Dictionary mapping parameter names to values in the
*unconstrained* space, suitable for use in inference algorithms
(e.g., HMC/NUTS).
"""
tr = handlers.trace(handlers.seed(model, 0)).get_trace(
*model_args, **model_kwargs)
bij = {}
for name, site in tr.items():
if site["type"] == "sample" and not site["is_observed"]:
bij[name] = biject_to(site["fn"].support)
z_params = {k: bij[k].inv(params_constrained[k]) for k in keys}
return z_params
def _seed_and_substitute(model, params_dict, param_space, rng_key):
"""Seed and substitute parameters into a NumPyro model."""
if param_space == "unconstrained":
substituted = handlers.substitute(
model,
substitute_fn=functools.partial(
infer.util._unconstrain_reparam, params_dict),
)
elif param_space == "constrained":
substituted = handlers.substitute(model, data=params_dict)
else:
raise ValueError(
"param_space must be 'constrained' or 'unconstrained'.")
return handlers.seed(substituted, rng_seed=rng_key)
def _std_residuals_from_model_independent_normal(
model,
params_dict,
param_space,
rng_key,
*,
sigma_sd,
mu_name="tcmodel",
obs_name="obs",
observed=None,
model_args=(),
model_kwargs=None,
):
"""
Build standardized residuals z = (y - mu(theta)) / sigma for independent Gaussian likelihood.
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
seeded = _seed_and_substitute(model, params_dict, param_space, rng_key)
trace = handlers.trace(seeded).get_trace(*model_args, **model_kwargs)
if mu_name not in trace:
raise KeyError(
f"deterministic mu '{mu_name}' not found in trace. "
"Record it via numpyro.deterministic(mu_name, mu)."
)
mu = jnp.asarray(trace[mu_name]["value"]).reshape(-1)
sigma_sd = jnp.asarray(sigma_sd).reshape(-1)
if observed is not None:
y = jnp.asarray(observed).reshape(-1)
else:
if (obs_name not in trace):
raise KeyError(f"obs site '{obs_name}' not found in trace.")
else:
y = jnp.asarray(trace[obs_name]["value"]).reshape(-1)
if y.shape != mu.shape or y.shape != sigma_sd.shape:
raise ValueError(
f"shape mismatch: y {y.shape}, mu {mu.shape}, sigma {sigma_sd.shape}")
return (y - mu) / sigma_sd # (N,)
[docs]
def information_from_model_independent_normal(
*,
model=None,
model_args=(),
model_kwargs=None,
pdic=None,
mu_name=None,
observed=None,
obs_name=None,
keys=None,
sigma_sd=None,
param_space="unconstrained",
rng_key=None,
):
"""
Compute Fisher information matrix for independent Gaussian likelihood directly from a NumPyro model,
using (observed - mu(pdic)) / sigma_sd obtained from a NumPyro model.
Args:
model: NumPyro model.
model_args, model_kwargs: static args/kwargs for the model.
pdic: dict of parameter values in constrained space.
mu_name: deterministic site name for the model mean.
observed: 1D array of observed values; obs_name is used if not provided.
obs_name: observed site name.
keys: list of parameter names to differentiate (order preserved).
sigma_sd: 1D array of standard deviations (SD) for iid noise.
param_space: 'constrained' or 'unconstrained'; use 'unconstrained' to initialize inverse_mass_matrix.
rng_key: PRNG key (default = jax.random.PRNGKey(0)).
Returns:
dict: A dictionary containing the Fisher information results and related metadata:
- "fisher" (jnp.ndarray): The (P, P) Fisher information matrix.
- "col_slices" (dict[str, slice]): Mapping from each parameter name to its
corresponding column range in the Fisher matrix.
- "col_names" (list[str]): Flattened per-column names, matching the order of
columns in the Fisher matrix.
- "params_unconstrained" (dict[str, jnp.ndarray]): Parameter values in the
unconstrained space used for differentiation.
"""
assert model is not None and pdic is not None and mu_name is not None and keys is not None and sigma_sd is not None
if (observed is None) and (obs_name is None):
raise ValueError("Either `observed` or `obs_name` must be provided.")
keys = list(keys)
if param_space == "unconstrained":
_pdic = _to_unconstrained(
model, pdic, keys, *model_args, **(model_kwargs or {}))
elif param_space == "constrained":
_pdic = dict({k: pdic[k] for k in keys})
else:
raise ValueError(
"param_space must be 'constrained' or 'unconstrained'.")
pdic_sub = OrderedDict((k, _pdic[k]) for k in keys)
rng_key = random.PRNGKey(0) if rng_key is None else rng_key
model_kwargs = {} if model_kwargs is None else model_kwargs
base = dict({k: v for k, v in _pdic.items()
if k not in (mu_name, obs_name)})
def r_fn(p_sub):
p_all = dict(base)
p_all.update(p_sub)
return _std_residuals_from_model_independent_normal(
model,
p_all,
param_space,
rng_key,
sigma_sd=sigma_sd,
mu_name=mu_name,
obs_name=obs_name,
model_args=model_args,
model_kwargs=model_kwargs,
observed=observed
) # (N,)
# Jacobian of standardized residuals w.r.t. params (ordered by `keys`)
Jtree = jacrev(r_fn)(pdic_sub)
# Stack columns in stable key order; flatten trailing dims per key
N = Jtree[keys[0]].shape[0]
cols, names, slices, c0 = [], [], {}, 0
for k in keys:
Jk = jnp.asarray(Jtree[k]).reshape(N, -1)
cols.append(Jk)
d = Jk.shape[1]
names += [k] if d == 1 else [f"{k}[{i}]" for i in range(d)]
slices[k] = slice(c0, c0 + d)
c0 += d
J = jnp.hstack(cols) # (N, P)
F = J.T @ J
return {
"fisher": F,
"col_slices": slices,
"col_names": names,
"params_unconstrained": _pdic
}