Nested Sampling with Dynesty
This notebook fits GP hyperparameters to a synthetic stellar lightcurve using nested sampling via dynesty. Unlike MCMC, nested sampling also estimates the Bayesian evidence, which is useful for model comparison.
Sections
1. Build the spot model and simulate a lightcurve
2. Find the MAP estimate with GPSolver.fit_map()
3. Run nested sampling with DynestySampler
4. Inspect diagnostics and posterior summary
5. Dynamic nested sampling
6. Visualize the posterior
See also
For gradient-based MCMC sampling, see the BlackJAX MCMC tutorial.
import sys
sys.path.append("../..")
import numpy as np
import matplotlib.pyplot as plt
from spotgp import (
    TrapezoidSymmetricEnvelope,
    VisibilityFunction,
    SpotEvolutionModel,
    LightcurveModel,
    GPSolver,
)
from spotgp.mcmc import DynestySampler
np.random.seed(42)
1. Build the spot model and simulate a lightcurve
This section is identical to the BlackJAX MCMC tutorial so that both samplers can be compared on the same data.
envelope = TrapezoidSymmetricEnvelope(lspot=12.0, tau_spot=6.0)
visibility = VisibilityFunction(peq=5.0, kappa=0.3, inc=np.pi / 3)
model = SpotEvolutionModel(
    envelope=envelope,
    visibility=visibility,
    nspot_rate=0.25,  # spots per day
    alpha_max=0.05,   # peak spot angular radius [rad]
    fspot=0.0,
)
print("param_keys:", model.param_keys)
print(f"sigma_k = {model.sigma_k:.5f}")
tsim = 200
tsamp = 0.5
nspot = int(tsim * 0.25)
lc = LightcurveModel.from_spot_model(
    model, nspot=nspot, tsim=tsim, tsamp=tsamp,
    long=[0, 2 * np.pi],
)
sigma_n = 0.3 * np.std(lc.flux)
flux_obs = lc.flux + np.random.normal(0, sigma_n, lc.flux.shape)
flux_err = np.full_like(flux_obs, sigma_n)
tobs = lc.t
print(f"Lightcurve: {len(tobs)} points over {tsim} days")
print(f"Signal std: {np.std(lc.flux):.5f} Noise: {sigma_n:.5f}")
fig, ax = plt.subplots(figsize=(13, 4))
ax.errorbar(tobs, (flux_obs - 1) * 100, yerr=flux_err * 100,
            fmt=".k", ms=3, capsize=0, alpha=0.5, label="Observed")
ax.plot(tobs, (lc.flux - 1) * 100, "r-", lw=1.2, label="True")
ax.set_xlabel("Time [days]", fontsize=20)
ax.set_ylabel(r"$\Delta$ Flux [%]", fontsize=20)
ax.legend(fontsize=14)
ax.set_xlim(tobs[0], tobs[-1])
fig.tight_layout()
plt.show()
2. Find the MAP estimate
See also
For a detailed tutorial on MAP optimization, see the GPSolver Tutorial.
bounds = {
    "peq": (2.0, 10.0),
    "kappa": (-0.5, 0.8),
    "inc": (0.1, np.pi / 2),
    "lspot": (1.0, 30.0),
    "tau_spot": (0.5, 20.0),
    "log_sigma_k": (-5.0, -1.0),
}
gp = GPSolver(tobs, flux_obs, flux_err, model, bounds=bounds).build_jax()
print("param_keys:", gp.param_keys)
print("bandwidth :", gp.bandwidth, "pts")
import time
t0 = time.time()
theta_map, opt_result = gp.fit_map(nopt=5, method="L-BFGS-B")
print(f"fit_map() wall time: {time.time() - t0:.1f} s (converged: {opt_result.success})")
# True values for comparison
theta_true = {
    "peq": model.peq,
    "kappa": model.kappa,
    "inc": model.inc,
    "lspot": model.lspot,
    "tau_spot": model.tau_spot,
    "sigma_k": model.sigma_k,
}
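print(f"\n{'param':>12s} {'true':>10s} {'MAP':>10s}")
print("-" * 38)
for k, v_true in theta_true.items():
    if k in theta_map:
        v_map = theta_map[k]
    elif "log_" + k in theta_map:
        # Fitted in log10 space (e.g. log_sigma_k); convert to linear units
        # so the MAP value is comparable to the true value.
        v_map = 10 ** theta_map["log_" + k]
    else:
        v_map = float("nan")
    print(f"{k:>12s} {v_true:10.4f} {v_map:10.4f}")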
3. Run nested sampling with DynestySampler
DynestySampler wraps GPSolver and provides nested sampling via dynesty.
The workflow is:
1. Create the sampler with a GPSolver and (optionally) a custom prior transform
2. run_sampling(): run the nested sampler to completion
By default, DynestySampler constructs a uniform prior transform from the
GP bounds. You can pass a custom prior_transform function to the constructor
if you need non-uniform priors.
Key parameters for run_sampling:
| Parameter | Default | Description |
|---|---|---|
| nlive | 500 | Number of live points |
| dlogz | 0.01 | Stopping criterion on remaining log-evidence |
| bound | | Bounding method |
| sample | | Sampling method |
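The dlogz stopping criterion can be read as a bound on the evidence that has not yet been accounted for. The sketch below assumes the usual nested-sampling estimate, in which the remaining contribution is approximated by the largest live-point likelihood times the remaining prior volume. The function name and the toy numbers are illustrative and are not part of spotgp or dynesty.

```python
import math

def remaining_dlogz(logz, logl_max, logvol):
    """Return ln(Z + dZ) - ln(Z), where dZ ~ L_max * X estimates the
    evidence still contained in the live points."""
    logz_remain = logl_max + logvol
    # Stable log-sum-exp of the accumulated and remaining evidence
    m = max(logz, logz_remain)
    logz_total = m + math.log(math.exp(logz - m) + math.exp(logz_remain - m))
    return logz_total - logz

# Early in a run most of the evidence is still ahead; late in a run it is not.
early = remaining_dlogz(logz=-50.0, logl_max=-10.0, logvol=-5.0)
late = remaining_dlogz(logz=-12.0, logl_max=-10.0, logvol=-9.0)
print(f"early: {early:.3f}  late: {late:.5f}")
```

Under this reading, sampling would stop once the returned value drops below the chosen dlogz, so a smaller dlogz means a longer run and a more complete evidence integral.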
Unlike MCMC, nested sampling does not require warmup or step-size adaptation. It also produces the log-evidence \(\ln \mathcal{Z}\) as a direct output, which is useful for model comparison via Bayes factors.
sampler = DynestySampler(gp, save_dir="results/dynesty_demo")
t0 = time.time()
samples, info = sampler.run_sampling(
    nlive=300,
    dlogz=0.5,
    bound="multi",
    sample="rwalk",
)
elapsed = time.time() - t0
print(f"\nWall time : {elapsed:.1f} s")
print(f"Samples shape : {samples.shape}")
print(f"log(Z) : {info['logz']:.2f} +/- {info['logzerr']:.2f}")
print(f"Iterations : {info['niter']}")
print(f"Efficiency : {info['eff']:.1f}%")
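Once two competing models have each been run, their log-evidences can be turned into a log Bayes factor by simple subtraction. The sketch below assumes the two logzerr values are independent, so they add in quadrature. The numbers are hypothetical placeholders, not results from this notebook, and the helper name is illustrative.

```python
import math

def log_bayes_factor(logz_a, logzerr_a, logz_b, logzerr_b):
    """Return ln(Z_a / Z_b) and its uncertainty, assuming independent errors."""
    diff = logz_a - logz_b
    err = math.hypot(logzerr_a, logzerr_b)
    return diff, err

# Hypothetical log-evidences for two models fitted to the same data
ln_b, ln_b_err = log_bayes_factor(-120.4, 0.15, -123.9, 0.20)
print(f"ln B = {ln_b:.2f} +/- {ln_b_err:.2f}")
```

A positive value favors the first model. The difference is only meaningful when both runs use the same data and properly normalized priors.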
Custom prior transform
By default, DynestySampler uses a uniform prior over the bounds. If you
need a different prior (e.g. a Gaussian or log-uniform prior on certain
parameters), pass a custom prior_transform function:
def my_prior_transform(u):
    """Map unit cube [0, 1]^n to physical parameter space."""
    theta = np.empty_like(u)
    theta[0] = 2.0 + u[0] * 8.0   # peq: uniform [2, 10]
    theta[1] = -0.5 + u[1] * 1.3  # kappa: uniform [-0.5, 0.8]
    # ... etc.
    return theta
sampler = DynestySampler(gp, prior_transform=my_prior_transform)
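For the non-uniform priors mentioned above, each unit-cube coordinate is passed through the inverse CDF of the desired distribution. The following is a minimal sketch, not the library's own implementation. It assumes the parameter order peq, kappa, inc, lspot, tau_spot, log_sigma_k (check gp.param_keys in practice), and the Gaussian mean and width for peq are illustrative choices.

```python
import math
from statistics import NormalDist

import numpy as np

# Illustrative Gaussian prior on peq (mean and width are assumptions)
PEQ_PRIOR = NormalDist(mu=5.0, sigma=1.0)

def mixed_prior_transform(u):
    """Map the unit cube [0, 1]^6 to parameter space with mixed priors.

    Assumed order: peq, kappa, inc, lspot, tau_spot, log_sigma_k.
    """
    u = np.asarray(u, dtype=float)
    theta = np.empty_like(u)
    theta[0] = PEQ_PRIOR.inv_cdf(u[0])            # peq: Gaussian via inverse CDF
    theta[1] = -0.5 + u[1] * 1.3                  # kappa: uniform [-0.5, 0.8]
    theta[2] = 0.1 + u[2] * (math.pi / 2 - 0.1)   # inc: uniform [0.1, pi/2]
    theta[3] = 1.0 + u[3] * 29.0                  # lspot: uniform [1, 30]
    lo, hi = math.log(0.5), math.log(20.0)
    theta[4] = math.exp(lo + u[4] * (hi - lo))    # tau_spot: log-uniform [0.5, 20]
    theta[5] = -5.0 + u[5] * 4.0                  # log_sigma_k: uniform [-5, -1]
    return theta

theta_mid = mixed_prior_transform(np.full(6, 0.5))
print(theta_mid)
```

Note that a Gaussian prior is unbounded, so draws for peq can fall outside the (2, 10) range given in bounds. Whether GPSolver tolerates that is not established here, and a truncated Gaussian may be the safer choice.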
4. Diagnostics and posterior summary
sampler.summary() prints the mean, standard deviation, and 16th/50th/84th
percentiles of each parameter, using the same interface as BlackJAXSampler.
stats = sampler.summary()
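The same statistics can be computed directly from a sample array for cross-checking. This is an independent sketch rather than the code behind summary(). It assumes the rows are equally weighted posterior draws, and the demo array is synthetic stand-in data, not output from this notebook.

```python
import numpy as np

def summarize(samples, names):
    """Mean, std, and 16/50/84 percentiles per column of an (n, d) array."""
    samples = np.asarray(samples, dtype=float)
    q16, q50, q84 = np.percentile(samples, [16, 50, 84], axis=0)
    mean = samples.mean(axis=0)
    std = samples.std(axis=0, ddof=1)
    return {
        name: {"mean": mean[i], "std": std[i],
               "q16": q16[i], "q50": q50[i], "q84": q84[i]}
        for i, name in enumerate(names)
    }

# Synthetic stand-in for a posterior sample array
rng = np.random.default_rng(0)
demo = rng.normal(loc=[5.0, 0.3], scale=[0.5, 0.05], size=(4000, 2))
table = summarize(demo, ["peq", "kappa"])
for name, row in table.items():
    print(f"{name:>8s} {row['mean']:8.3f} {row['std']:8.3f} "
          f"{row['q16']:8.3f} {row['q50']:8.3f} {row['q84']:8.3f}")
```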
Checkpointing
Results are automatically saved when run_sampling() completes.
You can also save and load manually:
# Save checkpoint
sampler.save_checkpoint()
# Load samples from disk (without restoring the full sampler state)
all_samples = DynestySampler.load_samples("results/dynesty_demo/mcmc_checkpoint.npz")
print(f"Samples loaded from disk: {all_samples.shape}")
# Restore full sampler state from checkpoint
sampler2 = DynestySampler(gp, save_dir="results/dynesty_demo")
sampler2.load_checkpoint()
print(f"log(Z) from checkpoint: {sampler2._info['logz']:.2f}")
5. Dynamic nested sampling
Dynamic nested sampling adaptively allocates live points to improve
both evidence and posterior estimates. Use run_dynamic_sampling() instead
of run_sampling():
- nlive_init controls the initial baseline run
- nlive_batch controls how many live points are added per batch
- wt_kwargs={"pfrac": 0.8} weights 80% toward posterior accuracy and 20% toward evidence
sampler_dyn = DynestySampler(gp, save_dir="results/dynesty_dynamic_demo")
t0 = time.time()
samples_dyn, info_dyn = sampler_dyn.run_dynamic_sampling(
    nlive_init=250,
    nlive_batch=100,
    dlogz_init=0.5,
    maxbatch=5,
    wt_kwargs={"pfrac": 0.8},
)
elapsed = time.time() - t0
print(f"\nWall time : {elapsed:.1f} s")
print(f"Samples shape : {samples_dyn.shape}")
print(f"log(Z) : {info_dyn['logz']:.2f} +/- {info_dyn['logzerr']:.2f}")
print(f"Iterations : {info_dyn['niter']}")
sampler_dyn.summary()
6. Visualize the posterior
6a. Corner plot
import corner
# Use the static nested sampling samples for plotting
display_samples = samples.copy()
lsk_idx = list(gp.param_keys).index("log_sigma_k")
display_samples[:, lsk_idx] = 10 ** display_samples[:, lsk_idx]
display_labels = [k.replace("log_sigma_k", "sigma_k") for k in gp.param_keys]
truths = [theta_true.get(k.replace("log_", ""), None) for k in gp.param_keys]
fig = corner.corner(
    display_samples,
    labels=display_labels,
    truths=truths,
    show_titles=True,
    title_fmt=".3f",
    color="steelblue",
)
fig.suptitle("Dynesty posterior corner plot", fontsize=14, y=1.01)
plt.show()
6b. Posterior covariance ellipses
The inherited plot_covariance() method overlays the Laplace approximation
(Gaussian ellipses from the Hessian at the MAP) on the nested sampling
posterior samples, so the two can be compared directly.
fig, axes = sampler.plot_covariance(
    method="hessian_map",
    true_params=theta_true,
    samples=samples,
    color="steelblue",
    alpha=0.3,
)
plt.show()
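For reference, the Laplace approximation behind these ellipses is usually written as follows. This is the standard textbook form; that method="hessian_map" uses exactly this construction is an assumption here. \(d\) denotes the observed lightcurve.

\[
p(\theta \mid d) \approx \mathcal{N}\!\left(\theta_{\mathrm{MAP}},\, \Sigma\right),
\qquad
\Sigma = H^{-1},
\qquad
H_{ij} = -\left.\frac{\partial^2 \ln p(\theta \mid d)}{\partial \theta_i \, \partial \theta_j}\right|_{\theta = \theta_{\mathrm{MAP}}}
\]

Where the sampled posterior is close to Gaussian, the ellipses should trace the sample cloud. Visible mismatches point to skewness, multimodality, or a posterior that piles up against a prior bound.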
6c. Posterior predictive check
Compare the GP prediction, autocorrelation function (ACF), and power spectral density (PSD) at the MAP and at the posterior median against the data.
# Posterior median parameter vector
theta_median = {k: float(np.median(samples[:, i]))
                for i, k in enumerate(gp.param_keys)}
tlag_plot = np.arange(0, 3 * model.peq, tsamp)
fig, axes = plt.subplots(3, 1, figsize=(12, 11))
gp.plot_prediction(theta=theta_map, ax=axes[0],
                   model_color="steelblue", model_label="MAP")
gp.plot_prediction(theta=theta_median, ax=axes[0],
                   model_color="tomato", model_label="Posterior median",
                   data_color=None, data_label=None)
gp.plot_acf(theta=theta_map, ax=axes[1], tlags=tlag_plot,
            model_color="steelblue", model_label="MAP")
gp.plot_acf(theta=theta_median, ax=axes[1], tlags=tlag_plot,
            model_color="tomato", model_label="Posterior median")
gp.plot_psd(theta=theta_map, ax=axes[2],
            model_color="steelblue", model_label="MAP")
gp.plot_psd(theta=theta_median, ax=axes[2],
            model_color="tomato", model_label="Posterior median")
fig.suptitle("MAP vs posterior median (dynesty)", fontsize=14)
fig.tight_layout()
plt.show()