API Reference (spotgp)
Lightcurve Model
- class spotgp.lightcurve.LightcurveModel(peq=4.0, kappa=0.0, inc=1.5707963267948966, nspot=None, tau_spot=None, tem=2, tdec=2, alpha_max=0.1, fspot=0, lspot=5, long=[0, 6.283185307179586], lat=[-1.5707963267948966, 1.5707963267948966], tsim=28, tsamp=0.02, limb_darkening=False, tmax=None, rotate=True, grow=True, nspot_rate=None)
Bases: object
JAX-accelerated model of a spotted star and its lightcurve.
Same interface as the NumPy version, but uses JAX to vectorize the computation across all spots simultaneously.
- Parameters:
peq (float) – Equatorial period of the star.
kappa (float) – Differential rotation shear.
inc (float) – Inclination of the star.
nspot (int, optional) – Number of spots. Defaults to None.
tau_spot (float, optional) – Timescale for both emergence and decay of the spots. Defaults to None.
tem (float, optional) – Emergence timescale of the spots. Defaults to 2.
tdec (float, optional) – Decay timescale of the spots. Defaults to 2.
alpha_max (float, optional) – Maximum angular area of the spots. Defaults to 0.1.
fspot (float, optional) – Spot contrast fraction. Defaults to 0.
lspot (float, optional) – Spot lifetime. Defaults to 5.
long (list, optional) – Range of spot longitudes. Defaults to [0, 2*pi].
lat (list, optional) – Range of spot latitudes. Defaults to [-pi/2, pi/2].
tsim (float, optional) – End simulation time. Defaults to 28.
tsamp (float, optional) – Sampling cadence. Defaults to 0.02.
limb_darkening (bool, optional) – Flag to enable limb darkening. Defaults to False.
- classmethod from_spot_model(spot_model: SpotEvolutionModel, nspot: int = None, *, nspot_rate: float = None, **kwargs)
Construct a LightcurveModel from a SpotEvolutionModel.
- Parameters:
spot_model (SpotEvolutionModel) – Fully configured spot evolution model.
nspot (int, optional) – Total number of spots to simulate.
nspot_rate (float, optional) – Spot emergence rate [spots/day]. The actual number of spots is max(1, int(nspot_rate * tsim)). Exactly one of nspot or nspot_rate must be provided.
**kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).
- Return type:
LightcurveModel
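The spot-count rule quoted for nspot_rate can be written out as a tiny helper. This is a sketch of the documented formula only; the name nspot_from_rate is hypothetical and not part of spotgp.

```python
def nspot_from_rate(nspot_rate: float, tsim: float) -> int:
    """Number of simulated spots implied by an emergence rate over tsim days.

    Hypothetical helper mirroring the documented rule max(1, int(nspot_rate * tsim)).
    """
    # int() truncates toward zero; max() guarantees at least one spot.
    return max(1, int(nspot_rate * tsim))


# With the default tsim=28 days:
print(nspot_from_rate(0.5, 28))   # 14
print(nspot_from_rate(0.01, 28))  # 1 (0.28 truncates to 0, then is clamped to 1)
```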
- classmethod from_hparam(hparam: dict, nspot: int = None, *, nspot_rate: float = None, **kwargs)
Construct a LightcurveModel from a GPSolver-compatible hparam dict.
Accepts the same raw hparam dict that GPSolver/AnalyticKernel take, including all amplitude modes (sigma_k, nspot_rate, or nspot), and both symmetric (tau) and asymmetric (tau_em + tau_dec) envelopes. This removes the need to manually decompose the dict in scripts.
- Parameters:
hparam (dict) – Raw hyperparameter dict. Must contain peq, kappa, inc, lspot, tau_spot (or tau_em/tau_dec), and an amplitude specification.
nspot (int, optional) – Total number of spots to simulate.
nspot_rate (float, optional) – Spot emergence rate [spots/day]. Exactly one of nspot or nspot_rate must be provided.
**kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).
- Return type:
LightcurveModel
- Flux(teval)
Compute the full lightcurve using JAX vmap over all spots.
Instead of a Python loop over nspot, all spots are computed in parallel via JAX’s vmap.
- animate_lightcurve(fps=30, duration=10.0, outfile=None, dpi=150, show_spots=True, show_grid=True, show_params=True, figsize=(14, 5.5), save_last_frame=None, show_dr=True, label_size=18)
Animate the starspot evolution with two panels: a 2D projection of the rotating star (left) and the lightcurve (right).
- Parameters:
fps (int) – Frames per second (default 30).
duration (float) – Animation duration in seconds (default 10).
outfile (str or None) – Output file path (.mp4 or .gif). If None, returns the animation object without saving.
dpi (int) – Resolution (default 150).
show_spots (bool) – If True, show individual spot contributions on the lightcurve panel (default True).
show_grid (bool) – If True, draw latitude/longitude grid on the star (default True).
show_params (bool) – If True, show parameter annotation on the lightcurve panel (default True).
figsize (tuple) – Figure size (default (14, 5.5)).
save_last_frame (str or None) – If provided, save the last frame of the animation as a static image to this file path (e.g. “frame.png”).
show_dr (bool) – If True, color the stellar disk by latitude-dependent rotation frequency and display a colorbar (default True).
label_size (int or float) – Font size for all labels, tick marks, and text in the plot (default 18).
- Returns:
anim – The animation object.
- Return type:
matplotlib.animation.FuncAnimation
- animate_butterfly(fps=30, duration=10.0, outfile=None, dpi=150, show_spots=True, show_grid=True, show_params=True, figsize=(18, 5.5), save_last_frame=None, show_dr=True, label_size=18)
Animate the starspot evolution with three panels: a 2D projection of the rotating star (left), the lightcurve (center), and a butterfly diagram of spot latitude vs. time (right).
- Parameters:
fps (int) – Frames per second (default 30).
duration (float) – Animation duration in seconds (default 10).
outfile (str or None) – Output file path (.mp4 or .gif). If None, returns the animation object without saving.
dpi (int) – Resolution (default 150).
show_spots (bool) – If True, show individual spot contributions on the lightcurve panel (default True).
show_grid (bool) – If True, draw latitude/longitude grid on the star (default True).
show_params (bool) – If True, show parameter annotation above the figure (default True).
figsize (tuple) – Figure size (default (18, 5.5)).
save_last_frame (str or None) – If provided, save the last frame of the animation as a static image to this file path (e.g. “frame.png”).
show_dr (bool) – If True, color the stellar disk by latitude-dependent rotation frequency and display a colorbar (default True).
label_size (int or float) – Font size for all labels, tick marks, and text in the plot (default 18).
- Returns:
anim – The animation object.
- Return type:
matplotlib.animation.FuncAnimation
Analytic Kernel
- class spotgp.analytic_kernel.AnalyticKernel(model_or_hparam, n_harmonics=3, n_lat=64, lat_range=None, quadrature='trapezoid')
Bases: object
JAX-accelerated analytic GP kernel for stellar rotation variability.
- Parameters:
model_or_hparam (SpotEvolutionModel or dict) – Either a SpotEvolutionModel instance (new API) or a raw hparam dict (backward-compatible old API).
n_harmonics (int) – Number of Fourier harmonics for the visibility function (default 3).
n_lat (int) – Number of latitude quadrature points (default 64).
lat_range (tuple) – (min, max) latitude in radians (default (-pi/2, pi/2)).
quadrature (str) – Latitude integration method: “trapezoid” (default) or “gauss-legendre”.
- kernel(lag, lat_dist=None)
Full GP kernel averaged over latitude.
Uses jax.lax.scan for memory-efficient accumulation: only one lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).
When the visibility function is an EdgeOnVisibilityFunction, the latitude-averaged |c_n|^2 are known constants and the latitude loop is bypassed entirely.
- Parameters:
lag (array_like) – Time lags [days]. Can be 1D or 2D.
lat_dist (callable or None) – Latitude probability density. If None, uniform.
- Returns:
K
- Return type:
ndarray, same shape as lag input.
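The memory argument above (one lag-sized buffer instead of an n_lat × M array) can be illustrated with a NumPy loop that carries a single accumulator, which is the role jax.lax.scan plays in the documented method. This is a sketch under stated assumptions: per_latitude_kernel is a hypothetical stand-in (a decaying cosine at the local rate Omega(lat) = (2*pi/peq) * (1 - kappa*sin(lat)**2)), not spotgp's kernel, and "uniform" is taken to mean uniform in latitude.

```python
import numpy as np


def per_latitude_kernel(lag, lat, peq=5.0, kappa=0.2, tau=3.0):
    """Hypothetical single-latitude kernel: a decaying cosine at the local rotation rate.

    Assumes Omega(lat) = (2*pi/peq) * (1 - kappa*sin(lat)**2); illustrative only.
    """
    omega = 2 * np.pi / peq * (1 - kappa * np.sin(lat) ** 2)
    return np.exp(-0.5 * (lag / tau) ** 2) * np.cos(omega * lag)


def latitude_averaged_kernel(lag, n_lat=64, lat_range=(-np.pi / 2, np.pi / 2), lat_dist=None):
    """Trapezoid-rule latitude average that keeps a single lag-sized accumulator."""
    lag = np.asarray(lag, dtype=float)
    lats = np.linspace(lat_range[0], lat_range[1], n_lat)
    density = np.ones(n_lat) if lat_dist is None else np.asarray(lat_dist(lats), dtype=float)
    weights = np.full(n_lat, lats[1] - lats[0]) * density
    weights[[0, -1]] *= 0.5           # trapezoid end-point weights
    weights /= weights.sum()          # normalize the latitude density to unit mass
    acc = np.zeros_like(lag)          # the only lag-sized buffer: O(M) memory
    for w, lat in zip(weights, lats):  # sequential carry, analogous to jax.lax.scan
        acc += w * per_latitude_kernel(lag, lat)
    return acc
```

Because each latitude only adds into acc, the loop never materializes the full n_lat × M table; the lag array may be 1D or 2D.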
- compute_psd(omega, lat_dist=None)
Analytic power spectral density.
- Parameters:
omega (array_like) – Angular frequencies [rad/day].
lat_dist (callable or None) – Latitude probability density.
- Returns:
freq (ndarray [cycles/day])
power (ndarray)
- build_jax(n_lag=256)
Pre-compile and warm up JAX JIT computation for this kernel.
jax.lax.scan (used inside kernel()) triggers XLA compilation on its first call for a given array shape. That compilation can take several seconds and is easy to mistake for slow runtime. Call build_jax() once after constructing the kernel to pay that cost upfront; subsequent calls to kernel() and compute_psd() with the same shape will be fast.
- Parameters:
n_lag (int) – Length of the dummy lag array used to drive compilation (default 256). The actual value does not matter as long as it is representative of the sizes you will use at runtime.
- Returns:
self – Returns self so the call can be chained: ak = AnalyticKernel(model).build_jax().
- Return type:
AnalyticKernel
Numerical Kernel
- class spotgp.numerical_kernel.NumericalKernel(model_or_hparam, tsim=20, tsamp=0.05, nsim=1000.0, verbose=True)
Bases: object
Gaussian Process for Stellar Rotation
- Parameters:
model_or_hparam (dict) –
Dict of hyperparameters. Required keys: peq, kappa, inc, lspot, tau_spot, alpha_max. For the kernel amplitude, provide EITHER:
sigma_k : overall amplitude prefactor, OR
nspot + fspot : number of spots and spot contrast, from which sigma_k is computed as sqrt(N_spot) * (1 - f_spot) / pi.
Note: nspot is always required for the numerical simulations.
tsim (float) – Simulation time (default: 20).
tsamp (float) – Time sampling (default: 0.05).
nsim (int) – Number of simulations (default: 1000).
verbose (bool) – Whether to print verbose output (default: True).
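The amplitude rule sigma_k = sqrt(N_spot) * (1 - f_spot) / pi can be evaluated directly. A minimal sketch of that formula; sigma_k_from_spots is a hypothetical name, not a spotgp function.

```python
import math


def sigma_k_from_spots(nspot: int, fspot: float) -> float:
    """Kernel amplitude prefactor implied by a spot count and spot contrast fraction."""
    return math.sqrt(nspot) * (1.0 - fspot) / math.pi


print(sigma_k_from_spots(100, 0.0))  # 10/pi, about 3.183
```

Note that fspot = 1 gives sigma_k = 0: spots with no contrast produce no variability.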
GP Solver
- class spotgp.gp_solver.GPSolver(data_or_x, y=None, yerr=None, model_or_hparam=None, kernel_type='analytic', mean=None, fit_sigma_n=False, bounds=None, log_prior=None, matrix_solver='cholesky_banded', bandwidth=None, save_dir=None, **kernel_kwargs)
Bases: object
JAX-accelerated Gaussian Process solver for stellar lightcurves.
Handles covariance matrix construction, Cholesky factorization, log-likelihood evaluation, prediction, MAP estimation, and mass matrix computation.
- Parameters:
data_or_x (TimeSeriesData or array_like, shape (N,)) – Either a TimeSeriesData object, or observation times [days]. When a TimeSeriesData is passed, y and yerr must be None (they are read from the data object).
y (array_like, shape (N,), optional) – Observed flux values. Required when data_or_x is an array.
yerr (array_like, shape (N,) or float, optional) – Measurement uncertainties (1-sigma). Required when data_or_x is an array.
model_or_hparam (SpotEvolutionModel or dict) – Either a SpotEvolutionModel (new API) or a raw hparam dict (backward-compatible old API).
kernel_type ({"analytic"}) – Which kernel to use (default: “analytic”).
mean (float or callable or None) – Mean function.
fit_sigma_n (bool) – If True, include white noise amplitude sigma_n as a free parameter for optimization/sampling (default False).
bounds (dict or None) – Parameter bounds for optimization. If None, uses defaults.
log_prior (callable or None) – Custom log-prior function f(theta_arr) -> scalar. If None, uses soft uniform within bounds.
**kernel_kwargs – Extra keyword arguments forwarded to the kernel constructor.
- DEFAULT_BOUNDS = {'inc': (0.01, 3.1315926535897933), 'kappa': (0.001, 0.999), 'lat_max': (0.0, 1.5707963267948966), 'lat_min': (0.0, 1.5707963267948966), 'lspot': (0.1, 20.0), 'n_sn': (-10.0, 10.0), 'peq': (0.5, 50.0), 'sigma_k': (1e-06, 1.0), 'sigma_n': (1e-06, 0.1), 'sigma_sn': (0.05, 10.0), 'tau_dec': (0.05, 10.0), 'tau_em': (0.05, 10.0), 'tau_spot': (0.05, 10.0)}
- build_jax(recompute=True)
Pre-compile and warm up all JAX JIT functions for this solver.
_build_logposterior() creates four @jax.jit-decorated functions (log_posterior, neg_log_posterior, grad_log_posterior, grad_neg_log_posterior) that each trigger a separate XLA compilation on their first call. Combined with the banded Cholesky solver, that can add up to 10–30 s of invisible overhead before a fit or MCMC run starts. Call build_jax() once after constructing the solver to pay that cost upfront.
- Returns:
self – Returns self so the call can be chained: gp = GPSolver(...).build_jax().
- Return type:
GPSolver
- log_likelihood()
Marginal log-likelihood of the data under the GP.
- Returns:
logL – log p(y | X, theta)
- Return type:
float
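For reference, the marginal log-likelihood of a GP with white noise is log p(y) = -1/2 r^T C^{-1} r - 1/2 log det C - N/2 log(2 pi), with r = y - mean and C = K + diag(yerr^2). The sketch below evaluates it with a dense NumPy Cholesky factor; it assumes a constant mean and does not reproduce spotgp's default cholesky_banded solver.

```python
import numpy as np


def gp_log_likelihood(y, K, yerr, mean=0.0):
    """log p(y | X, theta) for y ~ N(mean, K + diag(yerr**2)), via a dense Cholesky factor."""
    r = np.asarray(y, dtype=float) - mean                # residuals
    noise = np.broadcast_to(np.asarray(yerr, dtype=float) ** 2, r.shape)
    C = K + np.diag(noise)                               # signal + white-noise covariance
    L = np.linalg.cholesky(C)                            # C = L @ L.T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, r))  # alpha = C^{-1} r
    return (-0.5 * r @ alpha                             # data-fit term
            - np.sum(np.log(np.diag(L)))                 # -0.5 * log det C
            - 0.5 * r.size * np.log(2 * np.pi))          # normalization
```

Using the Cholesky diagonal for the log-determinant avoids forming C^{-1} or det C explicitly, which is what keeps the evaluation stable for large N.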
- predict(xpred, return_cov=False)
Predictive distribution at new input locations.
- Parameters:
xpred (array_like, shape (M,)) – Prediction times.
return_cov (bool) – If True, return full predictive covariance.
- Returns:
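mu_pred (ndarray, shape (M,)) – Predictive mean.
var_pred (ndarray, shape (M,) or (M, M)) – Predictive variance, or the full (M, M) predictive covariance if return_cov=True.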
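The predictive distribution follows the standard GP conditioning formulas: mean Ks C^{-1} y and covariance Kss - Ks C^{-1} Ks^T. This NumPy sketch assumes a zero mean and a stationary kernel passed as a function of the lag; se_kernel is a hypothetical squared-exponential used only for illustration.

```python
import numpy as np


def gp_predict(t, y, yerr, tpred, kernel, return_cov=False):
    """Zero-mean GP predictive mean and variance (or covariance) at tpred."""
    t, y, tpred = (np.asarray(a, dtype=float) for a in (t, y, tpred))
    noise = np.broadcast_to(np.asarray(yerr, dtype=float) ** 2, t.shape)
    C = kernel(t[:, None] - t[None, :]) + np.diag(noise)  # (N, N) training covariance
    Ks = kernel(tpred[:, None] - t[None, :])              # (M, N) cross-covariance
    Kss = kernel(tpred[:, None] - tpred[None, :])         # (M, M) prior covariance
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))   # C^{-1} y
    mu = Ks @ alpha                                       # predictive mean
    V = np.linalg.solve(L, Ks.T)                          # L^{-1} Ks^T, shape (N, M)
    cov = Kss - V.T @ V                                   # Kss - Ks C^{-1} Ks^T
    return mu, (cov if return_cov else np.diag(cov).copy())


def se_kernel(tau):
    """Hypothetical stationary kernel used only for illustration."""
    return np.exp(-0.5 * tau ** 2)
```

Far from the data the prediction relaxes to the prior: mean 0 and variance kernel(0).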
- plot_prediction(theta=None, n_points=2000, n_sigma=(1, 2), ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='GP mean', data_label='Data')
Plot the GP posterior mean and uncertainty bands over the data.
If theta is provided, the GP is temporarily updated to those hyperparameters before predicting, so the prediction reflects the given parameter values rather than whatever was last set internally.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If None, uses the current internal hyperparameters.
n_points (int) – Number of prediction points spanning the data baseline.
n_sigma (int or sequence of int) – Which sigma levels to shade. E.g. (1, 2) draws both ±1σ and ±2σ bands (default). Pass a single int for one band.
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
data_color (str) – Color for the data points (default 'k').
model_color (str) – Color for the model curve and uncertainty bands (default 'r').
show_legend (bool) – Whether to draw a legend.
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- sample_lightcurves(theta=None, xpred=None, n_samples=5, n_points=2000, source='prior', rng=None)
Sample lightcurves from the GP prior or posterior.
- Parameters:
theta (dict or array_like, optional) – Kernel parameters. Accepts a physical dict with keys from param_keys, a sampling-space dict with log_-prefixed keys, or a flat array matching param_keys. If None, uses the current internal hyperparameters.
xpred (array_like, optional) – Times at which to evaluate the samples. If None, uses n_points evenly spaced times spanning the data baseline.
n_samples (int) – Number of lightcurve samples to draw (default 5).
n_points (int) – Number of prediction points when xpred is None (default 2000).
source ({'prior', 'posterior'}) – Whether to sample from the GP prior or posterior (default ‘prior’).
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
- Returns:
xpred (ndarray, shape (M,)) – Prediction times.
samples (ndarray, shape (n_samples, M)) – Sampled lightcurves.
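Drawing a prior lightcurve amounts to multiplying i.i.d. standard normals by a Cholesky factor of the kernel matrix. The NumPy sketch below assumes a zero-mean prior and a hypothetical squared-exponential kernel; the small diagonal jitter is an assumption added for numerical stability. Posterior draws follow the same recipe with the predictive mean and covariance in place of zero and K.

```python
import numpy as np


def sample_gp_prior(tpred, kernel, n_samples=5, rng=None, jitter=1e-8):
    """Draw n_samples zero-mean GP prior realizations at times tpred."""
    rng = np.random.default_rng() if rng is None else rng
    tpred = np.asarray(tpred, dtype=float)
    K = kernel(tpred[:, None] - tpred[None, :])
    # A small diagonal jitter keeps the Cholesky factorization numerically stable.
    L = np.linalg.cholesky(K + jitter * np.eye(tpred.size))
    z = rng.standard_normal((tpred.size, n_samples))    # i.i.d. N(0, 1) draws
    return (L @ z).T                                     # shape (n_samples, M)


def se_kernel(tau):
    """Hypothetical kernel used only for illustration."""
    return np.exp(-0.5 * tau ** 2)
```

Passing a seeded numpy.random.Generator as rng makes the draws reproducible, matching the role of the rng argument above.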
- plot_samples(theta=None, xpred=None, n_samples=5, n_points=2000, source='prior', rng=None, ax=None, show_data=True, data_color='k', sample_alpha=0.7, sample_lw=1.0, cmap='tab10', show_legend=True, xlim=None, ylim=None)
Plot sampled lightcurves from the GP prior or posterior.
- Parameters:
theta (dict or array_like, optional) – Kernel parameters (see sample_lightcurves).
xpred (array_like, optional) – Prediction times. If None, n_points evenly spaced times spanning the data baseline are used.
n_samples (int) – Number of samples to draw and plot (default 5).
n_points (int) – Number of prediction points when xpred is None (default 2000).
source ({'prior', 'posterior'}) – Sample from the GP prior or posterior (default ‘prior’).
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
show_data (bool) – Whether to overlay the observed data (default True).
data_color (str) – Color for data points (default ‘k’).
sample_alpha (float) – Opacity for sample curves (default 0.7).
sample_lw (float) – Line width for sample curves (default 1.0).
cmap (str) – Matplotlib colormap name for sample colors (default ‘tab10’).
show_legend (bool) – Whether to draw a legend (default True).
xlim (tuple, optional) – x-axis limits.
ylim (tuple, optional) – y-axis limits.
- Returns:
ax
- Return type:
matplotlib Axes
- compute_acf(tlags=None, n_bins=50, normalize=True)
Compute the empirical autocorrelation function of the data.
Delegates to self.data.compute_acf().
- Parameters:
tlags (array_like, optional) – Bin edges for time lags [days]. If provided, n_bins is inferred as len(tlags) - 1. If None, n_bins linearly spaced bins from 0 to half the baseline are used.
n_bins (int) – Number of lag bins (used when tlags is None, default 50).
normalize (bool) – If True (default), normalize so ACF(0) ~ 1.
- Returns:
lag_centers (ndarray, shape (n_bins,)) – Bin centers.
acf (ndarray, shape (n_bins,)) – Empirical ACF at each bin center.
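For unevenly sampled data, one common way to build an empirical ACF is to average the products of mean-subtracted pairs y_i * y_j in bins of |t_i - t_j|. The NumPy sketch below does that with the binning conventions described above; it is an assumption about the approach, not spotgp's implementation, and it forms all O(N^2) pairs, so it suits modest N only.

```python
import numpy as np


def binned_acf(t, y, tlags=None, n_bins=50, normalize=True):
    """Empirical autocovariance of an unevenly sampled series, binned by lag."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float) - np.mean(y)
    if tlags is None:
        tlags = np.linspace(0.0, 0.5 * (t.max() - t.min()), n_bins + 1)
    tlags = np.asarray(tlags, dtype=float)
    n_bins = tlags.size - 1                       # inferred from the edges
    i, j = np.triu_indices(t.size)                # all pairs i <= j, including zero lag
    lag = np.abs(t[j] - t[i])
    prod = y[i] * y[j]
    idx = np.digitize(lag, tlags) - 1             # bin index for each pair
    ok = (idx >= 0) & (idx < n_bins)              # drop pairs outside the edge range
    sums = np.bincount(idx[ok], weights=prod[ok], minlength=n_bins)
    counts = np.bincount(idx[ok], minlength=n_bins)
    acf = np.full(n_bins, np.nan)                 # empty bins stay NaN
    filled = counts > 0
    acf[filled] = sums[filled] / counts[filled]
    if normalize:
        acf = acf / np.var(y)                     # so the zero-lag bin is ~1
    centers = 0.5 * (tlags[:-1] + tlags[1:])
    return centers, acf
```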
- compute_kernel(tlags)
Evaluate the analytic kernel at the given time lags.
- Parameters:
tlags (array_like, shape (M,)) – Time lags [days].
- Returns:
K – Kernel values at each lag.
- Return type:
ndarray, shape (M,)
- fit_acf(theta0=None, keys=None, tlags=None, n_bins=50, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)
Fit the analytic kernel to the empirical ACF via least-squares.
Minimizes sum_i (ACF_data(lag_i) - K(lag_i; theta))^2 over the kernel hyperparameters, using JAX gradients and scipy.
- Parameters:
theta0 (dict or array_like, optional) –
- Starting point. Can be:
None: uses self.theta0 (kernel params only, no sigma_n).
dict: values for any subset of kernel keys set the starting point. If keys is not given, the dict keys that overlap with KERNEL_HPARAM_KEYS are treated as the free variables; the rest are held fixed. Extra keys not in KERNEL_HPARAM_KEYS are ignored.
array_like: full kernel theta vector (6 elements).
keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all kernel parameters are varied.
tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.
n_bins (int) – Number of lag bins (used when tlags is None).
method (str) – Scipy optimizer method.
maxiter (int) – Maximum iterations.
ftol (float) – Function-value convergence tolerance (default 0, disabled).
gtol (float) – Gradient-norm convergence tolerance (default 1e-8).
disp (bool) – If True, print optimizer convergence messages (default False).
nopt (int) – Number of independent optimization trials (default 1). When > 1, fit_acf_parallel is called and the best result across all trials is returned.
ncore (int or None) – Number of parallel workers. Only used when nopt > 1.
rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.
- Returns:
theta_dict (dict) – Full dictionary of all kernel hyperparameters (fixed + optimized).
result (scipy.optimize.OptimizeResult) – Full optimizer output.
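The least-squares objective sum_i (ACF_data(lag_i) - K(lag_i; theta))^2 can be minimized with scipy.optimize.minimize, as sketched below. Two swaps relative to the documented method: model_acf is a hypothetical stand-in kernel (amplitude, envelope timescale, period), and gradients come from SciPy's finite differences rather than JAX autodiff.

```python
import numpy as np
from scipy.optimize import minimize


def model_acf(lag, theta):
    """Hypothetical stand-in kernel: amp**2 * exp(-lag**2 / (2 tau**2)) * cos(2 pi lag / peq)."""
    amp, tau, peq = theta
    return amp ** 2 * np.exp(-0.5 * (lag / tau) ** 2) * np.cos(2 * np.pi * lag / peq)


def acf_objective(theta, lag, acf_data):
    """Least-squares objective: sum_i (ACF_data(lag_i) - K(lag_i; theta))**2."""
    return np.sum((acf_data - model_acf(lag, theta)) ** 2)


# Synthetic "empirical" ACF: the stand-in kernel plus a little noise.
rng = np.random.default_rng(0)
lag = np.linspace(0.0, 20.0, 50)
theta_true = np.array([1.0, 6.0, 5.0])
acf_data = model_acf(lag, theta_true) + 0.01 * rng.standard_normal(lag.size)

theta0 = np.array([0.8, 4.0, 4.8])
result = minimize(acf_objective, theta0, args=(lag, acf_data), method="L-BFGS-B",
                  bounds=[(1e-3, 10.0), (0.05, 20.0), (0.5, 50.0)])
print(result.x)  # should land close to theta_true
```

Bounding the amplitude away from zero keeps it positive, since the objective only sees amp**2.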
- fit_acf_parallel(nopt=10, ncore=None, keys=None, tlags=None, n_bins=50, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None)
Run fit_acf from multiple random starting points in parallel. Starting points are drawn uniformly within the kernel parameter bounds.
- Parameters:
nopt (int) – Number of independent optimization trials (default 10).
ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.
keys (list of str, optional) – Free parameters (forwarded to fit_acf).
tlags – Forwarded to fit_acf.
n_bins – Forwarded to fit_acf.
method (str) – Optimizer method (default “nelder-mead”).
maxiter – Forwarded to fit_acf.
ftol – Forwarded to fit_acf.
gtol – Forwarded to fit_acf.
disp – Forwarded to fit_acf.
return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
- Returns:
theta_best (dict, or list of dict if return_all=True) – Best-fit kernel hyperparameters.
result_best (scipy.optimize.OptimizeResult, or list of OptimizeResult if return_all=True) – Optimizer output for the best fit.
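The multi-start strategy (uniform random starts within the bounds, keep the lowest objective) can be sketched with SciPy as below. Assumptions: the trials run sequentially here, whereas the documented method spreads them over ncore workers, and bumpy is a hypothetical test objective with several local minima.

```python
import numpy as np
from scipy.optimize import minimize


def multistart_minimize(fun, bounds, nopt=10, rng=None, method="Nelder-Mead"):
    """Run `minimize` from nopt uniform random starts within bounds; best result first."""
    rng = np.random.default_rng() if rng is None else rng
    lo, hi = np.asarray(bounds, dtype=float).T
    starts = rng.uniform(lo, hi, size=(nopt, lo.size))   # uniform within the bounds
    # Sequential for clarity; a parallel version would map these over worker processes.
    results = [minimize(fun, x0, method=method, bounds=bounds) for x0 in starts]
    return sorted(results, key=lambda r: r.fun)          # lowest objective first


def bumpy(x):
    """A 1D objective with several local minima."""
    return float(np.sin(3.0 * x[0]) + 0.1 * x[0] ** 2)


results = multistart_minimize(bumpy, bounds=[(-5.0, 5.0)], nopt=50, rng=np.random.default_rng(0))
best = results[0]
print(best.x, best.fun)
```

Returning the full sorted list corresponds to return_all=True; taking results[0] corresponds to the default.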
- fit_acf_psd(theta0=None, keys=None, tlags=None, n_bins=50, n_freq=200, dt_kernel=None, acf_weight=1.0, psd_weight=1.0, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False)
Fit kernel parameters jointly to the empirical ACF and PSD.
Minimizes a weighted sum of two normalized mean-squared-error terms:
loss = acf_weight * acf_loss + psd_weight * psd_loss
where
acf_loss = mean((ACF_data - K_model)^2) / mean(ACF_data^2)
is the relative MSE of the kernel against the empirical ACF (unnormalized autocovariance), and
psd_loss = mean((PSD_data_norm - PSD_model_norm)^2)
is the MSE between the Lomb-Scargle periodogram and the analytic kernel PSD, both normalized to unit integral so the comparison is independent of overall amplitude.
The model PSD is computed via a direct cosine transform of the kernel evaluated on a uniform lag grid, making it fully differentiable with respect to the kernel parameters.
- Parameters:
theta0 (dict or array_like, optional) – Starting point in self.param_keys space (sampling space, with log_-prefixed keys where applicable). Follows the same convention as fit_map: None uses self.theta0, a dict overrides named entries and infers free keys, an array is used directly.
keys (list of str, optional) – Parameters to vary during optimization (names from self.param_keys). Defaults to all kernel parameters (first 6 entries of self.param_keys, i.e. excluding sigma_n if present).
tlags (array_like, optional) – Bin edges for the empirical ACF. If None, n_bins+1 edges linearly spaced from 0 to half the baseline.
n_bins (int) – Number of ACF lag bins when tlags is None (default 50).
n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram (default 200).
dt_kernel (float, optional) – Uniform lag spacing [days] for evaluating the analytic kernel before the direct cosine transform. Defaults to one-fifth of the median data spacing.
acf_weight (float) – Weight for the ACF loss term (default 1.0).
psd_weight (float) – Weight for the PSD loss term (default 1.0).
method (str) – Scipy optimizer method (default "L-BFGS-B").
maxiter (int) – Maximum optimizer iterations (default 500).
ftol (float) – Function-value convergence tolerance forwarded to scipy (default 0).
gtol (float) – Gradient-norm convergence tolerance forwarded to scipy (default 1e-8).
disp (bool) – Print optimizer messages if True.
- Returns:
theta_dict (dict) – Best-fit parameters in self.param_keys space.
result (scipy.optimize.OptimizeResult)
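The two loss terms and the direct cosine transform described for fit_acf_psd can be written down in NumPy. This is a sketch of the stated formulas, not spotgp's code: the trapezoid helper is defined locally, the kernel samples are assumed to start at lag 0 and to have decayed to about zero by the last lag, and in practice psd_data would come from a Lomb-Scargle periodogram evaluated on freq.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoid-rule integral of samples y(x)."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def acf_loss(acf_data, k_model):
    """Relative MSE of the model kernel against the empirical autocovariance."""
    return np.mean((acf_data - k_model) ** 2) / np.mean(acf_data ** 2)


def cosine_transform_psd(kernel_vals, dt, freq):
    """Direct cosine transform S(f) = integral over all tau of K(tau) cos(2 pi f tau).

    kernel_vals are K at lags 0, dt, 2*dt, ...; K is even, so each positive lag
    also stands in for its negative twin (weight 2*dt), while zero lag gets dt.
    """
    lags = dt * np.arange(kernel_vals.size)
    weights = np.full(lags.size, 2.0 * dt)
    weights[0] = dt
    return (weights * kernel_vals) @ np.cos(2 * np.pi * np.outer(lags, freq))


def psd_loss(psd_data, psd_model, freq):
    """MSE between the two PSDs after each is normalized to unit integral."""
    data_n = psd_data / trapezoid(psd_data, freq)
    model_n = psd_model / trapezoid(psd_model, freq)
    return np.mean((data_n - model_n) ** 2)


def joint_loss(acf_data, k_model, psd_data, psd_model, freq, acf_weight=1.0, psd_weight=1.0):
    """loss = acf_weight * acf_loss + psd_weight * psd_loss."""
    return (acf_weight * acf_loss(acf_data, k_model)
            + psd_weight * psd_loss(psd_data, psd_model, freq))
```

Every operation is a smooth array expression, which is why the same construction stays differentiable when written with jax.numpy.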
- plot_acf(theta=None, tlags=None, n_bins=50, ax=None, normalize=False, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic ACF', data_label='Data ACF')
Plot the empirical ACF and optionally the analytic kernel.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel is overplotted.
tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.
n_bins (int) – Number of lag bins (used when tlags is None).
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
normalize (bool) – If True, normalize both curves by the data variance so ACF(0) ≈ 1 (default False).
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- plot_psd(theta=None, n_freq=500, dt_kernel=None, ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic PSD', data_label='Data Lomb-Scargle')
Plot the empirical PSD (Lomb-Scargle) and optionally the analytic kernel PSD (FFT of the autocovariance function).
Both curves are normalized so their integral over positive frequencies equals the data variance, making them directly comparable.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel PSD is overplotted.
n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram (default 500).
dt_kernel (float, optional) – Time step [days] for evaluating the analytic kernel on a uniform grid before FFT. Defaults to one-fifth of the median data spacing.
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
data_color (str) – Color for the data curve (default 'k').
model_color (str) – Color for the model curve (default 'r').
show_legend (bool) – Whether to draw a legend.
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- plot_covariance_matrix(theta=None, ax=None, cmap='RdBu_r', show_colorbar=True, vmax=None, nbins=50, show=False, filename='covariance_matrix.png')
Plot the GP covariance matrix K (signal only, no noise).
Entries outside the banded support are set to zero, matching the cholesky_banded approximation. The matrix is binned to nbins x nbins before plotting. The bandwidth boundary is drawn as dashed lines, and band width plus matrix sparsity are annotated.
- Parameters:
theta (dict or array_like, optional) – Kernel hyperparameters. Accepts a physical dict with keys from param_keys, a sampling-space dict with log_-prefixed keys, or a raw array. If None, uses the current self.hparam values.
ax (matplotlib Axes, optional) – Axes to plot on. If None, a new figure is created.
cmap (str, optional) – Colormap name. Defaults to "RdBu_r" (diverging, centered at zero).
show_colorbar (bool, optional) – Whether to add a colorbar. Default True.
vmax (float, optional) – Symmetric color scale limit [-vmax, vmax]. If None, uses the maximum absolute value of the banded matrix.
nbins (int, optional) – Bin the N x N matrix down to nbins x nbins by averaging non-overlapping blocks before plotting. Default 50.
show (bool, optional) – If True, call plt.show(). Default False.
filename (str, optional) – Filename used when saving to save_dir. Default "covariance_matrix.png".
- Returns:
ax
- Return type:
matplotlib Axes
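The two preprocessing steps described above (zero entries outside the band, then average non-overlapping blocks down to nbins x nbins) can be sketched in NumPy. Assumption: the band is measured as an index offset |i - j| <= bandwidth, which may differ from how spotgp defines its bandwidth; band_and_bin is a hypothetical name.

```python
import numpy as np


def band_and_bin(K, bandwidth, nbins=50):
    """Zero K outside |i - j| <= bandwidth, then block-average to (nbins, nbins).

    Returns the binned matrix and the fraction of entries zeroed (sparsity).
    """
    K = np.asarray(K, dtype=float)
    N = K.shape[0]
    if N < nbins:
        raise ValueError("need at least nbins rows to bin")
    i, j = np.indices(K.shape)
    in_band = np.abs(i - j) <= bandwidth             # banded support, in index offsets
    K_band = np.where(in_band, K, 0.0)
    sparsity = 1.0 - in_band.mean()                  # fraction of zeroed entries
    m = N // nbins                                   # rows/columns per block
    K_trim = K_band[: m * nbins, : m * nbins]        # drop leftovers that do not fill a block
    binned = K_trim.reshape(nbins, m, nbins, m).mean(axis=(1, 3))
    return binned, sparsity
```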
- get_theta()
Return the current kernel hyperparameters as a dictionary.
- Returns:
theta – Keys and values for all kernel (and optionally noise) hyperparameters, e.g. {“peq”: 5.0, “kappa”: 0.2, …}.
- Return type:
dict
- update_hparam(hparam)
Update hyperparameters and rebuild kernel and covariance.
Accepts a SpotEvolutionModel, an hparam dict (legacy keys like lspot), or a theta-style dict whose keys match self.spot_model.param_keys (e.g. tau_em, lat_min).
- fit_map(theta0=None, keys=None, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)
Find the maximum a posteriori (MAP) estimate.
Uses scipy.optimize.minimize with JAX-computed gradients. When nopt > 1, delegates to fit_map_parallel, which runs nopt independent trials from random starting points (drawn uniformly within the bounds) and returns the best result.
- Parameters:
theta0 (dict or array_like, optional) –
- Starting point. Can be:
None: uses self.theta0 (current hyperparameters).
dict: values for any subset of param_keys set the starting point. If keys is not given, the dict keys that overlap with self.param_keys are treated as the free variables to optimize; the rest are held fixed. Extra keys not in param_keys are ignored.
array_like: full theta vector (length n_params).
Ignored when nopt > 1 (starting points are randomized).
keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all parameters are varied.
method (str) – Scipy optimizer method (default “L-BFGS-B”).
maxiter (int) – Maximum iterations.
ftol (float) – Function-value convergence tolerance for L-BFGS-B (default 0, i.e. disabled so that convergence is controlled by gtol).
gtol (float) – Gradient-norm convergence tolerance (default 1e-8).
disp (bool) – If True, print optimizer convergence messages (default False).
nopt (int) – Number of independent optimization trials (default 1). When > 1, fit_map_parallel is called and the best result across all trials is returned.
ncore (int or None) – Number of parallel workers for multi-start runs. If None, uses nopt or the number of available CPUs, whichever is smaller. Only used when nopt > 1.
rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.
- Returns:
theta_dict (dict) – Full dictionary of all hyperparameters (fixed + optimized).
result (scipy OptimizeResult) – Full optimizer output.
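The free/fixed convention for theta0 and keys can be pictured as packing the free parameters into an optimizer vector and unpacking back into a full dict. A minimal sketch under assumptions: PARAM_KEYS is a hypothetical ordering chosen for illustration (not spotgp's param_keys), and split_free_fixed is a hypothetical helper.

```python
import numpy as np

# Hypothetical parameter ordering, for illustration only.
PARAM_KEYS = ["peq", "kappa", "inc", "lspot", "tau_spot", "sigma_k"]


def split_free_fixed(current, theta0=None, keys=None):
    """Return (free keys, starting vector, unpack) for a partial fit.

    current: dict of all current hyperparameter values.
    theta0:  optional dict whose recognized keys override the start and, when
             `keys` is None, define which parameters are free.
    """
    theta0 = {} if theta0 is None else theta0
    overrides = {k: float(v) for k, v in theta0.items() if k in PARAM_KEYS}  # extra keys ignored
    start = {**current, **overrides}
    if keys is not None:
        free = list(keys)                 # explicit keys win
    elif overrides:
        free = [k for k in PARAM_KEYS if k in overrides]
    else:
        free = list(PARAM_KEYS)           # no dict and no keys: vary everything
    x0 = np.array([start[k] for k in free], dtype=float)

    def unpack(x):
        """Map an optimizer vector back to a full dict; non-free keys stay fixed."""
        full = dict(start)
        full.update(zip(free, (float(v) for v in x)))
        return full

    return free, x0, unpack
```

The optimizer then only ever sees x0 and calls the objective on unpack(x), so held-fixed parameters cannot drift.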
- fit_map_parallel(nopt=10, ncore=None, keys=None, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None, theta0=None, jitter=0.01)
Run fit_map from multiple random starting points in parallel. Starting points are drawn uniformly within the parameter bounds.
- Parameters:
nopt (int) – Number of independent optimization trials (default 10).
ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.
keys (list of str, optional) – Free parameters (forwarded to fit_map).
method (str) – Optimizer method (default “nelder-mead”).
maxiter – Forwarded to fit_map.
ftol – Forwarded to fit_map.
gtol – Forwarded to fit_map.
disp – Forwarded to fit_map.
return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
theta0 (dict, optional) – Initial parameter guess to include as one of the starting points. Replaces one random start so the total number of trials stays nopt.
- Returns:
theta_best (dict, or list of dict if return_all=True) – Best-fit hyperparameters.
result_best (scipy.optimize.OptimizeResult, or list of OptimizeResult if return_all=True) – Optimizer output for the best fit.
- mass_matrix_hessian_map(theta_map=None)
Estimate the inverse mass matrix from the Hessian of the negative log-likelihood at the MAP.
M^{-1} = H^{-1}, where H = d^2(-log L)/d theta^2 evaluated at the MAP.
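Numerically, M^{-1} = H^{-1} needs the Hessian of -log L at the MAP and one matrix inverse. The sketch below swaps in central finite differences for the JAX autodiff Hessian and takes the negative log-likelihood as a plain Python callable; both helper names are hypothetical.

```python
import numpy as np


def numerical_hessian(f, x, eps=1e-3):
    """Central finite-difference Hessian of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    n = x.size
    steps = np.eye(n) * eps
    H = np.empty((n, n))
    for a in range(n):
        for b in range(n):
            H[a, b] = (f(x + steps[a] + steps[b]) - f(x + steps[a] - steps[b])
                       - f(x - steps[a] + steps[b]) + f(x - steps[a] - steps[b])) / (4 * eps ** 2)
    return 0.5 * (H + H.T)                  # symmetrize away round-off


def inverse_mass_matrix(neg_log_like, theta_map, eps=1e-3):
    """M^{-1} = H^{-1}, with H the Hessian of -log L at the MAP."""
    return np.linalg.inv(numerical_hessian(neg_log_like, theta_map, eps))
```

For a Gaussian -log L = 1/2 (theta - mu)^T A (theta - mu), this recovers A^{-1}, the posterior covariance, which is why H^{-1} is a natural inverse mass matrix.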
- mass_matrix_fisher(theta_map=None, eigval_clip=1e-06, white_noise=1e-08)
Estimate the inverse mass matrix from the Fisher information.
For the GP log-likelihood:
I_{ij} = (1/2) tr(K^{-1} dK/dtheta_i K^{-1} dK/dtheta_j)
When matrix_solver="cholesky_full", the kernel derivatives dK/dtheta_i are computed via JAX forward-mode autodiff (jacfwd) on the full N×N covariance matrix.
When matrix_solver="cholesky_banded", the exact Fisher requires the dense N×N kernel and its inverse, which would defeat the purpose of banded storage. Instead, the Fisher is approximated by the Hessian of the banded negative log-likelihood at the MAP (Fisher ≈ observed information at the MLE).
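The trace formula I_ij = (1/2) tr(K^{-1} dK/dtheta_i K^{-1} dK/dtheta_j) can be evaluated densely once the derivative matrices are available. In this sketch the derivatives are analytic for a hypothetical squared-exponential covariance (parameters amp and ell, fixed white noise sigma_n) rather than obtained with jacfwd; the inverse mass matrix would then be the inverse of the returned matrix.

```python
import numpy as np


def se_covariance_and_grads(t, amp, ell, sigma_n):
    """C = amp**2 exp(-tau**2 / (2 ell**2)) + sigma_n**2 I and its derivatives wrt (amp, ell).

    A hypothetical squared-exponential kernel, chosen because its derivatives are analytic.
    """
    tau = t[:, None] - t[None, :]
    R = np.exp(-0.5 * (tau / ell) ** 2)
    C = amp ** 2 * R + sigma_n ** 2 * np.eye(t.size)
    dC_damp = 2.0 * amp * R
    dC_dell = amp ** 2 * R * tau ** 2 / ell ** 3
    return C, [dC_damp, dC_dell]


def fisher_information(C, dC_list):
    """I_ij = 0.5 * tr(C^{-1} dC_i C^{-1} dC_j) for a zero-mean Gaussian likelihood."""
    Cinv_dC = [np.linalg.solve(C, dC) for dC in dC_list]
    n = len(dC_list)
    info = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            info[i, j] = 0.5 * np.trace(Cinv_dC[i] @ Cinv_dC[j])
    return info
```

The N×N solves are what make this exact form expensive, which is the cost the banded solver avoids by falling back to the observed-information approximation.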
Power Spectral Density
- spotgp.psd.compute_psd(y, t=None, dt=None, normalization='psd', freq_min=None, freq_max=None, n_freq=None, samples_per_peak=5)
Compute the Power Spectral Density of a time series using astropy.timeseries.LombScargle.
Works for both evenly and unevenly sampled data.
- Parameters:
y (array-like, shape (N,)) – Time series values.
t (array-like, shape (N,), optional) – Sample times. If None, integer indices scaled by dt are used.
dt (float, optional) – Sampling interval. Used only when t is None (default: 1).
normalization ({"psd", "standard", "model", "log"}) – Passed directly to LombScargle.autopower / power.
freq_min (float, optional) – Minimum frequency to evaluate.
freq_max (float, optional) – Maximum frequency to evaluate.
n_freq (int, optional) – Number of frequency grid points.
samples_per_peak (float, optional) – Controls the frequency grid density (default 5).
- Returns:
freq (ndarray) – Frequencies in cycles per unit time.
power (ndarray) – PSD evaluated at each frequency.
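To illustrate what the Lomb-Scargle power measures, here is a minimal NumPy sketch of the classical periodogram for unevenly sampled data. It is neither compute_psd nor astropy's implementation: it subtracts the mean instead of fitting it (astropy's LombScargle fits a floating mean by default), uses a simple oversampled grid controlled by samples_per_peak, and the name lomb_scargle_psd is hypothetical. Expect numerical differences from the astropy call.

```python
import numpy as np


def lomb_scargle_psd(t, y, samples_per_peak=5, freq_max=None):
    """Classical Lomb-Scargle periodogram of a mean-subtracted series (illustration only)."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    y = y - y.mean()                              # subtract the mean instead of fitting it
    baseline = t.max() - t.min()
    df = 1.0 / (samples_per_peak * baseline)      # oversampled grid spacing
    if freq_max is None:
        freq_max = 0.5 * t.size / baseline        # "average Nyquist" frequency
    freq = np.arange(df, freq_max, df)            # cycles per unit time

    w = 2.0 * np.pi * freq                        # angular frequencies, shape (F,)
    wt = w[:, None] * t[None, :]                  # shape (F, N)
    # Offset tau that decouples the sine and cosine terms at each frequency
    tau = np.arctan2(np.sin(2 * wt).sum(axis=1), np.cos(2 * wt).sum(axis=1)) / (2 * w)
    arg = w[:, None] * (t[None, :] - tau[:, None])
    c, s = np.cos(arg), np.sin(arg)
    power = 0.5 * ((c @ y) ** 2 / (c**2).sum(axis=1) + (s @ y) ** 2 / (s**2).sum(axis=1))
    return freq, power


# Unevenly sampled sinusoid at 0.1 cycles per unit time plus a little noise
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 100.0, 300))
y = np.sin(2 * np.pi * 0.1 * t) + 0.1 * rng.normal(size=t.size)
freq, power = lomb_scargle_psd(t, y)
peak_freq = freq[np.argmax(power)]  # close to 0.1
```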
MCMC Sampler
- class spotgp.mcmc.MCMCSampler(gp)
Bases: object
Base MCMC sampler for GP hyperparameters.
Wraps a GPSolver object and provides shared storage, diagnostics, summary statistics, corner plots, and dict conversion. Subclasses implement specific sampling algorithms (e.g. NUTS).
- Parameters:
gp (GPSolver) – A configured GPSolver instance.
- property param_keys
- property n_params
- summary()
Print summary statistics of the posterior samples.
- Returns:
stats – Parameter names mapped to (mean, std, 16th, 50th, 84th percentiles).
- Return type:
dict
- plot_covariance(method='fisher', theta_map=None, n_sigma=2, n_grid=200, samples=None, figsize=None, color='C0', alpha=0.3, true_params=None, savefig=None, **corner_kwargs)
Corner plot of 2D covariance ellipses from the Hessian or Fisher matrix, with 1D marginal Gaussians on the diagonal.
Uses corner.corner to lay out the figure when MCMC samples are provided, and overlays the Laplace/Fisher Gaussian approximation (ellipses + 1D marginals).
- Parameters:
method ({"fisher", "hessian_map", "laplace"}) – Which matrix to use for the Gaussian approximation (default "fisher").
theta_map (array_like, optional) – Center of the ellipses. If None, uses MAP estimate.
n_sigma (float) – Number of sigma for the ellipse contours (default 2).
n_grid (int) – Grid resolution for the ellipse curves (default 200).
samples (array_like, optional) – If provided, plotted as the corner histogram/contours. If None, the figure is created with empty axes and only the Gaussian approximation is drawn.
figsize (tuple, optional) – Figure size.
color (str) – Color for Gaussian ellipses and marginals (default “C0”).
alpha (float) – Fill alpha for the ellipse interiors (default 0.3).
true_params (dict or array_like, optional) – True parameter values to mark with crosshairs.
savefig (str, optional) – If provided, save figure to this path.
**corner_kwargs – Extra keyword arguments forwarded to corner.corner (e.g. quantiles, show_titles, hist_kwargs).
- Returns:
fig, axes
- Return type:
matplotlib Figure and 2D array of Axes.
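The Gaussian overlay can be sketched with NumPy: each off-diagonal panel draws a contour of the 2x2 sub-block of the covariance (the inverse of the Hessian or Fisher matrix, not a sub-block of that matrix itself), and each diagonal panel draws a 1-D Gaussian with the corresponding variance. This sketch assumes the contour is the n_sigma Mahalanobis ellipse obtained by scaling a unit circle with the Cholesky factor; that is one common convention, and a contour enclosing a fixed 2-D probability mass would instead use a chi-square quantile. The helper names are hypothetical.

```python
import numpy as np


def covariance_ellipse(center, cov2, n_sigma=2.0, n_grid=200):
    """Points on the n_sigma Mahalanobis contour of a 2-D Gaussian, shape (2, n_grid)."""
    L = np.linalg.cholesky(np.asarray(cov2, dtype=float))  # cov2 = L @ L.T
    phi = np.linspace(0.0, 2.0 * np.pi, n_grid)
    unit_circle = np.stack([np.cos(phi), np.sin(phi)])     # shape (2, n_grid)
    return np.asarray(center, dtype=float)[:, None] + n_sigma * (L @ unit_circle)


def marginal_pdf(x, mean, var):
    """1-D Gaussian marginal for the diagonal panels."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)


# Hypothetical 3-parameter covariance, e.g. the inverse of a Fisher matrix
cov = np.array([[1.0, 0.3, 0.0],
                [0.3, 0.5, 0.1],
                [0.0, 0.1, 2.0]])
theta_map = np.array([1.0, 2.0, 3.0])
i, j = 0, 1                                   # the (i, j) off-diagonal panel
cov_ij = cov[np.ix_([i, j], [i, j])]          # 2x2 sub-block for that panel
xy = covariance_ellipse(theta_map[[i, j]], cov_ij, n_sigma=2.0)
```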
- plot_corner_map(samples=None, checkpoint_path=None, cmap='viridis', marker_size=40, savefig=None, true_params=None, **corner_kwargs)
Corner plot of MCMC samples with MAP solutions overlaid as scatter points colored by their log-likelihood.
- Parameters:
samples (array_like, optional) – Shape (n_samples, n_params). If None, loads from the checkpoint file.
checkpoint_path (str, optional) – Path to the checkpoint .npz file containing MAP solutions. If None, uses the default checkpoint file.
cmap (str) – Colormap for the MAP scatter points (default “viridis”).
marker_size (float) – Marker size for scatter points (default 40).
savefig (str, optional) – If provided, save figure to this path.
true_params (dict or array_like, optional) – True parameter values to mark with crosshairs.
**corner_kwargs – Extra keyword arguments forwarded to corner.corner.
- Returns:
fig, axes
- Return type:
matplotlib Figure and 2D array of Axes.
- class spotgp.mcmc.BlackJAXSampler(gp, save_dir='results', checkpoint_file='mcmc_checkpoint.npz')
Bases: MCMCSampler
NUTS sampler using the BlackJAX library.
Inherits diagnostics, summary, plotting, and dict conversion from MCMCSampler. Adds run_map, run_warmup, and run_sampling for gradient-based No-U-Turn sampling with dual-averaging step-size adaptation.
When multiple chains are requested, sampling is parallelized across available devices via jax.pmap. Chains are distributed evenly across devices (n_chains must be divisible by jax.device_count()). On a single GPU this behaves identically to the previous jax.vmap implementation.
- Parameters:
gp (GPSolver) – A configured GPSolver instance.
save_dir (str, optional) – Directory for all outputs produced by this sampler (corner plots, covariance plots, etc.). Created automatically if it does not exist. When set, save_checkpoint defaults to saving the checkpoint inside this directory.
checkpoint_file (str, optional) – Path to the checkpoint file. When provided, overrides the default save_dir/mcmc_checkpoint.npz. If neither checkpoint_file nor save_dir is given, no checkpoint file is set until one is passed to a later method.
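The device layout described above can be sketched as a reshape followed by jax.pmap over devices and jax.vmap over the chains on each device. This is a hedged illustration rather than the sampler's code: run_chains_on_devices and step_fn are hypothetical names, step_fn stands in for one NUTS transition, and on a single-device machine the pmap axis simply has size 1.

```python
import jax
import jax.numpy as jnp


def run_chains_on_devices(step_fn, init_positions):
    """Apply step_fn to each chain, spreading chains evenly over devices with pmap."""
    n_chains = init_positions.shape[0]
    n_dev = jax.device_count()
    if n_chains % n_dev != 0:
        raise ValueError(f"n_chains={n_chains} must be divisible by {n_dev} devices")
    # (n_chains, n_params) -> (n_devices, chains_per_device, n_params)
    sharded = init_positions.reshape(n_dev, n_chains // n_dev, -1)
    # pmap over devices, vmap over the chains that live on each device
    out = jax.pmap(jax.vmap(step_fn))(sharded)
    return out.reshape(n_chains, -1)


def step_fn(x):
    """Hypothetical stand-in for one NUTS transition: a gradient step on 0.5 * |x|^2."""
    return x - 0.1 * jax.grad(lambda z: 0.5 * jnp.sum(z**2))(x)


n_chains = 2 * jax.device_count()
init = jnp.arange(3.0 * n_chains).reshape(n_chains, 3)
out = run_chains_on_devices(step_fn, init)
```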
- run_map(nopt=10, keys=None, checkpoint_file=None, theta0=None, **kwargs)
Find MAP solutions via parallel multi-start optimization.
Runs GPSolver.fit_map_parallel and stores the results. If the checkpoint file already contains MAP data, the solutions are loaded from it instead of re-running the optimization.
- Parameters:
nopt (int) – Number of independent optimization restarts (default 10).
keys (list of str, optional) – Parameter names to optimize. If None, uses all bounded parameters from GPSolver.
theta0 (dict, optional) – Initial parameter guess to include as one of the optimization starting points. It replaces one random start, so the total number of restarts stays nopt.
checkpoint_file (str, optional) – Path to save/load MAP solutions. If provided, also updates the sampler’s default checkpoint path. Defaults to self._checkpoint_file.
**kwargs – Additional keyword arguments passed to GPSolver.fit_map_parallel (e.g. method, maxiter).
- Returns:
all_theta_maps – All MAP solutions sorted by objective (best first).
- Return type:
- run_warmup(n_warmup=500, theta_init=None, mass_matrix_method='hessian_map', step_size=None, rng_key=None, target_accept=0.8, progress_bar=False, n_chains=1, checkpoint_file=None, warmup_method='window_adaptation', pathfinder_maxiter=100, pathfinder_maxcor=10, pathfinder_num_elbo=200)
Run warmup phase: adapt step size and mass matrix.
Supports three warmup strategies:
"window_adaptation" (default) – BlackJAX’s standard dual-averaging window adaptation of both step size and mass matrix.
"pathfinder" – multi-path Pathfinder via L-BFGS.
"dual_averaging" – fixes the mass matrix (from the Hessian at the MAP) and adapts only the step size.
After warmup, the adapted parameters are stored on the sampler and a checkpoint is saved (if checkpoint_file is set).
- Parameters:
n_warmup (int) – Number of warmup steps (default 500).
theta_init (dict or array_like, optional) – Initial position. If None, uses GPSolver’s MAP estimate. Can also be a list of dicts or 2-D array for per-chain starting points.
mass_matrix_method ({"hessian_map", "fisher", "laplace", "diagonal", None}) – Method to estimate the mass matrix (default "hessian_map").
step_size (float, optional) – Initial NUTS step size. If None, a heuristic is used.
rng_key (jax.random.PRNGKey, optional) – Random key. Default: PRNGKey(0).
target_accept (float) – Target acceptance rate (default 0.8).
progress_bar (bool) – If True, show progress during window adaptation.
n_chains (int) – Number of chains (used to validate device count and store per-chain init positions).
checkpoint_file (str, optional) – Override the default checkpoint file path. When set, updates self._checkpoint_file for all subsequent save/load operations. Defaults to save_dir/mcmc_checkpoint.npz when save_dir is set.
warmup_method ({"window_adaptation", "pathfinder", "dual_averaging"}) – Warmup strategy (default "window_adaptation").
pathfinder_maxiter (int) – Max L-BFGS iterations for Pathfinder (default 100).
pathfinder_maxcor (int) – L-BFGS history size for Pathfinder (default 10).
pathfinder_num_elbo (int) – Number of ELBO samples for Pathfinder (default 200).
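Step-size adaptation in both the "window_adaptation" and "dual_averaging" modes relies on dual averaging. Below is a plain-Python sketch of the dual-averaging recursion from the original NUTS paper (Hoffman and Gelman, 2014), assuming the commonly quoted constants gamma=0.05, t0=10, kappa=0.75; BlackJAX's internal implementation and defaults may differ, and dual_averaging_step_size is a hypothetical name. The acceptance function is a toy, deterministic stand-in for the acceptance statistic NUTS would report; its ideal step size is -log(0.8), about 0.223, and the adapted value lands close to it.

```python
import math


def dual_averaging_step_size(accept_prob, eps0=1.0, target_accept=0.8, n_iter=1000,
                             gamma=0.05, t0=10.0, kappa=0.75):
    """Adapt a step size so that accept_prob(eps) approaches target_accept.

    accept_prob(eps) returns the acceptance statistic observed with step size eps.
    """
    mu = math.log(10.0 * eps0)  # shrinkage point, biased toward larger steps
    h_bar = 0.0                 # running average of (target - acceptance)
    log_eps = math.log(eps0)
    log_eps_bar = 0.0           # averaged iterate, returned at the end
    for m in range(1, n_iter + 1):
        alpha = accept_prob(math.exp(log_eps))
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target_accept - alpha)
        # Too-low acceptance (h_bar > 0) pushes the step size down, and vice versa
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
    return math.exp(log_eps_bar)


# Toy acceptance curve alpha(eps) = exp(-eps)
eps = dual_averaging_step_size(lambda e: math.exp(-e), eps0=1.0, target_accept=0.8)
```

The averaged iterate, rather than the last one, is what gets used for sampling, because the raw iterate keeps fluctuating while the running average settles.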
- run_sampling(n_samples=1000)
Run NUTS sampling using the adapted parameters from run_warmup.
Must be called after run_warmup, unless adapted parameters have been restored from a checkpoint.
- Parameters:
n_samples (int) – Number of post-warmup samples per chain (default 1000).
- Returns:
samples (jnp.ndarray) – Shape (n_samples, n_params) when n_chains=1, or (n_chains, n_samples, n_params) when n_chains > 1.
info (dict) – Sampling diagnostics (arrays have a leading chain dimension when n_chains > 1).
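Because the shape of samples depends on n_chains, downstream code often normalizes it first. A small NumPy sketch with hypothetical helper names (as_chain_array, flatten_chains), assuming exactly the two layouts documented above:

```python
import numpy as np


def as_chain_array(samples):
    """Return samples as (n_chains, n_samples, n_params), adding a chain axis if needed."""
    samples = np.asarray(samples)  # also accepts jax arrays
    if samples.ndim == 2:          # n_chains == 1: (n_samples, n_params)
        samples = samples[None, ...]
    if samples.ndim != 3:
        raise ValueError(f"unexpected samples shape {samples.shape}")
    return samples


def flatten_chains(samples):
    """Pool all chains into one (n_chains * n_samples, n_params) array, e.g. for a corner plot."""
    chains = as_chain_array(samples)
    return chains.reshape(-1, chains.shape[-1])


single = np.zeros((1000, 3))     # layout returned when n_chains=1
multi = np.zeros((4, 1000, 3))   # layout returned when n_chains=4
pooled = flatten_chains(multi)
```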
- save_checkpoint(path=None, append_samples=True, plot_corner=False)
Save sampler state to disk for later resumption.
When append_samples=True (the default), new samples are appended to any existing samples already stored in path, and self.samples is cleared from memory. This enables a sample-checkpoint-clear loop that keeps memory usage constant.
- Parameters:
path (str, optional) – File path (saved as .npz). If None, uses the checkpoint_file set in run_warmup, or save_dir/mcmc_checkpoint.npz if save_dir was set.
append_samples (bool) – If True, append the current self.samples to any samples already on disk, then clear self.samples from memory. If False, overwrite the file with only the current in-memory samples.
plot_corner (bool) – If True, load all samples currently on disk after saving and write a corner plot to save_dir/corner_plot.png (or alongside the checkpoint file if save_dir is not set).
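The sample-checkpoint-clear loop can be sketched with plain NumPy .npz files. This is an assumption-laden illustration, not spotgp's checkpoint format: the key name "samples" and the helper append_samples_npz are hypothetical. Note that this naive version reads the existing on-disk array back in order to concatenate, so peak memory during each save grows with the file; what stays bounded is the in-memory buffer between saves.

```python
import os
import tempfile

import numpy as np


def append_samples_npz(path, new_samples, key="samples"):
    """Append new_samples to the array stored under `key` in an .npz file.

    `path` should end in ".npz"; np.savez would otherwise add the suffix itself.
    """
    new_samples = np.asarray(new_samples)
    if os.path.exists(path):
        with np.load(path) as data:
            combined = np.concatenate([data[key], new_samples], axis=0)
    else:
        combined = new_samples
    np.savez(path, **{key: combined})
    return combined.shape[0]


# Sample-checkpoint-clear loop: only one batch is held in memory between saves
rng = np.random.default_rng(1)
path = os.path.join(tempfile.mkdtemp(), "mcmc_checkpoint.npz")
for _ in range(3):
    batch = rng.normal(size=(100, 2))  # stand-in for one run_sampling call
    n_total = append_samples_npz(path, batch)
    del batch                          # drop the in-memory copy before the next round
```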
- load_checkpoint(checkpoint_file=None)
Restore sampler state from a checkpoint file.
Loads only the NUTS state and adapted kernel parameters needed to resume sampling. Samples stored in the file are not loaded into memory; use load_samples to read them later.
- Parameters:
checkpoint_file (str, optional) – Path to a .npz checkpoint file. If provided, also updates the sampler’s default checkpoint path. If None, uses the default save_dir/mcmc_checkpoint.npz.
- run_smc(n_particles=500, n_mcmc_steps=10, n_adapt_steps=25, target_ess=0.5, target_accept=0.6, rng_key=None, step_size=None, mass_matrix_method='hessian_map', theta_init=None, max_tempering_steps=200, checkpoint_every=10, checkpoint_file=None, particle_batch_size=None, max_num_doublings=10)
Run adaptive tempered Sequential Monte Carlo.
Starts from the prior and anneals toward the full posterior using an adaptive temperature schedule. At each tempering step, particles are resampled and rejuvenated with NUTS moves. The NUTS step size is re-adapted via dual averaging at each tempering stage using a representative particle.
- Parameters:
n_particles (int) – Number of SMC particles (default 500).
n_mcmc_steps (int) – NUTS rejuvenation steps per tempering stage (default 10).
n_adapt_steps (int) – Dual-averaging warmup steps to adapt the NUTS step size at each tempering stage (default 25).
target_ess (float) – Target effective sample size as a fraction of n_particles (default 0.5).
target_accept (float) – Target NUTS acceptance rate for dual averaging (default 0.6).
rng_key (jax.random.PRNGKey, optional) – Random key. Default: PRNGKey(42).
step_size (float, optional) – Initial NUTS step size. If None, a heuristic from the mass matrix is used.
mass_matrix_method (str, optional) – Method to estimate the inverse mass matrix (default "hessian_map"). Set to None to use an identity matrix.
theta_init (dict or array_like, optional) – Reference point for mass matrix estimation. If None, the MAP estimate is used.
max_tempering_steps (int) – Safety limit on the number of tempering stages (default 200).
checkpoint_every (int) – Save a checkpoint every this many tempering steps (default 10). Set to 0 to disable periodic checkpointing.
checkpoint_file (str, optional) – Override the default checkpoint file path.
particle_batch_size (int, optional) – Process particles in batches of this size to limit GPU memory usage. When multiple GPUs are visible, the batches are distributed across devices via jax.pmap. n_particles must be divisible by this value (and by particle_batch_size * n_devices for multi-GPU). If None, all particles are evaluated at once (the original BlackJAX behavior).
max_num_doublings (int, optional) – Maximum NUTS tree depth (default 10). Lower values (e.g. 5-6) reduce peak GPU memory per particle at the cost of shorter trajectories.
- Returns:
samples (np.ndarray, shape (n_particles, n_params)) – Weighted posterior particles at the final temperature.
info (dict) – Diagnostics including tempering schedule and log evidence estimate.
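An adaptive temperature schedule is commonly chosen so that the effective sample size of the incremental importance weights w_i = exp((beta_next - beta) * log L_i) equals target_ess times the number of particles. A minimal NumPy sketch using bisection, assuming the particles are equally weighted at the current temperature (as after resampling); ess_fraction and next_temperature are hypothetical names, and this is not claimed to match BlackJAX's solver. In the full algorithm the particles would then be reweighted, resampled, and moved with NUTS before the next temperature is chosen.

```python
import numpy as np


def ess_fraction(log_w):
    """Effective sample size of log-weights, as a fraction of the particle count."""
    w = np.exp(log_w - np.max(log_w))  # stabilize before exponentiating
    return w.sum() ** 2 / (w.size * np.sum(w**2))


def next_temperature(log_likelihood, beta, target_ess=0.5, tol=1e-8):
    """Find beta_next in (beta, 1] where the incremental-weight ESS drops to target_ess."""
    if ess_fraction((1.0 - beta) * log_likelihood) >= target_ess:
        return 1.0                     # can jump straight to the full posterior
    lo, hi = beta, 1.0                 # invariant: lo meets the target, hi does not
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_fraction((mid - beta) * log_likelihood) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo


# Stand-in log-likelihoods for 500 equally weighted particles drawn from the prior
rng = np.random.default_rng(42)
log_like = rng.normal(scale=50.0, size=500)
beta_1 = next_temperature(log_like, beta=0.0, target_ess=0.5)
```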