API Reference (src)#
Lightcurve Model#
- class src.lightcurve.LightcurveModel(peq=4.0, kappa=0.0, inc=1.5707963267948966, nspot=None, tau_spot=None, tem=2, tdec=2, alpha_max=0.1, fspot=0, lspot=5, long=[0, 6.283185307179586], lat=[-1.5707963267948966, 1.5707963267948966], tsim=28, tsamp=0.02, limb_darkening=False, tmax=None, rotate=True, grow=True, nspot_rate=None)[source]#
Bases: object
JAX-accelerated star with spots and its lightcurve.
Same interface as the numpy version but uses JAX for vectorized computation across all spots simultaneously.
- Parameters:
peq (float) – Equatorial period of the star.
kappa (float) – Differential rotation shear.
inc (float) – Inclination of the star.
nspot (int) – Number of spots.
tau_spot (float, optional) – Timescale for both emergence and decay of the spots. Defaults to None.
tem (float, optional) – Emergence timescale of the spots. Defaults to 2.
tdec (float, optional) – Decay timescale of the spots. Defaults to 2.
alpha_max (float, optional) – Maximum angular area of the spots. Defaults to 0.1.
fspot (float, optional) – Spot contrast fraction. Defaults to 0.
lspot (float, optional) – Spot lifetime. Defaults to 5.
long (list, optional) – Range of spot longitudes. Defaults to [0, 2*pi].
lat (list, optional) – Range of spot latitudes. Defaults to [-pi/2, pi/2].
tsim (float, optional) – End simulation time. Defaults to 28.
tsamp (float, optional) – Sampling cadence. Defaults to 0.02.
limb_darkening (bool, optional) – Flag to enable limb darkening. Defaults to False.
- classmethod from_spot_model(spot_model: SpotEvolutionModel, nspot: int = None, *, nspot_rate: float = None, **kwargs)[source]#
Construct a LightcurveModel from a SpotEvolutionModel.
- Parameters:
spot_model (SpotEvolutionModel) – Fully configured spot evolution model.
nspot (int, optional) – Total number of spots to simulate.
nspot_rate (float, optional) – Spot emergence rate [spots/day]. The actual number of spots is max(1, int(nspot_rate * tsim)). Exactly one of nspot or nspot_rate must be provided.
**kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).
- Return type:
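The rule relating nspot, nspot_rate, and tsim can be sketched in a few lines. The helper name resolve_nspot is hypothetical and is not part of the package; only the formula max(1, int(nspot_rate * tsim)) and the "exactly one of" constraint come from the description above.

```python
def resolve_nspot(nspot=None, nspot_rate=None, tsim=28.0):
    """Hypothetical helper mirroring the documented nspot / nspot_rate rule."""
    # Exactly one of the two arguments must be given.
    if (nspot is None) == (nspot_rate is None):
        raise ValueError("Provide exactly one of nspot or nspot_rate.")
    if nspot is not None:
        return int(nspot)
    # Documented conversion: truncate rate * duration, but keep at least one spot.
    return max(1, int(nspot_rate * tsim))


print(resolve_nspot(nspot_rate=0.5, tsim=28.0))   # 14
print(resolve_nspot(nspot_rate=0.01, tsim=28.0))  # 1
```

Because of the int() truncation, low rates over short simulations collapse to a single spot rather than zero.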
- classmethod from_hparam(hparam: dict, nspot: int = None, *, nspot_rate: float = None, **kwargs)[source]#
Construct a LightcurveModel from a GPSolver-compatible hparam dict.
Accepts the same raw hparam dict that GPSolver/AnalyticKernel take, including all amplitude modes (sigma_k, nspot_rate, or nspot), and both symmetric (tau) and asymmetric (tau_em + tau_dec) envelopes. This removes the need to manually decompose the dict in scripts.
- Parameters:
hparam (dict) – Raw hyperparameter dict. Must contain peq, kappa, inc, lspot, tau_spot (or tau_em/tau_dec), and an amplitude specification.
nspot (int, optional) – Total number of spots to simulate.
nspot_rate (float, optional) – Spot emergence rate [spots/day]. Exactly one of nspot or nspot_rate must be provided.
**kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).
- Return type:
- Flux(teval)[source]#
Compute the full lightcurve using JAX vmap over all spots.
Instead of a Python loop over nspot, all spots are computed in parallel via JAX’s vmap.
- animate_lightcurve(fps=30, duration=10.0, outfile=None, dpi=150, show_spots=True, show_grid=True, show_params=True, figsize=(14, 5.5), save_last_frame=None, show_dr=True, label_size=18)[source]#
Animate the starspot evolution with two panels: a 2D projection of the rotating star (left) and the lightcurve (right).
- Parameters:
fps (int) – Frames per second (default 30).
duration (float) – Animation duration in seconds (default 10).
outfile (str or None) – Output file path (.mp4 or .gif). If None, returns the animation object without saving.
dpi (int) – Resolution (default 150).
show_spots (bool) – If True, show individual spot contributions on the lightcurve panel (default True).
show_grid (bool) – If True, draw latitude/longitude grid on the star (default True).
show_params (bool) – If True, show parameter annotation on the lightcurve panel (default True).
figsize (tuple) – Figure size (default (14, 5.5)).
save_last_frame (str or None) – If provided, save the last frame of the animation as a static image to this file path (e.g. “frame.png”).
show_dr (bool) – If True, color the stellar disk by latitude-dependent rotation frequency and display a colorbar (default True).
label_size (int or float) – Font size for all labels, tick marks, and text in the plot (default 18).
- Returns:
anim – The animation object.
- Return type:
matplotlib.animation.FuncAnimation
Analytic Kernel#
- class src.analytic_kernel.AnalyticKernel(model_or_hparam, n_harmonics=3, n_lat=64, lat_range=None, quadrature='trapezoid')[source]#
Bases: object
JAX-accelerated analytic GP kernel for stellar rotation variability.
- Parameters:
model_or_hparam (SpotEvolutionModel or dict) – Either a SpotEvolutionModel instance (new API) or a raw hparam dict (backward-compatible old API).
n_harmonics (int) – Number of Fourier harmonics for the visibility function (default 3).
n_lat (int) – Number of latitude quadrature points (default 64).
lat_range (tuple) – (min, max) latitude in radians (default (-pi/2, pi/2)).
quadrature (str) – Latitude integration method: “trapezoid” or “gauss-legendre”.
- kernel(lag, lat_dist=None)[source]#
Full GP kernel averaged over latitude.
Uses jax.lax.scan for memory-efficient accumulation: only one lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).
When the visibility function is an EdgeOnVisibilityFunction, the latitude-averaged |c_n|^2 are known constants and the latitude loop is bypassed entirely.
- Parameters:
lag (array_like) – Time lags [days]. Can be 1D or 2D.
lat_dist (callable or None) – Latitude probability density. If None, uniform.
- Returns:
K
- Return type:
ndarray, same shape as lag input.
- compute_psd(omega, lat_dist=None)[source]#
Analytic power spectral density.
- Parameters:
omega (array_like) – Angular frequencies [rad/day].
lat_dist (callable or None) – Latitude probability density.
- Returns:
freq (ndarray [cycles/day])
power (ndarray)
- build_jax(n_lag=256)[source]#
Pre-compile and warm up JAX JIT computation for this kernel.
jax.lax.scan (used inside kernel()) triggers XLA compilation on its first call for a given array shape. That compilation can take several seconds and is easy to mistake for slow runtime. Call build_jax() once after constructing the kernel to pay that cost upfront; subsequent calls to kernel() and compute_psd() with the same shape will be fast.
- Parameters:
n_lag (int) – Length of the dummy lag array used to drive compilation (default 256). The actual value does not matter as long as it is representative of the sizes you will use at runtime.
- Returns:
self – Returns self so the call can be chained: ak = AnalyticKernel(model).build_jax().
- Return type:
Numerical Kernel#
- class src.numerical_kernel.NumericalKernel(model_or_hparam, tsim=20, tsamp=0.05, nsim=1000.0, verbose=True)[source]#
Bases: object
Gaussian Process for Stellar Rotation
- Parameters:
hparam (dict) –
Dict of hyperparameters. Required keys: peq, kappa, inc, lspot, tau_spot, alpha_max. For the kernel amplitude, provide EITHER:
sigma_k : overall amplitude prefactor, OR
nspot + fspot : number of spots and spot contrast, from which sigma_k is computed as sqrt(N_spot) * (1 - f_spot) / pi.
Note: nspot is always required for the numerical simulations.
tsim (float) – Simulation time (default: 20).
tsamp (float) – Time sampling (default: 0.05).
nsim (int) – Number of simulations (default: 1e3).
verbose (bool) – Whether to print verbose output (default: True).
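As a quick check of the amplitude relation quoted in the hparam notes, the snippet below evaluates sigma_k = sqrt(N_spot) * (1 - f_spot) / pi directly. The function name sigma_k_from_spots is hypothetical; the package may compute the value internally in a different way (the GPSolver notes, for instance, also mention alpha_max).

```python
import math


def sigma_k_from_spots(nspot, fspot):
    """Hypothetical helper: sigma_k = sqrt(N_spot) * (1 - f_spot) / pi."""
    return math.sqrt(nspot) * (1.0 - fspot) / math.pi


# Fully dark spots (fspot = 0) give the largest amplitude for a given nspot.
print(sigma_k_from_spots(nspot=100, fspot=0.0))
# A contrast of fspot = 1 means spots are as bright as the photosphere: no signal.
print(sigma_k_from_spots(nspot=100, fspot=1.0))
```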
GP Solver#
- class src.gp_solver.GPSolver(x, y, yerr, model_or_hparam, kernel_type='analytic', mean=None, fit_sigma_n=False, bounds=None, log_prior=None, matrix_solver='cholesky_banded', bandwidth=None, save_dir=None, **kernel_kwargs)[source]#
Bases: object
JAX-accelerated Gaussian Process solver for stellar lightcurves.
Handles covariance matrix construction, Cholesky factorization, log-likelihood evaluation, prediction, MAP estimation, and mass matrix computation.
- Parameters:
x (array_like, shape (N,)) – Observation times [days].
y (array_like, shape (N,)) – Observed flux values.
yerr (array_like, shape (N,) or float) – Measurement uncertainties (1-sigma).
hparam (dict) – Kernel hyperparameters. Required keys: peq, kappa, inc, lspot, tau_spot. For amplitude, provide either sigma_k directly or all of nspot, fspot, and alpha_max (sigma_k computed automatically).
kernel_type ({"analytic"}) – Which kernel to use (default: “analytic”).
mean (float or callable or None) – Mean function.
fit_sigma_n (bool) – If True, include white noise amplitude sigma_n as a free parameter for optimization/sampling (default False).
bounds (dict or None) – Parameter bounds for optimization. If None, uses defaults.
log_prior (callable or None) – Custom log-prior function f(theta_arr) -> scalar. If None, uses soft uniform within bounds.
kernel_kwargs (dict) – Extra kwargs forwarded to the kernel constructor.
- DEFAULT_BOUNDS = {'inc': (0.01, 3.1315926535897933), 'kappa': (0.001, 0.999), 'lspot': (0.1, 20.0), 'n_sn': (-10.0, 10.0), 'peq': (0.5, 50.0), 'sigma_k': (1e-06, 1.0), 'sigma_n': (1e-06, 0.1), 'sigma_sn': (0.05, 10.0), 'tau_dec': (0.05, 10.0), 'tau_em': (0.05, 10.0), 'tau_spot': (0.05, 10.0)}#
- build_jax()[source]#
Pre-compile and warm up all JAX JIT functions for this solver.
_build_logposterior() creates four @jax.jit-decorated functions (log_posterior, neg_log_posterior, grad_log_posterior, grad_neg_log_posterior) that each trigger a separate XLA compilation on their first call. Combined with the banded Cholesky solver, that can add up to 10–30 s of invisible overhead before a fit or MCMC run starts.
Call build_jax() once after constructing the solver to pay that cost upfront.
- Returns:
self – Returns self so the call can be chained: gp = GPSolver(...).build_jax().
- Return type:
- log_likelihood()[source]#
Marginal log-likelihood of the data under the GP.
- Returns:
logL – log p(y | X, theta)
- Return type:
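For orientation, the quantity log p(y | X, theta) has the standard closed form for a Gaussian process. The sketch below is a dense NumPy version under the assumption of a constant mean and a generic covariance matrix K; it is not the solver's banded JAX implementation, and gp_log_likelihood is a hypothetical name.

```python
import numpy as np


def gp_log_likelihood(y, K, yerr, mean=0.0):
    """Dense marginal log-likelihood of y under N(mean, K + diag(yerr^2))."""
    r = np.asarray(y, dtype=float) - mean
    # Add the measurement variance to the diagonal of the signal covariance.
    C = np.asarray(K, dtype=float) + np.diag(np.ones_like(r) * np.asarray(yerr, dtype=float) ** 2)
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L, r)                  # L^{-1} r, so alpha @ alpha = r^T C^{-1} r
    logdet = 2.0 * np.sum(np.log(np.diag(L)))      # log |C| from the Cholesky factor
    return -0.5 * alpha @ alpha - 0.5 * logdet - 0.5 * r.size * np.log(2.0 * np.pi)


# Placeholder squared-exponential covariance, used only for illustration.
t = np.linspace(0.0, 5.0, 20)
K = 0.1 * np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2)
print(gp_log_likelihood(np.sin(t) * 0.1, K, yerr=0.05))
```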
- predict(xpred, return_cov=False)[source]#
Predictive distribution at new input locations.
- Parameters:
xpred (array_like, shape (M,)) – Prediction times.
return_cov (bool) – If True, return full predictive covariance.
- Returns:
mu_pred (ndarray, shape (M,))
var_pred (ndarray, shape (M,) or (M, M))
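The shapes of mu_pred and var_pred follow from the usual GP conditioning formulas. The following is a minimal dense sketch assuming a zero mean and a stationary kernel passed as a function of lag; gp_predict and sq_exp are hypothetical names, and the squared-exponential kernel is only a stand-in for the analytic spot kernel.

```python
import numpy as np


def gp_predict(x, y, yerr, xpred, kernel, return_cov=False):
    """Zero-mean GP conditional: mean and variance (or covariance) at xpred."""
    x, y, xpred = (np.asarray(a, dtype=float) for a in (x, y, xpred))
    C = kernel(x[:, None] - x[None, :]) + np.diag(np.ones_like(x) * np.asarray(yerr, dtype=float) ** 2)
    Ks = kernel(xpred[:, None] - x[None, :])         # cross-covariance, shape (M, N)
    Kss = kernel(xpred[:, None] - xpred[None, :])    # prior covariance at xpred, shape (M, M)
    mu = Ks @ np.linalg.solve(C, y)
    cov = Kss - Ks @ np.linalg.solve(C, Ks.T)
    # Full (M, M) covariance on request, otherwise only its diagonal, shape (M,).
    return (mu, cov) if return_cov else (mu, np.diag(cov))


def sq_exp(lag):
    """Placeholder kernel for illustration only."""
    return np.exp(-0.5 * lag ** 2)


x = np.array([0.0, 1.0, 2.0])
y = np.array([1.0, 0.0, -1.0])
mu, var = gp_predict(x, y, 1e-3, np.linspace(0.0, 2.0, 5), sq_exp)
print(mu.shape, var.shape)
```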
- plot_prediction(theta=None, n_points=2000, n_sigma=(1, 2), ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='GP mean', data_label='Data')[source]#
Plot the GP posterior mean and uncertainty bands over the data.
If theta is provided, the GP is temporarily updated to those hyperparameters before predicting, so the prediction reflects the given parameter values rather than whatever was last set internally.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If None, uses the current internal hyperparameters.
n_points (int) – Number of prediction points spanning the data baseline.
n_sigma (int or sequence of int) – Which sigma levels to shade. E.g. (1, 2) draws both ±1σ and ±2σ bands (default). Pass a single int for one band.
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
data_color (str) – Color for the data points.
model_color (str) – Color for the model curve and uncertainty bands.
show_legend (bool) – Whether to draw a legend.
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- compute_acf(tlags, normalize=True)[source]#
Compute the empirical autocorrelation function of (x, y, yerr) by binning data pairs into time-lag bins.
- Parameters:
tlags (array_like, shape (M,)) – Bin edges for time lags [days]. The ACF is evaluated at bin centers: 0.5*(tlags[:-1] + tlags[1:]).
normalize (bool) – If True (default), normalize so ACF(0) = variance of y.
- Returns:
lag_centers (ndarray, shape (M-1,)) – Bin centers.
acf (ndarray, shape (M-1,)) – Empirical ACF at each bin center.
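The pair-binning idea can be illustrated with a short NumPy sketch. It assumes mean-subtracted products averaged per lag bin, ignores yerr, and leaves out the normalize option; binned_acf is a hypothetical name, and the actual method may weight or normalize differently.

```python
import numpy as np


def binned_acf(x, y, tlags):
    """Average (y_i - mean)(y_j - mean) over all pairs whose |t_i - t_j| falls in each lag bin."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    tlags = np.asarray(tlags, dtype=float)
    r = y - y.mean()
    i, j = np.triu_indices(x.size)             # all pairs with i <= j, including zero lag
    lags = np.abs(x[j] - x[i])
    prod = r[i] * r[j]
    which = np.digitize(lags, tlags) - 1       # bin index of each pair
    nb = tlags.size - 1
    ok = (which >= 0) & (which < nb)           # drop pairs outside the bin edges
    sums = np.bincount(which[ok], weights=prod[ok], minlength=nb)
    counts = np.bincount(which[ok], minlength=nb)
    acf = np.where(counts > 0, sums / np.maximum(counts, 1), np.nan)
    lag_centers = 0.5 * (tlags[:-1] + tlags[1:])
    return lag_centers, acf


t = np.arange(10.0)
flux = np.where(t % 2 == 0, 1.0, -1.0)         # alternating signal with period 2
centers, acf = binned_acf(t, flux, tlags=[-0.5, 0.5, 1.5, 2.5])
print(centers, acf)
```

With M bin edges the outputs have M-1 entries, matching the shapes listed above; empty bins are returned as NaN in this sketch.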
- compute_kernel(tlags)[source]#
Evaluate the analytic kernel at the given time lags.
- Parameters:
tlags (array_like, shape (M,)) – Time lags [days].
- Returns:
K – Kernel values at each lag.
- Return type:
ndarray, shape (M,)
- fit_acf(theta0=None, keys=None, tlags=None, n_bins=50, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)[source]#
Fit the analytic kernel to the empirical ACF via least-squares.
Minimizes sum_i (ACF_data(lag_i) - K(lag_i; theta))^2 over the kernel hyperparameters, using JAX gradients and scipy.
- Parameters:
theta0 (dict or array_like, optional) –
- Starting point. Can be:
None: uses self.theta0 (kernel params only, no sigma_n).
dict: values for any subset of kernel keys set the starting point. If keys is not given, the dict keys that overlap with KERNEL_HPARAM_KEYS are treated as the free variables; the rest are held fixed. Extra keys not in KERNEL_HPARAM_KEYS are ignored.
array_like: full kernel theta vector (6 elements).
keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all kernel parameters are varied.
tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.
n_bins (int) – Number of lag bins (used when tlags is None).
method (str) – Scipy optimizer method.
maxiter (int) – Maximum iterations.
ftol (float) – Function-value convergence tolerance (default 0, disabled).
gtol (float) – Gradient-norm convergence tolerance (default 1e-8).
disp (bool) – If True, print optimizer convergence messages (default False).
nopt (int) – Number of independent optimisation trials (default 1). When > 1, fit_acf_parallel is called and the best result across all trials is returned.
ncore (int or None) – Number of parallel workers. Only used when nopt > 1.
rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.
- Returns:
theta_dict (dict) – Full dictionary of all kernel hyperparameters (fixed + optimized).
result (scipy.optimize.OptimizeResult) – Full optimizer output.
- fit_acf_parallel(nopt=10, ncore=None, keys=None, tlags=None, n_bins=50, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None)[source]#
Run fit_acf from multiple random starting points in parallel.
Starting points are drawn uniformly within the kernel parameter bounds.
- Parameters:
nopt (int) – Number of independent optimization trials (default 10).
ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.
keys (list of str, optional) – Free parameters (forwarded to fit_acf).
tlags – Forwarded to fit_acf.
n_bins – Forwarded to fit_acf.
method (str) – Optimizer method (default “nelder-mead”).
maxiter – Forwarded to fit_acf.
ftol – Forwarded to fit_acf.
gtol – Forwarded to fit_acf.
disp – Forwarded to fit_acf.
return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
- Returns:
theta_best (dict (or list of dict if return_all=True)) – Best-fit kernel hyperparameters.
result_best (scipy.optimize.OptimizeResult) – (or list of OptimizeResult if return_all=True)
- fit_acf_psd(theta0=None, keys=None, tlags=None, n_bins=50, n_freq=200, dt_kernel=None, acf_weight=1.0, psd_weight=1.0, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False)[source]#
Fit kernel parameters jointly to the empirical ACF and PSD.
Minimizes a weighted sum of two normalized mean-squared-error terms:
loss = acf_weight * acf_loss + psd_weight * psd_loss
where
acf_loss = mean((ACF_data - K_model)^2) / mean(ACF_data^2)
is the relative MSE of the kernel against the empirical ACF (unnormalized autocovariance), and
psd_loss = mean((PSD_data_norm - PSD_model_norm)^2)
is the MSE between the Lomb-Scargle periodogram and the analytic kernel PSD, both normalized to unit integral so the comparison is independent of overall amplitude.
The model PSD is computed via a direct cosine transform of the kernel evaluated on a uniform lag grid, making it fully differentiable with respect to the kernel parameters.
- Parameters:
theta0 (dict or array_like, optional) – Starting point in self.param_keys space (sampling space, with log_-prefixed keys where applicable). Follows the same convention as fit_map: None uses self.theta0, a dict overrides named entries and infers free keys, an array is used directly.
keys (list of str, optional) – Parameters to vary during optimization (names from self.param_keys). Defaults to all kernel parameters (first 6 entries of self.param_keys, i.e. excluding sigma_n if present).
tlags (array_like, optional) – Bin edges for the empirical ACF. If None, n_bins+1 edges linearly spaced from 0 to half the baseline.
n_bins (int) – Number of ACF lag bins when tlags is None (default 50).
n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram (default 200).
dt_kernel (float, optional) – Uniform lag spacing [days] for evaluating the analytic kernel before the direct cosine transform. Defaults to one-fifth of the median data spacing.
acf_weight (float) – Weight for the ACF loss term (default 1.0).
psd_weight (float) – Weight for the PSD loss term (default 1.0).
method (str) – Scipy optimizer method (default "L-BFGS-B").
maxiter (int) – Maximum optimizer iterations (default 500).
ftol (float) – Convergence tolerances forwarded to scipy.
gtol (float) – Convergence tolerances forwarded to scipy.
disp (bool) – Print optimizer messages if True.
- Returns:
theta_dict (dict) – Best-fit parameters in self.param_keys space.
result (scipy.optimize.OptimizeResult)
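The joint objective can be written down compactly. The sketch below follows the two loss terms as stated (relative MSE for the ACF, MSE between unit-integral PSDs) and includes one plausible discretization of the direct cosine transform on a uniform lag grid. All function names are hypothetical, and the exact quadrature and frequency grid used by the solver are assumptions here.

```python
import numpy as np


def unit_integral(power, freq):
    """Scale a PSD so that its trapezoidal integral over freq equals one."""
    area = np.sum(0.5 * (power[1:] + power[:-1]) * np.diff(freq))
    return power / area


def cosine_transform_psd(kernel_vals, dt, freq):
    """Direct cosine transform of a kernel sampled at lags 0, dt, 2*dt, ..."""
    lags = dt * np.arange(kernel_vals.size)
    w = np.full(kernel_vals.size, 2.0)
    w[0] = 1.0   # zero lag counted once, positive lags twice (kernel is even in lag)
    return dt * (np.cos(2.0 * np.pi * np.outer(freq, lags)) @ (w * kernel_vals))


def joint_loss(acf_data, k_model, psd_data, psd_model, freq,
               acf_weight=1.0, psd_weight=1.0):
    """loss = acf_weight * acf_loss + psd_weight * psd_loss, as defined above."""
    acf_loss = np.mean((acf_data - k_model) ** 2) / np.mean(acf_data ** 2)
    psd_loss = np.mean((unit_integral(psd_data, freq) - unit_integral(psd_model, freq)) ** 2)
    return acf_weight * acf_loss + psd_weight * psd_loss


freq = np.linspace(0.05, 2.0, 50)
lag = 0.1 * np.arange(200)
k = np.exp(-lag / 3.0) * np.cos(2.0 * np.pi * lag / 4.0)   # toy quasi-periodic kernel
psd_model = cosine_transform_psd(k, 0.1, freq)
print(joint_loss(k, 0.9 * k, psd_model, 2.0 * psd_model, freq))
```

Because both PSDs are rescaled to unit integral, multiplying the model PSD by a constant leaves psd_loss unchanged; only the ACF term is sensitive to the overall amplitude.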
- plot_acf(theta=None, tlags=None, n_bins=50, ax=None, normalize=False, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic ACF', data_label='Data ACF')[source]#
Plot the empirical ACF and optionally the analytic kernel.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel is overplotted.
tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.
n_bins (int) – Number of lag bins (used when tlags is None).
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
normalize (bool) – If True, normalize both curves by the data variance so ACF(0) ≈ 1 (default False, per the signature).
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- plot_psd(theta=None, n_freq=500, dt_kernel=None, ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic PSD', data_label='Data Lomb-Scargle')[source]#
Plot the empirical PSD (Lomb-Scargle) and optionally the analytic kernel PSD (FFT of the autocovariance function).
Both curves are normalized so their integral over positive frequencies equals the data variance, making them directly comparable.
- Parameters:
theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel PSD is overplotted.
n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram.
dt_kernel (float, optional) – Time step [days] for evaluating the analytic kernel on a uniform grid before FFT. Defaults to one-fifth of the median data spacing.
ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.
data_color (str) – Color for the data curve.
model_color (str) – Color for the model curve.
show_legend (bool) – Whether to draw a legend.
xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.
ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.
- Returns:
ax
- Return type:
matplotlib Axes
- plot_covariance_matrix(theta=None, ax=None, cmap='RdBu_r', show_colorbar=True, vmax=None, nbins=50, show=False, filename='covariance_matrix.png')[source]#
Plot the GP covariance matrix K (signal only, no noise).
Entries outside the banded support are set to zero, matching the cholesky_banded approximation. The matrix is binned to nbins x nbins before plotting. The bandwidth boundary is drawn as dashed lines, and band width plus matrix sparsity are annotated.
- Parameters:
theta (dict or array_like, optional) – Kernel hyperparameters. Accepts a physical dict with keys from param_keys, a sampling-space dict with log_-prefixed keys, or a raw array. If None, uses the current self.hparam values.
ax (matplotlib Axes, optional) – Axes to plot on. If None, a new figure is created.
cmap (str, optional) – Colormap name. Defaults to "RdBu_r" (diverging, centred at zero).
show_colorbar (bool, optional) – Whether to add a colorbar. Default True.
vmax (float, optional) – Symmetric color scale limit [-vmax, vmax]. If None, uses the maximum absolute value of the banded matrix.
nbins (int, optional) – Bin the N x N matrix down to nbins x nbins by averaging non-overlapping blocks before plotting. Default 50.
show (bool, optional) – If True, call plt.show(). Default False.
filename (str, optional) – Filename used when saving to save_dir. Default "covariance_matrix.png".
- Returns:
ax
- Return type:
matplotlib Axes
- get_theta()[source]#
Return the current kernel hyperparameters as a dictionary.
- Returns:
theta – Keys and values for all kernel (and optionally noise) hyperparameters, e.g. {“peq”: 5.0, “kappa”: 0.2, …}.
- Return type:
- update_hparam(hparam)[source]#
Update hyperparameters and rebuild kernel and covariance.
Accepts either a SpotEvolutionModel or a hparam dict.
- fit_map(theta0=None, keys=None, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)[source]#
Find the maximum a posteriori (MAP) estimate.
Uses scipy.optimize.minimize with JAX-computed gradients. When nopt > 1, delegates to fit_map_parallel, which runs nopt independent trials from random starting points (drawn uniformly within the bounds) and returns the best result.
- Parameters:
theta0 (dict or array_like, optional) –
- Starting point. Can be:
None: uses self.theta0 (current hyperparameters).
dict: values for any subset of param_keys set the starting point. If keys is not given, the dict keys that overlap with self.param_keys are treated as the free variables to optimize; the rest are held fixed. Extra keys not in param_keys are ignored.
array_like: full theta vector (length n_params).
Ignored when nopt > 1 (starting points are randomised).
keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all parameters are varied.
method (str) – Scipy optimizer method (default “L-BFGS-B”).
maxiter (int) – Maximum iterations.
ftol (float) – Function-value convergence tolerance for L-BFGS-B (default 0, i.e. disabled so that convergence is controlled by gtol).
gtol (float) – Gradient-norm convergence tolerance (default 1e-8).
disp (bool) – If True, print optimizer convergence messages (default False).
nopt (int) – Number of independent optimisation trials (default 1). When > 1, fit_map_parallel is called and the best result across all trials is returned.
ncore (int or None) – Number of parallel workers for multi-start runs. If None, uses nopt or the number of available CPUs, whichever is smaller. Only used when nopt > 1.
rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.
- Returns:
theta_dict (dict) – Full dictionary of all hyperparameters (fixed + optimized).
result (scipy OptimizeResult) – Full optimizer output.
- fit_map_parallel(nopt=10, ncore=None, keys=None, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None)[source]#
Run fit_map from multiple random starting points in parallel.
Starting points are drawn uniformly within the parameter bounds.
- Parameters:
nopt (int) – Number of independent optimization trials (default 10).
ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.
keys (list of str, optional) – Free parameters (forwarded to fit_map).
method (str) – Optimizer method (default “nelder-mead”).
maxiter – Forwarded to fit_map.
ftol – Forwarded to fit_map.
gtol – Forwarded to fit_map.
disp – Forwarded to fit_map.
return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.
rng (numpy.random.Generator, optional) – Random number generator for reproducibility.
- Returns:
theta_best (dict (or list of dict if return_all=True)) – Best-fit hyperparameters.
result_best (scipy.optimize.OptimizeResult) – (or list of OptimizeResult if return_all=True)
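The two defaults described for the multi-start routines (uniform starting points within the bounds, and a worker count capped by nopt and the CPU count) can be sketched as follows. The helper names are hypothetical, the bounds are a subset of DEFAULT_BOUNDS as listed above, and whether the draw happens in physical or log (sampling) space is not stated, so physical space is assumed here.

```python
import os

import numpy as np

# Subset of GPSolver.DEFAULT_BOUNDS, copied from the listing above.
bounds = {"peq": (0.5, 50.0), "kappa": (0.001, 0.999), "lspot": (0.1, 20.0)}


def draw_starts(bounds, nopt, rng=None):
    """Draw nopt starting points uniformly inside the per-parameter bounds."""
    rng = np.random.default_rng() if rng is None else rng
    lo = np.array([b[0] for b in bounds.values()])
    hi = np.array([b[1] for b in bounds.values()])
    return rng.uniform(lo, hi, size=(nopt, len(bounds)))


def default_ncore(nopt, ncore=None):
    """nopt or the number of available CPUs, whichever is smaller."""
    if ncore is not None:
        return ncore
    return min(nopt, os.cpu_count() or 1)


starts = draw_starts(bounds, nopt=10, rng=np.random.default_rng(42))
print(starts.shape, default_ncore(10))
```

Passing a seeded numpy.random.Generator, as the rng argument allows, makes the set of starting points reproducible.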
- mass_matrix_hessian_map(theta_map=None)[source]#
Estimate the inverse mass matrix from the Hessian of the negative log-likelihood at the MAP.
M^{-1} = H^{-1} where H = d^2(-log L)/d theta^2 at the MAP.
- mass_matrix_fisher(theta_map=None)[source]#
Estimate the inverse mass matrix from the Fisher information.
For the GP log-likelihood:
I_{ij} = (1/2) tr(K^{-1} dK/dtheta_i K^{-1} dK/dtheta_j)
When matrix_solver="cholesky_full", the kernel derivatives dK/dtheta_i are computed via JAX forward-mode autodiff (jacfwd) on the full N×N covariance matrix.
When matrix_solver="cholesky_banded", the exact Fisher requires the dense N×N kernel and its inverse, which would defeat the purpose of banded storage. Instead, the Fisher is approximated by the Hessian of the banded negative log-likelihood at the MAP (Fisher ≈ observed information at the MLE).
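The trace formula for the Fisher information can be checked on a small dense example. The sketch below takes the derivative matrices dK/dtheta_i as given (for instance from finite differences or autodiff) rather than computing them with jacfwd; fisher_information is a hypothetical name and the white-noise covariance is chosen only because its Fisher information is known in closed form.

```python
import numpy as np


def fisher_information(K, dK):
    """I_ij = 0.5 * tr(K^{-1} dK_i K^{-1} dK_j) for a list of derivative matrices dK."""
    A = [np.linalg.solve(K, d) for d in dK]    # K^{-1} dK/dtheta_i, one per parameter
    n = len(A)
    info = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            info[i, j] = 0.5 * np.trace(A[i] @ A[j])
    return info


# Toy case: K = sigma^2 * I with a single parameter sigma, so dK/dsigma = 2 * sigma * I
# and the Fisher information is 2 * N / sigma^2.
N, sigma = 4, 2.0
K = sigma ** 2 * np.eye(N)
dK = [2.0 * sigma * np.eye(N)]
print(fisher_information(K, dK))   # [[2.]]
```

Each solve costs O(N^3) on a dense matrix, which is the cost that the banded approximation described above avoids.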
Power Spectral Density#
- src.psd.compute_psd(y, t=None, dt=None, normalization='psd', freq_min=None, freq_max=None, n_freq=None, samples_per_peak=5)[source]#
Compute the Power Spectral Density of a time series using astropy.timeseries.LombScargle.
Works for both evenly and unevenly sampled data.
- Parameters:
y (array-like, shape (N,)) – Time series values.
t (array-like, shape (N,), optional) – Sample times. If None, integer indices scaled by dt are used.
dt (float, optional) – Sampling interval. Used only when t is None (default: 1).
normalization ({"psd", "standard", "model", "log"}) – Passed directly to LombScargle.autopower / power.
freq_min (float, optional) – Minimum frequency to evaluate.
freq_max (float, optional) – Maximum frequency to evaluate.
n_freq (int, optional) – Number of frequency grid points.
samples_per_peak (float, optional) – Controls the frequency grid density (default 5).
- Returns:
freq (ndarray) – Frequencies in cycles per unit time.
power (ndarray) – PSD evaluated at each frequency.
MCMC Sampler#
- class src.mcmc.MCMCSampler(gp)[source]#
Bases: object
Base MCMC sampler for GP hyperparameters.
Wraps a GPSolver object and provides shared storage, diagnostics, summary statistics, corner plots, and dict conversion. Subclasses implement specific sampling algorithms (e.g. NUTS).
- Parameters:
gp (GPSolver) – A configured GPSolver instance.
- property param_keys#
- property n_params#
- summary()[source]#
Print summary statistics of the posterior samples.
- Returns:
stats – Parameter names mapped to (mean, std, 16%, 50%, 84%).
- Return type:
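The five statistics listed for each parameter can be reproduced from a flat sample array with NumPy. This is a sketch assuming samples of shape (n_total, n_params); summarize is a hypothetical name, and the sampler's own summary() may format or compute the values differently.

```python
import numpy as np


def summarize(samples, param_keys):
    """Map each parameter name to (mean, std, 16%, 50%, 84%) of its samples."""
    samples = np.asarray(samples, dtype=float)
    q16, q50, q84 = np.percentile(samples, [16, 50, 84], axis=0)
    return {
        key: (samples[:, i].mean(), samples[:, i].std(), q16[i], q50[i], q84[i])
        for i, key in enumerate(param_keys)
    }


rng = np.random.default_rng(1)
fake = rng.normal(loc=[5.0, 0.2], scale=[0.5, 0.05], size=(2000, 2))
for name, stats in summarize(fake, ["peq", "kappa"]).items():
    print(name, stats)
```

The 16th and 84th percentiles bracket the central 68% interval, which coincides with mean ± std only for a Gaussian posterior.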
- plot_covariance(method='fisher', theta_map=None, n_sigma=2, n_grid=200, samples=None, figsize=None, color='C0', alpha=0.3, true_params=None, savefig=None, **corner_kwargs)[source]#
Corner plot of 2D covariance ellipses from the Hessian or Fisher matrix, with 1D marginal Gaussians on the diagonal.
Uses corner.corner to lay out the figure when MCMC samples are provided, and overlays the Laplace/Fisher Gaussian approximation (ellipses + 1D marginals).
- Parameters:
method ({"fisher", "hessian_map", "laplace"}) – Which matrix to use for the Gaussian approximation.
theta_map (array_like, optional) – Center of the ellipses. If None, uses MAP estimate.
n_sigma (float) – Number of sigma for the ellipse contours (default 2).
n_grid (int) – Grid resolution for the ellipse curves (default 200).
samples (array_like, optional) – If provided, plotted as the corner histogram/contours. If None, the figure is created with empty axes and only the Gaussian approximation is drawn.
figsize (tuple, optional) – Figure size.
color (str) – Color for Gaussian ellipses and marginals (default “C0”).
alpha (float) – Fill alpha for the ellipse interiors (default 0.3).
true_params (dict or array_like, optional) – True parameter values to mark with crosshairs.
savefig (str, optional) – If provided, save figure to this path.
**corner_kwargs – Extra keyword arguments forwarded to corner.corner (e.g. quantiles, show_titles, hist_kwargs).
- Returns:
fig, axes
- Return type:
matplotlib Figure and 2D array of Axes.
- class src.mcmc.BlackJAXSampler(gp, save_dir=None)[source]#
Bases: MCMCSampler
NUTS sampler using the BlackJAX library.
Inherits diagnostics, summary, plotting, and dict conversion from MCMCSampler. Adds run_nuts for gradient-based No-U-Turn sampling with dual-averaging step-size adaptation.
- Parameters:
- run_nuts(n_samples=1000, n_warmup=500, theta_init=None, mass_matrix_method='hessian_map', step_size=None, rng_key=None, target_accept=0.8, progress_bar=False, n_chains=1, checkpoint_file=None)[source]#
Run BlackJAX NUTS sampler.
Uses blackjax.window_adaptation for JIT-compiled warmup (step-size and mass-matrix adaptation), then jax.lax.scan for the sampling loop. Both paths avoid Python-level loops, minimizing retracing overhead and memory from accumulated intermediates.
Warmup always runs a single chain to adapt the step size and mass matrix. When n_chains > 1, the adapted parameters are shared across all chains, which are initialized with jittered copies of the warmup endpoint and sampled in parallel via jax.vmap.
- Parameters:
n_samples (int) – Number of post-warmup samples per chain (default 1000).
n_warmup (int) – Number of warmup steps for step-size adaptation (default 500).
theta_init (dict or array_like, optional) – Initial position. If None, uses GPSolver’s MAP estimate.
mass_matrix_method ({"hessian_map", "fisher", "laplace", "diagonal", None}) – Method to estimate the mass matrix (delegated to GPSolver).
step_size (float, optional) – Initial NUTS step size before adaptation. If None, a heuristic based on the mass matrix scale is used.
rng_key (jax.random.PRNGKey, optional) – Random key. Default: PRNGKey(0).
target_accept (float) – Target acceptance rate for dual averaging (default 0.8).
progress_bar (bool) – If True, print periodic progress updates during the lax.scan sampling loop (default False).
n_chains (int) – Number of independent chains to run in parallel via jax.vmap (default 1). All chains share the same adapted step size and mass matrix from a single warmup.
checkpoint_file (str, optional) – Default file path for save_checkpoint. When set, calling save_checkpoint() with no arguments will use this path.
- Returns:
samples (jnp.ndarray) – Shape (n_samples, n_params) when n_chains=1, or (n_chains, n_samples, n_params) when n_chains > 1.
info (dict) – Sampling diagnostics (arrays have a leading chain dimension when n_chains > 1).
- save_checkpoint(path=None, append_samples=True, plot_corner=False)[source]#
Save sampler state to disk for later resumption.
When append_samples=True (the default), new samples are appended to any existing samples already stored in path, and self.samples is cleared from memory. This enables a sample-checkpoint-clear loop that keeps memory usage constant.
- Parameters:
path (str, optional) – File path (saved as .npz). If None, uses the checkpoint_file set in run_nuts, or save_dir/checkpoint.npz if save_dir was set.
append_samples (bool) – If True, append current self.samples to any samples already on disk, then clear self.samples from memory. If False, overwrite with only the current in-memory samples.
plot_corner (bool) – If True, load all samples currently on disk after saving and write a corner plot to save_dir/corner_plot.png (or alongside the checkpoint file if save_dir is not set).
- load_checkpoint(path)[source]#
Restore sampler state from a checkpoint file.
Loads only the NUTS state and adapted kernel parameters needed to resume sampling. Samples stored in the file are not loaded into memory; use load_samples to read them later.
- Parameters:
path (str) – Path to a .npz checkpoint file.
- static load_samples(path, flatten_chains=True)[source]#
Read all samples from a checkpoint file without loading the sampler state.
- Parameters:
- Returns:
samples – Shape (n_total, n_params) when flatten_chains=True, or (n_chains, n_samples, n_params) otherwise.
- Return type:
np.ndarray
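The relation between the two output shapes is a plain reshape. The sketch below assumes n_total = n_chains * n_samples and that single-chain arrays are passed through unchanged; flatten_chains here is a hypothetical standalone function, not the static method itself.

```python
import numpy as np


def flatten_chains(samples):
    """Collapse (n_chains, n_samples, n_params) into (n_chains * n_samples, n_params)."""
    samples = np.asarray(samples)
    if samples.ndim == 3:
        n_chains, n_samples, n_params = samples.shape
        return samples.reshape(n_chains * n_samples, n_params)
    return samples   # already flat: (n_total, n_params)


chains = np.zeros((4, 1000, 6))
print(flatten_chains(chains).shape)   # (4000, 6)
```

Flattening is convenient for corner plots and summary statistics; per-chain diagnostics need the unflattened array.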
- resume_nuts(n_samples=1000, n_chains=None, rng_key=None, progress_bar=False)[source]#
Continue NUTS sampling from a previous run or loaded checkpoint.
Skips warmup entirely and uses the previously adapted step size and mass matrix (shared across all chains). Returns only the new batch of samples. Call save_checkpoint afterward to append the batch to disk and free memory.
- Parameters:
n_samples (int) – Number of additional samples per chain (default 1000).
n_chains (int, optional) – Number of chains to run. If None (default), uses the value stored in the sampler state (from run_nuts or load_checkpoint).
rng_key (jax.random.PRNGKey, optional) – Random key. If None, advances from the last key used.
progress_bar (bool) – If True, print periodic progress updates (default False). Only supported for single-chain runs.
- Returns:
samples (jnp.ndarray) – Shape (n_samples, n_params) for single chain, or (n_chains, n_samples, n_params) for multiple chains.
info (dict) – Diagnostics for this batch.