"""
MCMC sampling for GP hyperparameters.
MCMCSampler is the base class providing shared diagnostics, summary,
and plotting utilities. BlackJAXSampler adds gradient-based NUTS
sampling via the BlackJAX library.
"""
import jax
jax.config.update("jax_enable_x64", True)
import os
import jax.numpy as jnp
import numpy as np
try:
from .gp_solver import GPSolver
except ImportError:
from gp_solver import GPSolver
__all__ = ["MCMCSampler", "BlackJAXSampler", "DynestySampler"]
# =====================================================================
# Base class
# =====================================================================
class MCMCSampler:
    """
    Base MCMC sampler for GP hyperparameters.

    Wraps a GPSolver object and provides shared storage, diagnostics,
    summary statistics, corner plots, and dict conversion. Subclasses
    implement specific sampling algorithms (e.g. NUTS).

    Parameters
    ----------
    gp : GPSolver
        A configured GPSolver instance.

    Raises
    ------
    TypeError
        If ``gp`` is not a GPSolver instance.
    """

    def __init__(self, gp):
        if not isinstance(gp, GPSolver):
            raise TypeError("gp must be a GPSolver instance")
        self.gp = gp
        # Storage for MCMC results
        self.samples = None
        self._info = None
        self._last_state = None
        self._adapted_step_size = None
        self._adapted_inv_mass = None
        self._last_rng_key = None
        self._checkpoint_file = None
        self._map_completed = False
        self._warmup_completed = False

    @property
    def param_keys(self):
        """Parameter names, as exposed by the wrapped GPSolver."""
        return self.gp.param_keys

    @property
    def n_params(self):
        """Number of parameters, as exposed by the wrapped GPSolver."""
        return self.gp.n_params

    # =================================================================
    # Internal helpers
    # =================================================================
    @staticmethod
    def _as_2d_samples(samples):
        """
        Return ``samples`` as a NumPy array with chains merged.

        A 3-D ``(n_chains, n_samples, n_params)`` array is reshaped to
        ``(n_chains * n_samples, n_params)``; any other shape is
        returned unchanged.
        """
        arr = np.asarray(samples)
        if arr.ndim == 3:
            arr = arr.reshape(-1, arr.shape[-1])
        return arr

    def _truth_array(self, true_params):
        """
        Convert ``true_params`` to a float array ordered like ``param_keys``.

        Dict entries missing from ``true_params`` become NaN. Returns
        None when ``true_params`` is None.
        """
        if true_params is None:
            return None
        if isinstance(true_params, dict):
            return np.array(
                [true_params.get(k, np.nan) for k in self.param_keys],
                dtype=np.float64)
        return np.asarray(true_params, dtype=np.float64)

    # =================================================================
    # Diagnostics
    # =================================================================
    def summary(self):
        """
        Print summary statistics of the posterior samples.

        Multi-chain samples of shape ``(n_chains, n_samples, n_params)``
        are pooled across chains before the statistics are computed.

        Returns
        -------
        stats : dict
            Parameter names mapped to a dict with the keys ``"mean"``,
            ``"std"``, ``"q16"``, ``"q50"`` and ``"q84"``.

        Raises
        ------
        RuntimeError
            If no samples are stored on the sampler.
        """
        if self.samples is None:
            raise RuntimeError("No samples available. Run a sampler first.")
        samples = self._as_2d_samples(self.samples)

        stats = {}
        print(f"{'param':>12s} {'mean':>10s} {'std':>10s} "
              f"{'16%':>10s} {'50%':>10s} {'84%':>10s}")
        print("-" * 68)
        for i, key in enumerate(self.param_keys):
            col = samples[:, i]
            q16, q50, q84 = np.percentile(col, [16, 50, 84])
            m, s = np.mean(col), np.std(col)
            stats[key] = {"mean": m, "std": s,
                          "q16": q16, "q50": q50, "q84": q84}
            print(f"{key:>12s} {m:10.5f} {s:10.5f} "
                  f"{q16:10.5f} {q50:10.5f} {q84:10.5f}")

        info = self._info
        if info is not None and "n_divergent" in info:
            print(f"\nDivergences: {info['n_divergent']}")
            # The info dict may carry "n_divergent" without
            # "acceptance_rate", so only report the rate when present.
            if "acceptance_rate" in info:
                print(f"Mean acceptance: "
                      f"{np.mean(info['acceptance_rate']):.3f}")
        return stats

    def plot_covariance(self, method="fisher", theta_map=None,
                        n_sigma=2, n_grid=200, samples=None,
                        figsize=None, color="C0", alpha=0.3,
                        true_params=None, savefig=None,
                        **corner_kwargs):
        """
        Corner plot of 2D covariance ellipses from the Hessian or Fisher
        matrix, with 1D marginal Gaussians on the diagonal.

        Uses ``corner.corner`` to lay out the figure when MCMC samples
        are provided, and overlays the Laplace/Fisher Gaussian
        approximation (ellipses + 1D marginals).

        Parameters
        ----------
        method : {"fisher", "hessian_map", "laplace"}
            Which matrix to use for the Gaussian approximation.
        theta_map : array_like, optional
            Center of the ellipses. If None, uses MAP estimate.
        n_sigma : float
            Number of sigma for the ellipse contours (default 2).
        n_grid : int
            Grid resolution for the ellipse curves (default 200).
        samples : array_like, optional
            If provided, plotted as the corner histogram/contours
            (3-D multi-chain arrays are pooled across chains first).
            If None, the figure is created with empty axes and only
            the Gaussian approximation is drawn.
        figsize : tuple, optional
            Figure size.
        color : str
            Color for Gaussian ellipses and marginals (default "C0").
        alpha : float
            Fill alpha for the ellipse interiors (default 0.3).
        true_params : dict or array_like, optional
            True parameter values to mark with crosshairs.
        savefig : str, optional
            If provided, save figure to this path.
        **corner_kwargs
            Extra keyword arguments forwarded to ``corner.corner``
            (e.g. ``quantiles``, ``show_titles``, ``hist_kwargs``).

        Returns
        -------
        fig, axes : matplotlib Figure and 2D array of Axes.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        import matplotlib.pyplot as plt

        gp = self.gp
        n = self.n_params
        keys = list(self.param_keys)

        # Get MAP center
        if theta_map is None:
            if gp.map_estimate is None:
                gp.fit_map()
            theta_map = gp.map_estimate
        mu = np.asarray(theta_map, dtype=np.float64)

        # Get covariance matrix: compute the requested matrix only when
        # the solver has not cached it yet.
        if method == "fisher":
            if gp._fisher_matrix is None:
                gp.mass_matrix_fisher(theta_map)
        elif method == "hessian_map":
            if gp._hessian is None:
                gp.mass_matrix_hessian_map(theta_map)
        elif method == "laplace":
            if gp._laplace_hessian is None:
                gp.mass_matrix_laplace(theta_map)
        else:
            raise ValueError(f"Unknown method: {method!r}")
        # NOTE(review): presumably ``inverse_mass_matrix`` reflects the
        # most recently computed method -- confirm against GPSolver when
        # a cached matrix short-circuits the computation above.
        cov = np.asarray(gp.inverse_mass_matrix)

        true_arr = self._truth_array(true_params)

        # --- Build figure with corner ----------------------------------
        if samples is not None:
            import corner

            samples = self._as_2d_samples(samples)
            corner_defaults = dict(
                labels=keys,
                show_titles=True,
                plot_density=True,
                plot_contours=True,
                # density=True puts the histograms on the same scale as
                # the normalized Gaussian marginals drawn below.
                hist_kwargs={"density": True},
            )
            corner_defaults.update(corner_kwargs)
            if true_arr is not None:
                corner_defaults.setdefault("truths", list(true_arr))
            if figsize is not None:
                corner_defaults["fig"] = plt.figure(figsize=figsize)
            fig = corner.corner(samples, **corner_defaults)
        else:
            # No samples -- create an empty corner-style grid
            if figsize is None:
                figsize = (2.5 * n, 2.5 * n)
            # squeeze=False always yields a 2-D axes array, also for n == 1
            fig, axes_grid = plt.subplots(n, n, figsize=figsize,
                                          squeeze=False)
            for i in range(n):
                # Hide upper triangle
                for j in range(i + 1, n):
                    axes_grid[i, j].set_visible(False)
                # Label edges
                for j in range(i + 1):
                    if i == n - 1:
                        axes_grid[i, j].set_xlabel(keys[j])
                    if j == 0 and i > 0:
                        axes_grid[i, j].set_ylabel(keys[i])
            fig.subplots_adjust(hspace=0.05, wspace=0.05)
        axes = np.array(fig.axes).reshape(n, n)

        # --- Overlay Gaussian approximation ----------------------------
        t_ellipse = np.linspace(0, 2 * np.pi, n_grid)
        unit_circle = np.array([np.cos(t_ellipse), np.sin(t_ellipse)])
        for i in range(n):
            for j in range(i + 1):
                ax = axes[i, j]
                if i == j:
                    # 1D Gaussian marginal (properly normalized)
                    sigma_i = np.sqrt(cov[i, i])
                    x_range = np.linspace(mu[i] - 4 * sigma_i,
                                          mu[i] + 4 * sigma_i, 300)
                    pdf = (np.exp(-0.5 * ((x_range - mu[i]) / sigma_i) ** 2)
                           / (sigma_i * np.sqrt(2 * np.pi)))
                    ax.plot(x_range, pdf, color=color, lw=1.5)
                    ax.fill_between(x_range, pdf, alpha=alpha,
                                    color=color)
                    ax.axvline(mu[i], color=color, ls="--", lw=0.8)
                    if true_arr is not None and np.isfinite(true_arr[i]):
                        ax.axvline(true_arr[i], color="k", ls=":", lw=1)
                    continue

                # 2D covariance ellipses (x axis: param j, y axis: param i)
                sub_cov = np.array([[cov[j, j], cov[j, i]],
                                    [cov[i, j], cov[i, i]]])
                eigvals, eigvecs = np.linalg.eigh(sub_cov)
                # Guard against slightly negative eigenvalues
                eigvals = np.maximum(eigvals, 0)
                # Unit circle mapped onto the 1-sigma ellipse
                one_sigma = eigvecs @ np.diag(np.sqrt(eigvals)) @ unit_circle
                for ns in [1, n_sigma]:
                    ax.plot(mu[j] + ns * one_sigma[0],
                            mu[i] + ns * one_sigma[1],
                            color=color, lw=1.2)
                ax.fill(mu[j] + one_sigma[0], mu[i] + one_sigma[1],
                        color=color, alpha=alpha)
                ax.plot(mu[j], mu[i], "+", color=color, ms=8,
                        mew=1.5)
                if (true_arr is not None
                        and np.isfinite(true_arr[j])
                        and np.isfinite(true_arr[i])):
                    ax.plot(true_arr[j], true_arr[i], "x",
                            color="k", ms=6, mew=1.2)

        if savefig is not None:
            fig.savefig(savefig, dpi=150, bbox_inches="tight")
        return fig, axes

    def plot_corner_map(self, samples=None, checkpoint_path=None,
                        cmap="viridis", marker_size=40, savefig=None,
                        true_params=None, **corner_kwargs):
        """
        Corner plot of MCMC samples with MAP solutions overlaid as
        scatter points colored by their log-likelihood.

        Parameters
        ----------
        samples : array_like, optional
            Shape ``(n_samples, n_params)``, or
            ``(n_chains, n_samples, n_params)`` (chains are pooled).
            If None, loads from the checkpoint file via
            ``self.load_samples``.
        checkpoint_path : str, optional
            Path to checkpoint ``.npz`` file containing MAP solutions.
            If None, uses the default checkpoint file. When the file
            holds no MAP solutions, the ``all_theta_maps`` /
            ``map_loglikes`` attributes of the sampler are used instead.
        cmap : str
            Colormap for the MAP scatter points (default "viridis").
        marker_size : float
            Marker size for scatter points (default 40).
        savefig : str, optional
            If provided, save figure to this path.
        true_params : dict or array_like, optional
            True parameter values to mark with crosshairs.
        **corner_kwargs
            Extra keyword arguments forwarded to ``corner.corner``.

        Returns
        -------
        fig, axes : matplotlib Figure and 2D array of Axes.

        Raises
        ------
        ValueError
            If no samples can be obtained or no MAP solutions are found.
        """
        import corner
        import matplotlib
        import matplotlib.pyplot as plt
        from matplotlib.colors import Normalize

        gp = self.gp
        keys = list(self.param_keys)
        n = len(keys)
        path = checkpoint_path or self._checkpoint_file

        # Load samples
        if samples is None:
            if path is None:
                raise ValueError("No samples provided and no checkpoint file set.")
            samples = self.load_samples(path)
        # Flatten multi-chain samples: (n_chains, n_samples, n_params) -> (N, n_params)
        samples = self._as_2d_samples(samples)

        # Load MAP solutions and log-likelihoods
        all_theta_maps = None
        map_loglikes = None
        if path is not None:
            _path = path if path.endswith(".npz") else path + ".npz"
            # NOTE: allow_pickle=True unpickles arbitrary objects; only
            # open checkpoint files from a trusted source.
            with np.load(_path, allow_pickle=True) as data:
                if "all_theta_maps" in data:
                    all_theta_maps = list(data["all_theta_maps"])
                if "map_loglikes" in data:
                    map_loglikes = np.asarray(data["map_loglikes"])
        if all_theta_maps is None:
            # Nothing on disk: fall back to MAP results held in memory
            all_theta_maps = getattr(self, "all_theta_maps", None)
            map_loglikes = getattr(self, "map_loglikes", None)
        if all_theta_maps is None or len(all_theta_maps) == 0:
            raise ValueError("No MAP solutions found. Run find_map / run_map first.")

        # Convert MAP dicts to arrays -> (n_maps, n_params)
        map_arr = np.array([
            np.array([float(tm[k]) for k in keys])
            if isinstance(tm, dict)
            else np.asarray(tm, dtype=np.float64)
            for tm in all_theta_maps
        ])

        # Compute log-likelihoods if not available
        if map_loglikes is None:
            map_loglikes = [float(gp.log_likelihood_fn(m)) for m in map_arr]
        map_loglikes = np.asarray(map_loglikes, dtype=np.float64)

        true_arr = self._truth_array(true_params)

        # Build corner plot with LaTeX rendering disabled, restoring the
        # previous setting afterwards even if plotting fails.
        old_usetex = matplotlib.rcParams.get("text.usetex", False)
        matplotlib.rcParams["text.usetex"] = False
        try:
            corner_defaults = dict(
                labels=keys,
                show_titles=True,
                title_fmt=".3f",
                plot_density=True,
                plot_contours=True,
                hist_kwargs={"density": True},
            )
            if true_arr is not None:
                corner_defaults["truths"] = list(true_arr)
            corner_defaults.update(corner_kwargs)
            fig = corner.corner(samples, **corner_defaults)
            axes = np.array(fig.axes).reshape(n, n)

            # Color normalization for log-likelihoods
            norm = Normalize(vmin=map_loglikes.min(), vmax=map_loglikes.max())
            line_colors = plt.get_cmap(cmap)(norm(map_loglikes))

            for i in range(n):
                # Overlay MAP scatter on off-diagonal panels
                for j in range(i):
                    axes[i, j].scatter(map_arr[:, j], map_arr[:, i],
                                       c=map_loglikes, cmap=cmap, norm=norm,
                                       s=marker_size, edgecolors="k",
                                       linewidths=0.5, zorder=10)
                # Overlay MAP values on diagonal histograms
                for val, c in zip(map_arr[:, i], line_colors):
                    axes[i, i].axvline(val, color=c, lw=1.2, alpha=0.7,
                                       zorder=10)

            # Add colorbar
            cbar_ax = fig.add_axes([1.02, 0.15, 0.02, 0.7])
            fig.colorbar(
                plt.cm.ScalarMappable(norm=norm, cmap=cmap),
                cax=cbar_ax, label="log-likelihood")

            if savefig is not None:
                fig.savefig(savefig, dpi=150, bbox_inches="tight")
                print(f"Corner + MAP plot saved to {savefig}")
        finally:
            matplotlib.rcParams["text.usetex"] = old_usetex
        return fig, axes

    def to_dict(self, samples=None):
        """
        Convert samples array to a dict keyed by parameter name.

        Parameters
        ----------
        samples : jnp.ndarray, optional
            Shape (n_samples, n_params), or
            (n_chains, n_samples, n_params), in which case the chains
            are pooled. If None, uses self.samples.

        Returns
        -------
        d : dict
            {param_name: array of shape (n_samples,)}

        Raises
        ------
        RuntimeError
            If no samples are given and none are stored on the sampler.
        """
        if samples is None:
            samples = self.samples
        if samples is None:
            raise RuntimeError("No samples available.")
        samples = self._as_2d_samples(samples)
        return {k: samples[:, i]
                for i, k in enumerate(self.param_keys)}
# =====================================================================
# BlackJAX NUTS sampler
# =====================================================================
class BlackJAXSampler(MCMCSampler):
"""
NUTS sampler using the BlackJAX library.
Inherits diagnostics, summary, plotting, and dict conversion from
MCMCSampler. Adds ``run_map``, ``run_warmup``, and
``run_sampling`` for gradient-based No-U-Turn sampling with
dual-averaging step-size adaptation.
When multiple chains are requested, sampling is parallelized across
available devices via ``jax.pmap``. Chains are distributed evenly
across devices (``n_chains`` must be divisible by
``jax.device_count()``). On a single GPU this behaves identically
to the previous ``jax.vmap`` implementation.
Parameters
----------
gp : GPSolver
A configured GPSolver instance.
save_dir : str, optional
Directory for all outputs produced by this sampler (corner
plots, covariance plots, etc.). Created automatically if it
does not exist. When set, ``save_checkpoint`` will default to
saving the checkpoint inside this directory.
checkpoint_file : str, optional
Checkpoint file name, joined onto ``save_dir`` to form the
checkpoint path (default ``"mcmc_checkpoint.npz"``, i.e.
``save_dir/mcmc_checkpoint.npz``). If ``None``, no checkpoint
file is set until one is passed to a later method.
"""
def __init__(self, gp, save_dir="results", checkpoint_file="mcmc_checkpoint.npz"):
    """
    Set up the sampler, its output directory and its checkpoint path.

    Parameters
    ----------
    gp : GPSolver
        A configured GPSolver instance (validated by the base class).
    save_dir : str or None
        Output directory; created if it does not exist. ``None``
        disables the output directory.
    checkpoint_file : str or None
        Checkpoint file name. Joined onto ``save_dir`` when a
        directory is set, otherwise used as given. ``None`` leaves
        the checkpoint path unset.
    """
    super().__init__(gp)
    self.save_dir = save_dir
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    if checkpoint_file is None:
        self._checkpoint_file = None
    elif save_dir is None:
        # No output directory: there is nothing to join the name onto,
        # so use the checkpoint path exactly as given.
        self._checkpoint_file = checkpoint_file
    else:
        self._checkpoint_file = os.path.join(save_dir, checkpoint_file)

    self._n_devices = jax.device_count()
def run_map(self, nopt=10, keys=None, checkpoint_file=None, theta0=None, **kwargs):
    """
    Find MAP solutions via parallel multi-start optimization.

    Runs ``GPSolver.fit_map_parallel`` and stores the results.
    If the checkpoint file already contains MAP data, loads from
    it instead of re-running the optimization.

    Parameters
    ----------
    nopt : int
        Number of independent optimization restarts (default 10).
    keys : list of str, optional
        Parameter names to optimize. If None, uses
        ``self.gp.param_keys``.
    checkpoint_file : str, optional
        Path to save/load MAP solutions. If provided, also
        updates the sampler's default checkpoint path. Defaults
        to ``self._checkpoint_file``. A missing ``.npz`` suffix is
        added, matching the name ``np.savez`` writes to.
    theta0 : dict, optional
        Initial parameter guess, forwarded to
        ``GPSolver.fit_map_parallel``.
    **kwargs
        Additional keyword arguments passed to
        ``GPSolver.fit_map_parallel`` (e.g. ``method``,
        ``maxiter``).

    Returns
    -------
    all_theta_maps : list
        All MAP solutions as returned by ``fit_map_parallel`` (or as
        stored in the checkpoint); the first entry becomes
        ``self.theta_map``.
    """
    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    path = self._checkpoint_file

    # np.savez appends ".npz" when the suffix is missing, so the file
    # on disk must be looked up under the suffixed name as well.
    npz_path = None
    if path is not None:
        npz_path = os.fspath(path)
        if not npz_path.endswith(".npz"):
            npz_path += ".npz"

    # Try loading from disk if a checkpoint file is available
    if npz_path is not None and os.path.exists(npz_path):
        # NOTE: allow_pickle=True unpickles arbitrary objects; only
        # open checkpoint files from a trusted source.
        with np.load(npz_path, allow_pickle=True) as data:
            cached_maps = (list(data["all_theta_maps"])
                           if "all_theta_maps" in data else None)
            cached_loglikes = (np.asarray(data["map_loglikes"])
                               if "map_loglikes" in data else None)
        if cached_maps:
            print(f"Loaded {len(cached_maps)} MAP solutions from {npz_path}")
            self.all_theta_maps = cached_maps
            self.theta_map = cached_maps[0]
            if cached_loglikes is not None:
                self.map_loglikes = cached_loglikes
            self._map_completed = True
            return cached_maps

    # Run optimization
    gp = self.gp
    if keys is None:
        keys = list(gp.param_keys)
    print(f"Finding MAP solution ({nopt} restarts, returning all)...")
    all_theta_maps, _all_results = gp.fit_map_parallel(
        nopt=nopt, keys=keys, return_all=True, theta0=theta0, **kwargs)
    self.all_theta_maps = all_theta_maps
    self.theta_map = all_theta_maps[0]

    # Compute log-likelihood for each MAP solution
    map_loglikes = []
    for tm in all_theta_maps:
        if isinstance(tm, dict):
            arr = np.array([float(tm[k]) for k in gp.param_keys],
                           dtype=np.float64)
        else:
            arr = np.asarray(tm, dtype=np.float64)
        map_loglikes.append(float(gp.log_likelihood_fn(arr)))
    self.map_loglikes = np.array(map_loglikes)
    print(f"MAP solution: {self.theta_map}")

    # Save to checkpoint file
    if npz_path is not None:
        # Merge with existing checkpoint data if present so entries
        # already stored in the file are carried over.
        save_kwargs = {}
        if os.path.exists(npz_path):
            with np.load(npz_path, allow_pickle=True) as existing:
                save_kwargs = {k: existing[k] for k in existing.files}
        save_kwargs["theta_map"] = self.theta_map
        save_kwargs["all_theta_maps"] = all_theta_maps
        save_kwargs["map_loglikes"] = self.map_loglikes
        np.savez(npz_path, **save_kwargs)
        print(f"MAP solutions saved to {npz_path}")

    self._map_completed = True
    return all_theta_maps
def _run_pathfinder_warmup(self, rng_key, theta_inits, log_posterior_fn,
                           target_accept=0.8, maxiter=100, maxcor=10,
                           num_elbo_samples=200):
    """Multi-path Pathfinder warmup: run single-path Pathfinder from
    each starting position, select the best by ELBO, and extract a
    diagonal mass matrix from the L-BFGS inverse Hessian factors.

    Parameters
    ----------
    rng_key : jax.random.PRNGKey
        Base key; a distinct key is derived per path with ``fold_in``.
    theta_inits : jnp.ndarray, shape (n_paths, n_params)
        Starting positions (e.g. top MAP solutions). A 1-D array is
        treated as a single path.
    log_posterior_fn : callable
        Un-normalized log-density.
    target_accept : float
        Accepted for interface compatibility; not used by this
        method (the step size comes from the mass-matrix heuristic
        below).
    maxiter : int
        Maximum L-BFGS iterations per path (default 100).
    maxcor : int
        L-BFGS history size (default 10).
    num_elbo_samples : int
        Samples per path for ELBO estimation (default 200).

    Returns
    -------
    best_position : jnp.ndarray, shape (n_params,)
        Position from the path with highest finite ELBO.
    inv_mass_diag : jnp.ndarray, shape (n_params,)
        Diagonal inverse mass matrix from the best path's
        L-BFGS inverse Hessian approximation.
    step_size : float
        Initial NUTS step size derived from the mass matrix.
    all_positions : jnp.ndarray, shape (n_paths, n_params)
        Position from each path (for chain initialization).

    Raises
    ------
    ValueError
        If ``theta_inits`` contains no starting position.
    """
    theta_inits = jnp.atleast_2d(jnp.asarray(theta_inits))
    n_paths = theta_inits.shape[0]
    if n_paths == 0:
        raise ValueError(
            "theta_inits must contain at least one starting position.")

    from blackjax.vi.pathfinder import approximate as pf_approximate

    print(f"Pathfinder warmup: {n_paths} paths, "
          f"maxiter={maxiter}, maxcor={maxcor}...")
    path_states = []
    for i in range(n_paths):
        path_key = jax.random.fold_in(rng_key, i)
        state, _info = pf_approximate(
            path_key,
            log_posterior_fn,
            theta_inits[i],
            num_samples=num_elbo_samples,
            maxiter=maxiter,
            maxcor=maxcor,
        )
        path_states.append(state)
        print(f" Path {i}: ELBO = {float(state.elbo):.2f}")

    # Select best path by ELBO. argmax would rank a NaN ELBO (a failed
    # path) as the maximum, so non-finite values are demoted to -inf.
    elbos = jnp.array([s.elbo for s in path_states])
    ranked = jnp.where(jnp.isfinite(elbos), elbos, -jnp.inf)
    best_idx = int(jnp.argmax(ranked))
    best = path_states[best_idx]
    print(f" Best path: {best_idx} (ELBO = {float(best.elbo):.2f})")

    # Extract diagonal inverse mass matrix from L-BFGS factors.
    # The approximate inverse Hessian is:
    #   H^{-1} = diag(alpha) + beta @ gamma @ beta^T
    # We take the diagonal for a diagonal mass matrix.
    bg = best.beta @ best.gamma
    bgbt_diag = jnp.sum(bg * best.beta, axis=1)  # diag(beta @ gamma @ beta^T)
    inv_mass_diag = best.alpha + bgbt_diag

    # Ensure all positive before clamping: a negative median would
    # invert the clip bounds below.
    inv_mass_diag = jnp.maximum(inv_mass_diag, 1e-10)
    # Clamp extreme values for stability (relative to the median)
    median_var = jnp.median(inv_mass_diag)
    inv_mass_diag = jnp.clip(inv_mass_diag,
                             median_var * 1e-4, median_var * 1e4)
    inv_mass_diag = jnp.maximum(inv_mass_diag, 1e-10)

    # Step size heuristic: median marginal scale, floored at 1e-5
    step_size = float(jnp.median(jnp.sqrt(inv_mass_diag)))
    step_size = max(step_size, 1e-5)

    all_positions = jnp.array([s.position for s in path_states])
    print(f" Adapted step size: {step_size:.6f}")
    print(f" Inv mass diag range: [{float(inv_mass_diag.min()):.2e}, "
          f"{float(inv_mass_diag.max()):.2e}]")
    return best.position, inv_mass_diag, step_size, all_positions
def run_warmup(self, n_warmup=500, theta_init=None,
               mass_matrix_method="hessian_map", step_size=None,
               rng_key=None, target_accept=0.8, progress_bar=False,
               n_chains=1, checkpoint_file=None,
               warmup_method="window_adaptation",
               pathfinder_maxiter=100, pathfinder_maxcor=10,
               pathfinder_num_elbo=200):
    """
    Run warmup phase: adapt step size and mass matrix.

    Supports three warmup strategies:

    - ``"window_adaptation"`` (default): BlackJAX's standard
      window adaptation of both step size and mass matrix.
    - ``"pathfinder"``: multi-path Pathfinder via L-BFGS.
    - ``"dual_averaging"``: fixes the mass matrix (diagonal of the
      matrix returned by ``GPSolver._get_mass_matrix``) and only
      adapts the step size.

    After warmup, adapted parameters are stored on the sampler
    and a checkpoint is saved (if a checkpoint file is set).

    Parameters
    ----------
    n_warmup : int
        Number of warmup steps (default 500). Not used by the
        ``"pathfinder"`` strategy.
    theta_init : dict or array_like, optional
        Initial position. If None, uses GPSolver's MAP estimate
        (running ``fit_map`` first when none is available).
        Can also be a list of dicts or 2-D array for per-chain
        starting points; the first entry is used for warmup.
    mass_matrix_method : {"hessian_map", "fisher", "laplace", "diagonal", None}
        Method to estimate the mass matrix (forwarded to
        ``GPSolver._get_mass_matrix``).
    step_size : float, optional
        Initial NUTS step size. If None, a heuristic is used.
    rng_key : jax.random.PRNGKey, optional
        Random key. Default: PRNGKey(0).
    target_accept : float
        Target acceptance rate (default 0.8).
    progress_bar : bool
        If True, show progress during window adaptation.
    n_chains : int
        Number of chains (validated against the device count and
        stored for ``run_sampling``).
    checkpoint_file : str, optional
        Override the default checkpoint file path. When set,
        updates ``self._checkpoint_file`` for all subsequent
        save/load operations.
    warmup_method : {"window_adaptation", "pathfinder", "dual_averaging"}
        Warmup strategy. ``None`` selects the default.
    pathfinder_maxiter : int
        Max L-BFGS iterations for Pathfinder (default 100).
    pathfinder_maxcor : int
        L-BFGS history size for Pathfinder (default 10).
    pathfinder_num_elbo : int
        Number of ELBO samples for Pathfinder (default 200).

    Raises
    ------
    ValueError
        If ``warmup_method`` is unknown, ``n_chains`` is smaller than
        one, or ``n_chains`` is not divisible by the device count.
    """
    # Validate cheap arguments first so a typo fails immediately rather
    # than silently running the default strategy.
    if warmup_method is None:
        warmup_method = "window_adaptation"
    valid_methods = ("window_adaptation", "pathfinder", "dual_averaging")
    if warmup_method not in valid_methods:
        raise ValueError(
            f"Unknown warmup_method: {warmup_method!r}. "
            f"Expected one of {valid_methods}.")
    if n_chains < 1:
        raise ValueError(f"n_chains must be >= 1, got {n_chains}.")

    import blackjax

    gp = self.gp
    n_devices = self._n_devices
    if n_chains > 1 and n_chains % n_devices != 0:
        raise ValueError(
            f"n_chains ({n_chains}) must be divisible by the number of "
            f"available devices ({n_devices}). Use n_chains in "
            f"{[n_devices * i for i in range(1, 5)]}.")
    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)

    # Initial position
    # theta_init can be:
    #   - None -> use GPSolver MAP estimate
    #   - dict -> single starting point
    #   - 1-D array -> single starting point
    #   - list of dicts -> per-chain starting points
    #   - 2-D array (n_chains, n_params) -> per-chain starting points
    _per_chain_inits = None
    if theta_init is None:
        if gp.map_estimate is None:
            gp.fit_map()
        # Normalized below exactly like a user-supplied value
        theta_init = gp.map_estimate
    if (isinstance(theta_init, (list, tuple)) and len(theta_init) > 0
            and isinstance(theta_init[0], dict)):
        _per_chain_inits = jnp.array(
            [[float(d[k]) for k in gp.param_keys] for d in theta_init],
            dtype=jnp.float64)
        theta_init = _per_chain_inits[0]  # first entry drives the warmup
    elif isinstance(theta_init, dict):
        theta_init = jnp.array(
            [float(theta_init[k]) for k in gp.param_keys],
            dtype=jnp.float64)
    else:
        theta_init = jnp.asarray(theta_init, dtype=jnp.float64)
        if theta_init.ndim == 2:
            _per_chain_inits = theta_init
            theta_init = _per_chain_inits[0]  # first entry drives the warmup

    self._n_chains = n_chains
    self._per_chain_inits = _per_chain_inits
    self._theta_init = theta_init
    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file

    warmup_key, sample_key = jax.random.split(rng_key)

    if warmup_method == "pathfinder":
        # -- Pathfinder warmup ------------------------------------
        # Build init array for multi-path: use per-chain inits if
        # available, otherwise a single path from theta_init.
        if _per_chain_inits is not None:
            pf_inits = _per_chain_inits
        else:
            pf_inits = theta_init[None, :]
        best_pos, adapted_inv_mass, adapted_step_size, pf_positions = \
            self._run_pathfinder_warmup(
                warmup_key, pf_inits, gp.log_posterior,
                target_accept=target_accept,
                maxiter=pathfinder_maxiter,
                maxcor=pathfinder_maxcor,
                num_elbo_samples=pathfinder_num_elbo,
            )
        if step_size is not None:
            adapted_step_size = step_size
        # Override per-chain inits with pathfinder positions
        self._per_chain_inits = pf_positions
        # Create a NUTS state at the best position for single-chain path
        warmup_state = blackjax.nuts(
            gp.log_posterior,
            step_size=adapted_step_size,
            inverse_mass_matrix=adapted_inv_mass,
        ).init(best_pos)
    else:
        # Shared by both remaining strategies --------------------------
        # Estimate mass matrix (delegated to GPSolver)
        inv_mass = gp._get_mass_matrix(mass_matrix_method, theta_init)
        # For NUTS, use diagonal mass matrix (more robust than full)
        inv_mass_diag = jnp.diag(inv_mass)
        # Clamp extreme values for stability
        median_var = jnp.median(inv_mass_diag)
        inv_mass_diag = jnp.clip(inv_mass_diag,
                                 median_var * 1e-4, median_var * 1e4)
        # Initial step size: heuristic based on mass matrix scale
        if step_size is None:
            step_size = float(0.5 * jnp.min(jnp.sqrt(inv_mass_diag)))
            step_size = max(step_size, 1e-5)

        if warmup_method == "dual_averaging":
            # -- Dual averaging warmup (fixed mass matrix) ------------
            print(f"Warmup: {n_warmup} steps (dual averaging, fixed mass matrix, "
                  f"init step_size={step_size:.6f})...")
            from blackjax.adaptation.step_size import (
                dual_averaging_adaptation,
            )
            da_init, da_update, _ = dual_averaging_adaptation(
                target=target_accept,
            )
            da_state = da_init(step_size)
            kernel = blackjax.nuts(
                gp.log_posterior,
                step_size=step_size,
                inverse_mass_matrix=inv_mass_diag,
            )
            warmup_state = kernel.init(theta_init)
            for _ in range(n_warmup):
                warmup_key, step_key = jax.random.split(warmup_key)
                warmup_state, info = kernel.step(step_key, warmup_state)
                da_state = da_update(da_state, info.acceptance_rate)
                # Rebuild the kernel so the next step uses the updated
                # step size.
                kernel = blackjax.nuts(
                    gp.log_posterior,
                    step_size=jnp.exp(da_state.log_step_size),
                    inverse_mass_matrix=inv_mass_diag,
                )
            adapted_step_size = jnp.exp(da_state.log_step_size_avg)
            adapted_inv_mass = inv_mass_diag
            print(f" Adapted step size: {float(adapted_step_size):.6f}")
            # Re-init warmup_state with the final adapted step size
            kernel = blackjax.nuts(
                gp.log_posterior,
                step_size=adapted_step_size,
                inverse_mass_matrix=adapted_inv_mass,
            )
            warmup_state = kernel.init(warmup_state.position)
        else:
            # -- Window adaptation warmup (default) -------------------
            print(f"Warmup: {n_warmup} steps (window adaptation, "
                  f"init step_size={step_size:.6f})...")
            warmup = blackjax.window_adaptation(
                blackjax.nuts,
                gp.log_posterior,
                is_mass_matrix_diagonal=True,
                initial_step_size=step_size,
                target_acceptance_rate=target_accept,
                progress_bar=progress_bar,
            )
            adapt_results, adapt_info = warmup.run(
                warmup_key, theta_init, num_steps=n_warmup,
            )
            adapted_step_size = adapt_results.parameters["step_size"]
            adapted_inv_mass = adapt_results.parameters["inverse_mass_matrix"]
            warmup_state = adapt_results.state
            print(f" Adapted step size: {float(adapted_step_size):.6f}")
            del adapt_results, adapt_info, warmup

    # Store adapted parameters and checkpoint between warmup and
    # sampling so that warmup intermediates can be freed.
    self._adapted_step_size = float(adapted_step_size)
    self._adapted_inv_mass = np.asarray(adapted_inv_mass)
    self._last_state = warmup_state
    self._last_rng_key = sample_key
    self._info = {
        "step_size": self._adapted_step_size,
        "n_warmup": n_warmup,
        "n_samples": 0,
        "n_chains": n_chains,
        "n_divergent": 0,
    }
    self._warmup_completed = True
    if self._checkpoint_file is not None:
        self.save_checkpoint(append_samples=False)
        print(" Warmup checkpoint saved; clearing warmup memory...")
    jax.clear_caches()
    # Warm up the log_posterior JIT kernel so CUDA timers are
    # accurate when the sampling scan launches.
    jax.block_until_ready(gp.log_posterior(theta_init))
def run_sampling(self, n_samples=1000):
    """
    Run NUTS sampling using adapted parameters from ``run_warmup``.

    Must be called after ``run_warmup`` (or with sampler state
    restored from a checkpoint).

    Parameters
    ----------
    n_samples : int
        Number of post-warmup samples per chain (default 1000).

    Returns
    -------
    samples : jnp.ndarray
        Shape ``(n_samples, n_params)`` when ``n_chains=1``, or
        ``(n_chains, n_samples, n_params)`` when ``n_chains > 1``.
    info : dict
        Sampling diagnostics (arrays have a leading chain
        dimension when ``n_chains > 1``).

    Raises
    ------
    RuntimeError
        If no adapted sampler state is available.
    ValueError
        If ``n_samples`` is smaller than one, or the number of chains
        is not divisible by the number of devices.
    """
    warmup_state = self._last_state
    adapted_step_size = self._adapted_step_size
    adapted_inv_mass = self._adapted_inv_mass
    sample_key = self._last_rng_key
    if (warmup_state is None or adapted_step_size is None
            or adapted_inv_mass is None or sample_key is None):
        raise RuntimeError(
            "No adapted sampler state available. Run run_warmup first.")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}.")

    import blackjax

    gp = self.gp
    n_devices = self._n_devices
    n_warmup = self._info["n_warmup"]
    # ``_n_chains`` is only assigned by run_warmup; fall back to the
    # chain count recorded in the info dict otherwise.
    n_chains = getattr(self, "_n_chains", None)
    if n_chains is None:
        n_chains = int(self._info.get("n_chains", 1))
    if n_chains > 1 and n_chains % n_devices != 0:
        raise ValueError(
            f"n_chains ({n_chains}) must be divisible by the number of "
            f"available devices ({n_devices}).")
    _per_chain_inits = getattr(self, "_per_chain_inits", None)

    # -- Sampling via lax.scan -----------------------------------
    def _run_one_chain(state, chain_key):
        """Sample one chain via lax.scan."""
        kernel = blackjax.nuts(
            gp.log_posterior,
            step_size=adapted_step_size,
            inverse_mass_matrix=adapted_inv_mass,
        )
        step_keys = jax.random.split(chain_key, n_samples)

        def one_step(carry, step_key):
            st, n_div = carry
            st, info = kernel.step(step_key, st)
            n_div = n_div + info.is_divergent.astype(jnp.int32)
            return (st, n_div), (st.position, info)

        (final_st, total_div), (positions, infos) = jax.lax.scan(
            one_step, (state, jnp.int32(0)), step_keys,
        )
        # NOTE(review): the returned key was already consumed by the
        # last step and is later split again as the next call's seed;
        # kept as-is so existing seeds reproduce the same draws.
        return final_st, total_div, positions, infos, step_keys[-1]

    if n_chains > 1:
        chains_per_device = n_chains // n_devices
        print(f"Sampling {n_samples} iterations x {n_chains} chains "
              f"across {n_devices} device(s)...")
        # If the state is already multi-chain (from a previous
        # run_sampling call), reuse it directly. Otherwise
        # initialize per-chain states from MAP solutions or jitter.
        is_multi_chain_state = warmup_state.position.ndim > 1
        if is_multi_chain_state:
            states = warmup_state
            if sample_key.ndim > 1:
                rng_key = jax.random.fold_in(sample_key[0], 1)
            else:
                rng_key = jax.random.fold_in(sample_key, 1)
            sample_keys = jax.random.split(rng_key, n_chains)
        else:
            if (_per_chain_inits is not None
                    and _per_chain_inits.shape[0] >= n_chains):
                init_positions = _per_chain_inits[:n_chains]
                print(f" Using {n_chains} distinct MAP solutions as chain init positions")
            else:
                # Not enough distinct starts: jitter around the warmup
                # position, scaled by the adapted mass matrix.
                jitter_key, sample_key = jax.random.split(sample_key)
                jitter_scale = 0.01 * jnp.sqrt(adapted_inv_mass)
                n_params = warmup_state.position.shape[-1]
                noise = jax.random.normal(
                    jitter_key, shape=(n_chains, n_params))
                init_positions = warmup_state.position[None, :] \
                    + jitter_scale[None, :] * noise
            # Initialize NUTS states for each chain from init positions
            init_fn = blackjax.nuts(
                gp.log_posterior,
                step_size=adapted_step_size,
                inverse_mass_matrix=adapted_inv_mass,
            ).init
            states = jax.vmap(init_fn)(init_positions)
            sample_keys = jax.random.split(sample_key, n_chains)

        # Reshape for pmap: (n_devices, chains_per_device, ...)
        states = jax.tree.map(
            lambda x: x.reshape(n_devices, chains_per_device, *x.shape[1:]),
            states)
        sample_keys = sample_keys.reshape(n_devices, chains_per_device, -1)

        # pmap over devices, vmap over chains within each device
        all_final, all_div, all_pos, all_infos, all_last_keys = jax.pmap(
            jax.vmap(_run_one_chain)
        )(states, sample_keys)

        # Flatten device dimension: (n_devices, chains_per_device, ...) -> (n_chains, ...)
        all_final = jax.tree.map(
            lambda x: x.reshape(n_chains, *x.shape[2:]), all_final)
        all_div = all_div.reshape(n_chains)
        all_pos = all_pos.reshape(n_chains, n_samples, -1)
        all_infos = jax.tree.map(
            lambda x: x.reshape(n_chains, *x.shape[2:]), all_infos)
        all_last_keys = all_last_keys.reshape(n_chains, -1)

        # Shape: (n_chains, n_samples, n_params)
        self.samples = all_pos
        total_div = int(jnp.sum(all_div))
        self._info = {
            "divergences": np.asarray(all_infos.is_divergent),
            "acceptance_rate": np.asarray(all_infos.acceptance_rate),
            "num_steps": np.asarray(
                all_infos.num_integration_steps),
            "step_size": float(adapted_step_size),
            "n_warmup": n_warmup,
            "n_samples": n_samples,
            "n_chains": n_chains,
            "n_divergent": total_div,
        }
        self._last_state = jax.tree.map(jnp.array, all_final)
        self._adapted_step_size = float(adapted_step_size)
        self._adapted_inv_mass = np.asarray(adapted_inv_mass)
        self._last_rng_key = jnp.array(all_last_keys)

        mean_accept = float(jnp.mean(
            jnp.array(self._info["acceptance_rate"])))
        print(f"NUTS complete: {n_chains} chains x {n_samples} samples, "
              f"{total_div} total divergences, "
              f"mean acceptance rate = {mean_accept:.3f}")
    else:
        print(f"Sampling {n_samples} post-warmup iterations...")
        final_state, total_div, positions, infos, last_key = \
            _run_one_chain(warmup_state, sample_key)
        self.samples = positions
        self._info = {
            "divergences": np.asarray(infos.is_divergent),
            "acceptance_rate": np.asarray(infos.acceptance_rate),
            "num_steps": np.asarray(infos.num_integration_steps),
            "step_size": float(adapted_step_size),
            "n_warmup": n_warmup,
            "n_samples": n_samples,
            "n_chains": 1,
            "n_divergent": int(total_div),
        }
        self._last_state = final_state
        self._adapted_step_size = float(adapted_step_size)
        self._adapted_inv_mass = np.asarray(adapted_inv_mass)
        self._last_rng_key = last_key

        mean_accept = float(np.mean(self._info["acceptance_rate"]))
        print(f"NUTS complete: {n_samples} samples, "
              f"{int(total_div)} divergences, "
              f"mean acceptance rate = {mean_accept:.3f}")

    return self.samples, self._info
[docs]
def save_checkpoint(self, path=None, append_samples=True,
                    plot_corner=False):
    """
    Save sampler state to disk for later resumption.

    When ``append_samples=True`` (the default), new samples are
    appended to any existing samples already stored in ``path``,
    and ``self.samples`` is cleared from memory. This enables a
    sample-checkpoint-clear loop that keeps memory usage constant.

    Parameters
    ----------
    path : str, optional
        File path (saved as ``.npz``; the suffix is appended when
        missing). If None, uses ``self._checkpoint_file``, or
        ``save_dir/mcmc_checkpoint.npz`` if ``save_dir`` was set.
    append_samples : bool
        If True, append current ``self.samples`` to any samples
        already on disk, then clear ``self.samples`` from memory.
        If False, overwrite with only the current in-memory samples.
    plot_corner : bool
        If True, load all samples currently on disk after saving
        and write a corner plot to ``save_dir/corner_plot.png``
        (or alongside the checkpoint file if ``save_dir`` is not
        set).

    Raises
    ------
    RuntimeError
        If there is no sampler state (``self._last_state`` is None).
    ValueError
        If no output path can be determined.
    """
    if self._last_state is None:
        raise RuntimeError("No sampler state to save. Run run_warmup first.")

    if path is None:
        path = self._checkpoint_file
    if path is None and self.save_dir is not None:
        path = os.path.join(self.save_dir, "mcmc_checkpoint.npz")
    if path is None:
        raise ValueError(
            "No path provided, no checkpoint_file set, and no save_dir. "
            "Pass a path, set checkpoint_file in run_warmup, or set save_dir.")

    # np.savez appends ".npz" when the suffix is missing, so every read
    # and write below goes through the normalized name to hit the same
    # file that is actually created on disk.
    npz_path = path if path.endswith(".npz") else path + ".npz"

    samples_to_save = (np.asarray(self.samples)
                       if self.samples is not None else None)

    # Merge with samples already on disk.
    if (append_samples and samples_to_save is not None
            and os.path.exists(npz_path)):
        with np.load(npz_path) as existing:
            on_disk = existing["samples"] if "samples" in existing else None
        if on_disk is not None and on_disk.size > 0:
            # multi-chain: (n_chains, n_samples, n_params) -> axis=1
            # single-chain: (n_samples, n_params)         -> axis=0
            cat_axis = 1 if samples_to_save.ndim == 3 else 0
            samples_to_save = np.concatenate(
                [on_disk, samples_to_save], axis=cat_axis)

    save_kwargs = {
        # NUTS state (shape has leading chain dim when n_chains > 1)
        "position": np.asarray(self._last_state.position),
        "logdensity": np.asarray(self._last_state.logdensity),
        "logdensity_grad": np.asarray(self._last_state.logdensity_grad),
        # Adapted kernel parameters
        "step_size": np.asarray(self._adapted_step_size),
        "inverse_mass_matrix": np.asarray(self._adapted_inv_mass),
        "rng_key": np.asarray(self._last_rng_key),
        # Diagnostics (scalars)
        "n_warmup": np.asarray(self._info["n_warmup"]),
        "n_chains": np.asarray(getattr(self, "_n_chains", 1)),
    }

    if samples_to_save is not None:
        save_kwargs["samples"] = samples_to_save
        # Count draws along the sample axis; for multi-chain arrays
        # axis 0 is the chain axis, not the sample axis.
        sample_axis = 1 if samples_to_save.ndim == 3 else 0
        n_on_disk = samples_to_save.shape[sample_axis]
    else:
        save_kwargs["samples"] = np.array([])
        n_on_disk = 0

    has_maps = (hasattr(self, 'all_theta_maps')
                and self.all_theta_maps is not None
                and len(self.all_theta_maps) > 0)

    # Preserve MAP solutions and their log-likelihoods.
    if has_maps:
        save_kwargs['theta_map'] = self.theta_map
        save_kwargs['all_theta_maps'] = np.array(
            self.all_theta_maps, dtype=object)
        # Compute log-likelihood for each MAP solution.
        gp = self.gp
        map_loglikes = []
        for tm in self.all_theta_maps:
            if isinstance(tm, dict):
                arr = np.array([float(tm[k]) for k in gp.param_keys],
                               dtype=np.float64)
            else:
                arr = np.asarray(tm, dtype=np.float64)
            map_loglikes.append(float(gp.log_likelihood_fn(arr)))
        save_kwargs['map_loglikes'] = np.array(map_loglikes)

    np.savez(npz_path, **save_kwargs)

    if append_samples:
        # Free in-memory samples and per-sample diagnostics.
        self.samples = None
        self._info = {
            "step_size": self._info["step_size"],
            "n_warmup": self._info["n_warmup"],
            "n_samples": n_on_disk,
            "n_divergent": self._info.get("n_divergent", 0),
        }

    print(f"Checkpoint saved to {path} ({n_on_disk} samples on disk)")

    if plot_corner and n_on_disk > 0:
        # Plotting dependencies are imported lazily so that saving a
        # checkpoint does not require them.
        import corner
        import matplotlib
        import matplotlib.pyplot as plt

        corner_dir = (self.save_dir if self.save_dir is not None
                      else os.path.dirname(os.path.abspath(npz_path)))
        corner_path = os.path.join(corner_dir, "corner_plot.png")
        all_samples = self.load_samples(npz_path)

        if has_maps:
            fig, _ = self.plot_corner_map(
                samples=all_samples, checkpoint_path=npz_path,
                savefig=corner_path)
        else:
            # Temporarily disable LaTeX rendering and restore the
            # previous setting even if plotting fails.
            old_usetex = matplotlib.rcParams.get("text.usetex", False)
            matplotlib.rcParams["text.usetex"] = False
            try:
                fig = corner.corner(
                    all_samples,
                    labels=list(self.param_keys),
                    show_titles=True,
                    title_fmt=".3f",
                )
                fig.savefig(corner_path, dpi=150, bbox_inches="tight")
            finally:
                matplotlib.rcParams["text.usetex"] = old_usetex
        plt.close(fig)
        print(f"Corner plot saved to {corner_path} "
              f"({n_on_disk} samples)")
[docs]
def load_checkpoint(self, checkpoint_file=None):
    """
    Restore sampler state from a checkpoint file.

    Loads only the NUTS state and adapted kernel parameters needed
    to resume sampling. Samples stored in the file are **not**
    kept in memory — use ``load_samples`` to read them later.

    Parameters
    ----------
    checkpoint_file : str, optional
        Path to a ``.npz`` checkpoint file. If provided, also
        updates the sampler's default checkpoint path. If None,
        the sampler's current default checkpoint path is used.

    Raises
    ------
    ValueError
        If no checkpoint path is available.
    """
    import blackjax

    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    path = self._checkpoint_file
    if path is None:
        raise ValueError(
            "No checkpoint_file provided and no save_dir was set. "
            "Pass a checkpoint_file or set save_dir.")

    # np.savez appends ".npz" when saving to a suffix-less path, so fall
    # back to the suffixed name when the literal path does not exist.
    npz_path = path
    if not path.endswith(".npz") and not os.path.exists(path):
        npz_path = path + ".npz"

    restored_maps = None
    # allow_pickle=True is required because MAP solutions are stored as
    # object arrays by save_checkpoint.
    # NOTE: unpickling can execute arbitrary code; only load checkpoint
    # files from a trusted source.
    with np.load(npz_path, allow_pickle=True) as data:
        # Reconstruct NUTS state (works for both single and multi-chain:
        # arrays have a leading chain dimension when n_chains > 1).
        self._last_state = blackjax.mcmc.hmc.HMCState(
            position=jnp.asarray(data["position"]),
            logdensity=jnp.asarray(data["logdensity"]),
            logdensity_grad=jnp.asarray(data["logdensity_grad"]),
        )
        n_chains = int(data["n_chains"]) if "n_chains" in data else 1
        self._n_chains = n_chains
        if n_chains > 1:
            self._adapted_step_size = np.asarray(data["step_size"])
        else:
            self._adapted_step_size = float(data["step_size"])
        self._adapted_inv_mass = jnp.asarray(data["inverse_mass_matrix"])
        self._last_rng_key = jnp.asarray(data["rng_key"])

        # Only the sample count is kept; the array itself is discarded.
        stored = data["samples"]
        if stored.size > 0:
            # For multi-chain arrays axis 0 is the chain axis.
            sample_axis = 1 if stored.ndim == 3 else 0
            n_on_disk = stored.shape[sample_axis]
        else:
            n_on_disk = 0
        del stored
        n_warmup = int(data["n_warmup"])

        # Restore MAP solutions if present.
        if "all_theta_maps" in data:
            self.all_theta_maps = list(data["all_theta_maps"])
            theta_map = data["theta_map"]
            # A dict is stored as a 0-d object array; unwrap it.
            self.theta_map = (theta_map.item()
                              if theta_map.ndim == 0 else theta_map)
            if "map_loglikes" in data:
                self.map_loglikes = np.asarray(data["map_loglikes"])
            restored_maps = len(self.all_theta_maps)

    if restored_maps is not None:
        print(f"Restored {restored_maps} MAP solutions "
              f"from checkpoint")

    # Don't keep samples in memory — keep it lightweight.
    self.samples = None
    self._info = {
        "step_size": self._adapted_step_size,
        "n_warmup": n_warmup,
        "n_samples": n_on_disk,
        "n_chains": n_chains,
        "n_divergent": 0,
    }
    print(f"Checkpoint loaded from {path} "
          f"({n_on_disk} samples on disk, {n_chains} chain(s), "
          f"not loaded into memory)")
@staticmethod
def _make_batched_vmap(fn, n_particles, batch_size, n_devices=None):
"""Replace ``jax.vmap(fn)`` with a multi-GPU batched version.
Particles are split into chunks of ``batch_size``, distributed
across ``n_devices`` GPUs with ``pmap``, and each device
evaluates its chunk with ``vmap``.
When only one device is available (or ``n_devices=1``) it
falls back to ``lax.map(vmap(fn), batches)`` so that only
``batch_size`` evaluations are live at once.
Parameters
----------
fn : callable
Scalar function of a single particle, e.g.
``loglikelihood_fn(theta) -> float``.
n_particles : int
Total number of particles. Must be divisible by
``batch_size``.
batch_size : int
Number of particles to evaluate simultaneously per
device.
n_devices : int, optional
Number of JAX devices to use. Defaults to all visible
devices.
Returns
-------
batched_fn : callable
``batched_fn(particles)`` with ``particles`` of shape
``(n_particles, ...)``, returns ``(n_particles, ...)``.
"""
if n_devices is None:
n_devices = jax.device_count()
if n_particles % batch_size != 0:
raise ValueError(
f"n_particles ({n_particles}) must be divisible by "
f"particle_batch_size ({batch_size}).")
n_batches = n_particles // batch_size
if n_devices > 1 and n_batches >= n_devices:
# Multi-GPU path: pmap across devices, scan over rounds
if n_batches % n_devices != 0:
raise ValueError(
f"n_particles / particle_batch_size "
f"({n_batches}) must be divisible by "
f"n_devices ({n_devices}).")
rounds_per_device = n_batches // n_devices
def batched_fn(all_particles):
# (n_devices, rounds_per_device, batch_size, ...)
shaped = all_particles.reshape(
n_devices, rounds_per_device, batch_size,
*all_particles.shape[1:])
def _device_work(device_batches):
# device_batches: (rounds_per_device, batch_size, ...)
def _one_round(_, batch):
return None, jax.vmap(fn)(batch)
_, results = jax.lax.scan(
_one_round, None, device_batches)
return results # (rounds_per_device, batch_size, ...)
# (n_devices, rounds_per_device, batch_size, ...)
out = jax.pmap(_device_work)(shaped)
flat = out.reshape(n_particles, *out.shape[3:])
# Strip pmap sharding so the next tempering step's
# pmap (which creates a new mesh) won't clash.
return jnp.array(np.asarray(flat))
else:
# Single-GPU path: sequential scan over batches
def batched_fn(all_particles):
shaped = all_particles.reshape(
n_batches, batch_size, *all_particles.shape[1:])
def _one_round(_, batch):
return None, jax.vmap(fn)(batch)
_, out = jax.lax.scan(_one_round, None, shaped)
return out.reshape(n_particles, *out.shape[2:])
return batched_fn
@staticmethod
def _make_batched_update(raw_nuts_kernel, nuts_init_fn,
tempered_logposterior_fn,
step_size, inverse_mass_matrix,
num_mcmc_steps,
n_particles, batch_size, n_devices=None,
max_num_doublings=10):
"""Batched MCMC rejuvenation distributed across GPUs.
Replaces ``jax.vmap(mcmc_kernel)`` in
``blackjax.smc.base.update_and_take_last`` with a
``pmap``/``scan``-based version so that only ``batch_size``
NUTS chains are live simultaneously. Uses the raw NUTS
kernel directly (not the wrapped ``SamplingAlgorithm``) so
the tempered log-posterior can be swapped each step.
Returns ``(update_fn, n_particles)``.
"""
if n_devices is None:
n_devices = jax.device_count()
if n_particles % batch_size != 0:
raise ValueError(
f"n_particles ({n_particles}) must be divisible by "
f"particle_batch_size ({batch_size}).")
n_batches = n_particles // batch_size
def _single_mcmc(rng_key, position):
state = nuts_init_fn(position, tempered_logposterior_fn)
def body_fn(state, rng_key):
new_state, info = raw_nuts_kernel(
rng_key, state, tempered_logposterior_fn,
step_size, inverse_mass_matrix,
max_num_doublings)
return new_state, info
keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info
if n_devices > 1 and n_batches >= n_devices:
if n_batches % n_devices != 0:
raise ValueError(
f"n_particles / particle_batch_size "
f"({n_batches}) must be divisible by "
f"n_devices ({n_devices}).")
rounds_per_device = n_batches // n_devices
def update_fn(keys, particles):
k_shaped = keys.reshape(
n_devices, rounds_per_device, batch_size,
*keys.shape[1:])
p_shaped = particles.reshape(
n_devices, rounds_per_device, batch_size,
*particles.shape[1:])
def _device_work(dk, dp):
def _one_round(_, args):
bk, bp = args
return None, jax.vmap(
_single_mcmc)(bk, bp)
_, results = jax.lax.scan(
_one_round, None, (dk, dp))
return results
positions, infos = jax.pmap(
_device_work)(k_shaped, p_shaped)
flat_pos = positions.reshape(
n_particles, *positions.shape[3:])
flat_infos = jax.tree.map(
lambda x: x.reshape(
n_particles, *x.shape[3:]),
infos)
# Strip pmap sharding so the next tempering step's
# pmap (which creates a new mesh) won't clash.
flat_pos = jnp.array(
np.asarray(flat_pos))
flat_infos = jax.tree.map(
lambda x: jnp.array(np.asarray(x)),
flat_infos)
return flat_pos, flat_infos
else:
def update_fn(keys, particles):
k_shaped = keys.reshape(
n_batches, batch_size, *keys.shape[1:])
p_shaped = particles.reshape(
n_batches, batch_size, *particles.shape[1:])
def _one_round(_, args):
bk, bp = args
return None, jax.vmap(
_single_mcmc)(bk, bp)
_, (positions, infos) = jax.lax.scan(
_one_round, None, (k_shaped, p_shaped))
flat_pos = positions.reshape(
n_particles, *positions.shape[2:])
flat_infos = jax.tree.map(
lambda x: x.reshape(
n_particles, *x.shape[2:]),
infos)
return flat_pos, flat_infos
return update_fn, n_particles
[docs]
def run_smc(self, n_particles=500, n_mcmc_steps=10,
            n_adapt_steps=25, target_ess=0.5, target_accept=0.6,
            rng_key=None, step_size=None,
            mass_matrix_method="hessian_map", theta_init=None,
            max_tempering_steps=200, checkpoint_every=10,
            checkpoint_file=None, particle_batch_size=None,
            max_num_doublings=10):
    """
    Run adaptive tempered Sequential Monte Carlo.

    Anneals toward the full posterior using an adaptive temperature
    schedule. At each tempering step, particles are resampled and
    rejuvenated with NUTS moves. The NUTS step size is re-adapted
    via dual averaging after each tempering stage, starting from the
    particle with the largest weight.

    Initial particles are drawn from a Gaussian centred on each entry
    of ``self.all_theta_maps`` (covariance = 4 x the inverse mass
    matrix, clipped to ``gp.bounds``) when that attribute is set and
    non-empty; otherwise they are drawn uniformly within
    ``gp.bounds``.

    Parameters
    ----------
    n_particles : int
        Number of SMC particles (default 500).
    n_mcmc_steps : int
        NUTS rejuvenation steps per tempering stage (default 10).
    n_adapt_steps : int
        Dual-averaging warmup steps to adapt the NUTS step size
        at each tempering stage (default 25).
    target_ess : float
        Target effective sample size as a fraction of
        ``n_particles`` (default 0.5).
    target_accept : float
        Target NUTS acceptance rate for dual averaging
        (default 0.6).
    rng_key : jax.random.PRNGKey, optional
        Random key. Default: PRNGKey(42).
    step_size : float, optional
        Initial NUTS step size. If None, half the smallest square
        root of the (clipped) inverse-mass diagonal is used, with a
        floor of 1e-5.
    mass_matrix_method : str, optional
        Method to estimate the inverse mass matrix (default
        ``"hessian_map"``). Set to None to use an identity
        matrix.
    theta_init : dict or array_like, optional
        Reference point for mass matrix estimation. If None,
        ``gp.map_estimate`` is used (``gp.fit_map()`` is called
        first when it is not available).
    max_tempering_steps : int
        Safety limit on the number of tempering stages
        (default 200).
    checkpoint_every : int
        Save a checkpoint every this many tempering steps
        (default 10). Set to 0 to disable periodic
        checkpointing.
    checkpoint_file : str, optional
        Override the default checkpoint file path.
    particle_batch_size : int, optional
        Process particles in batches of this size to limit GPU
        memory usage. When multiple GPUs are visible the
        batches are distributed across devices via
        ``jax.pmap``. ``n_particles`` must be divisible by
        this value (and by ``batch_size * n_devices`` for
        multi-GPU). If None, all particles are evaluated at
        once (original blackjax behavior).
    max_num_doublings : int, optional
        Maximum NUTS tree depth (default 10). Lower values
        (e.g. 5-6) reduce peak GPU memory per particle at the
        cost of shorter trajectories.

    Returns
    -------
    samples : np.ndarray, shape (n_particles, n_params)
        Particles at the final tempering stage. The particle
        weights are not part of the return value.
    info : dict
        Diagnostics: ``n_particles``, ``n_mcmc_steps``,
        ``n_adapt_steps``, ``n_tempering_steps``,
        ``tempering_schedule``, ``step_sizes``, ``log_evidence``,
        ``step_size``, ``n_warmup``, ``n_samples``, ``n_chains``.

    Notes
    -----
    Side effects: sets ``self.samples``, ``self._n_chains`` and
    ``self._info``; updates ``self._checkpoint_file`` when
    ``checkpoint_file`` is given; writes checkpoint files when a
    checkpoint path is configured.
    """
    # Third-party pieces are imported lazily inside the method.
    import blackjax
    from blackjax.smc.resampling import systematic
    from blackjax.adaptation.step_size import dual_averaging_adaptation
    gp = self.gp
    # Device count comes from the instance; it is set outside this method.
    n_devices = self._n_devices
    if rng_key is None:
        rng_key = jax.random.PRNGKey(42)
    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    # --- Mass matrix and step size ---------------------------------
    if theta_init is None:
        if gp.map_estimate is None:
            gp.fit_map()
        theta_init = gp.map_estimate
    if isinstance(theta_init, dict):
        # Order the dict values according to gp.param_keys.
        theta_init = jnp.array(
            [float(theta_init[k]) for k in gp.param_keys],
            dtype=jnp.float64)
    if mass_matrix_method is not None:
        inv_mass = gp._get_mass_matrix(mass_matrix_method, theta_init)
        inv_mass_diag = jnp.diag(inv_mass)
        # Limit the diagonal to within 1e-4 .. 1e4 times its median.
        median_var = jnp.median(inv_mass_diag)
        inv_mass_diag = jnp.clip(inv_mass_diag,
                                 median_var * 1e-4, median_var * 1e4)
    else:
        inv_mass_diag = jnp.ones(gp.n_params)
    if step_size is None:
        step_size = float(0.5 * jnp.min(jnp.sqrt(inv_mass_diag)))
        step_size = max(step_size, 1e-5)
    # --- Draw initial particles ------------------------------------
    init_key, run_key = jax.random.split(rng_key)
    bounds = gp.bounds
    lo, hi = bounds[:, 0], bounds[:, 1]
    # Use Laplace approximation around MAP solutions when available,
    # otherwise fall back to uniform prior draws.
    if (hasattr(self, 'all_theta_maps')
            and self.all_theta_maps is not None
            and len(self.all_theta_maps) > 0):
        # Convert MAP solutions to arrays
        map_arrays = []
        for tm in self.all_theta_maps:
            if isinstance(tm, dict):
                arr = jnp.array([float(tm[k]) for k in gp.param_keys],
                                dtype=jnp.float64)
            else:
                arr = jnp.asarray(tm, dtype=jnp.float64)
            map_arrays.append(arr)
        # Full (unclipped) inverse mass matrix when one was computed,
        # otherwise a diagonal matrix built from inv_mass_diag.
        if mass_matrix_method is not None:
            cov = np.asarray(inv_mass)  # (n_params, n_params)
        else:
            cov = np.diag(np.asarray(inv_mass_diag))
        # Inflate covariance to broaden the Laplace approximation
        cov_inflated = 4.0 * cov
        try:
            L_cov = np.linalg.cholesky(cov_inflated)
        except np.linalg.LinAlgError:
            # Fall back to diagonal if Cholesky fails
            L_cov = np.diag(np.sqrt(np.abs(np.diag(cov_inflated))))
        n_maps = len(map_arrays)
        # Distribute particles evenly across MAP solutions; the first
        # `remainder` solutions receive one extra particle each.
        particles_per_map = n_particles // n_maps
        remainder = n_particles % n_maps
        all_particles = []
        for i, mu in enumerate(map_arrays):
            n_i = particles_per_map + (1 if i < remainder else 0)
            init_key, draw_key = jax.random.split(init_key)
            z = jax.random.normal(draw_key, shape=(n_i, gp.n_params))
            # mu + z @ L^T gives draws with covariance L @ L^T.
            pts = jnp.asarray(mu) + z @ jnp.asarray(L_cov.T)
            all_particles.append(pts)
        particles = jnp.concatenate(all_particles, axis=0)
        # Clip to prior bounds
        particles = jnp.clip(particles, lo, hi)
        print(f"Initialized {n_particles} particles from Laplace "
              f"approximation around {n_maps} MAP solution(s)")
    else:
        particles = (jax.random.uniform(init_key,
                                        shape=(n_particles, gp.n_params))
                     * (hi - lo) + lo)
        print(f"Initialized {n_particles} particles from uniform prior")

    # --- Helper: adapt step size via dual averaging ---------------
    def _adapt_step_size(adapt_key, position, current_step_size,
                         lam, n_steps):
        """Run short dual-averaging warmup on one particle at the
        tempered log-density ``log_prior + lam * log_likelihood``."""
        def tempered_logdensity(theta):
            return gp.log_prior_fn(theta) + lam * gp.log_likelihood_fn(theta)
        da_init, da_update, _ = dual_averaging_adaptation(
            target=target_accept)
        da_state = da_init(current_step_size)
        kernel = blackjax.nuts(
            tempered_logdensity,
            step_size=current_step_size,
            inverse_mass_matrix=inv_mass_diag,
            max_num_doublings=max_num_doublings,
        )
        state = kernel.init(position)
        # Plain Python loop: the kernel is rebuilt after every step
        # with the step size proposed by dual averaging.
        for _ in range(n_steps):
            adapt_key, step_key = jax.random.split(adapt_key)
            state, info = kernel.step(step_key, state)
            da_state = da_update(da_state, info.acceptance_rate)
            new_ss = jnp.exp(da_state.log_step_size)
            kernel = blackjax.nuts(
                tempered_logdensity,
                step_size=new_ss,
                inverse_mass_matrix=inv_mass_diag,
                max_num_doublings=max_num_doublings,
            )
        # The averaged iterate is returned, floored at 1e-6.
        adapted_ss = float(jnp.exp(da_state.log_step_size_avg))
        return max(adapted_ss, 1e-6)

    # --- Build SMC kernel factory ---------------------------------
    use_batched = particle_batch_size is not None
    import blackjax.mcmc.nuts as nuts_module
    raw_nuts_kernel = nuts_module.build_kernel()

    # NOTE(review): the `ss` argument is not referenced inside this
    # factory; the step size reaches the kernels via `mcmc_parameters`.
    def _build_smc_kernel(ss):
        if not use_batched:
            # Standard blackjax path — vmap over all particles.
            # Use the raw kernel so SMC can swap the log-density
            # at each tempering step.
            return blackjax.adaptive_tempered_smc.build_kernel(
                logprior_fn=gp.log_prior_fn,
                loglikelihood_fn=gp.log_likelihood_fn,
                mcmc_step_fn=raw_nuts_kernel,
                mcmc_init_fn=nuts_module.init,
                resampling_fn=systematic,
                target_ess=target_ess,
            )
        # ----------------------------------------------------------
        # Batched path: replace jax.vmap with pmap/scan batches
        # so only particle_batch_size evaluations are live at once.
        # Uses the raw NUTS kernel so the tempered log-posterior
        # can be swapped at each tempering step.
        # ----------------------------------------------------------
        import blackjax.smc.ess as ess_mod
        import blackjax.smc.solver as solver_mod
        from blackjax.smc.tempered import TemperedSMCState
        import blackjax.smc.base as smc_base
        from jax.scipy.special import logsumexp
        batched_ll = self._make_batched_vmap(
            gp.log_likelihood_fn, n_particles,
            particle_batch_size, n_devices)

        # Choose the temperature increment whose incremental weights
        # give log-ESS equal to log(n * target_ess), capped so the
        # tempering parameter never exceeds 1.
        def _compute_delta(state):
            logprob = batched_ll(state.particles)
            n = logprob.shape[0]
            target_val = jnp.log(n * target_ess)
            max_delta = 1 - state.tempering_param

            def fun_to_solve(delta):
                # nan_to_num replaces non-finite entries before the
                # ESS is evaluated.
                log_w = jnp.nan_to_num(-delta * logprob)
                return ess_mod.log_ess(log_w) - target_val
            delta = solver_mod.dichotomy(
                fun_to_solve, 0.0, max_delta)
            return jnp.clip(delta, 0.0, max_delta)

        # One tempering step at a given target temperature:
        # resample -> batched NUTS rejuvenation -> reweight.
        def _batched_tempered_kernel(
                rng_key, state, num_mcmc_steps_,
                tempering_param, mcmc_parameters):
            delta = tempering_param - state.tempering_param
            cur_ss = mcmc_parameters["step_size"]
            cur_imm = mcmc_parameters["inverse_mass_matrix"]

            # Batched weight function
            def log_weights_fn(position):
                return delta * gp.log_likelihood_fn(position)
            batched_weight_fn = self._make_batched_vmap(
                log_weights_fn, n_particles,
                particle_batch_size, n_devices)
            # Tempered log-posterior for MCMC rejuvenation.
            # Use the NEW temperature so particles are moved toward
            # the correct target, and to avoid 0 * log_likelihood
            # at lambda=0 (which produces NaN gradients when the GP
            # covariance is ill-conditioned).
            _new_tp = tempering_param

            def tempered_logposterior_fn(position):
                return (gp.log_prior_fn(position)
                        + _new_tp
                        * gp.log_likelihood_fn(position))
            # Build batched MCMC update using raw kernel
            update_fn, _ = self._make_batched_update(
                raw_nuts_kernel,
                nuts_module.init,
                tempered_logposterior_fn,
                cur_ss, cur_imm,
                num_mcmc_steps_,
                n_particles,
                particle_batch_size,
                n_devices,
                max_num_doublings=max_num_doublings,
            )
            # --- Resample, update, reweight (mirrors smc.base.step)
            resampling_key, updating_key = jax.random.split(
                rng_key, 2)
            resampling_idx = systematic(
                resampling_key, state.weights, n_particles)
            resampled = jax.tree.map(
                lambda x: x[resampling_idx], state.particles)
            keys = jax.random.split(updating_key, n_particles)
            new_particles, update_info = update_fn(
                keys, resampled)
            log_w = batched_weight_fn(new_particles)
            logsum_w = logsumexp(log_w)
            # log of the mean incremental weight; passed as the second
            # field of SMCInfo below.
            norm_const = logsum_w - jnp.log(n_particles)
            # Normalised weights (sum to one).
            weights = jnp.exp(log_w - logsum_w)
            new_state = TemperedSMCState(
                new_particles, weights,
                state.tempering_param + delta)
            info = smc_base.SMCInfo(
                resampling_idx, norm_const, update_info)
            return new_state, info

        # Same call signature as the non-batched kernel used in the
        # tempering loop below.
        def kernel(rng_key, state, num_mcmc_steps,
                   mcmc_parameters):
            delta = _compute_delta(state)
            tempering_param = delta + state.tempering_param
            return _batched_tempered_kernel(
                rng_key, state, num_mcmc_steps,
                tempering_param, mcmc_parameters)
        return kernel

    smc_state = blackjax.adaptive_tempered_smc.init(particles)
    # --- Checkpoint helper ----------------------------------------
    chk_path = self._checkpoint_file

    # Writes the current SMC state to `chk_path`; does nothing when no
    # checkpoint path is configured.
    def _save_smc_checkpoint(smc_st, lambdas_, step_sizes_,
                             log_ev, run_key_, step_size_,
                             inv_mass_diag_):
        if chk_path is None:
            return
        save_kwargs = dict(
            particles=np.asarray(smc_st.particles),
            weights=np.asarray(smc_st.weights),
            tempering_param=float(smc_st.tempering_param),
            tempering_schedule=np.array(lambdas_),
            step_sizes=np.array(step_sizes_),
            log_evidence=log_ev,
            step_size=step_size_,
            inverse_mass_matrix=np.asarray(inv_mass_diag_),
            rng_key=np.asarray(run_key_),
            n_particles=n_particles,
            n_mcmc_steps=n_mcmc_steps,
            n_adapt_steps=n_adapt_steps,
            # Include samples key for compatibility with load_samples
            samples=np.asarray(smc_st.particles),
        )
        # Preserve MAP solutions if they exist on disk
        # (np.savez appends ".npz", so look for the suffixed name).
        _chk = chk_path if chk_path.endswith(".npz") else chk_path + ".npz"
        if os.path.exists(_chk):
            # allow_pickle=True: MAP entries may be object arrays.
            existing = np.load(_chk, allow_pickle=True)
            if "all_theta_maps" in existing:
                save_kwargs["all_theta_maps"] = existing["all_theta_maps"]
            if "theta_map" in existing:
                save_kwargs["theta_map"] = existing["theta_map"]
            if "map_loglikes" in existing:
                save_kwargs["map_loglikes"] = existing["map_loglikes"]
            existing.close()
        np.savez(chk_path, **save_kwargs)
        print(f" Checkpoint saved to {chk_path} "
              f"(lambda={float(smc_st.tempering_param):.6f})")

    # --- Run tempering loop ---------------------------------------
    if use_batched:
        print(f"SMC: {n_particles} particles, "
              f"batch_size={particle_batch_size}, "
              f"n_devices={n_devices}, "
              f"target_ess={target_ess:.2f}, "
              f"n_adapt={n_adapt_steps}, "
              f"target_accept={target_accept:.2f}")
    else:
        print(f"SMC: {n_particles} particles, "
              f"target_ess={target_ess:.2f}, "
              f"n_adapt={n_adapt_steps}, "
              f"target_accept={target_accept:.2f}")
    # Histories start with the initial temperature and step size.
    lambdas = [0.0]
    step_sizes = [step_size]
    log_evidence = 0.0
    for step in range(max_tempering_steps):
        run_key, step_key, adapt_key = jax.random.split(run_key, 3)
        smc_kernel = _build_smc_kernel(step_size)
        if use_batched:
            # Batched path handles params internally via raw kernel
            mcmc_params = {"step_size": step_size,
                           "inverse_mass_matrix": inv_mass_diag}
        else:
            # Non-batched blackjax path needs extend_params so
            # unshared_parameters_and_step_fn sees shape[0]==1
            # and treats them as shared across particles.
            from blackjax.smc.base import extend_params
            mcmc_params = extend_params(
                {"step_size": jnp.array(step_size),
                 "inverse_mass_matrix": inv_mass_diag,
                 "max_num_doublings": jnp.array(max_num_doublings)})
        smc_state, smc_info = smc_kernel(
            step_key,
            smc_state,
            num_mcmc_steps=n_mcmc_steps,
            mcmc_parameters=mcmc_params,
        )
        lam = float(smc_state.tempering_param)
        lambdas.append(lam)
        # Accumulate the evidence estimate; non-finite increments are
        # skipped (with a warning) rather than poisoning the total.
        ll_inc = float(smc_info.log_likelihood_increment)
        if np.isfinite(ll_inc):
            log_evidence += ll_inc
        else:
            print(f" Warning: non-finite log_likelihood_increment "
                  f"({ll_inc}) at step {step + 1}, skipping")
        print(f" Step {step + 1}: lambda={lam:.6f}, "
              f"step_size={step_size:.6f}, log_Z={log_evidence:.2f}")
        if lam >= 1.0:
            # Final temperature reached: record, checkpoint and stop.
            step_sizes.append(step_size)
            _save_smc_checkpoint(smc_state, lambdas, step_sizes,
                                 log_evidence, run_key, step_size,
                                 inv_mass_diag)
            break
        # Adapt step size for the next tempering stage using a
        # high-weight particle as the warmup starting point.
        best_idx = int(jnp.argmax(smc_state.weights))
        best_particle = smc_state.particles[best_idx]
        step_size = _adapt_step_size(
            adapt_key, best_particle, step_size, lam, n_adapt_steps)
        step_sizes.append(step_size)
        # Periodic checkpoint
        if (checkpoint_every > 0
                and (step + 1) % checkpoint_every == 0):
            _save_smc_checkpoint(smc_state, lambdas, step_sizes,
                                 log_evidence, run_key, step_size,
                                 inv_mass_diag)
    else:
        # for/else: runs only when the loop finished without `break`,
        # i.e. lambda never reached 1.0 within max_tempering_steps.
        print(f" Warning: reached max_tempering_steps="
              f"{max_tempering_steps} without reaching lambda=1.0 "
              f"(final={lam:.6f})")
        _save_smc_checkpoint(smc_state, lambdas, step_sizes,
                             log_evidence, run_key, step_size,
                             inv_mass_diag)
    # lambdas includes the initial 0.0, hence the -1.
    n_steps = len(lambdas) - 1
    print(f"SMC complete: {n_steps} tempering steps, "
          f"log_evidence={log_evidence:.2f}")
    # --- Store results --------------------------------------------
    final_particles = np.asarray(smc_state.particles)
    self.samples = final_particles
    self._n_chains = 1
    self._info = {
        "n_particles": n_particles,
        "n_mcmc_steps": n_mcmc_steps,
        "n_adapt_steps": n_adapt_steps,
        "n_tempering_steps": n_steps,
        "tempering_schedule": np.array(lambdas),
        "step_sizes": np.array(step_sizes),
        "log_evidence": log_evidence,
        "step_size": step_size,
        "n_warmup": 0,
        "n_samples": n_particles,
        "n_chains": 1,
    }
    return final_particles, self._info
[docs]
@staticmethod
def load_samples(path, flatten_chains=True):
    """
    Read all samples from a checkpoint file without loading
    the sampler state.

    Parameters
    ----------
    path : str
        Path to a ``.npz`` checkpoint file. A path given without the
        ``.npz`` suffix is accepted when only the suffixed file
        exists (``np.savez`` appends the suffix on write).
    flatten_chains : bool
        If True (default), collapse the chain dimension so the
        returned array is always ``(n_total, n_params)``. Set to
        False to get the raw ``(n_chains, n_samples, n_params)``
        array for per-chain diagnostics (e.g. R-hat).

    Returns
    -------
    samples : np.ndarray
        Shape ``(n_total, n_params)`` when ``flatten_chains=True``,
        or ``(n_chains, n_samples, n_params)`` otherwise.
    """
    # Resolve the name the same way the checkpoint writers do, but
    # never override a file that exists under the literal path.
    if (isinstance(path, str) and not path.endswith(".npz")
            and not os.path.exists(path)):
        path = path + ".npz"

    # The context manager closes the archive even if "samples" is
    # missing; indexing already returns a freshly read array.
    with np.load(path) as data:
        samples = data["samples"]

    if flatten_chains and samples.ndim == 3:
        n_chains, n_samp, n_params = samples.shape
        samples = samples.reshape(n_chains * n_samp, n_params)
    return samples
# =====================================================================
# Dynesty nested sampler
# =====================================================================
class DynestySampler(MCMCSampler):
"""
Nested sampler using the dynesty library.
Computes the posterior and Bayesian evidence via nested sampling.
Inherits diagnostics, summary, plotting, and dict conversion from
MCMCSampler.
Parameters
----------
gp : GPSolver
A configured GPSolver instance.
prior_transform : callable, optional
Function mapping the unit hypercube ``u`` (array of shape
``(n_params,)``) to the physical parameter space. If None,
a default uniform prior transform is constructed from the
GPSolver bounds.
save_dir : str, optional
Directory for outputs. Created automatically if it does not
exist.
checkpoint_file : str, optional
Path to the checkpoint file (relative to ``save_dir``).
"""
def __init__(self, gp, prior_transform=None, save_dir="results",
             checkpoint_file="mcmc_checkpoint.npz"):
    """
    Set up output locations and the prior transform.

    Parameters
    ----------
    gp : GPSolver
        A configured GPSolver instance.
    prior_transform : callable, optional
        Maps a unit-hypercube point ``u`` to parameter space. If
        None, a uniform transform ``lo + u * (hi - lo)`` is built
        from ``gp.bounds``.
    save_dir : str, optional
        Output directory; created if it does not exist. May be None.
    checkpoint_file : str, optional
        Checkpoint file name, joined onto ``save_dir`` when
        ``save_dir`` is set and used as given otherwise. None
        disables the default checkpoint path.
    """
    super().__init__(gp)
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
    self.save_dir = save_dir

    if checkpoint_file is None:
        self._checkpoint_file = None
    elif save_dir is None:
        # os.path.join(None, ...) raises TypeError; with no output
        # directory the checkpoint path is taken as given.
        self._checkpoint_file = checkpoint_file
    else:
        self._checkpoint_file = os.path.join(save_dir, checkpoint_file)

    if prior_transform is not None:
        self.prior_transform = prior_transform
    else:
        bounds = np.asarray(gp.bounds)
        lo = bounds[:, 0]
        hi = bounds[:, 1]

        def _default_prior_transform(u):
            # Affine map of the unit interval onto [lo, hi] per parameter.
            return lo + u * (hi - lo)

        self.prior_transform = _default_prior_transform
def _log_likelihood(self, theta):
"""Evaluate the GP log-likelihood (numpy-compatible wrapper)."""
val = float(self.gp.log_likelihood_fn(jnp.asarray(theta)))
if not np.isfinite(val):
return -np.inf
return val
def run_map(self, nopt=10, keys=None, checkpoint_file=None,
            theta0=None, **kwargs):
    """
    Find MAP solutions via parallel multi-start optimization.

    If the checkpoint file already contains MAP solutions they are
    loaded and returned without running the optimizer. Otherwise the
    optimizer is run and, when a checkpoint path is configured, the
    solutions are merged into that file (other stored arrays are
    kept).

    Parameters
    ----------
    nopt : int
        Number of independent optimization restarts (default 10).
    keys : list of str, optional
        Parameter names to optimize. Defaults to all
        ``gp.param_keys``.
    theta0 : dict, optional
        Initial parameter guess.
    checkpoint_file : str, optional
        Path to save/load MAP solutions. The ``.npz`` suffix is
        optional. Also becomes the sampler's default checkpoint path.
    **kwargs
        Additional keyword arguments passed to
        ``GPSolver.fit_map_parallel``.

    Returns
    -------
    all_theta_maps : list of dict
        All MAP solutions, in the order returned by the optimizer
        (or stored in the checkpoint); the first entry is used as
        ``self.theta_map``.
    """
    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    path = self._checkpoint_file

    # np.savez appends ".npz" when missing, so both the lookup and the
    # write below use the normalized name.
    npz_path = None
    if path is not None:
        npz_path = path if path.endswith(".npz") else path + ".npz"

    if npz_path is not None and os.path.exists(npz_path):
        loaded_maps = None
        loaded_loglikes = None
        # allow_pickle=True: MAP solutions are stored as object arrays.
        # NOTE: unpickling can execute arbitrary code; only load
        # checkpoint files from a trusted source.
        with np.load(npz_path, allow_pickle=True) as data:
            if "all_theta_maps" in data:
                loaded_maps = list(data["all_theta_maps"])
                if "map_loglikes" in data:
                    loaded_loglikes = np.asarray(data["map_loglikes"])
        # An empty stored list is treated as "no solutions" and the
        # optimizer is run instead of indexing into it.
        if loaded_maps:
            print(f"Loaded {len(loaded_maps)} MAP solutions from {path}")
            self.all_theta_maps = loaded_maps
            self.theta_map = loaded_maps[0]
            if loaded_loglikes is not None:
                self.map_loglikes = loaded_loglikes
            self._map_completed = True
            return loaded_maps

    gp = self.gp
    if keys is None:
        keys = list(gp.param_keys)
    print(f"Finding MAP solution ({nopt} restarts, returning all)...")
    all_theta_maps, all_results = gp.fit_map_parallel(
        nopt=nopt, keys=keys, return_all=True, theta0=theta0, **kwargs)
    self.all_theta_maps = all_theta_maps
    self.theta_map = all_theta_maps[0]

    # Log-likelihood of every solution, evaluated in gp.param_keys order.
    map_loglikes = []
    for tm in all_theta_maps:
        if isinstance(tm, dict):
            arr = np.array([float(tm[k]) for k in gp.param_keys],
                           dtype=np.float64)
        else:
            arr = np.asarray(tm, dtype=np.float64)
        map_loglikes.append(float(gp.log_likelihood_fn(arr)))
    self.map_loglikes = np.array(map_loglikes)
    print(f"MAP solution: {self.theta_map}")

    if npz_path is not None:
        # Carry over everything already in the file so that saving MAP
        # solutions does not discard sampler state or samples.
        save_kwargs = {}
        if os.path.exists(npz_path):
            with np.load(npz_path, allow_pickle=True) as existing:
                save_kwargs = {k: existing[k] for k in existing.files}
        save_kwargs["theta_map"] = self.theta_map
        save_kwargs["all_theta_maps"] = all_theta_maps
        save_kwargs["map_loglikes"] = self.map_loglikes
        np.savez(npz_path, **save_kwargs)
        print(f"MAP solutions saved to {path}")

    self._map_completed = True
    return all_theta_maps
def run_sampling(self, nlive=500, dlogz=0.01, bound="multi",
                 sample="auto", rstate=None, checkpoint_file=None,
                 **kwargs):
    """
    Run dynesty (static) nested sampling.

    Builds a ``dynesty.NestedSampler`` from ``self._log_likelihood`` and
    ``self.prior_transform``, runs it, resamples the weighted output to
    equally weighted posterior samples and stores the outcome on the
    instance (``samples``, ``_dynesty_results``, ``_info``).

    Parameters
    ----------
    nlive : int
        Number of live points (default 500).
    dlogz : float
        Stopping criterion on the remaining evidence (default 0.01).
    bound : str
        Bounding method (default "multi").
    sample : str
        Sampling method (default "auto").
    rstate : numpy.random.Generator or numpy.random.RandomState, optional
        Random state for reproducibility. It is handed both to the
        sampler and to the equal-weight resampling step.
        NOTE(review): which of the two types is accepted depends on the
        installed dynesty release -- confirm against its documentation.
    checkpoint_file : str, optional
        Override the default checkpoint file path.
    **kwargs
        Additional keyword arguments. Names recognised as
        ``dynesty.NestedSampler`` constructor options are passed to the
        constructor; everything else is passed to
        ``sampler.run_nested``.

    Returns
    -------
    samples : np.ndarray
        Equally weighted posterior samples, shape
        ``(n_samples, n_params)``.
    info : dict
        Diagnostics including log-evidence and its uncertainty.
    """
    import dynesty
    from dynesty.utils import resample_equal

    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file

    n_params = self.n_params

    # Options that belong to the NestedSampler constructor. Anything not
    # listed here is forwarded to run_nested, which rejects unknown
    # names, so constructor-only options must be listed explicitly.
    constructor_keys = {
        "update_interval", "first_update", "npdim", "bootstrap",
        "enlarge", "vol_dec", "vol_check", "walks", "facc",
        "slices", "fmove", "max_move", "maxiter_init",
        "maxcall_init", "ncdim", "pool", "queue_size",
        "periodic", "reflective", "use_pool", "live_points",
        "logl_args", "logl_kwargs", "ptform_args", "ptform_kwargs",
        "gradient", "grad_args", "grad_kwargs", "compute_jac",
        "blob", "save_history", "history_filename",
    }
    constructor_kwargs = {}
    run_kwargs = {}
    for key, value in kwargs.items():
        if key in constructor_keys:
            constructor_kwargs[key] = value
        else:
            run_kwargs[key] = value

    print(f"Dynesty nested sampling: nlive={nlive}, "
          f"dlogz={dlogz}, bound={bound}, sample={sample}")

    sampler = dynesty.NestedSampler(
        self._log_likelihood,
        self.prior_transform,
        n_params,
        nlive=nlive,
        bound=bound,
        sample=sample,
        rstate=rstate,
        **constructor_kwargs,
    )
    sampler.run_nested(dlogz=dlogz, **run_kwargs)
    results = sampler.results

    # Importance weights normalised by the final evidence estimate.
    weights = np.exp(results.logwt - results.logz[-1])
    # Pass the caller's random state so the resampled posterior is
    # reproducible too, not only the nested-sampling run itself.
    samples = resample_equal(results.samples, weights, rstate=rstate)

    self.samples = samples
    self._dynesty_results = results
    self._info = {
        "logz": results.logz[-1],
        "logzerr": results.logzerr[-1],
        "nlive": nlive,
        "niter": results.niter,
        "ncall": int(np.sum(results.ncall)),
        "eff": results.eff,
        "n_samples": len(samples),
        "n_chains": 1,
        "n_warmup": 0,
    }

    print(f"Dynesty complete: {results.niter} iterations, "
          f"{len(samples)} posterior samples")
    print(f" log(Z) = {results.logz[-1]:.2f} "
          f"+/- {results.logzerr[-1]:.2f}")

    if self._checkpoint_file is not None:
        self.save_checkpoint()
    return samples, self._info
def run_dynamic_sampling(self, nlive_init=500, nlive_batch=250,
                         dlogz_init=0.01, maxbatch=10,
                         wt_kwargs=None, rstate=None,
                         checkpoint_file=None, **kwargs):
    """
    Run dynesty dynamic nested sampling.

    Dynamic nested sampling allocates live points adaptively to
    improve both evidence and posterior estimates. The weighted output
    is resampled to equally weighted posterior samples and the outcome
    is stored on the instance (``samples``, ``_dynesty_results``,
    ``_info``).

    Parameters
    ----------
    nlive_init : int
        Initial number of live points (default 500).
    nlive_batch : int
        Number of live points per batch (default 250).
    dlogz_init : float
        Stopping criterion for the initial baseline run
        (default 0.01).
    maxbatch : int
        Maximum number of dynamic batches (default 10).
    wt_kwargs : dict, optional
        Keyword arguments for the importance weight function.
        Default: ``{"pfrac": 0.8}`` (80% posterior, 20% evidence).
    rstate : numpy.random.Generator or numpy.random.RandomState, optional
        Random state for reproducibility. It is handed both to the
        sampler and to the equal-weight resampling step.
        NOTE(review): which of the two types is accepted depends on the
        installed dynesty release -- confirm against its documentation.
    checkpoint_file : str, optional
        Override the default checkpoint file path.
    **kwargs
        Additional keyword arguments passed to
        ``dynesty.DynamicNestedSampler`` (the constructor only; nothing
        from ``kwargs`` reaches ``run_nested``).

    Returns
    -------
    samples : np.ndarray
        Equally weighted posterior samples, shape
        ``(n_samples, n_params)``.
    info : dict
        Diagnostics including log-evidence and its uncertainty.
    """
    import dynesty
    from dynesty.utils import resample_equal

    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    if wt_kwargs is None:
        wt_kwargs = {"pfrac": 0.8}

    n_params = self.n_params

    print(f"Dynesty dynamic nested sampling: "
          f"nlive_init={nlive_init}, nlive_batch={nlive_batch}, "
          f"maxbatch={maxbatch}")

    sampler = dynesty.DynamicNestedSampler(
        self._log_likelihood,
        self.prior_transform,
        n_params,
        rstate=rstate,
        **kwargs,
    )
    sampler.run_nested(
        nlive_init=nlive_init,
        nlive_batch=nlive_batch,
        dlogz_init=dlogz_init,
        maxbatch=maxbatch,
        wt_kwargs=wt_kwargs,
    )
    results = sampler.results

    # Importance weights normalised by the final evidence estimate.
    weights = np.exp(results.logwt - results.logz[-1])
    # Pass the caller's random state so the resampled posterior is
    # reproducible too, not only the nested-sampling run itself.
    samples = resample_equal(results.samples, weights, rstate=rstate)

    self.samples = samples
    self._dynesty_results = results
    self._info = {
        "logz": results.logz[-1],
        "logzerr": results.logzerr[-1],
        "nlive_init": nlive_init,
        "nlive_batch": nlive_batch,
        "niter": results.niter,
        "ncall": int(np.sum(results.ncall)),
        "eff": results.eff,
        "n_samples": len(samples),
        "n_chains": 1,
        "n_warmup": 0,
    }

    print(f"Dynamic nested sampling complete: {results.niter} "
          f"iterations, {len(samples)} posterior samples")
    print(f" log(Z) = {results.logz[-1]:.2f} "
          f"+/- {results.logzerr[-1]:.2f}")

    if self._checkpoint_file is not None:
        self.save_checkpoint()
    return samples, self._info
def save_checkpoint(self, path=None):
    """
    Save sampler results to disk.

    Writes the posterior samples, the evidence estimate (when
    ``self._info`` is set) and any MAP solutions held on the instance
    into a single ``.npz`` archive via ``numpy.savez``.

    Parameters
    ----------
    path : str, optional
        File path (saved as ``.npz``). If None, uses the
        default checkpoint file.

    Raises
    ------
    ValueError
        If neither ``path`` nor a default checkpoint file is available.
    """
    if path is None:
        path = self._checkpoint_file
    if path is None:
        raise ValueError(
            "No path provided and no checkpoint_file set.")

    if self.samples is not None:
        samples = np.asarray(self.samples)
    else:
        samples = np.array([])
    save_kwargs = {"samples": samples}

    if self._info is not None:
        # An entry can be present but None (e.g. after loading a
        # checkpoint that carried no evidence). Storing None would
        # produce a pickled object array that cannot be converted back
        # to float on reload, so unknown values are written as NaN.
        for key in ("logz", "logzerr"):
            value = self._info.get(key)
            save_kwargs[key] = np.nan if value is None else value

    # Preserve MAP solutions
    if getattr(self, "all_theta_maps", None) is not None:
        save_kwargs["theta_map"] = self.theta_map
        save_kwargs["all_theta_maps"] = np.array(
            self.all_theta_maps, dtype=object)
    if getattr(self, "map_loglikes", None) is not None:
        save_kwargs["map_loglikes"] = self.map_loglikes

    np.savez(path, **save_kwargs)

    n_samples = samples.shape[0] if self.samples is not None else 0
    print(f"Checkpoint saved to {path} ({n_samples} samples)")
def load_checkpoint(self, checkpoint_file=None):
    """
    Restore sampler results from a checkpoint file.

    Populates ``self.samples`` (when the file holds a non-empty
    ``samples`` array), rebuilds ``self._info`` and restores any stored
    MAP solutions (``all_theta_maps``, ``theta_map``, ``map_loglikes``).

    Parameters
    ----------
    checkpoint_file : str, optional
        Path to a ``.npz`` checkpoint file. If given, it also becomes
        the default checkpoint file. A path without the ``.npz`` suffix
        is accepted when only the suffixed file exists, because
        ``numpy.savez`` appends that suffix on write.

    Raises
    ------
    ValueError
        If no checkpoint file is provided or set.
    """
    if checkpoint_file is not None:
        self._checkpoint_file = checkpoint_file
    path = self._checkpoint_file
    if path is None:
        raise ValueError("No checkpoint_file provided or set.")

    # numpy.savez silently appends ".npz" while numpy.load does not, so
    # a checkpoint written as "run" lives in "run.npz".
    load_path = path
    if isinstance(path, (str, os.PathLike)):
        text_path = os.fspath(path)
        if (isinstance(text_path, str)
                and not text_path.endswith(".npz")
                and not os.path.exists(text_path)
                and os.path.exists(text_path + ".npz")):
            load_path = text_path + ".npz"

    # NOTE(security): allow_pickle=True unpickles object arrays, which
    # can execute arbitrary code. Only load checkpoints from trusted
    # sources.
    with np.load(load_path, allow_pickle=True) as data:
        stored = set(data.files)

        def _stored_float(key):
            # Missing entries and entries stored as None both map to None.
            if key not in stored:
                return None
            raw = data[key]
            value = raw.item() if raw.ndim == 0 else raw
            return None if value is None else float(value)

        if "samples" in stored:
            # Each data[...] access re-reads the archive member, so read
            # it once and reuse the array.
            samples = data["samples"]
            if samples.size > 0:
                self.samples = samples

        self._info = {
            "logz": _stored_float("logz"),
            "logzerr": _stored_float("logzerr"),
            "n_samples": (self.samples.shape[0]
                          if self.samples is not None else 0),
            "n_chains": 1,
            "n_warmup": 0,
        }

        if "all_theta_maps" in stored:
            self.all_theta_maps = list(data["all_theta_maps"])
            if "theta_map" in stored:
                theta_map = data["theta_map"]
                # A dict is stored as a 0-d object array; unwrap it.
                self.theta_map = (theta_map.item()
                                  if theta_map.ndim == 0 else theta_map)
            elif self.all_theta_maps:
                # Solutions are stored best-first.
                self.theta_map = self.all_theta_maps[0]
        if "map_loglikes" in stored:
            self.map_loglikes = np.asarray(data["map_loglikes"])

    n_samples = self.samples.shape[0] if self.samples is not None else 0
    logz = self._info["logz"]
    logz_str = f", log(Z)={logz:.2f}" if logz is not None else ""
    print(f"Checkpoint loaded from {load_path} "
          f"({n_samples} samples{logz_str})")
@staticmethod
def load_samples(path, flatten_chains=True):
    """
    Read samples from a checkpoint file.

    Parameters
    ----------
    path : str
        Path to a ``.npz`` checkpoint file. A path without the
        ``.npz`` suffix is accepted when only the suffixed file exists,
        because ``numpy.savez`` appends that suffix on write.
    flatten_chains : bool
        Ignored (included for API compatibility with
        BlackJAXSampler).

    Returns
    -------
    samples : np.ndarray
        Shape ``(n_samples, n_params)``.

    Raises
    ------
    KeyError
        If the file has no ``samples`` entry.
    """
    load_path = path
    if isinstance(path, (str, os.PathLike)):
        text_path = os.fspath(path)
        if (isinstance(text_path, str)
                and not text_path.endswith(".npz")
                and not os.path.exists(text_path)
                and os.path.exists(text_path + ".npz")):
            load_path = text_path + ".npz"

    # The context manager closes the archive even when "samples" is
    # missing and the lookup raises.
    with np.load(load_path) as data:
        return data["samples"]