Nested Sampling with Dynesty#

This notebook fits GP hyperparameters to a synthetic stellar lightcurve using nested sampling via dynesty. Unlike MCMC, nested sampling also estimates the Bayesian evidence, which is useful for model comparison.

Sections

  1. Build the spot model and simulate a lightcurve

  2. Find the MAP estimate with GPSolver.fit_map()

  3. Run nested sampling with DynestySampler

  4. Inspect diagnostics and posterior summary

  5. Dynamic nested sampling

  6. Visualize the posterior

See also

For gradient-based MCMC sampling, see the BlackJAX MCMC tutorial.

import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt

from spotgp import (
    TrapezoidSymmetricEnvelope,
    VisibilityFunction,
    SpotEvolutionModel,
    LightcurveModel,
    GPSolver,
)
from spotgp.mcmc import DynestySampler

np.random.seed(42)

1. Build the spot model and simulate a lightcurve#

This section is identical to the BlackJAX MCMC tutorial so that both samplers can be compared on the same data.

envelope = TrapezoidSymmetricEnvelope(lspot=12.0, tau_spot=6.0)
visibility = VisibilityFunction(peq=5.0, kappa=0.3, inc=np.pi / 3)
model = SpotEvolutionModel(
    envelope=envelope,
    visibility=visibility,
    nspot_rate=0.25,   # spots per day
    alpha_max=0.05,    # peak spot angular radius [rad]
    fspot=0.0,
)
print("param_keys:", model.param_keys)
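print(f"sigma_k = {model.sigma_k:.5f}")

tsim  = 200                # simulation length [days]
tsamp = 0.5                # sampling cadence [days]
nspot = int(tsim * 0.25)   # expected spot count: tsim x nspot_rate (0.25 spots/day)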

lc = LightcurveModel.from_spot_model(
    model, nspot=nspot, tsim=tsim, tsamp=tsamp,
    long=[0, 2 * np.pi],
)

sigma_n  = 0.3 * np.std(lc.flux)
flux_obs = lc.flux + np.random.normal(0, sigma_n, lc.flux.shape)
flux_err = np.full_like(flux_obs, sigma_n)
tobs     = lc.t

print(f"Lightcurve: {len(tobs)} points over {tsim} days")
print(f"Signal std: {np.std(lc.flux):.5f}   Noise: {sigma_n:.5f}")

fig, ax = plt.subplots(figsize=(13, 4))
ax.errorbar(tobs, (flux_obs - 1) * 100, yerr=flux_err * 100,
            fmt=".k", ms=3, capsize=0, alpha=0.5, label="Observed")
ax.plot(tobs, (lc.flux - 1) * 100, "r-", lw=1.2, label="True")
ax.set_xlabel("Time [days]", fontsize=20)
ax.set_ylabel(r"$\Delta$ Flux [%]", fontsize=20)
ax.legend(fontsize=14)
ax.set_xlim(tobs[0], tobs[-1])
fig.tight_layout()
plt.show()

2. Find the MAP estimate#

See also

For a detailed tutorial on MAP optimization, see the GPSolver tutorial.

bounds = {
    "peq":         (2.0, 10.0),
    "kappa":       (-0.5, 0.8),
    "inc":         (0.1, np.pi / 2),
    "lspot":       (1.0, 30.0),
    "tau_spot":    (0.5, 20.0),
    "log_sigma_k": (-5.0, -1.0),
}

gp = GPSolver(tobs, flux_obs, flux_err, model, bounds=bounds).build_jax()

print("param_keys:", gp.param_keys)
print("bandwidth :", gp.bandwidth, "pts")

import time

t0 = time.time()
theta_map, opt_result = gp.fit_map(nopt=5, method="L-BFGS-B")
print(f"fit_map() wall time: {time.time() - t0:.1f} s  (converged: {opt_result.success})")

# True values for comparison
theta_true = {
    "peq":      model.peq,
    "kappa":    model.kappa,
    "inc":      model.inc,
    "lspot":    model.lspot,
    "tau_spot": model.tau_spot,
    "sigma_k":  model.sigma_k,
}

print(f"\n{'param':>12s}  {'true':>10s}  {'MAP':>10s}")
print("-" * 38)
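for k, v_true in theta_true.items():
    if k in theta_map:
        v_map = theta_map[k]
    elif "log_" + k in theta_map:
        # Fitted in log10 space (e.g. log_sigma_k); convert to linear units so the
        # comparison matches the 10 ** conversion used for the corner plot.
        v_map = 10 ** theta_map["log_" + k]
    else:
        v_map = float("nan")
    print(f"{k:>12s}  {v_true:10.4f}  {v_map:10.4f}")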

3. Run nested sampling with DynestySampler#

DynestySampler wraps GPSolver and provides nested sampling via dynesty. The workflow is:

  1. Create the sampler with a GPSolver and (optionally) a custom prior transform

  2. Call run_sampling() to run the nested sampler to completion

By default, DynestySampler constructs a uniform prior transform from the GP bounds. You can pass a custom prior_transform function to the constructor if you need non-uniform priors.
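As a rough illustration of what a uniform prior transform built from bounds looks like, the sketch below maps the unit cube onto each interval. The helper name make_uniform_prior_transform and the key order are assumptions made for this example; this is not DynestySampler's actual implementation, and in a real run the order must follow gp.param_keys.

```python
import numpy as np

# Bounds copied from the MAP section. The key order is an assumption here and
# must match gp.param_keys in a real run.
demo_bounds = {
    "peq":         (2.0, 10.0),
    "kappa":       (-0.5, 0.8),
    "inc":         (0.1, np.pi / 2),
    "lspot":       (1.0, 30.0),
    "tau_spot":    (0.5, 20.0),
    "log_sigma_k": (-5.0, -1.0),
}


def make_uniform_prior_transform(bounds, keys):
    """Return a function mapping the unit cube to uniform priors on `bounds`.

    Hypothetical helper for illustration only.
    """
    lo = np.array([bounds[k][0] for k in keys], dtype=float)
    hi = np.array([bounds[k][1] for k in keys], dtype=float)

    def prior_transform(u):
        # Inverse CDF of Uniform(lo, hi): theta = lo + u * (hi - lo)
        return lo + np.asarray(u, dtype=float) * (hi - lo)

    return prior_transform


demo_keys = list(demo_bounds)
uniform_transform = make_uniform_prior_transform(demo_bounds, demo_keys)

# The centre of the unit cube maps to the midpoint of every interval.
print(dict(zip(demo_keys, uniform_transform(np.full(len(demo_keys), 0.5)))))
```

Each component of u is treated independently, so every parameter receives its own uniform prior.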

Key parameters for run_sampling:

| Parameter | Default | Description |
|---|---|---|
| nlive | 500 | Number of live points |
| dlogz | 0.01 | Stopping criterion on remaining log-evidence |
| bound | "multi" | Bounding method ("none", "single", "multi", "balls", "cubes") |
| sample | "auto" | Sampling method ("auto", "unif", "rwalk", "slice", "rslice") |

Unlike MCMC, nested sampling does not require warmup or step-size adaptation. It also produces the log-evidence \(\ln \mathcal{Z}\) as a direct output, which is useful for model comparison via Bayes factors.
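To make the roles of nlive and dlogz concrete, here is a deliberately simple nested sampler for a toy problem: a 2-D Gaussian likelihood with a uniform prior on the unit square. It is a sketch of the general algorithm under simplifying assumptions (new points come from brute-force rejection sampling from the prior, and the prior volume shrinks deterministically by a factor exp(-1/nlive) per iteration). It is not how dynesty or DynestySampler is implemented, and every name in it is made up for this example.

```python
import math
import random

SIGMA = 0.1  # width of the toy Gaussian likelihood (illustrative choice)


def log_likelihood(x, y):
    """Unnormalised 2-D Gaussian centred in the unit square (peak value 1)."""
    return -((x - 0.5) ** 2 + (y - 0.5) ** 2) / (2.0 * SIGMA ** 2)


def logaddexp(a, b):
    """Stable ln(exp(a) + exp(b)); one of the arguments may be -inf."""
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def toy_nested_sampling(nlive=100, dlogz=0.1, seed=0):
    rng = random.Random(seed)
    # Live points are drawn from the prior, here uniform on the unit square.
    live = [(rng.random(), rng.random()) for _ in range(nlive)]
    logl = [log_likelihood(x, y) for x, y in live]

    logz = -math.inf  # accumulated ln(evidence)
    log_x = 0.0       # ln(remaining prior volume)
    log_shell = math.log1p(-math.exp(-1.0 / nlive))  # ln(1 - exp(-1/nlive))
    niter = 0
    while True:
        # Stop once the largest plausible remaining contribution, L_max * X,
        # would change ln(Z) by less than dlogz.
        if logaddexp(logz, max(logl) + log_x) - logz < dlogz:
            break
        worst = min(range(nlive), key=logl.__getitem__)
        threshold = logl[worst]
        # The worst live point contributes L_worst times the shell it bounds.
        logz = logaddexp(logz, threshold + log_x + log_shell)
        # Replace it with a prior draw that beats the likelihood threshold.
        while True:
            x, y = rng.random(), rng.random()
            candidate = log_likelihood(x, y)
            if candidate > threshold:
                break
        live[worst], logl[worst] = (x, y), candidate
        log_x -= 1.0 / nlive
        niter += 1

    # Add the evidence still held by the final set of live points.
    log_mean_l = math.log(sum(math.exp(v) for v in logl) / nlive)
    logz = logaddexp(logz, log_mean_l + log_x)
    return logz, niter


toy_logz, toy_niter = toy_nested_sampling(nlive=100, dlogz=0.1)
analytic_logz = math.log(2.0 * math.pi * SIGMA ** 2)  # Gaussian integral, edges neglected
print(f"toy ln(Z) = {toy_logz:.2f} after {toy_niter} iterations "
      f"(analytic: {analytic_logz:.2f})")
```

More live points reduce the scatter in ln(Z) at the cost of more iterations, while a smaller dlogz lets the run continue deeper into the likelihood peak before stopping.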

sampler = DynestySampler(gp, save_dir="results/dynesty_demo")

t0 = time.time()
samples, info = sampler.run_sampling(
    nlive=300,
    dlogz=0.5,
    bound="multi",
    sample="rwalk",
)
elapsed = time.time() - t0

print(f"\nWall time     : {elapsed:.1f} s")
print(f"Samples shape : {samples.shape}")
print(f"log(Z)        : {info['logz']:.2f} +/- {info['logzerr']:.2f}")
print(f"Iterations    : {info['niter']}")
print(f"Efficiency    : {info['eff']:.1f}%")
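The log-evidence reported above can be turned into a Bayes factor once a second model has been run on the same data. The sketch below assumes two independent runs with independent evidence errors; the function name is hypothetical, and the numbers are made-up placeholders rather than results from this notebook.

```python
import math


def log_bayes_factor(logz_a, logzerr_a, logz_b, logzerr_b):
    """Return ln(Z_a / Z_b) and its uncertainty, assuming independent errors."""
    ln_b = logz_a - logz_b
    ln_b_err = math.hypot(logzerr_a, logzerr_b)  # errors add in quadrature
    return ln_b, ln_b_err


# Placeholder evidence values for two hypothetical models (illustration only).
ln_b, ln_b_err = log_bayes_factor(-120.0, 0.3, -123.0, 0.4)
print(f"ln B = {ln_b:.2f} +/- {ln_b_err:.2f}")
```

In practice the inputs would be info["logz"] and info["logzerr"] from two sampler runs that share the same data but use different models. A positive ln B favors the first model.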

Custom prior transform#

By default, DynestySampler uses a uniform prior over the bounds. If you need a different prior (e.g. a Gaussian or log-uniform prior on certain parameters), pass a custom prior_transform function:

def my_prior_transform(u):
    """Map unit cube [0, 1]^n to physical parameter space."""
    theta = np.empty_like(u)
    theta[0] = 2.0 + u[0] * 8.0          # peq: uniform [2, 10]
    theta[1] = -0.5 + u[1] * 1.3         # kappa: uniform [-0.5, 0.8]
    # ... etc. (fill every remaining entry; np.empty_like leaves them uninitialized)
    return theta

sampler = DynestySampler(gp, prior_transform=my_prior_transform)
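Non-uniform priors follow the same pattern: each unit-cube coordinate is pushed through the inverse CDF of the desired prior. The sketch below shows a Gaussian and a log-uniform prior for a two-parameter toy case. The prior values, the parameter order (peq, tau_spot), and the function name are assumptions made for illustration. A real transform needs one entry per key in gp.param_keys. Note that an unbounded Gaussian prior can propose values outside the GP bounds.

```python
from statistics import NormalDist

import numpy as np

# Hypothetical prior choices, for illustration only.
PEQ_PRIOR = NormalDist(mu=5.0, sigma=0.5)  # Gaussian prior on peq [days]
TAU_LO, TAU_HI = 0.5, 20.0                 # log-uniform range for tau_spot [days]


def example_prior_transform(u):
    """Map u in [0, 1]^2 to (peq, tau_spot) through inverse CDFs."""
    # Keep u strictly inside (0, 1) so the Gaussian inverse CDF stays finite.
    u = np.clip(np.asarray(u, dtype=float), 1e-12, 1.0 - 1e-12)
    theta = np.empty_like(u)
    theta[0] = PEQ_PRIOR.inv_cdf(float(u[0]))      # Gaussian inverse CDF
    theta[1] = TAU_LO * (TAU_HI / TAU_LO) ** u[1]  # log-uniform inverse CDF
    return theta


# The centre of the unit cube maps to the median of each prior.
print(example_prior_transform([0.5, 0.5]))
```

The log-uniform transform spends equal prior mass per decade, which suits scale parameters such as timescales.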

4. Diagnostics and posterior summary#

sampler.summary() prints the mean, standard deviation, and 16th/50th/84th percentiles for each parameter, using the same interface as BlackJAXSampler.

stats = sampler.summary()
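For reference, the same statistics can be computed directly from a sample array with NumPy. This is a minimal sketch that assumes the samples are equally weighted and arranged as (n_samples, n_params). The function name and the demo array are invented for illustration, and the output format of sampler.summary() itself may differ.

```python
import numpy as np


def summarize_samples(samples, keys):
    """Mean, std, and 16/50/84 percentiles per column of an (n, d) array."""
    samples = np.asarray(samples, dtype=float)
    q16, q50, q84 = np.percentile(samples, [16, 50, 84], axis=0)
    return {
        k: {
            "mean": samples[:, i].mean(),
            "std": samples[:, i].std(ddof=1),
            "q16": q16[i],
            "q50": q50[i],
            "q84": q84[i],
        }
        for i, k in enumerate(keys)
    }


# Deterministic stand-in for posterior samples (not output from the sampler).
demo_samples = np.column_stack([
    np.linspace(0.0, 1.0, 101),
    np.linspace(10.0, 20.0, 101),
])
demo_stats = summarize_samples(demo_samples, ["a", "b"])
for name, row in demo_stats.items():
    print(name, {k: round(float(v), 3) for k, v in row.items()})
```

With the real run, the call would be summarize_samples(samples, gp.param_keys), provided the returned samples are equally weighted.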

Checkpointing#

Results are automatically saved when run_sampling() completes. You can also save and load manually:

# Save checkpoint
sampler.save_checkpoint()

# Load samples from disk (without restoring the full sampler state)
all_samples = DynestySampler.load_samples("results/dynesty_demo/mcmc_checkpoint.npz")
print(f"Samples loaded from disk: {all_samples.shape}")

# Restore full sampler state from checkpoint
sampler2 = DynestySampler(gp, save_dir="results/dynesty_demo")
sampler2.load_checkpoint()
print(f"log(Z) from checkpoint: {sampler2._info['logz']:.2f}")

5. Dynamic nested sampling#

Dynamic nested sampling adaptively allocates live points to improve both evidence and posterior estimates. Use run_dynamic_sampling() instead of run_sampling():

  • nlive_init controls the initial baseline run

  • nlive_batch controls how many live points are added per batch

  • wt_kwargs={"pfrac": 0.8} weights 80% toward posterior accuracy and 20% toward evidence
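The effect of pfrac can be pictured as a weighted blend of two per-iteration importance profiles, one tracking where the posterior mass sits and one tracking where the evidence is still uncertain. The sketch below only illustrates that blend. The function name and the toy profiles are assumptions, and dynesty's own weight function differs in detail.

```python
import numpy as np


def combined_importance(post_importance, evid_importance, pfrac=0.8):
    """Blend two importance profiles, giving a fraction `pfrac` to the posterior."""
    post = np.asarray(post_importance, dtype=float)
    evid = np.asarray(evid_importance, dtype=float)
    # Normalise each profile so that pfrac acts as a true fraction.
    post = post / post.sum()
    evid = evid / evid.sum()
    return pfrac * post + (1.0 - pfrac) * evid


# Made-up importance profiles over five iterations.
post_imp = [0.0, 1.0, 4.0, 1.0, 0.0]
evid_imp = [4.0, 3.0, 2.0, 1.0, 0.0]
weights = combined_importance(post_imp, evid_imp, pfrac=0.8)
print(weights, weights.sum())
```

Setting pfrac=1.0 would place new live points purely where the posterior mass is, while pfrac=0.0 would target the evidence alone.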

sampler_dyn = DynestySampler(gp, save_dir="results/dynesty_dynamic_demo")

t0 = time.time()
samples_dyn, info_dyn = sampler_dyn.run_dynamic_sampling(
    nlive_init=250,
    nlive_batch=100,
    dlogz_init=0.5,
    maxbatch=5,
    wt_kwargs={"pfrac": 0.8},
)
elapsed = time.time() - t0

print(f"\nWall time     : {elapsed:.1f} s")
print(f"Samples shape : {samples_dyn.shape}")
print(f"log(Z)        : {info_dyn['logz']:.2f} +/- {info_dyn['logzerr']:.2f}")
print(f"Iterations    : {info_dyn['niter']}")
sampler_dyn.summary()

6. Visualize the posterior#

6a. Corner plot#

import corner

# Use the static nested sampling samples for plotting
display_samples = samples.copy()
lsk_idx = list(gp.param_keys).index("log_sigma_k")
display_samples[:, lsk_idx] = 10 ** display_samples[:, lsk_idx]

display_labels = [k.replace("log_sigma_k", "sigma_k") for k in gp.param_keys]
truths = [theta_true.get(k.replace("log_", ""), None) for k in gp.param_keys]

fig = corner.corner(
    display_samples,
    labels=display_labels,
    truths=truths,
    show_titles=True,
    title_fmt=".3f",
    color="steelblue",
)
fig.suptitle("Dynesty posterior corner plot", fontsize=14, y=1.01)
plt.show()

6b. Posterior covariance ellipses#

The inherited plot_covariance() method compares the Laplace approximation (Gaussian ellipses from the Hessian at the MAP) with the nested sampling posterior samples.

fig, axes = sampler.plot_covariance(
    method="hessian_map",
    true_params=theta_true,
    samples=samples,
    color="steelblue",
    alpha=0.3,
)
plt.show()
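The Laplace approximation behind those ellipses can be sketched in a few lines: the covariance is the inverse Hessian of the negative log-posterior at the MAP, and its eigendecomposition gives the ellipse axes. The example below uses a toy Gaussian posterior with a known covariance and a finite-difference Hessian. All names are hypothetical, and plot_covariance() may compute its Hessian differently.

```python
import numpy as np


def numerical_hessian(f, x0, eps=1e-3):
    """Central finite-difference Hessian of a scalar function f at x0."""
    x0 = np.asarray(x0, dtype=float)
    n = x0.size
    hess = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n)
            ej = np.zeros(n)
            ei[i] = eps
            ej[j] = eps
            hess[i, j] = (
                f(x0 + ei + ej) - f(x0 + ei - ej)
                - f(x0 - ei + ej) + f(x0 - ei - ej)
            ) / (4.0 * eps ** 2)
    return hess


# Toy negative log-posterior: an exact Gaussian with a known covariance.
true_cov = np.array([[0.04, 0.012], [0.012, 0.09]])
theta_hat = np.array([5.0, 0.3])  # stand-in for the MAP estimate
precision = np.linalg.inv(true_cov)


def neg_log_post(theta):
    d = np.asarray(theta, dtype=float) - theta_hat
    return 0.5 * d @ precision @ d


hess = numerical_hessian(neg_log_post, theta_hat)
laplace_cov = np.linalg.inv(0.5 * (hess + hess.T))  # symmetrise, then invert

# Ellipse geometry: 1-sigma semi-axes and orientation of the major axis.
eigvals, eigvecs = np.linalg.eigh(laplace_cov)
semi_axes = np.sqrt(eigvals)
angle_deg = np.degrees(np.arctan2(eigvecs[1, -1], eigvecs[0, -1]))
print("Laplace covariance:\n", laplace_cov)
print("semi-axes:", semi_axes, "major-axis angle [deg]:", angle_deg)
```

Where the sampled posterior is visibly non-Gaussian, the ellipses and the sample cloud will disagree, which is the purpose of the comparison.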

6c. Posterior predictive check#

Plot the GP prediction, autocorrelation function (ACF), and power spectral density (PSD) at the MAP and at the posterior median against the data.

# Posterior median parameter vector
theta_median = {k: float(np.median(samples[:, i]))
                for i, k in enumerate(gp.param_keys)}

tlag_plot = np.arange(0, 3 * model.peq, tsamp)
fig, axes = plt.subplots(3, 1, figsize=(12, 11))

gp.plot_prediction(theta=theta_map,    ax=axes[0],
                   model_color="steelblue", model_label="MAP")
gp.plot_prediction(theta=theta_median, ax=axes[0],
                   model_color="tomato",    model_label="Posterior median",
                   data_color=None,         data_label=None)

gp.plot_acf(theta=theta_map,    ax=axes[1], tlags=tlag_plot,
            model_color="steelblue", model_label="MAP")
gp.plot_acf(theta=theta_median, ax=axes[1], tlags=tlag_plot,
            model_color="tomato",    model_label="Posterior median")

gp.plot_psd(theta=theta_map,    ax=axes[2],
            model_color="steelblue", model_label="MAP")
gp.plot_psd(theta=theta_median, ax=axes[2],
            model_color="tomato",    model_label="Posterior median")

fig.suptitle("MAP vs posterior median (dynesty)", fontsize=14)
fig.tight_layout()
plt.show()