API Reference (spotgp)

Lightcurve Model

class spotgp.lightcurve.LightcurveModel(peq=4.0, kappa=0.0, inc=1.5707963267948966, nspot=None, tau_spot=None, tem=2, tdec=2, alpha_max=0.1, fspot=0, lspot=5, long=[0, 6.283185307179586], lat=[-1.5707963267948966, 1.5707963267948966], tsim=28, tsamp=0.02, limb_darkening=False, tmax=None, rotate=True, grow=True, nspot_rate=None)

Bases: object

JAX-accelerated star with spots and its lightcurve.

Same interface as the numpy version but uses JAX for vectorized computation across all spots simultaneously.

Parameters:
  • peq (float) – Equatorial period of the star.

  • kappa (float) – Differential rotation shear.

  • inc (float) – Inclination of the star.

  • nspot (int) – Number of spots.

  • tau_spot (float, optional) – Timescale for both emergence and decay of the spots. Defaults to None.

  • tem (float, optional) – Emergence timescale of the spots. Defaults to 2.

  • tdec (float, optional) – Decay timescale of the spots. Defaults to 2.

  • alpha_max (float, optional) – Maximum angular area of the spots. Defaults to 0.1.

  • fspot (float, optional) – Spot contrast fraction. Defaults to 0.

  • lspot (float, optional) – Spot lifetime. Defaults to 5.

  • long (list, optional) – Range of spot longitudes. Defaults to [0, 2*pi].

  • lat (list, optional) – Range of spot latitudes. Defaults to [-pi/2, pi/2].

  • tsim (float, optional) – End simulation time. Defaults to 28.

  • tsamp (float, optional) – Sampling cadence. Defaults to 0.02.

  • limb_darkening (bool, optional) – Flag to enable limb darkening. Defaults to False.

classmethod from_spot_model(spot_model: SpotEvolutionModel, nspot: int = None, *, nspot_rate: float = None, **kwargs)

Construct a LightcurveModel from a SpotEvolutionModel.

Parameters:
  • spot_model (SpotEvolutionModel) – Fully configured spot evolution model.

  • nspot (int, optional) – Total number of spots to simulate.

  • nspot_rate (float, optional) – Spot emergence rate [spots/day]. The actual number of spots is max(1, int(nspot_rate * tsim)). Exactly one of nspot or nspot_rate must be provided.

  • **kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).

Return type:

LightcurveModel
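
The spot-count rule documented for nspot_rate can be sketched in a few lines. The helper name resolve_nspot is hypothetical and is not part of spotgp; it only mirrors the stated behaviour (exactly one of nspot or nspot_rate, and max(1, int(nspot_rate * tsim)) when a rate is given).

```python
def resolve_nspot(nspot=None, nspot_rate=None, tsim=28.0):
    """Hypothetical helper mirroring the documented rule; not spotgp code."""
    # Exactly one of the two arguments may be supplied.
    if (nspot is None) == (nspot_rate is None):
        raise ValueError("Provide exactly one of nspot or nspot_rate.")
    if nspot is not None:
        return int(nspot)
    # Truncate the expected count, but always simulate at least one spot.
    return max(1, int(nspot_rate * tsim))


print(resolve_nspot(nspot_rate=0.5, tsim=28.0))   # 14
print(resolve_nspot(nspot_rate=0.01, tsim=28.0))  # 1
```

Note that very low rates still yield one spot, so the simulated lightcurve is never flat by construction.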

classmethod from_hparam(hparam: dict, nspot: int = None, *, nspot_rate: float = None, **kwargs)

Construct a LightcurveModel from a GPSolver-compatible hparam dict.

Accepts the same raw hparam dict that GPSolver/AnalyticKernel take, including all amplitude modes (sigma_k, nspot_rate, or nspot) and both symmetric (tau_spot) and asymmetric (tau_em + tau_dec) envelopes. This removes the need to manually decompose the dict in scripts.

Parameters:
  • hparam (dict) – Raw hyperparameter dict. Must contain peq, kappa, inc, lspot, tau_spot (or tau_em/tau_dec), and an amplitude specification.

  • nspot (int, optional) – Total number of spots to simulate.

  • nspot_rate (float, optional) – Spot emergence rate [spots/day]. Exactly one of nspot or nspot_rate must be provided.

  • **kwargs – Forwarded to LightcurveModel.__init__ (e.g. tsim, tsamp, lat, long).

Return type:

LightcurveModel

Flux(teval)

Compute the full lightcurve using JAX vmap over all spots.

Instead of a Python loop over nspot, all spots are computed in parallel via JAX’s vmap.

plot_lightcurve(show_spots=True, show_title=True)

Plot the lightcurve.

animate_lightcurve(fps=30, duration=10.0, outfile=None, dpi=150, show_spots=True, show_grid=True, show_params=True, figsize=(14, 5.5), save_last_frame=None, show_dr=True, label_size=18)

Animate the starspot evolution with two panels: a 2D projection of the rotating star (left) and the lightcurve (right).

Parameters:
  • fps (int) – Frames per second (default 30).

  • duration (float) – Animation duration in seconds (default 10).

  • outfile (str or None) – Output file path (.mp4 or .gif). If None, returns the animation object without saving.

  • dpi (int) – Resolution (default 150).

  • show_spots (bool) – If True, show individual spot contributions on the lightcurve panel (default True).

  • show_grid (bool) – If True, draw latitude/longitude grid on the star (default True).

  • show_params (bool) – If True, show parameter annotation on the lightcurve panel (default True).

  • figsize (tuple) – Figure size (default (14, 5.5)).

  • save_last_frame (str or None) – If provided, save the last frame of the animation as a static image to this file path (e.g. “frame.png”).

  • show_dr (bool) – If True, color the stellar disk by latitude-dependent rotation frequency and display a colorbar (default True).

  • label_size (int or float) – Font size for all labels, tick marks, and text in the plot (default 18).

Returns:

anim – The animation object.

Return type:

matplotlib.animation.FuncAnimation

animate_butterfly(fps=30, duration=10.0, outfile=None, dpi=150, show_spots=True, show_grid=True, show_params=True, figsize=(18, 5.5), save_last_frame=None, show_dr=True, label_size=18)

Animate the starspot evolution with three panels: a 2D projection of the rotating star (left), the lightcurve (center), and a butterfly diagram of spot latitude vs. time (right).

Parameters:
  • fps (int) – Frames per second (default 30).

  • duration (float) – Animation duration in seconds (default 10).

  • outfile (str or None) – Output file path (.mp4 or .gif). If None, returns the animation object without saving.

  • dpi (int) – Resolution (default 150).

  • show_spots (bool) – If True, show individual spot contributions on the lightcurve panel (default True).

  • show_grid (bool) – If True, draw latitude/longitude grid on the star (default True).

  • show_params (bool) – If True, show parameter annotation above the figure (default True).

  • figsize (tuple) – Figure size (default (18, 5.5)).

  • save_last_frame (str or None) – If provided, save the last frame of the animation as a static image to this file path (e.g. “frame.png”).

  • show_dr (bool) – If True, color the stellar disk by latitude-dependent rotation frequency and display a colorbar (default True).

  • label_size (int or float) – Font size for all labels, tick marks, and text in the plot (default 18).

Returns:

anim – The animation object.

Return type:

matplotlib.animation.FuncAnimation

Analytic Kernel

class spotgp.analytic_kernel.AnalyticKernel(model_or_hparam, n_harmonics=3, n_lat=64, lat_range=None, quadrature='trapezoid')

Bases: object

JAX-accelerated analytic GP kernel for stellar rotation variability.

Parameters:
  • model_or_hparam (SpotEvolutionModel or dict) – Either a SpotEvolutionModel instance (new API) or a raw hparam dict (backward-compatible old API).

  • n_harmonics (int) – Number of Fourier harmonics for the visibility function (default 3).

  • n_lat (int) – Number of latitude quadrature points (default 64).

  • lat_range (tuple) – (min, max) latitude in radians (default (-pi/2, pi/2)).

  • quadrature (str) – Latitude integration method: “trapezoid” or “gauss-legendre”.
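
To illustrate what the two quadrature options mean, the sketch below builds latitude nodes and weights for each rule with NumPy and integrates cos(phi) over the default latitude range. The function latitude_nodes is a hypothetical stand-in; how AnalyticKernel constructs its nodes internally is not documented here.

```python
import numpy as np


def latitude_nodes(n_lat=64, lat_range=(-np.pi / 2, np.pi / 2),
                   quadrature="trapezoid"):
    """Return (nodes, weights) for a latitude integral. Illustrative only."""
    a, b = lat_range
    if quadrature == "trapezoid":
        phi = np.linspace(a, b, n_lat)
        w = np.full(n_lat, (b - a) / (n_lat - 1))
        w[0] *= 0.5   # end points carry half weight
        w[-1] *= 0.5
    elif quadrature == "gauss-legendre":
        x, wx = np.polynomial.legendre.leggauss(n_lat)  # nodes on [-1, 1]
        phi = 0.5 * (b - a) * x + 0.5 * (b + a)         # map to [a, b]
        w = 0.5 * (b - a) * wx
    else:
        raise ValueError(f"unknown quadrature: {quadrature}")
    return phi, w


# Integrate cos(phi) over [-pi/2, pi/2]; the exact value is 2.
for rule in ("trapezoid", "gauss-legendre"):
    phi, w = latitude_nodes(quadrature=rule)
    print(rule, np.sum(w * np.cos(phi)))
```

For smooth integrands, Gauss-Legendre typically reaches a given accuracy with far fewer nodes than the trapezoid rule, which is the usual reason to prefer it when n_lat must stay small.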

omega0(phi)

Latitude-dependent rotation angular frequency [rad/day].

R_Gamma(lag)

Autocorrelation of the squared envelope (delegates to envelope).

cn_squared(phi)

Squared Fourier visibility coefficients at latitude phi.

kernel_single_latitude(lag, phi)

Single-spot kernel at a fixed latitude.

kernel(lag, lat_dist=None)

Full GP kernel averaged over latitude.

Uses jax.lax.scan for memory-efficient accumulation: only one lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).

When the visibility function is an EdgeOnVisibilityFunction, the latitude-averaged |c_n|^2 are known constants and the latitude loop is bypassed entirely.

Parameters:
  • lag (array_like) – Time lags [days]. Can be 1D or 2D.

  • lat_dist (callable or None) – Latitude probability density. If None, uniform.

Returns:

K

Return type:

ndarray, same shape as lag input.
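
The memory argument for the latitude average can be sketched with plain NumPy: accumulate one latitude at a time into a single lag-sized buffer rather than materializing an (n_lat, M) array. Everything below is an assumption made for illustration: the toy per-latitude kernel, the shear law omega(phi) = (2*pi/peq) * (1 - kappa * sin(phi)^2), and the uniform weights. The real method uses jax.lax.scan and the package's own visibility coefficients.

```python
import numpy as np


def toy_kernel_single_latitude(lag, phi, peq=4.0, kappa=0.2, ell=5.0):
    """Toy stand-in: decaying cosine with an ASSUMED shear law."""
    omega = 2.0 * np.pi / peq * (1.0 - kappa * np.sin(phi) ** 2)
    return np.exp(-0.5 * (lag / ell) ** 2) * np.cos(omega * lag)


def kernel_accumulated(lag, phi, weights):
    """Weighted sum over latitude using one lag-sized buffer (O(M) memory)."""
    acc = np.zeros_like(lag, dtype=float)
    for p, w in zip(phi, weights):
        acc += w * toy_kernel_single_latitude(lag, p)
    return acc


lag = np.linspace(0.0, 20.0, 200)
phi = np.linspace(-np.pi / 2, np.pi / 2, 64)
weights = np.full(phi.size, 1.0 / phi.size)  # uniform density (assumed)

K_scan = kernel_accumulated(lag, phi, weights)

# Reference: build the full (n_lat, M) array, then reduce over latitude.
K_full = (weights[:, None]
          * toy_kernel_single_latitude(lag[None, :], phi[:, None])).sum(axis=0)

print(np.allclose(K_scan, K_full))
```

Both routes give the same result; the accumulated form simply never holds more than one lag-sized array at a time.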

kernel_solid_body(lag, lat_dist=None)

Kernel for solid-body rotation (kappa=0).

compute_psd(omega, lat_dist=None)

Analytic power spectral density.

Parameters:
  • omega (array_like) – Angular frequencies [rad/day].

  • lat_dist (callable or None) – Latitude probability density.

Returns:

  • freq (ndarray [cycles/day])

  • power (ndarray)

build_jax(n_lag=256)

Pre-compile and warm up JAX JIT computation for this kernel.

jax.lax.scan (used inside kernel()) triggers XLA compilation on its first call for a given array shape. That compilation can take several seconds and is easy to mistake for slow runtime. Call build_jax() once after constructing the kernel to pay that cost upfront — subsequent calls to kernel() and compute_psd() with the same shape will be fast.

Parameters:

n_lag (int) – Length of the dummy lag array used to drive compilation (default 256). XLA compiles once per array shape, so choose a length that matches the lag arrays you will pass at runtime; other shapes trigger a fresh compilation on first use.

Returns:

self – Returns self so the call can be chained: ak = AnalyticKernel(model).build_jax().

Return type:

AnalyticKernel

Numerical Kernel

class spotgp.numerical_kernel.NumericalKernel(model_or_hparam, tsim=20, tsamp=0.05, nsim=1000.0, verbose=True)

Bases: object

Gaussian Process for Stellar Rotation

Parameters:
  • model_or_hparam (dict) –

    Dict of hyperparameters. Required keys: peq, kappa, inc, lspot, tau_spot, alpha_max. For the kernel amplitude, provide EITHER:

    • sigma_k : overall amplitude prefactor, OR

    • nspot + fspot : number of spots and spot contrast, from which sigma_k is computed as sqrt(N_spot) * (1 - f_spot) / pi.

    Note: nspot is always required for the numerical simulations.

  • tsim (float) – Simulation time (default: 20).

  • tsamp (float) – Time sampling (default: 0.05).

  • nsim (int) – Number of simulations (default: 1e3).

  • verbose (bool) – Whether to print verbose output (default: True).
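
As a quick check of the amplitude relation quoted above, sigma_k = sqrt(N_spot) * (1 - f_spot) / pi, the snippet below evaluates it directly. The function name is hypothetical and not part of spotgp.

```python
import math


def sigma_k_from_spots(nspot, fspot):
    """Amplitude prefactor from the documented relation (illustrative)."""
    return math.sqrt(nspot) * (1.0 - fspot) / math.pi


# Four spots at 50% contrast: sqrt(4) * 0.5 / pi = 1 / pi
print(sigma_k_from_spots(nspot=4, fspot=0.5))
```

The amplitude grows as the square root of the spot count and vanishes when fspot = 1, that is, when spots have no contrast against the photosphere.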

get_acf()

compute_psd(tarr=None, freq_min=None, freq_max=None, normalization='psd', nsims=100)

plot_autocorrelation()

Plot the autocorrelation function.

Returns:

fig – The figure object.

Return type:

matplotlib Figure

spotgp.numerical_kernel.generate_sims(theta, nsim=1000.0, **kwargs)

Generate synthetic lightcurves for a given set of parameters.

Parameters:
  • theta (tuple) – Tuple containing peq, kappa, inc, and nspot parameters.

  • nsim (int, optional) – Number of simulations. Defaults to 1000.

  • **kwargs – Additional arguments for the StarSpot class.

Returns:

Array of synthetic lightcurves.

Return type:

numpy.ndarray

spotgp.numerical_kernel.avg_covariance_tlag(K)

GP Solver

class spotgp.gp_solver.GPSolver(data_or_x, y=None, yerr=None, model_or_hparam=None, kernel_type='analytic', mean=None, fit_sigma_n=False, bounds=None, log_prior=None, matrix_solver='cholesky_banded', bandwidth=None, save_dir=None, **kernel_kwargs)

Bases: object

JAX-accelerated Gaussian Process solver for stellar lightcurves.

Handles covariance matrix construction, Cholesky factorization, log-likelihood evaluation, prediction, MAP estimation, and mass matrix computation.

Parameters:
  • data_or_x (TimeSeriesData or array_like, shape (N,)) – Either a TimeSeriesData object, or observation times [days]. When a TimeSeriesData is passed, y and yerr must be None (they are read from the data object).

  • y (array_like, shape (N,), optional) – Observed flux values. Required when data_or_x is an array.

  • yerr (array_like, shape (N,) or float, optional) – Measurement uncertainties (1-sigma). Required when data_or_x is an array.

  • model_or_hparam (SpotEvolutionModel or dict) – Either a SpotEvolutionModel (new API) or a raw hparam dict (backward-compatible old API).

  • kernel_type ({"analytic"}) – Which kernel to use (default: “analytic”).

  • mean (float or callable or None) – Mean function.

  • fit_sigma_n (bool) – If True, include white noise amplitude sigma_n as a free parameter for optimization/sampling (default False).

  • bounds (dict or None) – Parameter bounds for optimization. If None, uses defaults.

  • log_prior (callable or None) – Custom log-prior function f(theta_arr) -> scalar. If None, uses soft uniform within bounds.

  • kernel_kwargs (dict) – Extra kwargs forwarded to the kernel constructor.

DEFAULT_BOUNDS = {'inc': (0.01, 3.1315926535897933), 'kappa': (0.001, 0.999), 'lat_max': (0.0, 1.5707963267948966), 'lat_min': (0.0, 1.5707963267948966), 'lspot': (0.1, 20.0), 'n_sn': (-10.0, 10.0), 'peq': (0.5, 50.0), 'sigma_k': (1e-06, 1.0), 'sigma_n': (1e-06, 0.1), 'sigma_sn': (0.05, 10.0), 'tau_dec': (0.05, 10.0), 'tau_em': (0.05, 10.0), 'tau_spot': (0.05, 10.0)}

build_jax(recompute=True)

Pre-compile and warm up all JAX JIT functions for this solver.

_build_logposterior() creates four @jax.jit-decorated functions (log_posterior, neg_log_posterior, grad_log_posterior, grad_neg_log_posterior) that each trigger a separate XLA compilation on their first call. Combined with the banded Cholesky solver, that can add up to 10–30 s of invisible overhead before a fit or MCMC run starts.

Call build_jax() once after constructing the solver to pay that cost upfront.

Returns:

self – Returns self so the call can be chained: gp = GPSolver(...).build_jax().

Return type:

GPSolver

log_likelihood()

Marginal log-likelihood of the data under the GP.

Returns:

logL – log p(y | X, theta)

Return type:

float
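
For readers unfamiliar with the quantity being returned, here is a minimal dense-matrix sketch of a GP marginal log-likelihood using a Cholesky factorization. The quasi-periodic toy_kernel and the function gp_log_likelihood are assumptions made for this example; GPSolver uses its own kernel and, by default, a banded solver.

```python
import numpy as np


def toy_kernel(lag, amp=1.0, ell=5.0, period=4.0):
    """Quasi-periodic toy kernel (assumed for illustration, not spotgp's)."""
    return amp * np.exp(-np.abs(lag) / ell) * np.cos(2.0 * np.pi * lag / period)


def gp_log_likelihood(t, y, yerr, kernel=toy_kernel, mean=0.0):
    """log p(y | t, theta) for a zero-mean-shifted GP with diagonal noise."""
    # Covariance: kernel at all pairwise lags plus measurement variance.
    K = kernel(t[:, None] - t[None, :]) + np.diag(yerr ** 2)
    L = np.linalg.cholesky(K)
    r = y - mean
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, r))  # K^{-1} r
    # log|K| = 2 * sum(log(diag(L)))
    return (-0.5 * r @ alpha
            - np.sum(np.log(np.diag(L)))
            - 0.5 * t.size * np.log(2.0 * np.pi))


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 20.0, 60))
y = 0.5 * np.sin(2.0 * np.pi * t / 4.0)
yerr = np.full(t.size, 0.05)
print(gp_log_likelihood(t, y, yerr))
```

The three terms are the data fit, the complexity penalty from the log-determinant, and the normalization constant.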

predict(xpred, return_cov=False)

Predictive distribution at new input locations.

Parameters:
  • xpred (array_like, shape (M,)) – Prediction times.

  • return_cov (bool) – If True, return full predictive covariance.

Returns:

  • mu_pred (ndarray, shape (M,))

  • var_pred (ndarray, shape (M,) or (M, M))
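
The two return shapes, (M,) for the variance or (M, M) for the full covariance, follow from the standard GP conditioning equations. The dense NumPy sketch below shows them under stated assumptions: toy_kernel and gp_predict are hypothetical names, the mean is zero, and no banded approximation is used.

```python
import numpy as np


def toy_kernel(lag, amp=1.0, ell=5.0, period=4.0):
    """Quasi-periodic toy kernel (assumed for illustration, not spotgp's)."""
    return amp * np.exp(-np.abs(lag) / ell) * np.cos(2.0 * np.pi * lag / period)


def gp_predict(t, y, yerr, tpred, kernel=toy_kernel, return_cov=False):
    """Posterior mean and variance (or covariance) of a zero-mean GP."""
    K = kernel(t[:, None] - t[None, :]) + np.diag(yerr ** 2)   # (N, N)
    Ks = kernel(tpred[:, None] - t[None, :])                   # (M, N)
    Kss = kernel(tpred[:, None] - tpred[None, :])              # (M, M)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))        # K^{-1} y
    mu = Ks @ alpha
    V = np.linalg.solve(L, Ks.T)                               # (N, M)
    cov = Kss - V.T @ V
    return (mu, cov) if return_cov else (mu, np.diag(cov))


rng = np.random.default_rng(1)
t = np.sort(rng.uniform(0.0, 20.0, 50))
y = 0.5 * np.sin(2.0 * np.pi * t / 4.0)
yerr = np.full(t.size, 0.05)
tpred = np.linspace(0.0, 20.0, 120)

mu_pred, var_pred = gp_predict(t, y, yerr, tpred)
print(mu_pred.shape, var_pred.shape)
```

Requesting only the diagonal is cheaper to store and is usually all that plotting uncertainty bands requires.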

plot_prediction(theta=None, n_points=2000, n_sigma=(1, 2), ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='GP mean', data_label='Data')

Plot the GP posterior mean and uncertainty bands over the data.

If theta is provided the GP is temporarily updated to those hyperparameters before predicting, so the prediction reflects the given parameter values rather than whatever was last set internally.

Parameters:
  • theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If None, uses the current internal hyperparameters.

  • n_points (int) – Number of prediction points spanning the data baseline.

  • n_sigma (int or sequence of int) – Which sigma levels to shade. E.g. (1, 2) draws both ±1σ and ±2σ bands (default). Pass a single int for one band.

  • ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.

  • data_color (str) – Color for the data points (default 'k').

  • model_color (str) – Color for the model curve and uncertainty bands (default 'r').

  • show_legend (bool) – Whether to draw a legend.

  • xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.

  • ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.

Returns:

ax

Return type:

matplotlib Axes

sample_prior(xpred, n_samples=1, rng=None)

Draw samples from the GP prior.

sample_posterior(xpred, n_samples=1, rng=None)

Draw samples from the GP posterior.

sample_lightcurves(theta=None, xpred=None, n_samples=5, n_points=2000, source='prior', rng=None)

Sample lightcurves from the GP prior or posterior.

Parameters:
  • theta (dict or array_like, optional) – Kernel parameters. Accepts a physical dict with keys from param_keys, a sampling-space dict with log_-prefixed keys, or a flat array matching param_keys. If None, uses the current internal hyperparameters.

  • xpred (array_like, optional) – Times at which to evaluate the samples. If None, uses n_points evenly spaced times spanning the data baseline.

  • n_samples (int) – Number of lightcurve samples to draw (default 5).

  • n_points (int) – Number of prediction points when xpred is None (default 2000).

  • source ({'prior', 'posterior'}) – Whether to sample from the GP prior or posterior (default ‘prior’).

  • rng (numpy.random.Generator, optional) – Random number generator for reproducibility.

Returns:

  • xpred (ndarray, shape (M,)) – Prediction times.

  • samples (ndarray, shape (n_samples, M)) – Sampled lightcurves.
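
Drawing from a GP prior amounts to colouring white noise with a Cholesky factor of the covariance. The sketch below shows this and reproduces the documented (n_samples, M) output shape. The names toy_kernel and sample_gp_prior, and the jitter value, are assumptions for the example rather than spotgp internals.

```python
import numpy as np


def toy_kernel(lag, amp=1.0, ell=5.0, period=4.0):
    """Quasi-periodic toy kernel (assumed for illustration, not spotgp's)."""
    return amp * np.exp(-np.abs(lag) / ell) * np.cos(2.0 * np.pi * lag / period)


def sample_gp_prior(tpred, kernel=toy_kernel, n_samples=5, rng=None,
                    jitter=1e-6):
    """Draw n_samples zero-mean GP realizations at times tpred."""
    rng = np.random.default_rng() if rng is None else rng
    K = kernel(tpred[:, None] - tpred[None, :])
    # A small diagonal jitter keeps the factorization numerically stable.
    L = np.linalg.cholesky(K + jitter * np.eye(tpred.size))
    z = rng.standard_normal((n_samples, tpred.size))
    return z @ L.T  # each row has covariance L L^T = K


tpred = np.linspace(0.0, 20.0, 200)
samples = sample_gp_prior(tpred, n_samples=5, rng=np.random.default_rng(42))
print(samples.shape)  # (5, 200)
```

Passing a seeded numpy.random.Generator, as the rng argument above allows, makes the draws reproducible.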

plot_samples(theta=None, xpred=None, n_samples=5, n_points=2000, source='prior', rng=None, ax=None, show_data=True, data_color='k', sample_alpha=0.7, sample_lw=1.0, cmap='tab10', show_legend=True, xlim=None, ylim=None)

Plot sampled lightcurves from the GP prior or posterior.

Parameters:
  • theta (dict or array_like, optional) – Kernel parameters (see sample_lightcurves).

  • xpred (array_like, optional) – Prediction times. If None, n_points evenly spaced.

  • n_samples (int) – Number of samples to draw and plot (default 5).

  • n_points (int) – Number of prediction points when xpred is None.

  • source ({'prior', 'posterior'}) – Sample from the GP prior or posterior (default ‘prior’).

  • rng (numpy.random.Generator, optional) – Random number generator for reproducibility.

  • ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.

  • show_data (bool) – Whether to overlay the observed data (default True).

  • data_color (str) – Color for data points (default ‘k’).

  • sample_alpha (float) – Opacity for sample curves (default 0.7).

  • sample_lw (float) – Line width for sample curves (default 1.0).

  • cmap (str) – Matplotlib colormap name for sample colors (default ‘tab10’).

  • show_legend (bool) – Whether to draw a legend (default True).

  • xlim (tuple, optional) – Limits for the x-axis.

  • ylim (tuple, optional) – Limits for the y-axis.

Returns:

ax

Return type:

matplotlib Axes

compute_acf(tlags=None, n_bins=50, normalize=True)

Compute the empirical autocorrelation function of the data.

Delegates to self.data.compute_acf().

Parameters:
  • tlags (array_like, optional) – Bin edges for time lags [days]. If provided, n_bins is inferred as len(tlags) - 1. If None, n_bins linearly spaced bins from 0 to half the baseline are used.

  • n_bins (int) – Number of lag bins (used when tlags is None, default 50).

  • normalize (bool) – If True (default), normalize so ACF(0) ~ 1.

Returns:

  • lag_centers (ndarray, shape (n_bins,)) – Bin centers.

  • acf (ndarray, shape (n_bins,)) – Empirical ACF at each bin center.
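
One common way to build an empirical ACF for unevenly sampled data is to bin pairwise products by time lag. The sketch below follows the bin-edge conventions described above (n_bins linearly spaced bins from 0 to half the baseline when tlags is None), but the estimator itself is an assumption; the actual implementation in self.data.compute_acf() may differ.

```python
import numpy as np


def binned_acf(t, y, tlags=None, n_bins=50, normalize=True):
    """Lag-binned autocovariance of (t, y). Illustrative estimator only."""
    y = y - np.mean(y)
    if tlags is None:
        tlags = np.linspace(0.0, 0.5 * (t.max() - t.min()), n_bins + 1)
    i, j = np.triu_indices(t.size)            # all pairs with i <= j
    lag = np.abs(t[j] - t[i])
    prod = y[i] * y[j]
    idx = np.digitize(lag, tlags) - 1         # bin index of each pair
    n = tlags.size - 1
    valid = (idx >= 0) & (idx < n)
    sums = np.bincount(idx[valid], weights=prod[valid], minlength=n)
    counts = np.bincount(idx[valid], minlength=n)
    acf = np.where(counts > 0, sums / np.maximum(counts, 1), np.nan)
    if normalize:
        acf = acf / np.var(y)                 # so that ACF(0) is close to 1
    centers = 0.5 * (tlags[:-1] + tlags[1:])
    return centers, acf


t = np.linspace(0.0, 40.0, 401)
y = np.sin(2.0 * np.pi * t / 4.0)             # period of 4 days
lag_centers, acf = binned_acf(t, y)
print(lag_centers.shape, acf.shape)
```

For this sinusoid the binned ACF is strongly negative near a lag of half a period and positive again near one full period, as expected.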

compute_kernel(tlags)

Evaluate the analytic kernel at the given time lags.

Parameters:

tlags (array_like, shape (M,)) – Time lags [days].

Returns:

K – Kernel values at each lag.

Return type:

ndarray, shape (M,)

fit_acf(theta0=None, keys=None, tlags=None, n_bins=50, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)

Fit the analytic kernel to the empirical ACF via least-squares.

Minimizes sum_i (ACF_data(lag_i) - K(lag_i; theta))^2 over the kernel hyperparameters, using JAX gradients and scipy.

Parameters:
  • theta0 (dict or array_like, optional) –

    Starting point. Can be:
    • None: uses self.theta0 (kernel params only, no sigma_n).

    • dict: values for any subset of kernel keys set the starting point. If keys is not given, the dict keys that overlap with KERNEL_HPARAM_KEYS are treated as the free variables; the rest are held fixed. Extra keys not in KERNEL_HPARAM_KEYS are ignored.

    • array_like: full kernel theta vector (6 elements).

  • keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all kernel parameters are varied.

  • tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.

  • n_bins (int) – Number of lag bins (used when tlags is None).

  • method (str) – Scipy optimizer method.

  • maxiter (int) – Maximum iterations.

  • ftol (float) – Function-value convergence tolerance (default 0, disabled).

  • gtol (float) – Gradient-norm convergence tolerance (default 1e-8).

  • disp (bool) – If True, print optimizer convergence messages (default False).

  • nopt (int) – Number of independent optimisation trials (default 1). When > 1, fit_acf_parallel is called and the best result across all trials is returned.

  • ncore (int or None) – Number of parallel workers. Only used when nopt > 1.

  • rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.

Returns:

  • theta_dict (dict) – Full dictionary of all kernel hyperparameters (fixed + optimized).

  • result (scipy.optimize.OptimizeResult) – Full optimizer output.

fit_acf_parallel(nopt=10, ncore=None, keys=None, tlags=None, n_bins=50, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None)

Run fit_acf from multiple random starting points in parallel.

Starting points are drawn uniformly within the kernel parameter bounds.

Parameters:
  • nopt (int) – Number of independent optimization trials (default 10).

  • ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.

  • keys (list of str, optional) – Free parameters (forwarded to fit_acf).

  • tlags – Forwarded to fit_acf.

  • n_bins – Forwarded to fit_acf.

  • method (str) – Optimizer method (default “nelder-mead”).

  • maxiter – Forwarded to fit_acf.

  • ftol – Forwarded to fit_acf.

  • gtol – Forwarded to fit_acf.

  • disp – Forwarded to fit_acf.

  • return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.

  • rng (numpy.random.Generator, optional) – Random number generator for reproducibility.

Returns:

  • theta_best (dict (or list of dict if return_all=True)) – Best-fit kernel hyperparameters.

  • result_best (scipy.optimize.OptimizeResult) – (or list of OptimizeResult if return_all=True)

fit_acf_psd(theta0=None, keys=None, tlags=None, n_bins=50, n_freq=200, dt_kernel=None, acf_weight=1.0, psd_weight=1.0, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False)

Fit kernel parameters jointly to the empirical ACF and PSD.

Minimizes a weighted sum of two normalized mean-squared-error terms:

loss = acf_weight * acf_loss + psd_weight * psd_loss

where

acf_loss = mean((ACF_data - K_model)^2) / mean(ACF_data^2)

is the relative MSE of the kernel against the empirical ACF (unnormalized autocovariance), and

psd_loss = mean((PSD_data_norm - PSD_model_norm)^2)

is the MSE between the Lomb-Scargle periodogram and the analytic kernel PSD, both normalized to unit integral so the comparison is independent of overall amplitude.

The model PSD is computed via a direct cosine transform of the kernel evaluated on a uniform lag grid, making it fully differentiable with respect to the kernel parameters.

Parameters:
  • theta0 (dict or array_like, optional) – Starting point in self.param_keys space (sampling space, with log_-prefixed keys where applicable). Follows the same convention as fit_map: None uses self.theta0, a dict overrides named entries and infers free keys, an array is used directly.

  • keys (list of str, optional) – Parameters to vary during optimization (names from self.param_keys). Defaults to all kernel parameters (first 6 entries of self.param_keys, i.e. excluding sigma_n if present).

  • tlags (array_like, optional) – Bin edges for the empirical ACF. If None, n_bins+1 edges linearly spaced from 0 to half the baseline.

  • n_bins (int) – Number of ACF lag bins when tlags is None (default 50).

  • n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram (default 200).

  • dt_kernel (float, optional) – Uniform lag spacing [days] for evaluating the analytic kernel before the direct cosine transform. Defaults to one-fifth of the median data spacing.

  • acf_weight (float) – Weight for the ACF loss term (default 1.0).

  • psd_weight (float) – Weight for the PSD loss term (default 1.0).

  • method (str) – Scipy optimizer method (default "L-BFGS-B").

  • maxiter (int) – Maximum optimizer iterations (default 500).

  • ftol (float) – Function-value convergence tolerance forwarded to scipy (default 0).

  • gtol (float) – Gradient-norm convergence tolerance forwarded to scipy (default 1e-8).

  • disp (bool) – Print optimizer messages if True.

Returns:

  • theta_dict (dict) – Best-fit parameters in self.param_keys space.

  • result (scipy.optimize.OptimizeResult)
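
The joint objective described for fit_acf_psd can be sketched as follows: a direct cosine transform turns the kernel into a PSD on a uniform lag grid, both PSDs are scaled to unit integral, and the two loss terms are combined with their weights. The function names, the toy kernel, and the tmax/dt grid settings are assumptions for this illustration; the real method evaluates the loss with JAX so it stays differentiable.

```python
import numpy as np


def toy_kernel(lag, ell=5.0, period=4.0):
    """Quasi-periodic toy kernel (assumed for illustration, not spotgp's)."""
    return np.exp(-np.abs(lag) / ell) * np.cos(2.0 * np.pi * lag / period)


def kernel_psd_cosine(kernel, omega, tmax=50.0, dt=0.05):
    """PSD of an even kernel via a direct cosine transform on a lag grid."""
    tau = np.arange(0.0, tmax + dt, dt)
    w = np.full(tau.size, dt)     # trapezoid weights on the uniform grid
    w[0] *= 0.5
    w[-1] *= 0.5
    wk = w * kernel(tau)
    return 2.0 * (np.cos(omega[:, None] * tau[None, :]) * wk[None, :]).sum(axis=1)


def unit_integral(psd, omega):
    """Scale a PSD so its trapezoid integral over omega equals 1."""
    area = np.sum(0.5 * (psd[1:] + psd[:-1]) * np.diff(omega))
    return psd / area


def joint_loss(acf_data, k_model, psd_data_norm, psd_model_norm,
               acf_weight=1.0, psd_weight=1.0):
    """loss = acf_weight * acf_loss + psd_weight * psd_loss."""
    acf_loss = np.mean((acf_data - k_model) ** 2) / np.mean(acf_data ** 2)
    psd_loss = np.mean((psd_data_norm - psd_model_norm) ** 2)
    return acf_weight * acf_loss + psd_weight * psd_loss


omega = np.linspace(0.1, 5.0, 500)            # angular frequency [rad/day]
psd_model = unit_integral(kernel_psd_cosine(toy_kernel, omega), omega)
print(omega[np.argmax(psd_model)], 2.0 * np.pi / 4.0)
```

Normalizing both PSDs to unit integral means the PSD term constrains the shape of the spectrum, while the unnormalized ACF term carries the amplitude information.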

plot_acf(theta=None, tlags=None, n_bins=50, ax=None, normalize=False, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic ACF', data_label='Data ACF')

Plot the empirical ACF and optionally the analytic kernel.

Parameters:
  • theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel is overplotted.

  • tlags (array_like, optional) – Bin edges for compute_acf. If None, linearly spaced from 0 to half the baseline with n_bins+1 edges.

  • n_bins (int) – Number of lag bins (used when tlags is None).

  • ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.

  • normalize (bool) – If True, normalize both curves by the data variance so ACF(0) ≈ 1 (default False).

  • xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.

  • ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.

Returns:

ax

Return type:

matplotlib Axes

plot_psd(theta=None, n_freq=500, dt_kernel=None, ax=None, data_color='k', model_color='r', show_legend=True, xlim=None, ylim=None, model_label='Analytic PSD', data_label='Data Lomb-Scargle')

Plot the empirical PSD (Lomb-Scargle) and optionally the analytic kernel PSD (FFT of the autocovariance function).

Both curves are normalized so their integral over positive frequencies equals the data variance, making them directly comparable.

Parameters:
  • theta (dict or array_like, shape (6,), optional) – Kernel parameters. Accepts a physical dict with keys from KERNEL_HPARAM_KEYS, a sampling-space dict with log_-prefixed keys (e.g. log_sigma_k), or a length-6 array. If provided, the analytic kernel PSD is overplotted.

  • n_freq (int) – Number of frequency points for the Lomb-Scargle periodogram.

  • dt_kernel (float, optional) – Time step [days] for evaluating the analytic kernel on a uniform grid before FFT. Defaults to one-fifth of the median data spacing.

  • ax (matplotlib Axes, optional) – Axes to plot on. If None, creates a new figure.

  • data_color (str) – Color for the data (Lomb-Scargle) curve (default 'k').

  • model_color (str) – Color for the analytic model curve (default 'r').

  • show_legend (bool) – Whether to draw a legend.

  • xlim (tuple, optional) – Limits for the x-axis. If None, defaults to the data range.

  • ylim (tuple, optional) – Limits for the y-axis. If None, defaults to the data range.

Returns:

ax

Return type:

matplotlib Axes

plot_covariance_matrix(theta=None, ax=None, cmap='RdBu_r', show_colorbar=True, vmax=None, nbins=50, show=False, filename='covariance_matrix.png')

Plot the GP covariance matrix K (signal only, no noise).

Entries outside the banded support are set to zero, matching the cholesky_banded approximation. The matrix is binned to nbins x nbins before plotting. The bandwidth boundary is drawn as dashed lines, and band width plus matrix sparsity are annotated.

Parameters:
  • theta (dict or array_like, optional) – Kernel hyperparameters. Accepts a physical dict with keys from param_keys, a sampling-space dict with log_-prefixed keys, or a raw array. If None, uses the current self.hparam values.

  • ax (matplotlib Axes, optional) – Axes to plot on. If None, a new figure is created.

  • cmap (str, optional) – Colormap name. Defaults to "RdBu_r" (diverging, centred at zero).

  • show_colorbar (bool, optional) – Whether to add a colorbar. Default True.

  • vmax (float, optional) – Symmetric color scale limit [-vmax, vmax]. If None, uses the maximum absolute value of the banded matrix.

  • nbins (int, optional) – Bin the N x N matrix down to nbins x nbins by averaging non-overlapping blocks before plotting. Default 50.

  • show (bool, optional) – If True, call plt.show(). Default False.

  • filename (str, optional) – Filename used when saving to save_dir. Default "covariance_matrix.png".

Returns:

ax

Return type:

matplotlib Axes
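
The two preprocessing steps described for this plot, zeroing entries outside the band and averaging non-overlapping blocks down to nbins x nbins, can be sketched with NumPy. The helper band_and_bin is hypothetical, and trimming the matrix so that blocks tile evenly is a simplification; how the method handles an N that is not a multiple of nbins is not documented here.

```python
import numpy as np


def band_and_bin(K, bandwidth, nbins=50):
    """Zero entries outside the band, then block-average to nbins x nbins."""
    n = K.shape[0]
    i, j = np.indices(K.shape)
    in_band = np.abs(i - j) <= bandwidth
    Kb = np.where(in_band, K, 0.0)
    sparsity = 1.0 - np.count_nonzero(in_band) / K.size
    block = n // nbins
    m = block * nbins        # trim so blocks tile evenly (simplification)
    binned = Kb[:m, :m].reshape(nbins, block, nbins, block).mean(axis=(1, 3))
    return binned, sparsity


t = np.linspace(0.0, 30.0, 300)
K = np.exp(-np.abs(t[:, None] - t[None, :]) / 5.0)   # toy covariance
binned, sparsity = band_and_bin(K, bandwidth=40, nbins=50)
print(binned.shape, round(sparsity, 3))
```

Block averaging keeps the plot readable for large N while preserving the visible width of the band.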

get_theta()

Return the current kernel hyperparameters as a dictionary.

Returns:

theta – Keys and values for all kernel (and optionally noise) hyperparameters, e.g. {"peq": 5.0, "kappa": 0.2, …}.

Return type:

dict

update_hparam(hparam)

Update hyperparameters and rebuild kernel and covariance.

Accepts a SpotEvolutionModel, an hparam dict (legacy keys like lspot), or a theta-style dict whose keys match self.spot_model.param_keys (e.g. tau_em, lat_min).

fit_map(theta0=None, keys=None, method='L-BFGS-B', maxiter=500, ftol=0, gtol=1e-08, disp=False, nopt=1, ncore=None, rng=None, _save=True)

Find the maximum a posteriori (MAP) estimate.

Uses scipy.optimize.minimize with JAX-computed gradients. When nopt > 1, delegates to fit_map_parallel which runs nopt independent trials from random starting points (drawn uniformly within the bounds) and returns the best result.

Parameters:
  • theta0 (dict or array_like, optional) –

    Starting point. Can be:
    • None: uses self.theta0 (current hyperparameters).

    • dict: values for any subset of param_keys set the starting point. If keys is not given, the dict keys that overlap with self.param_keys are treated as the free variables to optimize; the rest are held fixed. Extra keys not in param_keys are ignored.

    • array_like: full theta vector (length n_params).

    Ignored when nopt > 1 (starting points are randomised).

  • keys (list of str, optional) – Which parameters to vary during optimization. Overrides the automatic inference from a dict theta0. Parameters not listed are held fixed at their current values. If None and theta0 is not a dict, all parameters are varied.

  • method (str) – Scipy optimizer method (default “L-BFGS-B”).

  • maxiter (int) – Maximum iterations.

  • ftol (float) – Function-value convergence tolerance for L-BFGS-B (default 0, i.e. disabled so that convergence is controlled by gtol).

  • gtol (float) – Gradient-norm convergence tolerance (default 1e-8).

  • disp (bool) – If True, print optimizer convergence messages (default False).

  • nopt (int) – Number of independent optimisation trials (default 1). When > 1, fit_map_parallel is called and the best result across all trials is returned.

  • ncore (int or None) – Number of parallel workers for multi-start runs. If None, uses nopt or the number of available CPUs, whichever is smaller. Only used when nopt > 1.

  • rng (numpy.random.Generator, optional) – RNG for random starting points. Only used when nopt > 1.

Returns:

  • theta_dict (dict) – Full dictionary of all hyperparameters (fixed + optimized).

  • result (scipy OptimizeResult) – Full optimizer output.

fit_map_parallel(nopt=10, ncore=None, keys=None, method='nelder-mead', maxiter=500, ftol=0, gtol=1e-08, disp=False, return_all=False, rng=None, theta0=None, jitter=0.01)

Run fit_map from multiple random starting points in parallel.

Starting points are drawn uniformly within the parameter bounds.

Parameters:
  • nopt (int) – Number of independent optimization trials (default 10).

  • ncore (int or None) – Number of parallel workers. If None, uses nopt or the number of available CPUs, whichever is smaller.

  • keys (list of str, optional) – Free parameters (forwarded to fit_map).

  • method (str) – Optimizer method (default “nelder-mead”).

  • maxiter – Forwarded to fit_map.

  • ftol – Forwarded to fit_map.

  • gtol – Forwarded to fit_map.

  • disp – Forwarded to fit_map.

  • return_all (bool) – If True, return all solutions sorted by objective value. If False (default), return only the best solution.

  • rng (numpy.random.Generator, optional) – Random number generator for reproducibility.

  • theta0 (dict, optional) – Initial parameter guess to include as one of the starting points. Replaces one random start so the total number of trials stays nopt.

Returns:

  • theta_best (dict (or list of dict if return_all=True)) – Best-fit hyperparameters.

  • result_best (scipy.optimize.OptimizeResult (or list of OptimizeResult if return_all=True)) – Optimizer output for the best solution.
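
As a rough sketch of the multi-start setup described above: starting points are drawn uniformly within the bounds, an optional theta0 replaces one of them, and trials are ranked by objective value. The bounds layout, the helper names, and the toy objective are assumptions made for illustration; the real method runs an optimizer from each start.

```python
import os

import numpy as np


def draw_starts(bounds, nopt, rng, theta0=None):
    """Draw `nopt` starting points uniformly within `bounds`.

    bounds : dict mapping name -> (low, high); hypothetical layout.
    theta0 : optional dict that replaces one random start, so the
             total number of trials stays `nopt`.
    """
    names = list(bounds)
    low = np.array([bounds[k][0] for k in names])
    high = np.array([bounds[k][1] for k in names])
    starts = rng.uniform(low, high, size=(nopt, len(names)))
    if theta0 is not None:
        starts[0] = [theta0[k] for k in names]
    return names, starts


def n_workers(nopt, ncore=None):
    """Worker count: `ncore` if given, else min(nopt, available CPUs)."""
    if ncore is not None:
        return ncore
    return min(nopt, os.cpu_count() or 1)


rng = np.random.default_rng(0)
bounds = {"peq": (1.0, 10.0), "lspot": (0.5, 20.0)}  # assumed bounds
names, starts = draw_starts(
    bounds, nopt=10, rng=rng, theta0={"peq": 4.0, "lspot": 5.0}
)

# Rank the trials by a toy objective (stand-in for the optimized objective).
objective = ((starts - np.array([4.0, 5.0])) ** 2).sum(axis=1)
order = np.argsort(objective)  # best (smallest objective) first
```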

mass_matrix_hessian_map(theta_map=None)[source]#

Estimate the inverse mass matrix from the Hessian of the negative log-likelihood at the MAP.

M^{-1} = H^{-1} where H = d^2(-log L)/d theta^2 at the MAP.

Parameters:

theta_map (array_like, optional) – MAP estimate. If None, calls fit_map() first.

Returns:

inv_mass_matrix

Return type:

jnp.ndarray, shape (n_params, n_params)
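
The relation M^{-1} = H^{-1} can be checked on a toy problem. The sketch below uses a central finite-difference Hessian in NumPy rather than JAX autodiff, and a Gaussian negative log-likelihood whose covariance is known; all names here are illustrative and not part of spotgp.

```python
import numpy as np


def numerical_hessian(f, x, eps=1e-3):
    """Central finite-difference Hessian of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    n = x.size
    hess = np.zeros((n, n))
    steps = np.eye(n) * eps
    for i in range(n):
        for j in range(n):
            hess[i, j] = (
                f(x + steps[i] + steps[j]) - f(x + steps[i] - steps[j])
                - f(x - steps[i] + steps[j]) + f(x - steps[i] - steps[j])
            ) / (4.0 * eps**2)
    return 0.5 * (hess + hess.T)  # symmetrize against round-off


# Toy negative log-likelihood: a Gaussian with known covariance, so the
# inverse Hessian at the optimum should recover that covariance.
theta_map = np.array([4.0, 5.0])
true_cov = np.array([[0.04, 0.01], [0.01, 0.09]])
precision = np.linalg.inv(true_cov)


def neg_log_like(theta):
    d = theta - theta_map
    return 0.5 * d @ precision @ d


hessian = numerical_hessian(neg_log_like, theta_map)
inv_mass_matrix = np.linalg.inv(hessian)  # M^{-1} = H^{-1}
```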

mass_matrix_fisher(theta_map=None, eigval_clip=1e-06, white_noise=1e-08)[source]#

Estimate the inverse mass matrix from the Fisher information.

For the GP log-likelihood:

I_{ij} = (1/2) tr(K^{-1} dK/dtheta_i K^{-1} dK/dtheta_j)

When matrix_solver="cholesky_full", the kernel derivatives dK/dtheta_i are computed via JAX forward-mode autodiff (jacfwd) on the full N×N covariance matrix.

When matrix_solver="cholesky_banded", the exact Fisher requires the dense N×N kernel and its inverse, which would defeat the purpose of banded storage. Instead, the Fisher is approximated by the Hessian of the banded negative log-likelihood at the MAP (Fisher ≈ observed information at the MLE).

Parameters:

theta_map (array_like, optional) – Point at which to evaluate the Fisher information. If None, the MAP estimate is used.

Returns:

inv_mass_matrix

Return type:

jnp.ndarray, shape (n_params, n_params)
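
To make the trace formula concrete, the sketch below evaluates I_{ij} for a two-parameter squared-exponential kernel with analytic derivatives. The kernel, its parameters (amp, ell), and the noise_var term are stand-ins chosen for illustration; spotgp's own kernel and its autodiff path are not reproduced here.

```python
import numpy as np


def kernel_and_grads(t, amp, ell, noise_var=1e-2):
    """Squared-exponential kernel K and its derivatives w.r.t. (amp, ell)."""
    dt2 = (t[:, None] - t[None, :]) ** 2
    base = np.exp(-0.5 * dt2 / ell**2)
    K = amp**2 * base + noise_var * np.eye(t.size)
    dK_damp = 2.0 * amp * base
    dK_dell = amp**2 * base * dt2 / ell**3
    return K, [dK_damp, dK_dell]


def fisher_information(K, dKs):
    """I_ij = 0.5 * tr(K^{-1} dK_i K^{-1} dK_j)."""
    solved = [np.linalg.solve(K, dK) for dK in dKs]  # K^{-1} dK_i
    n = len(dKs)
    fisher = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            fisher[i, j] = 0.5 * np.trace(solved[i] @ solved[j])
    return fisher


t = np.linspace(0.0, 10.0, 40)
K, dKs = kernel_and_grads(t, amp=1.0, ell=1.0)
fisher = fisher_information(K, dKs)
inv_mass_matrix = np.linalg.inv(fisher)
```

Using linear solves instead of forming K^{-1} explicitly is a common numerical choice; whether the library does the same is not stated in this reference.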

mass_matrix_laplace(theta_map=None, eigval_clip=1e-06)[source]#

Laplace approximation: inverse mass matrix = inverse Hessian of the negative log-likelihood at the MAP.

The posterior is approximated as:

p(theta | data) ~ N(theta_MAP, H^{-1})

Parameters:

theta_map (array_like, optional) – MAP estimate. If None, calls fit_map() first.

Returns:

inv_mass_matrix

Return type:

jnp.ndarray, shape (n_params, n_params)

laplace_samples(n_samples=1000, rng_key=None)[source]#

Draw samples from the Laplace (Gaussian) approximation to the posterior.

Parameters:
  • n_samples (int)

  • rng_key (jax.random.PRNGKey, optional)

Returns:

samples

Return type:

jnp.ndarray, shape (n_samples, n_params)
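
Sampling from the Gaussian approximation N(theta_MAP, H^{-1}) amounts to a Cholesky factorization of the inverse Hessian. The following is a minimal NumPy sketch under that assumption; it uses a NumPy Generator in place of a JAX PRNG key, and the numbers are made up.

```python
import numpy as np


def gaussian_samples(theta_map, inv_hessian, n_samples, rng):
    """Draw samples from N(theta_map, inv_hessian) via a Cholesky factor."""
    chol = np.linalg.cholesky(inv_hessian)  # inv_hessian = chol @ chol.T
    z = rng.standard_normal((n_samples, len(theta_map)))
    return np.asarray(theta_map) + z @ chol.T


rng = np.random.default_rng(1)
theta_map = np.array([4.0, 5.0])                       # assumed MAP
inv_hessian = np.array([[0.04, 0.01], [0.01, 0.09]])   # assumed H^{-1}
samples = gaussian_samples(theta_map, inv_hessian, n_samples=4000, rng=rng)
```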

Power Spectral Density#

spotgp.psd.compute_psd(y, t=None, dt=None, normalization='psd', freq_min=None, freq_max=None, n_freq=None, samples_per_peak=5)[source]#

Compute the Power Spectral Density of a time series using astropy.timeseries.LombScargle.

Works for both evenly and unevenly sampled data.

Parameters:
  • y (array-like, shape (N,)) – Time series values.

  • t (array-like, shape (N,), optional) – Sample times. If None, integer indices scaled by dt are used.

  • dt (float, optional) – Sampling interval. Used only when t is None (default: 1).

  • normalization ({"psd", "standard", "model", "log"}) – Passed directly to LombScargle.autopower / power.

  • freq_min (float, optional) – Minimum frequency to evaluate.

  • freq_max (float, optional) – Maximum frequency to evaluate.

  • n_freq (int, optional) – Number of frequency grid points.

  • samples_per_peak (float, optional) – Controls the frequency grid density (default 5).

Returns:

  • freq (ndarray) – Frequencies in cycles per unit time.

  • power (ndarray) – PSD evaluated at each frequency.
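
compute_psd delegates to astropy.timeseries.LombScargle. To show what a Lomb-Scargle periodogram evaluates on unevenly sampled data, the sketch below writes out the classical formula directly in NumPy. It is a stand-in for illustration only: it does not reproduce astropy's normalization options, and the function names are hypothetical.

```python
import numpy as np


def default_times(y, dt=1.0):
    """Times used when t is None: integer indices scaled by dt."""
    return np.arange(len(y)) * dt


def lomb_scargle_classic(t, y, freq):
    """Classical (unnormalized) Lomb-Scargle power at positive frequencies."""
    y = y - y.mean()
    omega = 2.0 * np.pi * freq[:, None]                  # shape (F, 1)
    # Phase offset tau makes the sine and cosine terms orthogonal.
    tau = np.arctan2(
        np.sin(2.0 * omega * t).sum(axis=1),
        np.cos(2.0 * omega * t).sum(axis=1),
    ) / (2.0 * omega[:, 0])
    arg = omega * (t[None, :] - tau[:, None])            # shape (F, N)
    c, s = np.cos(arg), np.sin(arg)
    return 0.5 * (
        (c @ y) ** 2 / (c**2).sum(axis=1) + (s @ y) ** 2 / (s**2).sum(axis=1)
    )


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 28.0, 300))                 # uneven sampling
y = np.sin(2.0 * np.pi * 0.25 * t) + 0.1 * rng.standard_normal(t.size)
freq = np.linspace(0.02, 2.0, 2000)                      # cycles per unit time
power = lomb_scargle_classic(t, y, freq)
peak_freq = freq[np.argmax(power)]
```

For a signal with a period of 4 time units, the peak should sit close to 0.25 cycles per unit time.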

MCMC Sampler#

class spotgp.mcmc.MCMCSampler(gp)[source]#

Bases: object

Base MCMC sampler for GP hyperparameters.

Wraps a GPSolver object and provides shared storage, diagnostics, summary statistics, corner plots, and dict conversion. Subclasses implement specific sampling algorithms (e.g. NUTS).

Parameters:

gp (GPSolver) – A configured GPSolver instance.

property param_keys#
property n_params#
summary()[source]#

Print summary statistics of the posterior samples.

Returns:

stats – Parameter names mapped to (mean, std, 16%, 50%, 84%).

Return type:

dict
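
The tuple returned per parameter can be reproduced from a samples array with NumPy. This is a sketch of the documented output layout only; details such as the degrees of freedom used for the standard deviation are assumptions.

```python
import numpy as np


def summarize(samples, param_keys):
    """Map each parameter name to (mean, std, 16%, 50%, 84%)."""
    stats = {}
    for i, key in enumerate(param_keys):
        col = samples[:, i]
        q16, q50, q84 = np.percentile(col, [16, 50, 84])
        stats[key] = (col.mean(), col.std(), q16, q50, q84)
    return stats


# Deterministic toy "posterior" with two parameters (assumed names).
samples = np.column_stack(
    [np.linspace(0.0, 1.0, 101), np.linspace(10.0, 20.0, 101)]
)
stats = summarize(samples, ["peq", "lspot"])
```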

plot_covariance(method='fisher', theta_map=None, n_sigma=2, n_grid=200, samples=None, figsize=None, color='C0', alpha=0.3, true_params=None, savefig=None, **corner_kwargs)[source]#

Corner plot of 2D covariance ellipses from the Hessian or Fisher matrix, with 1D marginal Gaussians on the diagonal.

Uses corner.corner to lay out the figure when MCMC samples are provided, and overlays the Laplace/Fisher Gaussian approximation (ellipses + 1D marginals).

Parameters:
  • method ({"fisher", "hessian_map", "laplace"}) – Which matrix to use for the Gaussian approximation.

  • theta_map (array_like, optional) – Center of the ellipses. If None, uses MAP estimate.

  • n_sigma (float) – Number of sigma for the ellipse contours (default 2).

  • n_grid (int) – Grid resolution for the ellipse curves (default 200).

  • samples (array_like, optional) – If provided, plotted as the corner histogram/contours. If None, the figure is created with empty axes and only the Gaussian approximation is drawn.

  • figsize (tuple, optional) – Figure size.

  • color (str) – Color for Gaussian ellipses and marginals (default “C0”).

  • alpha (float) – Fill alpha for the ellipse interiors (default 0.3).

  • true_params (dict or array_like, optional) – True parameter values to mark with crosshairs.

  • savefig (str, optional) – If provided, save figure to this path.

  • **corner_kwargs – Extra keyword arguments forwarded to corner.corner (e.g. quantiles, show_titles, hist_kwargs).

Returns:

fig, axes

Return type:

matplotlib Figure and 2D array of Axes.
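
The n_sigma ellipses drawn in each 2D panel follow from the 2×2 sub-block of the covariance matrix. A minimal sketch of that geometry, assuming a Cholesky-based construction (the library's actual plotting code is not shown in this reference):

```python
import numpy as np


def covariance_ellipse(center, cov2d, n_sigma=2.0, n_grid=200):
    """Points on the n_sigma contour of a 2D Gaussian, shape (2, n_grid)."""
    phi = np.linspace(0.0, 2.0 * np.pi, n_grid)
    circle = np.stack([np.cos(phi), np.sin(phi)])  # unit circle
    chol = np.linalg.cholesky(cov2d)               # maps circle -> ellipse
    return np.asarray(center)[:, None] + n_sigma * (chol @ circle)


center = np.array([4.0, 5.0])                      # assumed theta_map
cov2d = np.array([[0.04, 0.01], [0.01, 0.09]])     # assumed covariance block
ellipse = covariance_ellipse(center, cov2d, n_sigma=2.0, n_grid=200)
```

Every point on the curve has Mahalanobis distance n_sigma from the center, which is what makes it an n-sigma contour of the Gaussian approximation.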

plot_corner_map(samples=None, checkpoint_path=None, cmap='viridis', marker_size=40, savefig=None, true_params=None, **corner_kwargs)[source]#

Corner plot of MCMC samples with MAP solutions overlaid as scatter points colored by their log-likelihood.

Parameters:
  • samples (array_like, optional) – Shape (n_samples, n_params). If None, loads from the checkpoint file.

  • checkpoint_path (str, optional) – Path to checkpoint .npz file containing MAP solutions. If None, uses the default checkpoint file.

  • cmap (str) – Colormap for the MAP scatter points (default “viridis”).

  • marker_size (float) – Marker size for scatter points (default 40).

  • savefig (str, optional) – If provided, save figure to this path.

  • true_params (dict or array_like, optional) – True parameter values to mark with crosshairs.

  • **corner_kwargs – Extra keyword arguments forwarded to corner.corner.

Returns:

fig, axes

Return type:

matplotlib Figure and 2D array of Axes.

to_dict(samples=None)[source]#

Convert samples array to a dict keyed by parameter name.

Parameters:

samples (jnp.ndarray, optional) – Shape (n_samples, n_params). If None, uses self.samples.

Returns:

d – {param_name: array of shape (n_samples,)}

Return type:

dict
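
The conversion is a column-wise split of the samples array. A small sketch with hypothetical parameter names:

```python
import numpy as np


def samples_to_dict(samples, param_keys):
    """Split an (n_samples, n_params) array into {name: (n_samples,) array}."""
    samples = np.asarray(samples)
    return {key: samples[:, i] for i, key in enumerate(param_keys)}


samples = np.arange(12.0).reshape(6, 2)           # 6 samples, 2 parameters
d = samples_to_dict(samples, ["peq", "lspot"])    # assumed names
```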

class spotgp.mcmc.BlackJAXSampler(gp, save_dir='results', checkpoint_file='mcmc_checkpoint.npz')[source]#

Bases: MCMCSampler

NUTS sampler using the BlackJAX library.

Inherits diagnostics, summary, plotting, and dict conversion from MCMCSampler. Adds run_map, run_warmup, and run_sampling for gradient-based No-U-Turn sampling with dual-averaging step-size adaptation.

When multiple chains are requested, sampling is parallelized across available devices via jax.pmap. Chains are distributed evenly across devices (n_chains must be divisible by jax.device_count()). On a single GPU this behaves identically to the previous jax.vmap implementation.

Parameters:
  • gp (GPSolver) – A configured GPSolver instance.

  • save_dir (str, optional) – Directory for all outputs produced by this sampler (corner plots, covariance plots, etc.). Created automatically if it does not exist. When set, save_checkpoint will default to saving the checkpoint inside this directory.

  • checkpoint_file (str, optional) – Path to the checkpoint file. When provided, overrides the default save_dir/mcmc_checkpoint.npz. If neither checkpoint_file nor save_dir is given, no checkpoint file is set until one is passed to a later method.
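
The divisibility requirement on n_chains can be checked before launching a run. The helper below is hypothetical; in practice the device count would come from jax.device_count().

```python
def chains_per_device(n_chains, n_devices):
    """Number of chains each device runs; n_chains must divide evenly."""
    if n_chains % n_devices != 0:
        raise ValueError(
            f"n_chains={n_chains} is not divisible by n_devices={n_devices}"
        )
    return n_chains // n_devices


per_device = chains_per_device(n_chains=8, n_devices=4)
```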

run_map(nopt=10, keys=None, checkpoint_file=None, theta0=None, **kwargs)[source]#

Find MAP solutions via parallel multi-start optimization.

Runs GPSolver.fit_map_parallel and stores the results. If the checkpoint file already contains MAP data, loads from it instead of re-running the optimization.

Parameters:
  • nopt (int) – Number of independent optimization restarts (default 10).

  • keys (list of str, optional) – Parameter names to optimize. If None, uses all bounded parameters from GPSolver.

  • theta0 (dict, optional) – Initial parameter guess to include as one of the optimization starting points. Replaces one random start so the total number of restarts stays nopt.

  • checkpoint_file (str, optional) – Path to save/load MAP solutions. If provided, also updates the sampler’s default checkpoint path. Defaults to self._checkpoint_file.

  • **kwargs – Additional keyword arguments passed to GPSolver.fit_map_parallel (e.g. method, maxiter).

Returns:

all_theta_maps – All MAP solutions sorted by objective (best first).

Return type:

list of dict

run_warmup(n_warmup=500, theta_init=None, mass_matrix_method='hessian_map', step_size=None, rng_key=None, target_accept=0.8, progress_bar=False, n_chains=1, checkpoint_file=None, warmup_method='window_adaptation', pathfinder_maxiter=100, pathfinder_maxcor=10, pathfinder_num_elbo=200)[source]#

Run warmup phase: adapt step size and mass matrix.

Supports three warmup strategies:

  • "window_adaptation" (default): BlackJAX’s standard dual-averaging window adaptation of both step size and mass matrix.

  • "pathfinder": multi-path Pathfinder via L-BFGS.

  • "dual_averaging": fixes the mass matrix (from Hessian at MAP) and only adapts the step size.

After warmup, adapted parameters are stored on the sampler and a checkpoint is saved (if checkpoint_file is set).

Parameters:
  • n_warmup (int) – Number of warmup steps (default 500).

  • theta_init (dict or array_like, optional) – Initial position. If None, uses GPSolver’s MAP estimate. Can also be a list of dicts or 2-D array for per-chain starting points.

  • mass_matrix_method ({"hessian_map", "fisher", "laplace", "diagonal", None}) – Method to estimate the mass matrix.

  • step_size (float, optional) – Initial NUTS step size. If None, a heuristic is used.

  • rng_key (jax.random.PRNGKey, optional) – Random key. Default: PRNGKey(0).

  • target_accept (float) – Target acceptance rate (default 0.8).

  • progress_bar (bool) – If True, show progress during window adaptation.

  • n_chains (int) – Number of chains (used to validate device count and store per-chain init positions).

  • checkpoint_file (str, optional) – Override the default checkpoint file path. When set, updates self._checkpoint_file for all subsequent save/load operations. Defaults to save_dir/mcmc_checkpoint.npz when save_dir is set.

  • warmup_method ({"window_adaptation", "pathfinder", "dual_averaging"}) – Warmup strategy.

  • pathfinder_maxiter (int) – Max L-BFGS iterations for Pathfinder (default 100).

  • pathfinder_maxcor (int) – L-BFGS history size for Pathfinder (default 10).

  • pathfinder_num_elbo (int) – Number of ELBO samples for Pathfinder (default 200).
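
For intuition about step-size adaptation, the sketch below implements the dual-averaging update from Hoffman and Gelman (2014) against a toy acceptance curve. It is not BlackJAX's or spotgp's code: the constants gamma, t0, and kappa are the values suggested in that paper, and toy_accept is an invented stand-in for the acceptance statistic a real NUTS step would report.

```python
import math


def dual_averaging(accept_prob, eps0, target_accept=0.8, n_steps=500,
                   gamma=0.05, t0=10.0, kappa=0.75):
    """Adapt a step size so that accept_prob(step) approaches target_accept."""
    mu = math.log(10.0 * eps0)   # point the iterates are shrunk toward
    log_eps = math.log(eps0)
    log_eps_bar = 0.0            # averaged iterate (the returned step size)
    h_bar = 0.0                  # running mean of (target - acceptance)
    for m in range(1, n_steps + 1):
        alpha = accept_prob(math.exp(log_eps))
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target_accept - alpha)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
    return math.exp(log_eps_bar)


def toy_accept(eps):
    """Invented acceptance curve: larger steps are accepted less often."""
    return math.exp(-eps)


step_size = dual_averaging(toy_accept, eps0=1.0, target_accept=0.8)
```

With this toy curve, the adapted step size should give an acceptance rate close to the 0.8 target.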

run_sampling(n_samples=1000)[source]#

Run NUTS sampling using adapted parameters from run_warmup.

Must be called after run_warmup, or after adapted parameters have been restored from a checkpoint.

Parameters:

n_samples (int) – Number of post-warmup samples per chain (default 1000).

Returns:

  • samples (jnp.ndarray) – Shape (n_samples, n_params) when n_chains=1, or (n_chains, n_samples, n_params) when n_chains > 1.

  • info (dict) – Sampling diagnostics (arrays have a leading chain dimension when n_chains > 1).

save_checkpoint(path=None, append_samples=True, plot_corner=False)[source]#

Save sampler state to disk for later resumption.

When append_samples=True (the default), new samples are appended to any existing samples already stored in path, and self.samples is cleared from memory. This enables a sample-checkpoint-clear loop that keeps memory usage constant.

Parameters:
  • path (str, optional) – File path (saved as .npz). If None, uses the checkpoint_file set in run_warmup, or save_dir/mcmc_checkpoint.npz if save_dir was set.

  • append_samples (bool) – If True, append current self.samples to any samples already on disk, then clear self.samples from memory. If False, overwrite with only the current in-memory samples.

  • plot_corner (bool) – If True, load all samples currently on disk after saving and write a corner plot to save_dir/corner_plot.png (or alongside the checkpoint file if save_dir is not set).
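
The sample-checkpoint-clear loop enabled by append_samples=True can be imitated with plain NumPy files. This sketch assumes a single array stored under a "samples" key; the real checkpoint also holds sampler state, and its key names are not documented here.

```python
import os
import tempfile

import numpy as np


def append_samples(path, new_samples):
    """Append to the 'samples' array in an .npz file (hypothetical key)."""
    if os.path.exists(path):
        with np.load(path) as data:
            new_samples = np.concatenate([data["samples"], new_samples], axis=0)
    np.savez(path, samples=new_samples)


with tempfile.TemporaryDirectory() as save_dir:
    path = os.path.join(save_dir, "mcmc_checkpoint.npz")
    rng = np.random.default_rng(0)
    for _ in range(3):
        batch = rng.standard_normal((100, 2))  # stand-in for new samples
        append_samples(path, batch)
        del batch                              # drop the in-memory copy
    with np.load(path) as data:
        total_shape = data["samples"].shape
```

Only one batch is held in memory at a time, while the file on disk grows with each iteration.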

load_checkpoint(checkpoint_file=None)[source]#

Restore sampler state from a checkpoint file.

Loads only the NUTS state and adapted kernel parameters needed to resume sampling. Samples stored in the file are not loaded into memory — use load_samples to read them later.

Parameters:

checkpoint_file (str, optional) – Path to a .npz checkpoint file. If provided, also updates the sampler’s default checkpoint path. If None, uses the default save_dir/mcmc_checkpoint.npz.

run_smc(n_particles=500, n_mcmc_steps=10, n_adapt_steps=25, target_ess=0.5, target_accept=0.6, rng_key=None, step_size=None, mass_matrix_method='hessian_map', theta_init=None, max_tempering_steps=200, checkpoint_every=10, checkpoint_file=None, particle_batch_size=None, max_num_doublings=10)[source]#

Run adaptive tempered Sequential Monte Carlo.

Starts from the prior and anneals toward the full posterior using an adaptive temperature schedule. At each tempering step, particles are resampled and rejuvenated with NUTS moves. The NUTS step size is re-adapted via dual averaging at each tempering stage using a representative particle.

Parameters:
  • n_particles (int) – Number of SMC particles (default 500).

  • n_mcmc_steps (int) – NUTS rejuvenation steps per tempering stage (default 10).

  • n_adapt_steps (int) – Dual-averaging warmup steps to adapt the NUTS step size at each tempering stage (default 25).

  • target_ess (float) – Target effective sample size as a fraction of n_particles (default 0.5).

  • target_accept (float) – Target NUTS acceptance rate for dual averaging (default 0.6).

  • rng_key (jax.random.PRNGKey, optional) – Random key. Default: PRNGKey(42).

  • step_size (float, optional) – Initial NUTS step size. If None, a heuristic from the mass matrix is used.

  • mass_matrix_method (str, optional) – Method to estimate the inverse mass matrix (default "hessian_map"). Set to None to use an identity matrix.

  • theta_init (dict or array_like, optional) – Reference point for mass matrix estimation. If None, the MAP estimate is used.

  • max_tempering_steps (int) – Safety limit on the number of tempering stages (default 200).

  • checkpoint_every (int) – Save a checkpoint every this many tempering steps (default 10). Set to 0 to disable periodic checkpointing.

  • checkpoint_file (str, optional) – Override the default checkpoint file path.

  • particle_batch_size (int, optional) – Process particles in batches of this size to limit GPU memory usage. When multiple GPUs are visible, the batches are distributed across devices via jax.pmap. n_particles must be divisible by this value (and by particle_batch_size * n_devices for multi-GPU). If None, all particles are evaluated at once (original BlackJAX behavior).

  • max_num_doublings (int, optional) – Maximum NUTS tree depth (default 10). Lower values (e.g. 5-6) reduce peak GPU memory per particle at the cost of shorter trajectories.

Returns:

  • samples (np.ndarray, shape (n_particles, n_params)) – Weighted posterior particles at the final temperature.

  • info (dict) – Diagnostics including tempering schedule and log evidence estimate.
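
An adaptive temperature schedule is commonly built by choosing each increment so that the effective sample size (ESS) of the incremental weights hits the target fraction. The sketch below shows one such step using bisection; it is a generic illustration under that assumption, not the library's solver, and the toy likelihood is invented.

```python
import numpy as np


def ess(log_w):
    """Effective sample size of unnormalized log-weights."""
    w = np.exp(log_w - log_w.max())
    return w.sum() ** 2 / (w**2).sum()


def next_temperature(log_like, beta, target_ess=0.5, n_bisect=60):
    """Next inverse temperature such that ESS ~= target_ess * n_particles."""
    n = log_like.size
    lo, hi = 0.0, 1.0 - beta
    if ess(hi * log_like) >= target_ess * n:
        return 1.0  # can jump straight to the full posterior
    for _ in range(n_bisect):
        mid = 0.5 * (lo + hi)
        if ess(mid * log_like) > target_ess * n:
            lo = mid  # weights still too uniform: take a larger step
        else:
            hi = mid
    return beta + 0.5 * (lo + hi)


rng = np.random.default_rng(0)
particles = rng.standard_normal(500)            # draws from a toy prior
log_like = -0.5 * (particles / 0.1) ** 2        # toy, sharply peaked likelihood
beta1 = next_temperature(log_like, beta=0.0, target_ess=0.5)
ess_fraction = ess(beta1 * log_like) / particles.size
```

In a full SMC run this step would be followed by resampling and NUTS rejuvenation, then repeated until the temperature reaches 1.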

static load_samples(path, flatten_chains=True)[source]#

Read all samples from a checkpoint file without loading the sampler state.

Parameters:
  • path (str) – Path to a .npz checkpoint file.

  • flatten_chains (bool) – If True (default), collapse the chain dimension so the returned array is always (n_total, n_params). Set to False to get the raw (n_chains, n_samples, n_params) array for per-chain diagnostics (e.g. R-hat).

Returns:

samples – Shape (n_total, n_params) when flatten_chains=True, or (n_chains, n_samples, n_params) otherwise.

Return type:

np.ndarray
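
The two return shapes are related by a reshape, and the unflattened form is what per-chain diagnostics need. The sketch below flattens a synthetic (n_chains, n_samples, n_params) array and computes a basic Gelman-Rubin R-hat; the rhat helper is a textbook formula written for illustration, not a spotgp function.

```python
import numpy as np


def rhat(chains):
    """Basic Gelman-Rubin R-hat per parameter for (n_chains, n_samples, n_params)."""
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean(axis=0)        # W
    between = n * chains.mean(axis=1).var(axis=0, ddof=1)   # B
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(0)
chains = rng.standard_normal((4, 1000, 2))       # synthetic, well-mixed chains
flat = chains.reshape(-1, chains.shape[-1])      # (n_total, n_params)
r = rhat(chains)
```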