import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
try:
from .params import resolve_hparam
from .envelope import (
EnvelopeFunction,
TrapezoidAsymmetricEnvelope,
SkewedGaussianEnvelope,
ExponentialEnvelope,
compute_R_Gamma_numerical,
)
from .spot_model import (
VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
_cn_squared_coefficients_jax, _gauss_legendre_grid,
)
except ImportError:
from params import resolve_hparam
from envelope import (
EnvelopeFunction,
TrapezoidAsymmetricEnvelope,
SkewedGaussianEnvelope,
ExponentialEnvelope,
compute_R_Gamma_numerical,
)
from spot_model import (
VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
_cn_squared_coefficients_jax, _gauss_legendre_grid,
)
# Public API of this module.  ``compute_R_Gamma_numerical`` is imported from
# the envelope module above and re-exported here.
__all__ = ["AnalyticKernel", "NonstationaryAnalyticKernel",
"compute_R_Gamma_numerical"]
[docs]
class AnalyticKernel:
    """
    JAX-accelerated analytic GP kernel for stellar rotation variability.

    Parameters
    ----------
    model_or_hparam : SpotEvolutionModel or dict
        Either a SpotEvolutionModel instance (new API) or a raw hparam dict
        (backward-compatible old API).
    n_harmonics : int
        Number of Fourier harmonics for the visibility function (default 3).
    n_lat : int
        Number of latitude quadrature points (default 64).
    lat_range : tuple or None
        (min, max) latitude in radians.  If None (default), the range of
        the spot model's latitude distribution is used.
    quadrature : str
        Latitude integration method: "trapezoid" or "gauss-legendre".

    Raises
    ------
    ValueError
        If ``quadrature`` is not one of the supported methods.
    """

    def __init__(self, model_or_hparam, n_harmonics=3, n_lat=64,
                 lat_range=None, quadrature="trapezoid"):
        # ── Accept SpotEvolutionModel or legacy hparam dict ────────────────
        if isinstance(model_or_hparam, SpotEvolutionModel):
            self.spot_model = model_or_hparam
            self.hparam = model_or_hparam.to_hparam()
        else:
            # Backward compat: dict input
            self.hparam = resolve_hparam(model_or_hparam)
            self.spot_model = SpotEvolutionModel.from_hparam(self.hparam)

        # ── Unpack commonly-used params ────────────────────────────────────
        self.envelope = self.spot_model.envelope
        self.visibility = self.spot_model.visibility
        self.peq = self.spot_model.peq
        self.kappa = self.spot_model.kappa
        self.inc = self.spot_model.inc
        self.lspot = self.spot_model.lspot
        self.sigma_k = self.spot_model.sigma_k
        self.tau_spot = self.spot_model.tau_spot

        # ── Envelope-type attributes (backward compat) ────────────────────
        if isinstance(self.envelope, SkewedGaussianEnvelope):
            self.envelope_type = "skew_normal"
            self.sigma_sn = self.envelope.sigma_sn
            self.n_sn = self.envelope.n_sn
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot
            self.asymmetric = False
            # Re-use grids from the envelope object
            self._R_Gamma_lag_grid = self.envelope._R_lag_grid
            self._R_Gamma_vals = self.envelope._R_vals
            self._Gh_sq_omega_grid = self.envelope._Gh_omega_grid
            self._Gh_sq_vals = self.envelope._Gh_sq_vals
        elif isinstance(self.envelope, TrapezoidAsymmetricEnvelope):
            self.envelope_type = "trapezoid_asymmetric"
            self.asymmetric = True
            self.tau_em = self.envelope.tau_em
            self.tau_dec = self.envelope.tau_dec
            self._te = min(self.tau_em, self.tau_dec)
            self._td = max(self.tau_em, self.tau_dec)
        elif isinstance(self.envelope, ExponentialEnvelope):
            self.envelope_type = "exponential"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot
        else:
            # Default: symmetric trapezoid (or any other future type)
            self.envelope_type = "trapezoid_symmetric"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot

        # ── Kernel config ──────────────────────────────────────────────────
        self.n_harmonics = n_harmonics
        self.n_lat = n_lat
        self.lat_range = (lat_range if lat_range is not None
                          else self.spot_model.latitude_distribution.lat_range)
        self.quadrature = quadrature
        if quadrature == "gauss-legendre":
            # Use the *resolved* range: the ``lat_range`` argument itself is
            # None by default and must not be indexed.
            self._quad_nodes, self._quad_weights = _gauss_legendre_grid(
                n_lat, self.lat_range[0], self.lat_range[1])
        elif quadrature == "trapezoid":
            self._quad_nodes = None
            self._quad_weights = None
        else:
            raise ValueError(
                f"Unknown quadrature method: {quadrature!r}. "
                "Use 'trapezoid' or 'gauss-legendre'.")

    # ── Private helpers ─────────────────────────────────────────────────────
    @staticmethod
    def _harmonic_series(cn_sq, w0, lag_flat):
        """Return ``c_0² + 2 Σ_{n≥1} c_n² cos(n·ω₀·τ)`` for a 1-D lag array."""
        ns = jnp.arange(1, len(cn_sq))
        cosine_terms = jnp.sum(
            cn_sq[1:] * jnp.cos(ns * w0 * lag_flat[:, None]), axis=1)
        return cn_sq[0] + 2 * cosine_terms

    def _latitude_weights(self, lat_dist=None):
        """
        Latitude nodes and normalised averaging weights.

        Parameters
        ----------
        lat_dist : callable or None
            Latitude density, evaluated one node at a time on Python floats.
            If None, the spot model's latitude distribution is used.

        Returns
        -------
        phi_grid : jnp.ndarray, shape (n_nodes,)
        weights : jnp.ndarray, shape (n_nodes,)
            ``density × quadrature rule``, normalised to sum to one so that
            ``Σ_i weights_i · g(phi_i)`` is a density-weighted average of g.
        """
        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution
        if self.quadrature == "gauss-legendre":
            phi_grid = jnp.asarray(self._quad_nodes)
            rule = jnp.asarray(self._quad_weights)
        else:
            phi_min, phi_max = self.lat_range
            phi_grid = jnp.linspace(phi_min, phi_max, self.n_lat)
            # Trapezoid rule on a uniform grid: end nodes carry half weight.
            # The common spacing cancels in the normalisation below.
            rule = jnp.ones(self.n_lat).at[0].set(0.5).at[-1].set(0.5)
        density = jnp.array([lat_dist(float(p)) for p in phi_grid])
        weights = density * rule
        # Normalising by the sum of the *same* weights keeps numerator and
        # denominator consistent (the average of a constant is that constant).
        return phi_grid, weights / jnp.sum(weights)

    # ── Core kernel helpers ─────────────────────────────────────────────────
    def omega0(self, phi):
        """Latitude-dependent rotation angular frequency [rad/day]."""
        return self.visibility.omega0(phi)

    def R_Gamma(self, lag):
        """Autocorrelation of the squared envelope (delegates to envelope)."""
        return self.envelope.R_Gamma(jnp.asarray(lag))

    def cn_squared(self, phi):
        """Squared Fourier visibility coefficients at latitude phi."""
        return self.visibility.cn_squared(phi, self.n_harmonics)

    # ── Single-latitude kernel ──────────────────────────────────────────────
    def kernel_single_latitude(self, lag, phi):
        """
        Single-spot kernel at a fixed latitude.

        The lag input is flattened; the result is a 1-D array with one entry
        per lag and does not include the ``sigma_k²`` prefactor.
        """
        lag = jnp.asarray(lag, dtype=float).ravel()
        R = self.R_Gamma(lag)
        cn_sq = self.cn_squared(phi)
        w0 = self.omega0(phi)
        return R * self._harmonic_series(cn_sq, w0, lag)

    # ── Stationary kernel (without sigma_k² scaling) ─────────────────────
    def _kernel_stationary(self, lag, lat_dist=None):
        """
        Stationary kernel *without* the σ_k² prefactor.

        Returns R_Γ(τ) · Σ_n w_n |c_n|² cos(n·ω₀·τ), averaged over latitude,
        with the same shape as ``lag``.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()
        R = self.R_Gamma(lag_flat)

        if isinstance(self.visibility, EdgeOnVisibilityFunction):
            # Edge-on case: coefficients are taken at phi = 0 and no
            # latitude loop is needed.
            cn_sq = self.visibility.cn_squared(0.0, self.n_harmonics)
            w0 = self.visibility.omega0(0.0)
            K = R * self._harmonic_series(cn_sq, w0, lag_flat)
            return K.reshape(orig_shape)

        phi_grid, weights = self._latitude_weights(lat_dist)

        def _scan_body(K_acc, node):
            phi, w = node
            contribution = self._harmonic_series(
                self.cn_squared(phi), self.omega0(phi), lag_flat)
            return K_acc + w * contribution, None

        # Accumulate one latitude at a time so only a single lag-sized
        # buffer is carried through the scan.
        K, _ = jax.lax.scan(
            _scan_body, jnp.zeros_like(lag_flat), (phi_grid, weights))
        return (R * K).reshape(orig_shape)

    # ── Full kernel (latitude-averaged) ────────────────────────────────────
    def kernel(self, lag, lat_dist=None):
        """
        Full GP kernel averaged over latitude.

        Uses jax.lax.scan for memory-efficient accumulation: only one
        lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).
        When the visibility function is an EdgeOnVisibilityFunction, the
        latitude loop is bypassed entirely.

        Parameters
        ----------
        lag : array_like
            Time lags [days]. Can be 1D or 2D.
        lat_dist : callable or None
            Latitude probability density. If None, the spot model's
            latitude distribution is used.

        Returns
        -------
        K : ndarray, same shape as lag input.
        """
        K = self._kernel_stationary(lag, lat_dist=lat_dist)
        return np.asarray(self.sigma_k ** 2 * K)

    def kernel_solid_body(self, lag, lat_dist=None):
        """
        Kernel for solid-body rotation (kappa=0).

        The squared visibility coefficients are averaged over latitude first
        and a single rotation frequency ``2π / peq`` is used for all
        latitudes.

        Parameters
        ----------
        lag : array_like
            Time lags [days]; any shape.
        lat_dist : callable or None
            Latitude probability density. If None, the spot model's
            latitude distribution is used.

        Returns
        -------
        K : ndarray, same shape as lag input.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()

        phi_grid, weights = self._latitude_weights(lat_dist)
        all_cn_sq = jax.vmap(
            lambda phi: _cn_squared_coefficients_jax(
                self.inc, phi, self.n_harmonics))(phi_grid)
        cn_sq_avg = jnp.sum(weights[:, None] * all_cn_sq, axis=0)

        w0 = 2 * jnp.pi / self.peq
        K = self.R_Gamma(lag_flat) * self._harmonic_series(
            cn_sq_avg, w0, lag_flat)
        return np.asarray((K * self.sigma_k ** 2).reshape(orig_shape))

    # ── Power spectral density ──────────────────────────────────────────────
    def compute_psd(self, omega, lat_dist=None):
        """
        Analytic power spectral density.

        The result is also stored on the instance as ``psd_omega``,
        ``psd_freq`` and ``psd_power``.

        Parameters
        ----------
        omega : array_like
            Angular frequencies [rad/day] (1-D).
        lat_dist : callable or None
            Latitude probability density. If None, the spot model's
            latitude distribution is used.

        Returns
        -------
        freq : ndarray [cycles/day]
        power : ndarray
        """
        omega = jnp.asarray(omega, dtype=float)
        envelope = self.envelope

        # Squared envelope transform: these envelope types expose it
        # directly, all others go through Gamma_hat and are squared here.
        if isinstance(envelope, (SkewedGaussianEnvelope, ExponentialEnvelope)):
            gamma_hat_sq = envelope.Gamma_hat_sq
        else:
            def gamma_hat_sq(w):
                return envelope.Gamma_hat(w) ** 2

        def _psd_at_lat(phi):
            cn_sq = self.cn_squared(phi)
            w0 = self.omega0(phi)

            def _harmonic(n):
                # Each harmonic contributes a pair of side-bands at ±n·ω₀.
                return cn_sq[n] * (gamma_hat_sq(omega - n * w0)
                                   + gamma_hat_sq(omega + n * w0))

            ns = jnp.arange(1, len(cn_sq))
            return (cn_sq[0] * gamma_hat_sq(omega)
                    + jnp.sum(jax.vmap(_harmonic)(ns), axis=0))

        phi_grid, weights = self._latitude_weights(lat_dist)
        all_contribs = jax.vmap(_psd_at_lat)(phi_grid)
        psd = jnp.sum(weights[:, None] * all_contribs, axis=0)
        psd = psd * self.sigma_k ** 2

        self.psd_omega = np.asarray(omega)
        self.psd_freq = np.asarray(omega / (2 * jnp.pi))
        self.psd_power = np.asarray(psd)
        return self.psd_freq, self.psd_power

    def build_jax(self, n_lag=256):
        """
        Pre-compile and warm up JAX JIT computation for this kernel.

        ``jax.lax.scan`` (used inside ``kernel()``) triggers XLA compilation
        on its first call for a given array shape. That compilation can take
        several seconds and is easy to mistake for slow runtime. Call
        ``build_jax()`` once after constructing the kernel to pay that cost
        upfront — subsequent calls to ``kernel()`` and ``compute_psd()`` with
        the same shape will be fast.

        Both timings (first call and repeat call) are printed to stdout.

        Parameters
        ----------
        n_lag : int
            Length of the dummy lag array used to drive compilation (default
            256). The actual value does not matter as long as it is
            representative of the sizes you will use at runtime.

        Returns
        -------
        self : AnalyticKernel
            Returns ``self`` so the call can be chained:
            ``ak = AnalyticKernel(model).build_jax()``.
        """
        import time

        # Dummy grids: lags over three equatorial periods, frequencies up to
        # twice the equatorial rotation frequency.
        dummy_lag = jnp.linspace(0.0, float(self.peq) * 3.0, n_lag)
        dummy_omega = jnp.linspace(0.0, 4.0 * float(np.pi / self.peq), n_lag)

        t0 = time.time()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel compiled in {np.round(time.time() - t0, 2)}s")

        t0 = time.time()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel recompute in {np.round(time.time() - t0, 2)}s")
        return self

    def __call__(self, lag, **kwargs):
        """Evaluate the kernel at the given lags."""
        return self.kernel(lag, **kwargs)
class NonstationaryAnalyticKernel(AnalyticKernel):
    """
    Non-stationary extension of AnalyticKernel with time-dependent σ_k.

    The covariance between times t1 and t2 is:

        K(t1, t2) = σ_k(t1) · σ_k(t2) · K_stationary(|t1 - t2|)

    where K_stationary is the latitude-averaged kernel without the σ_k²
    prefactor. This factorization guarantees positive semi-definiteness
    for any non-negative σ_k(t).

    Parameters
    ----------
    model_or_hparam : SpotEvolutionModel or dict
        Same as AnalyticKernel.
    sigma_k_func : callable
        Function mapping time (scalar or array) to σ_k values.
        Signature: ``sigma_k_func(t) -> array_like``.  A scalar return
        value is treated as a constant σ_k for all times.
    **kwargs
        Forwarded to AnalyticKernel (n_harmonics, n_lat, etc.).

    Examples
    --------
    >>> def activity_cycle(t, sigma0=0.01, amp=0.5, period=365.0):
    ...     return sigma0 * (1 + amp * jnp.sin(2 * jnp.pi * t / period))
    >>> nsk = NonstationaryAnalyticKernel(model, sigma_k_func=activity_cycle)
    >>> K = nsk.kernel_matrix(t_obs)
    """

    def __init__(self, model_or_hparam, sigma_k_func, **kwargs):
        super().__init__(model_or_hparam, **kwargs)
        self.sigma_k_func = sigma_k_func

    def kernel_matrix(self, t, lat_dist=None):
        """
        Build the full N×N covariance matrix for observation times.

        Parameters
        ----------
        t : array_like, shape (N,)
            Observation times [days]; flattened before use.
        lat_dist : callable or None
            Latitude probability density (forwarded to parent).

        Returns
        -------
        K : ndarray, shape (N, N)
        """
        t = jnp.asarray(t, dtype=float).ravel()
        lag = jnp.abs(t[:, None] - t[None, :])
        K_stat = self._kernel_stationary(lag, lat_dist=lat_dist)
        # Flatten and broadcast so that a sigma_k_func returning a scalar
        # (constant amplitude) works as well as one returning one value per
        # observation time.
        sk = jnp.broadcast_to(
            jnp.ravel(jnp.asarray(self.sigma_k_func(t))), t.shape)
        K = sk[:, None] * sk[None, :] * K_stat
        return np.asarray(K)

    def kernel(self, lag, lat_dist=None):
        """
        Stationary kernel using the constant ``self.sigma_k``.

        Provided for backward compatibility — use ``kernel_matrix(t)``
        for the non-stationary covariance.
        """
        return super().kernel(lag, lat_dist=lat_dist)

    def __call__(self, t, **kwargs):
        """Evaluate the non-stationary kernel matrix at observation times."""
        return self.kernel_matrix(t, **kwargs)