MCMC Sampling with BlackJAX#
This notebook fits GP hyperparameters to a synthetic stellar lightcurve using No-U-Turn Sampling (NUTS) via BlackJAX.
Sections
Build the spot model and simulate a lightcurve
Find the MAP estimate with GPSolver.fit_map()
Run NUTS warmup and sampling with BlackJAXSampler
Inspect diagnostics and posterior summary
Resume sampling in batches (constant memory)
Visualize the posterior — corner plot and covariance ellipses
import sys
sys.path.append("../..")
import numpy as np
import matplotlib.pyplot as plt
from spotgp import (
TrapezoidSymmetricEnvelope,
VisibilityFunction,
SpotEvolutionModel,
LightcurveModel,
GPSolver,
)
from spotgp.mcmc import BlackJAXSampler
# Seed NumPy's global RNG so the np.random.normal noise draw is reproducible.
np.random.seed(42)
1. Build the spot model and simulate a lightcurve#
# --- Spot model ---------------------------------------------------------
# Spot emergence rate [spots per day]. Bound to a single name so that the
# model and the simulated spot count further down cannot drift apart.
nspot_rate = 0.25

envelope = TrapezoidSymmetricEnvelope(lspot=12.0, tau_spot=6.0)
visibility = VisibilityFunction(peq=5.0, kappa=0.3, inc=np.pi / 3)
model = SpotEvolutionModel(
    envelope=envelope,
    visibility=visibility,
    nspot_rate=nspot_rate,  # spots per day
    alpha_max=0.05,         # peak spot angular radius [rad]
    fspot=0.0,
)
print("param_keys:", model.param_keys)
print(f"sigma_k = {model.sigma_k:.5f}")

# --- Simulated lightcurve -----------------------------------------------
tsim = 200   # simulation length [days]
tsamp = 0.5  # sampling interval (presumably in days, like tsim -- confirm)
# Expected number of spots over the baseline: rate * duration.
nspot = int(tsim * nspot_rate)
lc = LightcurveModel.from_spot_model(
    model, nspot=nspot, tsim=tsim, tsamp=tsamp,
    long=[0, 2 * np.pi],
)

# White noise with a standard deviation of 30% of the signal's own scatter.
sigma_n = 0.3 * np.std(lc.flux)
flux_obs = lc.flux + np.random.normal(0, sigma_n, lc.flux.shape)
flux_err = np.full_like(flux_obs, sigma_n)
tobs = lc.t
print(f"Lightcurve: {len(tobs)} points over {tsim} days")
print(f"Signal std: {np.std(lc.flux):.5f} Noise: {sigma_n:.5f}")
# Express fluxes as percent deviation from unity for display.
obs_pct = (flux_obs - 1) * 100
err_pct = flux_err * 100
true_pct = (lc.flux - 1) * 100

fig, ax = plt.subplots(figsize=(13, 4))

# Noisy observations as points with error bars, noiseless curve on top.
ax.errorbar(
    tobs,
    obs_pct,
    yerr=err_pct,
    fmt=".k",
    ms=3,
    capsize=0,
    alpha=0.5,
    label="Observed",
)
ax.plot(tobs, true_pct, "r-", lw=1.2, label="True")

ax.set_xlabel("Time [days]", fontsize=20)
ax.set_ylabel(r"$\Delta$ Flux [%]", fontsize=20)
ax.legend(fontsize=14)
ax.set_xlim(tobs[0], tobs[-1])

fig.tight_layout()
plt.show()
2. Find the MAP estimate#
See also
For a detailed tutorial on MAP optimization see GPSolver Tutorial
# Search interval for every fitted hyperparameter, given as a pair of limits.
bounds = dict(
    peq=(2.0, 10.0),
    kappa=(-0.5, 0.8),
    inc=(0.1, np.pi / 2),
    lspot=(1.0, 30.0),
    tau_spot=(0.5, 20.0),
    log_sigma_k=(-5.0, -1.0),
)

# Construct the solver on the observed data, then keep whatever build_jax()
# hands back as the working solver object.
gp = GPSolver(tobs, flux_obs, flux_err, model, bounds=bounds)
gp = gp.build_jax()

print("param_keys:", gp.param_keys)
print("bandwidth :", gp.bandwidth, "pts")
import time

# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be distorted by system clock adjustments the way a
# time.time() difference can.
t0 = time.perf_counter()
theta_map, opt_result = gp.fit_map(nopt=5, method="L-BFGS-B")
elapsed = time.perf_counter() - t0
print(f"fit_map() wall time: {elapsed:.1f} s (converged: {opt_result.success})")
# True values for comparison (physical, i.e. non-log, space)
theta_true = {
    "peq": model.peq,
    "kappa": model.kappa,
    "inc": model.inc,
    "lspot": model.lspot,
    "tau_spot": model.tau_spot,
    "sigma_k": model.sigma_k,
}

print(f"\n{'param':>12s} {'true':>10s} {'MAP':>10s}")
print("-" * 38)
for k, v_true in theta_true.items():
    log_key = "log_" + k
    if k in theta_map:
        v_map = theta_map[k]
    elif log_key in theta_map:
        # The parameter was fitted in log10 space (e.g. log_sigma_k); convert
        # it back so the MAP column is in the same units as the true column.
        v_map = 10.0 ** float(theta_map[log_key])
    else:
        v_map = float("nan")
    print(f"{k:>12s} {v_true:10.4f} {v_map:10.4f}")
# Lags from zero up to three times model.peq, in steps of tsamp.
tlag_plot = np.arange(0, 3 * model.peq, tsamp)

# Styling shared by all three MAP panels.
map_style = {"model_color": "steelblue", "model_label": "MAP fit"}

fig, axes = plt.subplots(3, 1, figsize=(12, 11))
ax_pred, ax_acf, ax_psd = axes

gp.plot_prediction(theta=theta_map, ax=ax_pred, **map_style)
gp.plot_acf(theta=theta_map, ax=ax_acf, tlags=tlag_plot, **map_style)
gp.plot_psd(theta=theta_map, ax=ax_psd, **map_style)

fig.suptitle("MAP fit", fontsize=14)
fig.tight_layout()
plt.show()
3. Run NUTS warmup and sampling#
BlackJAXSampler wraps GPSolver and provides gradient-based NUTS sampling
via BlackJAX. The workflow has three steps:
run_map() — find MAP solutions (or load from disk)
run_warmup() — adapt step size and mass matrix
run_sampling() — draw post-warmup samples (call repeatedly for batches)
Key parameters for run_warmup:
| Parameter | Default | Description |
|---|---|---|
| n_warmup | 500 | Warmup steps for dual-averaging step-size and mass-matrix adaptation |
| n_chains | 1 | Independent chains run in parallel |
| theta_init | MAP estimate | Starting position |
| mass_matrix_method | | Initial mass matrix method (e.g. "hessian_map") |
| target_accept | 0.8 | Target acceptance rate for dual averaging |
| checkpoint_file | None | Path to save state so sampling can be resumed later |
Warmup runs a single chain to adapt the step size and mass matrix; these
are then shared across all chains when n_chains > 1.
sampler = BlackJAXSampler(gp, save_dir="results/mcmc_demo")

# Warmup configuration, collected in one place and passed as keywords.
warmup_settings = {
    "n_warmup": 500,
    "n_chains": 2,
    "theta_init": theta_map,
    "mass_matrix_method": "hessian_map",
    "target_accept": 0.8,
    "checkpoint_file": "results/mcmc_demo/checkpoint.npz",
}
sampler.run_warmup(**warmup_settings)

# First batch of post-warmup draws plus the diagnostics returned with it.
samples, info = sampler.run_sampling(n_samples=500)

sample_shape = np.asarray(samples).shape
mean_accept = np.mean(info['acceptance_rate'])
print(f"\nSamples shape : {sample_shape} "
      f"(n_chains, n_samples, n_params)")
print(f"Divergences : {info['n_divergent']}")
print(f"Acceptance rate: {mean_accept:.3f}")
4. Diagnostics and posterior summary#
sampler.summary() prints mean, std, and 16/50/84 percentiles for each
parameter. info contains per-step diagnostics from the sampling loop.
stats = sampler.summary()

# Trace plot: one row per parameter, all chains overlaid
samples_np = np.asarray(samples)  # (n_chains, n_samples, n_params)
n_chains, n_samp, n_params = samples_np.shape

fig, axes = plt.subplots(n_params, 1, figsize=(12, 2.2 * n_params), sharex=True)
# plt.subplots returns a bare Axes rather than an array when n_params == 1;
# normalise so that the zip and axes[-1] below work for any parameter count.
axes = np.atleast_1d(axes)

chain_colors = ["steelblue", "tomato"]
log_prefix = "log_"

for i, (ax, key) in enumerate(zip(axes, gp.param_keys)):
    for c in range(n_chains):
        ax.plot(samples_np[c, :, i], lw=0.6, alpha=0.7,
                color=chain_colors[c % len(chain_colors)], label=f"chain {c}")

    # Reference line in the sampled parameterisation. theta_true holds
    # physical values, so a "log_<name>" parameter is matched against
    # "<name>" and converted with log10. Only keys that really carry the
    # prefix are stripped, and a key found directly is never logged again.
    base_key = key[len(log_prefix):] if key.startswith(log_prefix) else None
    if key in theta_true:
        true_val = theta_true[key]
    elif base_key is not None and base_key in theta_true:
        true_val = np.log10(theta_true[base_key])
    else:
        true_val = None

    if true_val is not None:
        ax.axhline(true_val, color="k", lw=1.2, ls="--", alpha=0.7)
    ax.set_ylabel(key, fontsize=11)
    if i == 0:
        ax.legend(fontsize=10, loc="upper right")

axes[-1].set_xlabel("Sample", fontsize=13)
fig.suptitle("NUTS trace plot", fontsize=13)
fig.tight_layout()
plt.show()
5. Resume sampling in batches#
save_checkpoint() appends the current batch to disk and frees self.samples
from memory. run_sampling() continues from the last NUTS state, skipping
warmup entirely.
This pattern keeps memory usage constant regardless of total sample count — the full chain is built on disk while only one batch lives in RAM at a time.
# Write the first batch to the checkpoint; the sampler is expected to drop
# its in-memory copy afterwards.
sampler.save_checkpoint()
print("In-memory samples after checkpoint:", sampler.samples)  # None

# Two further batches of 500 draws, each checkpointed as soon as it is done
for batch in range(2):
    samples_batch, info_batch = sampler.run_sampling(n_samples=500)
    sampler.save_checkpoint()
    batch_shape = np.asarray(samples_batch).shape
    batch_accept = np.mean(info_batch['acceptance_rate'])
    print(f"Batch {batch + 1}: {batch_shape} "
          f"acceptance={batch_accept:.3f}")

checkpoint_path = "results/mcmc_demo/checkpoint.npz"

# Everything on disk, chains collapsed -> (n_total, n_params), for analysis
all_samples = BlackJAXSampler.load_samples(checkpoint_path, flatten_chains=True)
print(f"Total samples (all batches, all chains): {all_samples.shape}")

# Same data with the chain axis kept -> (n_chains, n_samples, n_params),
# as needed for R-hat diagnostics
all_samples_chains = BlackJAXSampler.load_samples(checkpoint_path, flatten_chains=False)
print(f"Per-chain shape: {all_samples_chains.shape}")
6. Visualize the posterior#
6a. Corner plot#
import corner

# Show sigma_k in physical units: undo the log10 on the log_sigma_k column,
# working on a copy so all_samples itself stays in the sampled space.
lsk_idx = list(gp.param_keys).index("log_sigma_k")
display_samples = all_samples.copy()
display_samples[:, lsk_idx] = 10 ** all_samples[:, lsk_idx]

# Axis labels and true values, in the same order as the sample columns.
display_labels = []
truths = []
for k in gp.param_keys:
    display_labels.append(k.replace("log_sigma_k", "sigma_k"))
    truths.append(theta_true.get(k.replace("log_", "")))

corner_style = {
    "labels": display_labels,
    "truths": truths,
    "show_titles": True,
    "title_fmt": ".3f",
    "color": "steelblue",
}
fig = corner.corner(display_samples, **corner_style)
fig.suptitle("Posterior corner plot", fontsize=14, y=1.01)
plt.show()
6b. Posterior covariance ellipses#
sampler.plot_covariance() overlays Gaussian covariance ellipses from the
Laplace approximation at the MAP (method="hessian_map") against the MCMC
samples, making it easy to spot non-Gaussianity.
# NOTE(review): theta_true is keyed by "sigma_k" while all_samples holds the
# log_sigma_k column -- confirm plot_covariance maps between the two.
covariance_kwargs = {
    "method": "hessian_map",
    "true_params": theta_true,
    "samples": all_samples,
    "color": "steelblue",
    "alpha": 0.3,
}
fig, axes = sampler.plot_covariance(**covariance_kwargs)
plt.show()
6c. Posterior predictive check#
Plot the GP prediction at the posterior median against the data to verify the fit quality.
# Column-wise median of the flattened samples, keyed by parameter name
theta_median = {}
for i, k in enumerate(gp.param_keys):
    theta_median[k] = float(np.median(all_samples[:, i]))

tlag_plot = np.arange(0, 3 * model.peq, tsamp)

# (parameters, colour, legend label) for the two fits being compared
fits = (
    (theta_map, "steelblue", "MAP"),
    (theta_median, "tomato", "Posterior median"),
)

fig, axes = plt.subplots(3, 1, figsize=(12, 11))
ax_pred, ax_acf, ax_psd = axes

# Prediction panel: the second call passes data_color/data_label as None
# (presumably so the data are not drawn a second time -- confirm).
gp.plot_prediction(theta=theta_map, ax=ax_pred,
                   model_color="steelblue", model_label="MAP")
gp.plot_prediction(theta=theta_median, ax=ax_pred,
                   model_color="tomato", model_label="Posterior median",
                   data_color=None, data_label=None)

# Autocorrelation panel, MAP first and then posterior median
for fit_theta, fit_color, fit_label in fits:
    gp.plot_acf(theta=fit_theta, ax=ax_acf, tlags=tlag_plot,
                model_color=fit_color, model_label=fit_label)

# Power spectral density panel, same order
for fit_theta, fit_color, fit_label in fits:
    gp.plot_psd(theta=fit_theta, ax=ax_psd,
                model_color=fit_color, model_label=fit_label)

fig.suptitle("MAP vs posterior median", fontsize=14)
fig.tight_layout()
plt.show()