"""
MCMC sampling for GP hyperparameters.
MCMCSampler is the base class providing shared diagnostics, summary,
and plotting utilities. BlackJAXSampler adds gradient-based NUTS
sampling via the BlackJAX library.
"""
import jax

# The sampler builds float64 arrays (``dtype=jnp.float64``), so 64-bit mode
# is switched on before ``jax.numpy`` is imported and used.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Support both package-relative use and running next to gp_solver.py.
try:
    from .gp_solver import GPSolver
except ImportError:
    from gp_solver import GPSolver

__all__ = ["MCMCSampler", "BlackJAXSampler"]
# =====================================================================
# Base class
# =====================================================================
class MCMCSampler:
    """
    Base MCMC sampler for GP hyperparameters.

    Wraps a GPSolver object and provides shared storage, diagnostics,
    summary statistics, corner plots, and dict conversion. Subclasses
    implement specific sampling algorithms (e.g. NUTS).

    Parameters
    ----------
    gp : GPSolver
        A configured GPSolver instance.

    Raises
    ------
    TypeError
        If ``gp`` is not a GPSolver instance.
    """

    # method name -> (GPSolver attribute checked for a cached matrix,
    #                 GPSolver method called when that attribute is None)
    _COVARIANCE_SOURCES = {
        "fisher": ("_fisher_matrix", "mass_matrix_fisher"),
        "hessian_map": ("_hessian", "mass_matrix_hessian_map"),
        "laplace": ("_laplace_hessian", "mass_matrix_laplace"),
    }

    def __init__(self, gp):
        if not isinstance(gp, GPSolver):
            raise TypeError("gp must be a GPSolver instance")
        self.gp = gp
        # Storage for MCMC results
        self.samples = None
        self._info = None
        self._last_state = None
        self._adapted_step_size = None
        self._adapted_inv_mass = None
        self._last_rng_key = None
        self._checkpoint_file = None

    @property
    def param_keys(self):
        """Parameter names, as exposed by the wrapped GPSolver."""
        return self.gp.param_keys

    @property
    def n_params(self):
        """Number of parameters, as exposed by the wrapped GPSolver."""
        return self.gp.n_params

    @staticmethod
    def _pool_chains(samples):
        """
        Return ``samples`` as a NumPy array of shape (n_total, n_params).

        Arrays with more than two dimensions (e.g. multi-chain output of
        shape (n_chains, n_samples, n_params)) have their leading axes
        merged into the sample axis; 2D input is returned unchanged.
        """
        arr = np.asarray(samples)
        if arr.ndim > 2:
            arr = arr.reshape(-1, arr.shape[-1])
        return arr

    # =================================================================
    # Diagnostics
    # =================================================================
    def summary(self):
        """
        Print summary statistics of the posterior samples.

        Multi-chain samples are pooled across chains before the
        statistics are computed.

        Returns
        -------
        stats : dict
            Parameter names mapped to a dict with keys ``"mean"``,
            ``"std"``, ``"q16"``, ``"q50"`` and ``"q84"``.

        Raises
        ------
        RuntimeError
            If no samples are stored on the sampler.
        """
        if self.samples is None:
            raise RuntimeError("No samples available. Run a sampler first.")
        samples = self._pool_chains(self.samples)

        stats = {}
        print(f"{'param':>12s} {'mean':>10s} {'std':>10s} "
              f"{'16%':>10s} {'50%':>10s} {'84%':>10s}")
        print("-" * 68)
        for i, key in enumerate(self.param_keys):
            col = samples[:, i]
            q16, q50, q84 = np.percentile(col, [16, 50, 84])
            m, s = np.mean(col), np.std(col)
            stats[key] = {"mean": m, "std": s,
                          "q16": q16, "q50": q50, "q84": q84}
            print(f"{key:>12s} {m:10.5f} {s:10.5f} "
                  f"{q16:10.5f} {q50:10.5f} {q84:10.5f}")

        # The info dict may hold only scalar diagnostics (per-sample arrays
        # are dropped after checkpointing), so each entry is checked
        # separately instead of assuming both are present.
        info = self._info or {}
        if "n_divergent" in info:
            print(f"\nDivergences: {info['n_divergent']}")
        if "acceptance_rate" in info:
            print(f"Mean acceptance: "
                  f"{np.mean(info['acceptance_rate']):.3f}")
        return stats

    @staticmethod
    def _empty_corner_figure(plt, n, keys, figsize):
        """Create an n x n lower-triangular grid of labelled, empty axes."""
        if figsize is None:
            figsize = (2.5 * n, 2.5 * n)
        fig, grid = plt.subplots(n, n, figsize=figsize, squeeze=False)
        for i in range(n):
            for j in range(n):
                if j > i:
                    # Hide upper triangle
                    grid[i, j].set_visible(False)
                    continue
                # Label only the outer edges of the triangle
                if i == n - 1:
                    grid[i, j].set_xlabel(keys[j])
                if j == 0 and i > 0:
                    grid[i, j].set_ylabel(keys[i])
        fig.subplots_adjust(hspace=0.05, wspace=0.05)
        return fig

    def plot_covariance(self, method="fisher", theta_map=None,
                        n_sigma=2, n_grid=200, samples=None,
                        figsize=None, color="C0", alpha=0.3,
                        true_params=None, savefig=None,
                        **corner_kwargs):
        """
        Corner plot of 2D covariance ellipses from the Hessian or Fisher
        matrix, with 1D marginal Gaussians on the diagonal.

        Uses ``corner.corner`` to lay out the figure when MCMC samples
        are provided, and overlays the Laplace/Fisher Gaussian
        approximation (ellipses + 1D marginals).

        Parameters
        ----------
        method : {"fisher", "hessian_map", "laplace"}
            Which matrix to use for the Gaussian approximation.
        theta_map : array_like, optional
            Center of the ellipses. If None, uses MAP estimate.
        n_sigma : float
            Number of sigma for the outer ellipse contour (default 2).
            The 1-sigma contour is always drawn.
        n_grid : int
            Grid resolution for the ellipse curves (default 200).
        samples : array_like, optional
            If provided, plotted as the corner histogram/contours
            (multi-chain arrays are pooled across chains first).
            If None, the figure is created with empty axes and only
            the Gaussian approximation is drawn.
        figsize : tuple, optional
            Figure size.
        color : str
            Color for Gaussian ellipses and marginals (default "C0").
        alpha : float
            Fill alpha for the ellipse interiors (default 0.3).
        true_params : dict or array_like, optional
            True parameter values to mark with crosshairs. Parameters
            missing from a dict are skipped.
        savefig : str, optional
            If provided, save figure to this path.
        **corner_kwargs
            Extra keyword arguments forwarded to ``corner.corner``
            (e.g. ``quantiles``, ``show_titles``, ``hist_kwargs``).

        Returns
        -------
        fig, axes : matplotlib Figure and 2D array of Axes.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        import matplotlib.pyplot as plt

        # Reject a bad method before any (possibly expensive) MAP fit.
        if method not in self._COVARIANCE_SOURCES:
            raise ValueError(f"Unknown method: {method!r}")

        gp = self.gp
        n = self.n_params
        keys = list(self.param_keys)

        # Get MAP center
        if theta_map is None:
            if gp.map_estimate is None:
                gp.fit_map()
            theta_map = gp.map_estimate
        mu = np.asarray(theta_map, dtype=np.float64)

        # Get covariance matrix
        # NOTE(review): the covariance is always read from
        # gp.inverse_mass_matrix; presumably that attribute reflects the
        # most recently computed method -- confirm against GPSolver.
        cache_attr, builder = self._COVARIANCE_SOURCES[method]
        if getattr(gp, cache_attr) is None:
            getattr(gp, builder)(theta_map)
        cov = np.asarray(gp.inverse_mass_matrix)

        # Parse true_params
        if true_params is None:
            true_arr = None
        elif isinstance(true_params, dict):
            true_arr = np.array([true_params.get(k, np.nan) for k in keys],
                                dtype=np.float64)
        else:
            true_arr = np.asarray(true_params, dtype=np.float64)

        # --- Build figure with corner ----------------------------------
        if samples is not None:
            # corner is only needed when there are samples to draw.
            import corner

            corner_defaults = dict(
                labels=keys,
                show_titles=True,
                plot_density=True,
                plot_contours=True,
                hist_kwargs={"density": True},
            )
            corner_defaults.update(corner_kwargs)
            if true_arr is not None:
                corner_defaults.setdefault("truths", list(true_arr))
            if figsize is not None:
                corner_defaults["fig"] = plt.figure(figsize=figsize)
            fig = corner.corner(self._pool_chains(samples),
                                **corner_defaults)
        else:
            # No samples -- create an empty corner-style grid
            fig = self._empty_corner_figure(plt, n, keys, figsize)
        axes = np.array(fig.axes).reshape(n, n)

        # --- Overlay Gaussian approximation ----------------------------
        t_ellipse = np.linspace(0, 2 * np.pi, n_grid)
        unit_circle = np.array([np.cos(t_ellipse), np.sin(t_ellipse)])
        # 1-sigma plus the requested contour, without drawing the same
        # curve twice when n_sigma == 1.
        contour_levels = list(dict.fromkeys((1, n_sigma)))

        for i in range(n):
            for j in range(i + 1):
                ax = axes[i, j]
                if i == j:
                    # 1D Gaussian marginal (properly normalized)
                    sigma_i = np.sqrt(cov[i, i])
                    x_range = np.linspace(mu[i] - 4 * sigma_i,
                                          mu[i] + 4 * sigma_i, 300)
                    pdf = (np.exp(-0.5 * ((x_range - mu[i]) / sigma_i) ** 2)
                           / (sigma_i * np.sqrt(2 * np.pi)))
                    ax.plot(x_range, pdf, color=color, lw=1.5)
                    ax.fill_between(x_range, pdf, alpha=alpha, color=color)
                    ax.axvline(mu[i], color=color, ls="--", lw=0.8)
                    if true_arr is not None and np.isfinite(true_arr[i]):
                        ax.axvline(true_arr[i], color="k", ls=":", lw=1)
                    continue

                # 2D covariance ellipses: x-axis is parameter j, y-axis i
                sub_cov = np.array([[cov[j, j], cov[j, i]],
                                    [cov[i, j], cov[i, i]]])
                eigvals, eigvecs = np.linalg.eigh(sub_cov)
                # Clip tiny negative eigenvalues before the square root.
                eigvals = np.maximum(eigvals, 0)
                one_sigma = eigvecs @ np.diag(np.sqrt(eigvals)) @ unit_circle
                for ns in contour_levels:
                    ax.plot(mu[j] + ns * one_sigma[0],
                            mu[i] + ns * one_sigma[1],
                            color=color, lw=1.2)
                ax.fill(mu[j] + one_sigma[0], mu[i] + one_sigma[1],
                        color=color, alpha=alpha)
                ax.plot(mu[j], mu[i], "+", color=color, ms=8, mew=1.5)
                if (true_arr is not None
                        and np.isfinite(true_arr[j])
                        and np.isfinite(true_arr[i])):
                    ax.plot(true_arr[j], true_arr[i], "x",
                            color="k", ms=6, mew=1.2)

        if savefig is not None:
            fig.savefig(savefig, dpi=150, bbox_inches="tight")
        return fig, axes

    def to_dict(self, samples=None):
        """
        Convert samples array to a dict keyed by parameter name.

        Parameters
        ----------
        samples : jnp.ndarray, optional
            Shape (n_samples, n_params), or (n_chains, n_samples,
            n_params), in which case chains are pooled. If None, uses
            self.samples.

        Returns
        -------
        d : dict
            {param_name: array of shape (n_total_samples,)}

        Raises
        ------
        RuntimeError
            If no samples are given and none are stored.
        """
        if samples is None:
            samples = self.samples
        if samples is None:
            raise RuntimeError("No samples available.")
        samples = self._pool_chains(samples)
        return {k: samples[:, i]
                for i, k in enumerate(self.param_keys)}
# =====================================================================
# BlackJAX NUTS sampler
# =====================================================================
class BlackJAXSampler(MCMCSampler):
    """
    NUTS sampler using the BlackJAX library.

    Inherits diagnostics, summary, plotting, and dict conversion from
    MCMCSampler. Adds ``run_nuts`` for gradient-based No-U-Turn
    sampling with dual-averaging step-size adaptation.

    Parameters
    ----------
    gp : GPSolver
        A configured GPSolver instance.
    save_dir : str, optional
        Directory for all outputs produced by this sampler (corner
        plots, covariance plots, etc.). Created automatically if it
        does not exist. When set, ``save_checkpoint`` will default to
        saving the checkpoint inside this directory.
    """

    def __init__(self, gp, save_dir=None):
        super().__init__(gp)
        if save_dir is not None:
            import os
            os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self._n_chains = 1

    # =================================================================
    # Private helpers
    # =================================================================
    @staticmethod
    def _npz_path(path):
        """Return ``path`` as a str ending in ``.npz`` (the name
        ``np.savez`` actually writes to)."""
        import os
        path = os.fspath(path)
        return path if path.endswith(".npz") else path + ".npz"

    @staticmethod
    def _existing_npz(path):
        """Return ``path``, or ``path + '.npz'`` when only the latter can
        exist. Non-path objects (e.g. open files) pass through untouched."""
        import os
        if not isinstance(path, (str, os.PathLike)):
            return path
        path = os.fspath(path)
        if not path.endswith(".npz") and not os.path.exists(path):
            return path + ".npz"
        return path

    @staticmethod
    def _n_stored(samples):
        """Number of draws per chain in a samples array (0 if empty)."""
        if samples is None or samples.size == 0:
            return 0
        # (n_samples, n_params) or (n_chains, n_samples, n_params)
        return samples.shape[-2] if samples.ndim >= 2 else samples.shape[0]

    def _nuts_kernel(self, step_size, inv_mass):
        """Build a BlackJAX NUTS kernel on the GP log-posterior."""
        import blackjax
        return blackjax.nuts(
            self.gp.log_posterior,
            step_size=step_size,
            inverse_mass_matrix=inv_mass,
        )

    def _make_chain_runner(self, step_size, inv_mass, n_samples):
        """Return ``run_chain(state, key)`` that draws ``n_samples`` NUTS
        samples from one chain inside a single ``jax.lax.scan``."""
        kernel = self._nuts_kernel(step_size, inv_mass)

        def run_chain(state, chain_key):
            step_keys = jax.random.split(chain_key, n_samples)

            def one_step(carry, key):
                st, n_div = carry
                st, info = kernel.step(key, st)
                n_div = n_div + info.is_divergent.astype(jnp.int32)
                return (st, n_div), (st.position, info)

            (final_st, total_div), (positions, infos) = jax.lax.scan(
                one_step, (state, jnp.int32(0)), step_keys,
            )
            # The last per-step key is handed back so a later resume can
            # derive fresh keys from it.
            return final_st, total_div, positions, infos, step_keys[-1]

        return run_chain

    def _spawn_chains(self, state, key, n_chains, step_size, inv_mass):
        """Create ``n_chains`` NUTS states jittered around ``state`` and one
        PRNG key per chain. ``inv_mass`` is treated as a diagonal (1D)."""
        jitter_key, key = jax.random.split(key)
        jitter_scale = 0.01 * jnp.sqrt(inv_mass)
        noise = jax.random.normal(
            jitter_key, shape=(n_chains, state.position.shape[-1]))
        init_positions = state.position[None, :] \
            + jitter_scale[None, :] * noise
        init_fn = self._nuts_kernel(step_size, inv_mass).init
        states = jax.vmap(init_fn)(init_positions)
        return states, jax.random.split(key, n_chains)

    def _record_batch(self, results, *, step_size, n_warmup, n_samples,
                      n_chains):
        """Store one sampling batch on the instance; return
        ``(n_divergent, mean_acceptance)`` for reporting."""
        final_state, n_div, positions, infos, last_key = results
        self.samples = positions
        self._info = {
            "divergences": np.asarray(infos.is_divergent),
            "acceptance_rate": np.asarray(infos.acceptance_rate),
            "num_steps": np.asarray(infos.num_integration_steps),
            "step_size": step_size,
            "n_warmup": n_warmup,
            "n_samples": n_samples,
            "n_chains": n_chains,
            "n_divergent": int(jnp.sum(n_div)),
        }
        self._last_state = final_state
        self._last_rng_key = last_key
        mean_accept = float(np.mean(self._info["acceptance_rate"]))
        return self._info["n_divergent"], mean_accept

    # =================================================================
    # Sampling
    # =================================================================
    def run_nuts(self, n_samples=1000, n_warmup=500, theta_init=None,
                 mass_matrix_method="hessian_map", step_size=None,
                 rng_key=None, target_accept=0.8, progress_bar=False,
                 n_chains=1, checkpoint_file=None):
        """
        Run BlackJAX NUTS sampler.

        Uses ``blackjax.window_adaptation`` for warmup (step-size and
        mass-matrix adaptation), then ``jax.lax.scan`` for the sampling
        loop, so sampling avoids a Python-level loop.

        Warmup always runs a single chain to adapt the step size and
        mass matrix. When ``n_chains > 1``, the adapted parameters
        are shared across all chains, which are initialized with
        jittered copies of the warmup endpoint and sampled in parallel
        via ``jax.vmap``.

        Parameters
        ----------
        n_samples : int
            Number of post-warmup samples per chain (default 1000).
        n_warmup : int
            Number of warmup steps for adaptation (default 500).
        theta_init : dict or array_like, optional
            Initial position. If None, uses GPSolver's MAP estimate.
        mass_matrix_method : {"hessian_map", "fisher", "laplace", "diagonal", None}
            Method to estimate the mass matrix (delegated to GPSolver).
            The estimate is only used for the initial step-size
            heuristic; the sampling mass matrix comes from warmup.
        step_size : float, optional
            Initial NUTS step size before adaptation. If None, a
            heuristic based on the mass matrix scale is used.
        rng_key : jax.random.PRNGKey, optional
            Random key. Default: PRNGKey(0).
        target_accept : float
            Target acceptance rate for dual averaging (default 0.8).
        progress_bar : bool
            Forwarded to ``blackjax.window_adaptation`` (default
            False). The post-warmup sampling loop does not report
            progress.
        n_chains : int
            Number of independent chains to run in parallel via
            ``jax.vmap`` (default 1). All chains share the same
            adapted step size and mass matrix from a single warmup.
        checkpoint_file : str, optional
            Default file path for ``save_checkpoint``. When set,
            calling ``save_checkpoint()`` with no arguments will
            use this path, and a checkpoint is written right after
            warmup.

        Returns
        -------
        samples : jnp.ndarray
            Shape ``(n_samples, n_params)`` when ``n_chains=1``, or
            ``(n_chains, n_samples, n_params)`` when ``n_chains > 1``.
        info : dict
            Sampling diagnostics (arrays have a leading chain
            dimension when ``n_chains > 1``).
        """
        import blackjax

        gp = self.gp
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)

        # Initial position
        if theta_init is None:
            if gp.map_estimate is None:
                gp.fit_map()
            theta_init = gp.map_estimate
        elif isinstance(theta_init, dict):
            theta_init = jnp.array(
                [float(theta_init[k]) for k in gp.param_keys],
                dtype=jnp.float64)
        else:
            theta_init = jnp.asarray(theta_init, dtype=jnp.float64)

        # Estimate mass matrix (delegated to GPSolver)
        inv_mass = jnp.asarray(
            gp._get_mass_matrix(mass_matrix_method, theta_init))

        # Initial step size: heuristic based on mass matrix scale
        if step_size is None:
            # NOTE(review): _get_mass_matrix is assumed to return a full
            # matrix; a 1D result is taken to already be the diagonal
            # rather than being expanded by jnp.diag -- confirm.
            inv_mass_diag = (jnp.diag(inv_mass) if inv_mass.ndim == 2
                             else inv_mass)
            # Clamp extreme values for stability
            median_var = jnp.median(inv_mass_diag)
            inv_mass_diag = jnp.clip(inv_mass_diag,
                                     median_var * 1e-4, median_var * 1e4)
            step_size = float(0.5 * jnp.min(jnp.sqrt(inv_mass_diag)))
            step_size = max(step_size, 1e-5)

        self._n_chains = n_chains
        if checkpoint_file is not None:
            self._checkpoint_file = checkpoint_file
        # A new run starts from scratch: samples left over from an earlier
        # run must not end up in the warmup checkpoint written below.
        self.samples = None

        # -- Warmup (single chain) -----------------------------------
        print(f"Warmup: {n_warmup} steps (window adaptation, "
              f"init step_size={step_size:.6f})...")
        warmup_key, sample_key = jax.random.split(rng_key)
        warmup = blackjax.window_adaptation(
            blackjax.nuts,
            gp.log_posterior,
            is_mass_matrix_diagonal=True,
            initial_step_size=step_size,
            target_acceptance_rate=target_accept,
            progress_bar=progress_bar,
        )
        adapt_results, adapt_info = warmup.run(
            warmup_key, theta_init, num_steps=n_warmup,
        )
        adapted_step_size = adapt_results.parameters["step_size"]
        adapted_inv_mass = adapt_results.parameters["inverse_mass_matrix"]
        warmup_state = adapt_results.state
        print(f" Adapted step size: {float(adapted_step_size):.6f}")

        # Store adapted parameters and checkpoint between warmup and
        # sampling so that warmup intermediates can be freed.
        self._adapted_step_size = float(adapted_step_size)
        self._adapted_inv_mass = np.asarray(adapted_inv_mass)
        self._last_state = warmup_state
        self._last_rng_key = sample_key
        self._info = {
            "step_size": self._adapted_step_size,
            "n_warmup": n_warmup,
            "n_samples": 0,
            "n_chains": n_chains,
            "n_divergent": 0,
        }
        if self._checkpoint_file is not None:
            self.save_checkpoint(append_samples=False)
            print(" Warmup checkpoint saved; clearing warmup memory...")
        del adapt_results, adapt_info, warmup
        jax.clear_caches()

        # -- Sampling via lax.scan -----------------------------------
        run_chain = self._make_chain_runner(
            adapted_step_size, adapted_inv_mass, n_samples)
        multi = n_chains > 1
        if multi:
            print(f"Sampling {n_samples} iterations x {n_chains} chains...")
            # Jitter the warmup endpoint to create independent starting
            # positions for each chain
            states, chain_keys = self._spawn_chains(
                warmup_state, sample_key, n_chains,
                adapted_step_size, adapted_inv_mass)
            # Result arrays have shape (n_chains, n_samples, ...)
            results = jax.vmap(run_chain)(states, chain_keys)
        else:
            print(f"Sampling {n_samples} post-warmup iterations...")
            results = run_chain(warmup_state, sample_key)

        total_div, mean_accept = self._record_batch(
            results,
            step_size=float(adapted_step_size),
            n_warmup=n_warmup,
            n_samples=n_samples,
            n_chains=n_chains if multi else 1,
        )
        if multi:
            print(f"NUTS complete: {n_chains} chains x {n_samples} samples, "
                  f"{total_div} total divergences, "
                  f"mean acceptance rate = {mean_accept:.3f}")
        else:
            print(f"NUTS complete: {n_samples} samples, "
                  f"{total_div} divergences, "
                  f"mean acceptance rate = {mean_accept:.3f}")
        return self.samples, self._info

    # =================================================================
    # Checkpointing
    # =================================================================
    def save_checkpoint(self, path=None, append_samples=True,
                        plot_corner=False):
        """
        Save sampler state to disk for later resumption.

        When ``append_samples=True`` (the default), new samples are
        appended to any existing samples already stored in ``path``,
        and ``self.samples`` is cleared from memory. This enables a
        sample-checkpoint-clear loop that keeps memory usage constant.
        If there are no in-memory samples, samples already on disk are
        kept as they are.

        Parameters
        ----------
        path : str, optional
            File path (saved as ``.npz``; the extension is added when
            missing). If None, uses the ``checkpoint_file`` set in
            ``run_nuts``, or ``save_dir/checkpoint.npz`` if
            ``save_dir`` was set.
        append_samples : bool
            If True, append current ``self.samples`` to any samples
            already on disk, then clear ``self.samples`` from memory.
            If False, overwrite with only the current in-memory samples.
        plot_corner : bool
            If True, load all samples currently on disk after saving
            and write a corner plot to ``save_dir/corner_plot.png``
            (or alongside the checkpoint file if ``save_dir`` is not
            set).

        Raises
        ------
        RuntimeError
            If there is no sampler state to save.
        ValueError
            If no output path can be determined.
        """
        import os

        if self._last_state is None:
            raise RuntimeError("No sampler state to save. Run run_nuts first.")
        if path is None:
            path = self._checkpoint_file
        if path is None and self.save_dir is not None:
            path = os.path.join(self.save_dir, "checkpoint.npz")
        if path is None:
            raise ValueError(
                "No path provided, no checkpoint_file set, and no save_dir. "
                "Pass a path, set checkpoint_file in run_nuts, or set save_dir.")
        # np.savez appends ".npz" when it is missing; normalizing up front
        # keeps the existence check, the write and later reads on one name.
        path = self._npz_path(path)

        new_samples = (np.asarray(self.samples)
                       if self.samples is not None else None)
        samples_to_save = new_samples

        # Merge with samples already on disk
        if append_samples and os.path.exists(path):
            with np.load(path) as existing:
                on_disk = (existing["samples"]
                           if "samples" in existing.files else None)
            if on_disk is not None and on_disk.size > 0:
                if new_samples is None:
                    # Nothing new to add: keep what is on disk instead of
                    # overwriting it with an empty array.
                    samples_to_save = on_disk
                else:
                    # multi-chain: (n_chains, n_samples, n_params) -> axis=1
                    # single-chain: (n_samples, n_params) -> axis=0
                    cat_axis = 1 if new_samples.ndim == 3 else 0
                    samples_to_save = np.concatenate(
                        [on_disk, new_samples], axis=cat_axis)

        info = self._info or {}
        n_chains = getattr(self, "_n_chains", 1)
        save_kwargs = {
            # NUTS state (shape has leading chain dim when n_chains > 1)
            "position": np.asarray(self._last_state.position),
            "logdensity": np.asarray(self._last_state.logdensity),
            "logdensity_grad": np.asarray(self._last_state.logdensity_grad),
            # Adapted kernel parameters
            "step_size": np.asarray(self._adapted_step_size),
            "inverse_mass_matrix": np.asarray(self._adapted_inv_mass),
            "rng_key": np.asarray(self._last_rng_key),
            # Diagnostics (scalars)
            "n_warmup": np.asarray(info.get("n_warmup", 0)),
            "n_chains": np.asarray(n_chains),
            "samples": (samples_to_save if samples_to_save is not None
                        else np.array([])),
        }
        # Per-chain draw count (shape[0] would be n_chains for 3D arrays).
        n_on_disk = self._n_stored(samples_to_save)
        np.savez(path, **save_kwargs)

        if append_samples:
            # Free in-memory samples and per-sample diagnostics
            self.samples = None
            self._info = {
                "step_size": info.get("step_size", self._adapted_step_size),
                "n_warmup": info.get("n_warmup", 0),
                "n_samples": n_on_disk,
                "n_chains": info.get("n_chains", n_chains),
                "n_divergent": info.get("n_divergent", 0),
            }
        print(f"Checkpoint saved to {path} ({n_on_disk} samples on disk)")

        if plot_corner and n_on_disk > 0:
            import corner
            import matplotlib.pyplot as plt

            all_samples = self.load_samples(path)
            fig = corner.corner(
                all_samples,
                labels=list(self.param_keys),
                show_titles=True,
                title_fmt=".3f",
            )
            corner_dir = self.save_dir if self.save_dir is not None \
                else os.path.dirname(os.path.abspath(path))
            corner_path = os.path.join(corner_dir, "corner_plot.png")
            fig.savefig(corner_path, dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f"Corner plot saved to {corner_path} "
                  f"({n_on_disk} samples)")

    def load_checkpoint(self, path):
        """
        Restore sampler state from a checkpoint file.

        Restores only the NUTS state and adapted kernel parameters needed
        to resume sampling. Samples stored in the file are **not**
        kept in memory -- use ``load_samples`` to read them later.

        Parameters
        ----------
        path : str
            Path to a ``.npz`` checkpoint file. If the path has no
            ``.npz`` extension and does not exist, the extension is
            appended.
        """
        import blackjax

        path = self._existing_npz(path)
        with np.load(path) as data:
            # Reconstruct NUTS state (works for both single and multi-chain:
            # arrays have a leading chain dimension when n_chains > 1)
            self._last_state = blackjax.mcmc.hmc.HMCState(
                position=jnp.asarray(data["position"]),
                logdensity=jnp.asarray(data["logdensity"]),
                logdensity_grad=jnp.asarray(data["logdensity_grad"]),
            )
            n_chains = int(data["n_chains"]) if "n_chains" in data.files else 1
            if n_chains > 1:
                self._adapted_step_size = np.asarray(data["step_size"])
            else:
                self._adapted_step_size = float(data["step_size"])
            self._adapted_inv_mass = jnp.asarray(data["inverse_mass_matrix"])
            self._last_rng_key = jnp.asarray(data["rng_key"])
            # Only the per-chain count is kept; the array itself is dropped.
            n_on_disk = self._n_stored(
                data["samples"] if "samples" in data.files else None)
            n_warmup = int(data["n_warmup"])
        self._n_chains = n_chains

        # Don't keep samples in memory -- keep it lightweight
        self.samples = None
        self._info = {
            "step_size": self._adapted_step_size,
            "n_warmup": n_warmup,
            "n_samples": n_on_disk,
            "n_chains": n_chains,
            "n_divergent": 0,
        }
        print(f"Checkpoint loaded from {path} "
              f"({n_on_disk} samples on disk, {n_chains} chain(s), "
              f"not loaded into memory)")

    @staticmethod
    def load_samples(path, flatten_chains=True):
        """
        Read all samples from a checkpoint file without loading
        the sampler state.

        Parameters
        ----------
        path : str
            Path to a ``.npz`` checkpoint file. If the path has no
            ``.npz`` extension and does not exist, the extension is
            appended.
        flatten_chains : bool
            If True (default), collapse the chain dimension so the
            returned array is always ``(n_total, n_params)``. Set to
            False to get the raw ``(n_chains, n_samples, n_params)``
            array for per-chain diagnostics (e.g. R-hat).

        Returns
        -------
        samples : np.ndarray
            Shape ``(n_total, n_params)`` when ``flatten_chains=True``,
            or ``(n_chains, n_samples, n_params)`` otherwise.
        """
        with np.load(BlackJAXSampler._existing_npz(path)) as data:
            samples = data["samples"].copy()
        if flatten_chains and samples.ndim == 3:
            n_chains, n_samp, n_params = samples.shape
            samples = samples.reshape(n_chains * n_samp, n_params)
        return samples

    def resume_nuts(self, n_samples=1000, n_chains=None, rng_key=None,
                    progress_bar=False):
        """
        Continue NUTS sampling from a previous run or loaded checkpoint.

        Skips warmup entirely and uses the previously adapted step size
        and mass matrix (shared across all chains). Returns only the
        new batch of samples. Call ``save_checkpoint`` afterward to
        append the batch to disk and free memory.

        If the stored state is a single chain (e.g. a checkpoint written
        right after warmup) and more than one chain is requested, the
        chains are started from jittered copies of that state, as in
        ``run_nuts``.

        Parameters
        ----------
        n_samples : int
            Number of additional samples per chain (default 1000).
        n_chains : int, optional
            Number of chains to run. If None (default), uses the value
            stored in the sampler state (from ``run_nuts`` or
            ``load_checkpoint``). Must match the number of chains in
            a stored multi-chain state.
        rng_key : jax.random.PRNGKey, optional
            Random key. If None, a new key is derived from the last
            key used.
        progress_bar : bool
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        samples : jnp.ndarray
            Shape ``(n_samples, n_params)`` for single chain, or
            ``(n_chains, n_samples, n_params)`` for multiple chains.
        info : dict
            Diagnostics for this batch.

        Raises
        ------
        RuntimeError
            If there is no previous state to resume from.
        ValueError
            If ``n_chains`` does not match a stored multi-chain state.
        """
        if self._last_state is None:
            raise RuntimeError(
                "No previous state. Run run_nuts or load_checkpoint first.")
        if n_chains is None:
            n_chains = getattr(self, "_n_chains", 1)

        state = self._last_state
        step_size = self._adapted_step_size
        inv_mass = jnp.asarray(self._adapted_inv_mass)

        # A stored multi-chain state has a leading chain axis on position.
        batched = jnp.ndim(state.position) > 1
        if batched and n_chains != state.position.shape[0]:
            raise ValueError(
                f"Stored state holds {state.position.shape[0]} chain(s) "
                f"but n_chains={n_chains} was requested.")
        self._n_chains = n_chains

        run_chain = self._make_chain_runner(step_size, inv_mass, n_samples)
        multi = batched or n_chains > 1
        if batched:
            if rng_key is None:
                rng_key = jax.random.fold_in(self._last_rng_key[0], 1)
            chain_keys = jax.random.split(rng_key, n_chains)
            print(f"Resuming: {n_chains} chains x {n_samples} samples...")
            results = jax.vmap(run_chain)(state, chain_keys)
        elif n_chains > 1:
            # Single stored state, several chains requested: fan out.
            if rng_key is None:
                rng_key = jax.random.fold_in(self._last_rng_key, 1)
            states, chain_keys = self._spawn_chains(
                state, rng_key, n_chains, step_size, inv_mass)
            print(f"Resuming: {n_chains} chains x {n_samples} samples...")
            results = jax.vmap(run_chain)(states, chain_keys)
        else:
            if rng_key is None:
                rng_key = jax.random.split(self._last_rng_key)[0]
            print(f"Resuming: {n_samples} additional samples "
                  f"(step_size={float(step_size):.6f})...")
            results = run_chain(state, rng_key)

        n_div, mean_accept = self._record_batch(
            results,
            step_size=step_size,
            n_warmup=(self._info or {}).get("n_warmup", 0),
            n_samples=n_samples,
            n_chains=n_chains if multi else 1,
        )
        if multi:
            print(f"Resume complete: {n_chains} chains x {n_samples} "
                  f"new samples, {n_div} divergences, "
                  f"mean acceptance rate = {mean_accept:.3f}")
        else:
            print(f"Resume complete: {n_samples} new samples, "
                  f"{n_div} divergences, "
                  f"mean acceptance rate = {mean_accept:.3f}")
        return self.samples, self._info