MCMC Sampling with BlackJAX

This notebook fits GP hyperparameters to a synthetic stellar lightcurve using No-U-Turn Sampling (NUTS) via BlackJAX.

Sections

  1. Build the spot model and simulate a lightcurve

  2. Find the MAP estimate with GPSolver.fit_map()

  3. Run NUTS warmup and sampling with BlackJAXSampler

  4. Inspect diagnostics and posterior summary

  5. Resume sampling in batches (constant memory)

  6. Visualize the posterior — corner plot, covariance ellipses, and a posterior predictive check

import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt

from spotgp import (
    TrapezoidSymmetricEnvelope,
    VisibilityFunction,
    SpotEvolutionModel,
    LightcurveModel,
    GPSolver,
)
from spotgp.mcmc import BlackJAXSampler

# Seed NumPy's global RNG so the np.random draws in this notebook (e.g. the
# Gaussian noise added to the simulated flux) are reproducible between runs.
np.random.seed(42)

1. Build the spot model and simulate a lightcurve

# --- Spot model ---------------------------------------------------------------
# Spot emergence rate [spots per day]. Defined once so the model and the number
# of simulated spots below are always derived from the same value.
nspot_rate = 0.25

envelope = TrapezoidSymmetricEnvelope(lspot=12.0, tau_spot=6.0)
visibility = VisibilityFunction(peq=5.0, kappa=0.3, inc=np.pi / 3)
model = SpotEvolutionModel(
    envelope=envelope,
    visibility=visibility,
    nspot_rate=nspot_rate,   # spots per day
    alpha_max=0.05,          # peak spot angular radius [rad]
    fspot=0.0,
)
print("param_keys:", model.param_keys)
print(f"sigma_k = {model.sigma_k:.5f}")

# --- Simulated lightcurve -----------------------------------------------------
tsim  = 200                     # simulation length [days]
tsamp = 0.5                     # sampling interval (presumably days — TODO confirm)
nspot = int(tsim * nspot_rate)  # number of spots = rate x duration

lc = LightcurveModel.from_spot_model(
    model, nspot=nspot, tsim=tsim, tsamp=tsamp,
    long=[0, 2 * np.pi],        # presumably the allowed spot-longitude range — TODO confirm
)

# Add white Gaussian noise whose standard deviation is 30% of the scatter of
# the noiseless flux, and use that same value as the per-point error bar.
sigma_n  = 0.3 * np.std(lc.flux)
flux_obs = lc.flux + np.random.normal(0, sigma_n, lc.flux.shape)
flux_err = np.full_like(flux_obs, sigma_n)
tobs     = lc.t

print(f"Lightcurve: {len(tobs)} points over {tsim} days")
print(f"Signal std: {np.std(lc.flux):.5f}   Noise: {sigma_n:.5f}")

# Express every flux quantity as a percent deviation from unity for display.
obs_pct = (flux_obs - 1) * 100
err_pct = flux_err * 100
true_pct = (lc.flux - 1) * 100

fig, ax = plt.subplots(figsize=(13, 4))

# Noisy observations as faint black points, noiseless model as a red line.
obs_style = dict(fmt=".k", ms=3, capsize=0, alpha=0.5, label="Observed")
ax.errorbar(tobs, obs_pct, yerr=err_pct, **obs_style)
ax.plot(tobs, true_pct, "r-", lw=1.2, label="True")
ax.legend(fontsize=14)

# Axis cosmetics: labels, and an x-range that hugs the sampled time span.
ax.set_xlabel("Time [days]", fontsize=20)
ax.set_ylabel(r"$\Delta$ Flux [%]", fontsize=20)
ax.set_xlim(tobs[0], tobs[-1])

fig.tight_layout()
plt.show()

2. Find the MAP estimate

See also

For a detailed tutorial on MAP optimization, see the GPSolver Tutorial.

# Per-parameter (lower, upper) limits handed to GPSolver.
# log_sigma_k is treated as log10(sigma_k) throughout this notebook.
bounds = dict(
    peq=(2.0, 10.0),
    kappa=(-0.5, 0.8),
    inc=(0.1, np.pi / 2),
    lspot=(1.0, 30.0),
    tau_spot=(0.5, 20.0),
    log_sigma_k=(-5.0, -1.0),
)

# Construct the solver on the noisy data, then build its JAX backend.
solver = GPSolver(tobs, flux_obs, flux_err, model, bounds=bounds)
gp = solver.build_jax()

print("param_keys:", gp.param_keys)
print("bandwidth :", gp.bandwidth, "pts")
import time

# Time the MAP optimisation with a monotonic clock so the reported duration is
# not affected by system clock adjustments.
t0 = time.perf_counter()
theta_map, opt_result = gp.fit_map(nopt=5, method="L-BFGS-B")
elapsed = time.perf_counter() - t0
print(f"fit_map() wall time: {elapsed:.1f} s  (converged: {opt_result.success})")

# True values for comparison (all in linear/physical units)
theta_true = {
    "peq":      model.peq,
    "kappa":    model.kappa,
    "inc":      model.inc,
    "lspot":    model.lspot,
    "tau_spot": model.tau_spot,
    "sigma_k":  model.sigma_k,
}

print(f"\n{'param':>12s}  {'true':>10s}  {'MAP':>10s}")
print("-" * 38)
for k, v_true in theta_true.items():
    v_map = theta_map.get(k)
    if v_map is None:
        # The parameter may only be available in log10 form ("log_<name>").
        # Convert it back to linear units so the "true" and "MAP" columns are
        # directly comparable instead of mixing sigma_k with log10(sigma_k).
        v_log = theta_map.get("log_" + k)
        v_map = float("nan") if v_log is None else 10.0 ** v_log
    print(f"{k:>12s}  {v_true:10.4f}  {float(v_map):10.4f}")
# Lags for the autocorrelation panel: 0 up to three equatorial periods,
# stepped at the sampling interval.
tlag_plot = np.arange(0, 3 * model.peq, tsamp)

# Styling shared by all three MAP diagnostic panels.
map_style = {"model_color": "steelblue", "model_label": "MAP fit"}

fig, axes = plt.subplots(3, 1, figsize=(12, 11))
ax_pred, ax_acf, ax_psd = axes

# Top: GP prediction; middle: autocorrelation; bottom: power spectrum.
gp.plot_prediction(theta=theta_map, ax=ax_pred, **map_style)
gp.plot_acf(theta=theta_map, ax=ax_acf, tlags=tlag_plot, **map_style)
gp.plot_psd(theta=theta_map, ax=ax_psd, **map_style)

fig.suptitle("MAP fit", fontsize=14)
fig.tight_layout()
plt.show()

3. Run NUTS warmup and sampling

BlackJAXSampler wraps GPSolver and provides gradient-based NUTS sampling via BlackJAX. The workflow has three steps:

  1. run_map() — find MAP solutions (or load from disk). This notebook skips this step and instead passes the MAP estimate from Section 2 as theta_init.

  2. run_warmup() — adapt step size and mass matrix

  3. run_sampling() — draw post-warmup samples (call repeatedly for batches)

Key parameters for run_warmup:

| Parameter | Default | Description |
|---|---|---|
| n_warmup | 500 | Warmup steps for dual-averaging step-size and mass-matrix adaptation |
| n_chains | 1 | Independent chains run in parallel via jax.pmap |
| theta_init | MAP estimate | Starting position |
| mass_matrix_method | "hessian_map" | Initial mass matrix: "hessian_map", "fisher", "laplace", "diagonal", or None |
| target_accept | 0.8 | Target acceptance rate for dual averaging |
| checkpoint_file | None | Path to save state so sampling can be resumed later |
| warmup_method | "window_adaptation" | "window_adaptation", "pathfinder", or "dual_averaging" |

Warmup runs a single chain to adapt the step size and mass matrix; these are then shared across all chains when n_chains > 1.

sampler = BlackJAXSampler(gp, save_dir="results/mcmc_demo")

# Warmup configuration: start from the MAP estimate, run two chains, and
# write a checkpoint file that the batch-resume cell reads back later.
warmup_config = dict(
    n_warmup=500,
    n_chains=2,
    theta_init=theta_map,
    mass_matrix_method="hessian_map",
    target_accept=0.8,
    checkpoint_file="results/mcmc_demo/checkpoint.npz",
)
sampler.run_warmup(**warmup_config)

# First batch of post-warmup draws.
samples, info = sampler.run_sampling(n_samples=500)

samples_shape = np.asarray(samples).shape
print(f"\nSamples shape : {samples_shape}  (n_chains, n_samples, n_params)")
print(f"Divergences   : {info['n_divergent']}")
mean_accept = np.mean(info['acceptance_rate'])
print(f"Acceptance rate: {mean_accept:.3f}")

4. Diagnostics and posterior summary

sampler.summary() prints the mean, standard deviation, and 16th/50th/84th percentiles of each parameter. The info dictionary returned by run_sampling() contains per-step diagnostics from the sampling loop.

stats = sampler.summary()

# Trace plot: one row per parameter, all chains overlaid
samples_np = np.asarray(samples)   # (n_chains, n_samples, n_params)
n_chains, n_samp, n_params = samples_np.shape

# squeeze=False keeps `axes` a 2-D array even when there is only one
# parameter, so the zip() and axes[-1] below work for any n_params.
fig, axes = plt.subplots(n_params, 1, figsize=(12, 2.2 * n_params),
                         sharex=True, squeeze=False)
axes = axes[:, 0]
chain_colors = ["steelblue", "tomato"]
log_prefix = "log_"

for i, (ax, key) in enumerate(zip(axes, gp.param_keys)):
    for c in range(n_chains):
        ax.plot(samples_np[c, :, i], lw=0.6, alpha=0.7,
                color=chain_colors[c % len(chain_colors)], label=f"chain {c}")

    # True value in the sampled parameterisation. A truth stored under the
    # exact key is used as-is; only a "log_<name>" key that has to fall back
    # to the linear "<name>" truth is converted with log10. The prefix is
    # stripped only when actually present, so an unrelated key can never
    # alias another parameter's truth.
    true_val = theta_true.get(key)
    if true_val is None and key.startswith(log_prefix):
        linear_val = theta_true.get(key[len(log_prefix):])
        if linear_val is not None:
            true_val = np.log10(linear_val)
    if true_val is not None:
        ax.axhline(true_val, color="k", lw=1.2, ls="--", alpha=0.7)

    ax.set_ylabel(key, fontsize=11)
    if i == 0:
        # One legend is enough: chain colours are the same in every panel.
        ax.legend(fontsize=10, loc="upper right")

axes[-1].set_xlabel("Sample", fontsize=13)
fig.suptitle("NUTS trace plot", fontsize=13)
fig.tight_layout()
plt.show()

5. Resume sampling in batches

save_checkpoint() appends the current batch to disk and frees the in-memory copy (sampler.samples becomes None). A subsequent run_sampling() call continues from the last NUTS state, skipping warmup entirely.

This pattern keeps memory usage constant regardless of total sample count — the full chain is built on disk while only one batch lives in RAM at a time.

# Flush the first batch to the checkpoint; the in-memory samples are then
# expected to print as None.
sampler.save_checkpoint()
print("In-memory samples after checkpoint:", sampler.samples)  # None

# Draw two further batches of 500 samples, checkpointing after each one.
n_extra_batches = 2
for batch in range(n_extra_batches):
    samples_batch, info_batch = sampler.run_sampling(n_samples=500)
    sampler.save_checkpoint()
    batch_shape = np.asarray(samples_batch).shape
    batch_accept = np.mean(info_batch['acceptance_rate'])
    print(f"Batch {batch + 1}: {batch_shape}  acceptance={batch_accept:.3f}")

# Read everything back from disk for analysis.
checkpoint_path = "results/mcmc_demo/checkpoint.npz"

# Chains collapsed into one axis -> (n_total, n_params)
all_samples = BlackJAXSampler.load_samples(checkpoint_path, flatten_chains=True)
print(f"Total samples (all batches, all chains): {all_samples.shape}")

# Chains kept separate for R-hat diagnostics -> (n_chains, n_samples, n_params)
all_samples_chains = BlackJAXSampler.load_samples(checkpoint_path, flatten_chains=False)
print(f"Per-chain shape: {all_samples_chains.shape}")

6. Visualize the posterior

6a. Corner plot

import corner

param_names = list(gp.param_keys)

# Work on a copy so the pooled samples keep log_sigma_k in log10 form, and
# show that column as physical sigma_k in the figure.
display_samples = all_samples.copy()
lsk_idx = param_names.index("log_sigma_k")
display_samples[:, lsk_idx] = 10 ** display_samples[:, lsk_idx]

# Axis labels and truth markers, both expressed in the displayed (linear) units.
display_labels = []
truths = []
for pname in param_names:
    display_labels.append(pname.replace("log_sigma_k", "sigma_k"))
    truths.append(theta_true.get(pname.replace("log_", ""), None))

corner_options = dict(
    labels=display_labels,
    truths=truths,
    show_titles=True,
    title_fmt=".3f",
    color="steelblue",
)
fig = corner.corner(display_samples, **corner_options)
fig.suptitle("Posterior corner plot", fontsize=14, y=1.01)
plt.show()

6b. Posterior covariance ellipses

sampler.plot_covariance() overlays Gaussian covariance ellipses from the Laplace approximation at the MAP (method="hessian_map") against the MCMC samples, making it easy to spot non-Gaussianity.

# Draw covariance ellipses (method="hessian_map") together with the pooled
# MCMC samples so departures from Gaussianity are visible.
# NOTE(review): theta_true stores sigma_k in linear units, whereas the sampled
# parameter is log_sigma_k — confirm that plot_covariance() converts the truth
# to the sampled parameterisation before drawing it.
fig, axes = sampler.plot_covariance(
    method="hessian_map",
    true_params=theta_true,
    samples=all_samples,
    color="steelblue",
    alpha=0.3,
)
plt.show()

6c. Posterior predictive check

Plot the GP prediction, autocorrelation function, and power spectrum at the posterior median, overlaid on the MAP fit and the data, to verify the fit quality.

# Marginal posterior median of every parameter, keyed by parameter name.
theta_median = {}
for col, name in enumerate(gp.param_keys):
    theta_median[name] = float(np.median(all_samples[:, col]))

tlag_plot = np.arange(0, 3 * model.peq, tsamp)

# Styling for the two parameter sets being compared.
map_style = {"model_color": "steelblue", "model_label": "MAP"}
median_style = {"model_color": "tomato", "model_label": "Posterior median"}

fig, axes = plt.subplots(3, 1, figsize=(12, 11))
ax_pred, ax_acf, ax_psd = axes

# Prediction panel: the second call passes data_color/data_label as None.
gp.plot_prediction(theta=theta_map, ax=ax_pred, **map_style)
gp.plot_prediction(theta=theta_median, ax=ax_pred,
                   data_color=None, data_label=None, **median_style)

# Autocorrelation panel.
gp.plot_acf(theta=theta_map, ax=ax_acf, tlags=tlag_plot, **map_style)
gp.plot_acf(theta=theta_median, ax=ax_acf, tlags=tlag_plot, **median_style)

# Power-spectrum panel.
gp.plot_psd(theta=theta_map, ax=ax_psd, **map_style)
gp.plot_psd(theta=theta_median, ax=ax_psd, **median_style)

fig.suptitle("MAP vs posterior median", fontsize=14)
fig.tight_layout()
plt.show()