"""
envelope.py — Spot envelope function hierarchy.
Defines the abstract base class EnvelopeFunction and five concrete
implementations: TrapezoidSymmetricEnvelope, TrapezoidAsymmetricEnvelope,
SkewedGaussianEnvelope, ExponentialEnvelope, and ExponentialAsymmetricEnvelope.
To define a custom envelope, subclass EnvelopeFunction and implement:
- tau_spot (property) : characteristic timescale [days] [REQUIRED]
- Gamma(t) : normalized envelope, peak = 1 [REQUIRED]
- param_dict : {name: value} of free parameters [optional, needed for GPSolver]
- lspot (property) : plateau duration [days] [optional, default 0]
- R_Gamma(lag) : autocorrelation [optional, default: FFT]
- Gamma_hat(omega) : |FT[Gamma]|(omega) [optional, default: FFT]
- kernel_support() : upper lag support [optional, default: lspot+6*tau_spot]
"""
from __future__ import annotations
import numpy as np
import jax
import jax.numpy as jnp
from abc import ABC, abstractmethod
try:
from .distributions import as_distribution, is_distributed
except ImportError:
from distributions import as_distribution, is_distributed
__all__ = [
"EnvelopeFunction",
"TrapezoidSymmetricEnvelope",
"TrapezoidAsymmetricEnvelope",
"SkewedGaussianEnvelope",
"ExponentialEnvelope",
"ExponentialAsymmetricEnvelope",
# low-level helpers (re-exported for backward compat with analytic_kernel)
"compute_R_Gamma_numerical",
"_Gamma_hat",
"_R_Gamma_symmetric",
"_R_Gamma_asymmetric",
"_skew_normal_envelope_func",
"_compute_Gamma_hat_sq_numerical",
]
# ── Low-level JAX helpers ───────────────────────────────────────────────────
@jax.jit
def _Gamma_hat(omega, ell, tau_spot):
    """
    Fourier transform of the squared symmetric-trapezoid envelope Gamma(t).

    Gamma(t) is the square of a trapezoid with plateau ``ell`` and linear
    ramps of length ``tau_spot``.  The transform is real and even in omega.

    Parameters
    ----------
    omega : array_like
        Angular frequency [rad/day].
    ell : float
        Plateau duration [days].
    tau_spot : float
        Ramp duration [days].

    Returns
    -------
    jnp.ndarray
        Gamma_hat(omega), same shape as ``omega``.

    Notes
    -----
    The closed form is proportional to 1/omega^3 times a bracket that itself
    vanishes like omega^3, so evaluating it at small (but nonzero) omega
    loses all significant digits.  For ``|omega| * (ell/2 + tau_spot) < 1e-2``
    the even-moment expansion

        Gamma_hat = M0 - M2 * omega^2 / 2 + M4 * omega^4 / 24

    is used instead, where ``Mk`` is the k-th moment of Gamma(t).  The
    neglected term is O((omega * L)^6 / 720), far below double precision at
    the switch-over point.
    """
    omega = jnp.asarray(omega, dtype=float)
    half = 0.5 * ell

    # Dimensionless frequency: omega times the envelope half-width.
    small = jnp.abs(omega) * (half + tau_spot) < 1e-2

    # Closed form.  Feed a harmless frequency into the masked-out branch so
    # that neither the value nor its gradient ever sees a division by ~0.
    safe_w = jnp.where(small, 1.0, omega)
    closed_form = (4 / (tau_spot**2 * safe_w**3) *
                   (tau_spot * safe_w * jnp.cos(safe_w * ell / 2)
                    + jnp.sin(safe_w * ell / 2)
                    - jnp.sin(safe_w * ell / 2 + safe_w * tau_spot)))

    # Moments of the squared trapezoid (h = ell / 2):
    #   M0 = ell + 2 tau / 3
    #   M2 = 2 (h^3/3 + h^2 tau/3 + h tau^2/6 + tau^3/30)
    #   M4 = 2 (h^5/5 + h^4 tau/3 + h^3 tau^2/3 + h^2 tau^3/5
    #           + h tau^4/15 + tau^5/105)
    m0 = ell + 2 * tau_spot / 3
    m2 = 2 * (half**3 / 3
              + half**2 * tau_spot / 3
              + half * tau_spot**2 / 6
              + tau_spot**3 / 30)
    m4 = 2 * (half**5 / 5
              + half**4 * tau_spot / 3
              + half**3 * tau_spot**2 / 3
              + half**2 * tau_spot**3 / 5
              + half * tau_spot**4 / 15
              + tau_spot**5 / 105)
    w2 = omega**2
    series = m0 - m2 * w2 / 2 + m4 * w2**2 / 24

    return jnp.where(small, series, closed_form)
@jax.jit
def _R_Gamma_symmetric(lag, ell, tau_s):
    """
    Closed-form autocorrelation of the symmetric trapezoidal envelope.

    Piecewise degree-5 polynomial on [0, ell + 2*tau_s], zero beyond.
    Assumes ell/2 > tau_s.

    Parameters
    ----------
    lag : array_like
        Time lag(s) [days]; only ``|lag|`` is used.
    ell : float
        Plateau duration [days].
    tau_s : float
        Rise/decay timescale [days].

    Returns
    -------
    jnp.ndarray
        1-D array (the input is flattened with ``ravel``) of autocorrelation
        values, one per lag.

    Notes
    -----
    NOTE(review): the branch order below tests ``t <= tau_s`` before
    ``t <= ell``, so the selection is only meaningful when tau_s <= ell.
    The docstring's ``ell/2 > tau_s`` is stricter than that — confirm which
    bound the derivation actually requires.
    """
    # Work on the flattened absolute lag.
    t = jnp.abs(jnp.asarray(lag, dtype=float).ravel())
    # Interval 1: 0 <= t <= tau_s.  At t = 0 this equals ell + 2*tau_s/5.
    R1 = (ell + 2*tau_s/5
          - 4*t**2 / (3*tau_s)
          + 2*t**3 / (3*tau_s**2)
          - t**5 / (15*tau_s**4))
    # Interval 2: tau_s < t <= ell (linear in t).
    R2 = ell + 2*tau_s/3 - t
    # Interval 3: ell < t <= ell + tau_s.
    R3 = (t**5 / (30*tau_s**4)
          - (ell + 2*tau_s) * t**4 / (6*tau_s**4)
          + (ell**2 + 4*ell*tau_s + 2*tau_s**2) * t**3 / (3*tau_s**4)
          - ell*(ell**2 + 6*ell*tau_s + 6*tau_s**2) * t**2 / (3*tau_s**4)
          + (ell**4 + 8*ell**3*tau_s + 12*ell**2*tau_s**2
             - 6*tau_s**4) * t / (6*tau_s**4)
          + (-ell**5 - 10*ell**4*tau_s - 20*ell**3*tau_s**2
             + 30*ell*tau_s**4 + 20*tau_s**5) / (30*tau_s**4))
    # Interval 4: ell + tau_s < t <= ell + 2*tau_s; vanishes at the upper end.
    R4 = (ell + 2*tau_s - t)**5 / (30*tau_s**4)
    # All pieces are evaluated everywhere; nested where picks the active one.
    return jnp.where(t <= tau_s, R1,
           jnp.where(t <= ell, R2,
           jnp.where(t <= ell + tau_s, R3,
           jnp.where(t <= ell + 2*tau_s, R4,
                     0.0))))
@jax.jit
def _R_Gamma_asymmetric(lag, ell, te, td):
    """
    Closed-form autocorrelation of the asymmetric trapezoidal envelope.

    Assumes te <= td (enforced by caller via min/max swap).
    Six intervals on [0, ell + te + td], zero beyond.

    Parameters
    ----------
    lag : array_like
        Time lag(s) [days]; only ``|lag|`` is used.
    ell : float
        Plateau duration [days].
    te : float
        Shorter of the two ramp timescales [days].
    td : float
        Longer of the two ramp timescales [days].

    Returns
    -------
    jnp.ndarray
        1-D array (the input is flattened with ``ravel``) of autocorrelation
        values, one per lag.

    Notes
    -----
    NOTE(review): the breakpoints are tested in the order
    te, td, ell, ell+te, ell+td, ell+te+td, which presumes
    te <= td <= ell.  Nothing here checks td <= ell — confirm callers
    guarantee it.
    """
    # Work on the flattened absolute lag.
    t = jnp.abs(jnp.asarray(lag, dtype=float).ravel())
    # Powers and products reused by several pieces.
    td2 = td**2
    te2 = te**2
    td2te2 = td2 * te2
    ell2 = ell**2
    ell3 = ell**3
    # Interval 1: 0 <= t <= te.  At t = 0 this equals ell + (te + td)/5.
    R1 = (ell + (te + td) / 5
          - 2 * (1/te + 1/td) / 3 * t**2
          + (1/te2 + 1/td2) / 3 * t**3
          - (1/te**4 + 1/td**4) / 30 * t**5)
    # Interval 2: te < t <= td.
    R2 = (ell + te/3 + td/5
          - t / 2
          - 2 * t**2 / (3 * td)
          + t**3 / (3 * td2)
          - t**5 / (30 * td**4))
    # Interval 3: td < t <= ell (linear in t).
    R3 = ell + (te + td) / 3 - t
    # Interval 4: ell < t <= ell + te.
    R4 = (t**5 / (30 * td2te2)
          - (ell + td + te) * t**4 / (6 * td2te2)
          + (ell2 + 2*ell*td + 2*ell*te + 2*td*te) * t**3 / (3 * td2te2)
          - ell * (ell2 + 3*ell*td + 3*ell*te + 6*td*te) * t**2 / (3 * td2te2)
          + (ell**4 + 4*ell3*td + 4*ell3*te + 12*ell2*td*te
             - 6*td2te2) * t / (6 * td2te2)
          + (-ell**5 - 5*ell**4*td - 5*ell**4*te - 20*ell3*td*te
             + 30*ell*td2te2 + 10*td**3*te2 + 10*td2*te**3) / (30 * td2te2))
    # Interval 5: ell + te < t <= ell + td (cubic in t).
    R5 = (-t**3 / (3 * td2)
          + (ell + td + te/3) * t**2 / td2
          - (6*ell2 + 12*ell*td + 4*ell*te + 6*td2 + 4*td*te + te2) * t / (6 * td2)
          + (ell3/3 + ell2*td + ell2*te/3 + ell*td2 + 2*ell*td*te/3
             + ell*te2/6 + td**3/3 + td2*te/3 + td*te2/6 + te**3/30) / td2)
    # Interval 6: ell + td < t <= ell + te + td; D is the distance to the
    # end of the support, so R6 vanishes there.
    D = ell + te + td - t
    R6 = D**5 / (30 * td2te2)
    # All pieces are evaluated everywhere; nested where picks the active one.
    return jnp.where(t <= te, R1,
           jnp.where(t <= td, R2,
           jnp.where(t <= ell, R3,
           jnp.where(t <= ell + te, R4,
           jnp.where(t <= ell + td, R5,
           jnp.where(t <= ell + te + td, R6,
                     0.0))))))
def compute_R_Gamma_numerical(envelope_func, tau_ref, n_grid=4096, extent=12.0):
    """
    Compute R_Gamma(lag) = ∫ Gamma(t) · Gamma(t + lag) dt via FFT.

    The envelope is sampled on a uniform grid over
    ``[-extent * tau_ref, +extent * tau_ref]``, negative samples are clipped
    to zero, and the autocorrelation is obtained from the power spectrum.
    The FFT is zero-padded to ``2 * n_grid`` points so the circular
    correlation it yields equals the linear one for the returned lags.

    Parameters
    ----------
    envelope_func : callable
        f(t: np.ndarray) -> np.ndarray, the normalized envelope Gamma(t).
    tau_ref : float
        Reference timescale [days] setting grid extent and resolution.
    n_grid : int
        Number of time-grid points (default 4096).  Must be at least 2.
    extent : float
        Grid half-width in units of tau_ref (default 12.0).

    Returns
    -------
    lag_grid : jnp.ndarray, shape (n_grid,)
        Non-negative lags ``0, dt, 2*dt, ...`` [days].
    R_Gamma_vals : jnp.ndarray, shape (n_grid,)
        Autocorrelation at those lags.

    Raises
    ------
    ValueError
        If ``n_grid`` is smaller than 2 (no grid spacing can be defined).
    """
    if n_grid < 2:
        raise ValueError(
            f"n_grid must be at least 2 to define a grid spacing, got {n_grid}")

    T = float(extent) * float(tau_ref)
    t_np = np.linspace(-T, T, n_grid)
    dt = float(t_np[1] - t_np[0])

    env_np = np.asarray(envelope_func(t_np), dtype=np.float64)
    env_np = np.maximum(env_np, 0.0)

    # Wiener–Khinchin: autocorrelation = inverse FFT of |FFT|².  The factor
    # dt turns the discrete sum into an approximation of the integral.
    env_fft = np.fft.rfft(env_np, n=2 * n_grid)
    R_vals = np.fft.irfft(np.abs(env_fft) ** 2, n=2 * n_grid)[:n_grid] * dt

    lag_grid = np.arange(n_grid, dtype=np.float64) * dt
    return jnp.array(lag_grid), jnp.array(R_vals)
def _compute_Gamma_hat_sq_numerical(envelope_func, tau_ref, n_grid=4096, extent=12.0):
    """
    Tabulate |Gamma_hat(ω)|² for an envelope that has no closed-form
    transform (used by compute_psd).

    The envelope is sampled on ``n_grid`` uniform points spanning
    ``±extent * tau_ref``, clipped at zero, and transformed with a real FFT
    zero-padded to ``2 * n_grid`` samples.

    Parameters
    ----------
    envelope_func : callable
        f(t: np.ndarray) -> np.ndarray, the normalized envelope Gamma(t).
    tau_ref : float
        Reference timescale [days] setting grid extent and resolution.
    n_grid : int
        Number of time samples (default 4096).
    extent : float
        Grid half-width in units of tau_ref (default 12.0).

    Returns
    -------
    omega_grid : jnp.ndarray, shape (n_grid + 1,)
        Non-negative angular frequencies [rad/day].
    Gh_sq_vals : jnp.ndarray, shape (n_grid + 1,)
        Squared transform magnitude at those frequencies.
    """
    half_width = float(extent) * float(tau_ref)
    times = np.linspace(-half_width, half_width, n_grid)
    step = float(times[1] - times[0])

    # Sample the envelope and discard any negative excursions.
    samples = np.maximum(
        np.asarray(envelope_func(times), dtype=np.float64), 0.0)

    padded_len = 2 * n_grid
    # Scaling by the grid step approximates the continuous transform.
    spectrum = np.fft.rfft(samples, n=padded_len) * step
    power = np.abs(spectrum) ** 2

    # rfftfreq yields cycles/day; convert to rad/day.
    ang_freq = 2.0 * np.pi * np.fft.rfftfreq(padded_len, d=step)
    return jnp.array(ang_freq), jnp.array(power)
def _skew_normal_envelope_func(sigma_sn, n_sn):
    """
    Return a callable for the normalized skew-normal envelope.

    Implements Eq. (1) of Baranyi et al. (2021) A&A 653, A59:
        Gamma(t) ∝ exp(-t²/(2σ²)) · (1 + erf(n·t / (σ·√2)))

    n_sn < 0: rapid rise / slow decay.
    n_sn > 0: slow rise / rapid decay.
    n_sn = 0: symmetric Gaussian envelope.

    Parameters
    ----------
    sigma_sn : float
        Scale parameter σ [days].
    n_sn : float
        Skewness parameter n (dimensionless).

    Returns
    -------
    callable
        ``f(t) -> np.ndarray`` evaluating the envelope, scaled so that its
        true maximum over all t is 1.

    Notes
    -----
    The normalization constant is the analytic peak of the curve, located
    once here.  It does NOT depend on the array passed to ``f``, so
    ``f(t)[i] == f(t[i:i+1])[0]`` and scalar / partial / empty grids are all
    handled consistently.  (Normalizing by the maximum of the evaluated
    samples would make the result depend on the evaluation grid.)
    """
    from scipy.optimize import minimize_scalar
    from scipy.special import erf as _scipy_erf

    sigma = float(sigma_sn)
    n = float(n_sn)

    def _unnormalized(z):
        # z = t / sigma.  The clip only guards against round-off; the
        # expression is non-negative analytically.
        z = np.asarray(z, dtype=np.float64)
        env = np.exp(-z ** 2 / 2.0) * (1.0 + _scipy_erf(n * z / np.sqrt(2.0)))
        return np.maximum(env, 0.0)

    # Locate the mode.  The curve is unimodal and its mode satisfies
    # |z| < 1, so a coarse scan on [-2, 2] brackets it and a bounded
    # scalar minimization then polishes the bracket.
    z_scan = np.linspace(-2.0, 2.0, 4001)
    scan_vals = _unnormalized(z_scan)
    k = int(np.argmax(scan_vals))
    z_lo = float(z_scan[max(k - 1, 0)])
    z_hi = float(z_scan[min(k + 1, z_scan.size - 1)])
    refined = minimize_scalar(
        lambda z: -float(_unnormalized(z)),
        bounds=(z_lo, z_hi),
        method="bounded",
        options={"xatol": 1e-12},
    )
    peak = max(float(scan_vals[k]), -float(refined.fun))

    def _f(t):
        env = _unnormalized(np.asarray(t, dtype=np.float64) / sigma)
        return env / peak if peak > 0.0 else env

    return _f
# ── Abstract base class ─────────────────────────────────────────────────────
class EnvelopeFunction(ABC):
    """
    Abstract base class for spot size envelope functions.

    To define a custom envelope, subclass this and implement:

    **Required:**

    - ``tau_spot`` (property) : characteristic timescale [days].  Used to set
      the extent of numerical integration grids.
    - ``Gamma(t)`` : normalized envelope, peak = 1.

    **Optional** (numerical defaults are provided):

    - ``lspot`` (property) : plateau duration [days].  Default: 0.
    - ``param_dict`` (property): ``{name: value}`` dict of envelope
      parameters.  Needed for ``GPSolver`` to know which parameters to
      fit.  Default: ``{}`` (kernel evaluation still works; fitting will
      not expose envelope parameters).
    - ``R_Gamma(lag)`` : autocorrelation ∫ Γ(t)Γ(t+lag)dt.
      Default: computed via FFT from ``Gamma``.
    - ``Gamma_hat(omega)`` : |FT[Gamma]|(ω).
      Default: computed via FFT from ``Gamma``.
    - ``Gamma_hat_sq(omega)`` : |FT[Gamma]|²(ω).
      Default: ``Gamma_hat(omega) ** 2``.
    - ``kernel_support()`` : upper lag bound where R_Gamma is negligible.
      Default: ``lspot + 6 * tau_spot``.

    The numerical defaults are computed **lazily**: the FFT grids are built
    once on the first call and cached on the instance, so there is no cost for
    subclasses that override these methods.

    Example
    -------
    >>> class GaussianEnvelope(EnvelopeFunction):
    ...     def __init__(self, sigma):
    ...         self._sigma = float(sigma)
    ...     @property
    ...     def tau_spot(self):
    ...         return self._sigma
    ...     @property
    ...     def param_dict(self):
    ...         return {"tau_spot": self._sigma}
    ...     def Gamma(self, t):
    ...         import jax.numpy as jnp
    ...         return jnp.exp(-0.5 * (t / self._sigma) ** 2)
    """

    # ── Required ─────────────────────────────────────────────────────────────

    @property
    @abstractmethod
    def tau_spot(self) -> float:
        """Characteristic (scalar) timescale [days]."""

    @abstractmethod
    def Gamma(self, t):
        """
        Normalized spot-size envelope evaluated at relative times t.

        t = 0 is the center/peak of the envelope.  Must return values in
        [0, 1] with a peak of 1.  Should be JAX-compatible (jnp operations)
        so that it can be evaluated inside JIT-compiled code.
        """

    # ── Optional with defaults ────────────────────────────────────────────────

    @property
    def lspot(self) -> float:
        """Spot plateau duration [days]. Default 0 (no plateau)."""
        return 0.0

    @property
    def param_dict(self) -> dict:
        """
        Envelope parameters as ``{name: value}``.

        Override this to expose envelope parameters to ``GPSolver`` and
        ``SpotEvolutionModel``.  The default returns an empty dict, which means
        the envelope shape is fixed (not inferred during GP fitting).
        """
        return {}

    def kernel_support(self) -> float:
        """
        Upper bound on the lag support of R_Gamma [days].

        Used by ``GPSolver`` to compute the banded-Cholesky bandwidth.
        Default: ``lspot + 6 * tau_spot``.  Override for tighter bounds.
        """
        return self.lspot + 6.0 * self.tau_spot

    def Gamma_integral(self) -> float:
        r"""
        Integral of the squared-size envelope: :math:`\int \Gamma(t)\,dt`.

        Used by ``AnalyticMean`` to compute the expected flux deficit.
        Default: trapezoidal quadrature of ``Gamma(t)`` on 4096 points over
        ``±(6 * tau_spot + lspot)``.  Override with a closed-form expression
        for better accuracy.
        """
        T = 6.0 * self.tau_spot + self.lspot
        t_grid = np.linspace(-T, T, 4096)
        gamma_vals = np.asarray(self.Gamma(jnp.array(t_grid)))
        # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old
        # name; prefer the new one and fall back for older releases.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return float(integrate(gamma_vals, t_grid))

    # ── Lazy numerical grid helpers ───────────────────────────────────────────

    def _ensure_numerical_grids(self):
        """
        Build R_Gamma and |Gamma_hat|² grids from ``Gamma`` via FFT.

        Called automatically by the default ``R_Gamma``, ``Gamma_hat``, and
        ``Gamma_hat_sq`` methods.  Results are cached on the instance so the
        FFT is only computed once.

        The sampling window is ``±12 * tau_spot`` unless the plateau is so
        long (``lspot > 12 * tau_spot``) that this would leave less than
        ``6 * tau_spot`` of room for each tail; in that case the window is
        widened to ``±(lspot / 2 + 6 * tau_spot)`` so the envelope is not
        truncated.
        """
        if hasattr(self, '_num_R_lag_grid'):
            return

        def env_func(t_arr):
            return np.asarray(self.Gamma(jnp.array(t_arr)), dtype=np.float64)

        tau = float(self.tau_spot)
        extent = 12.0
        if tau > 0.0:
            extent = max(extent, (0.5 * float(self.lspot) + 6.0 * tau) / tau)

        self._num_R_lag_grid, self._num_R_vals = \
            compute_R_Gamma_numerical(env_func, tau, extent=extent)
        self._num_Gh_omega_grid, self._num_Gh_sq_vals = \
            _compute_Gamma_hat_sq_numerical(env_func, tau, extent=extent)

    # ── Default implementations ───────────────────────────────────────────────

    def R_Gamma(self, lag):
        """
        Autocorrelation R_Gamma(lag) = ∫ Gamma(t) · Gamma(t + lag) dt.

        Default: interpolated from an FFT-based precomputed grid; the input
        is flattened and only ``|lag|`` is used.
        Override with an analytic expression for better performance.
        """
        self._ensure_numerical_grids()
        lag_abs = jnp.abs(jnp.asarray(lag, dtype=float).ravel())
        return jnp.interp(lag_abs, self._num_R_lag_grid, self._num_R_vals)

    def Gamma_hat(self, omega):
        """
        Fourier transform magnitude |FT[Gamma]|(omega).

        Default: interpolated from an FFT-based precomputed grid.
        Override with a closed-form expression for better performance.
        """
        self._ensure_numerical_grids()
        omega = jnp.asarray(omega, dtype=float)
        return jnp.interp(
            jnp.abs(omega),
            self._num_Gh_omega_grid,
            jnp.sqrt(self._num_Gh_sq_vals),
        )

    def Gamma_hat_sq(self, omega):
        """
        |FT[Gamma](omega)|² = Gamma_hat(omega)².

        Default: interpolated from an FFT-based precomputed grid.
        """
        self._ensure_numerical_grids()
        omega = jnp.asarray(omega, dtype=float)
        return jnp.interp(
            jnp.abs(omega),
            self._num_Gh_omega_grid,
            self._num_Gh_sq_vals,
        )

    # ── Sympy analytic expressions (optional overrides) ───────────────────────

    def sympy_Gamma(self):
        """
        Sympy expression for Gamma(t).

        Override in subclasses that have a closed-form envelope.
        Returns None if no analytic expression is available.
        """
        return None

    def sympy_Gamma_hat(self):
        """
        Sympy expression for Gamma_hat(omega) = FT[Gamma](omega).

        Override in subclasses with a closed-form Fourier transform.
        Returns None if no analytic form is available.
        """
        return None

    def sympy_R_Gamma(self):
        """
        Sympy expression for R_Gamma(tau) = integral Gamma(t) Gamma(t+tau) dt.

        Override in subclasses with a closed-form autocorrelation.
        Returns None if no analytic form is available.
        """
        return None

    def _compute_Gamma_hat_symbolic(self, gamma_expr):
        """Attempt to derive Gamma_hat from gamma_expr via symbolic integration."""
        if gamma_expr is None:
            return None
        # Best effort by design: any failure (sympy missing, integration
        # error) simply means "no closed form available".
        try:
            import sympy as sp
            t = sp.Symbol('t', real=True)
            omega = sp.Symbol(r'\omega', real=True)
            result = sp.integrate(
                gamma_expr * sp.exp(-sp.I * omega * t), (t, -sp.oo, sp.oo))
            return None if result.has(sp.Integral) else result
        except Exception:
            return None

    def _compute_R_Gamma_symbolic(self, gamma_expr):
        """Attempt to derive R_Gamma from gamma_expr via symbolic integration."""
        if gamma_expr is None:
            return None
        # Best effort by design; see _compute_Gamma_hat_symbolic.
        try:
            import sympy as sp
            t = sp.Symbol('t', real=True)
            tau = sp.Symbol(r'\tau', real=True)
            result = sp.integrate(
                gamma_expr * gamma_expr.subs(t, t + tau), (t, -sp.oo, sp.oo))
            return None if result.has(sp.Integral) else result
        except Exception:
            return None

    def get_sympy(self, display=True, status=None, compute_symbolic=False):
        """
        Retrieve and display sympy expressions for Gamma, Gamma_hat, and R_Gamma.

        Prints or renders the LaTeX equation for each function that has a
        closed-form analytic expression, or '[numerical]' for functions that
        rely on FFT-based approximations.

        Requires sympy (``pip install sympy``).

        Parameters
        ----------
        display : bool, optional
            If True (default), render equations as formatted LaTeX in a
            Jupyter notebook (via IPython.display) or print them as LaTeX
            strings in a plain terminal.
        status : str or None, optional
            If provided, appended to the class name header in brackets,
            e.g. ``"user defined"`` renders as
            ``TrapezoidSymmetricEnvelope [user defined]``.
        compute_symbolic : bool, optional
            If True, attempt to derive Gamma_hat and R_Gamma symbolically
            from sympy_Gamma() when no explicit override exists.  Can be
            slow for complex envelopes.  Defaults to False.

        Returns
        -------
        dict
            ``{"Gamma": expr_or_None, "Gamma_hat": expr_or_None,
            "R_Gamma": expr_or_None}``

        Raises
        ------
        ImportError
            If sympy is not installed.
        """
        try:
            import sympy as sp
        except ImportError as err:
            raise ImportError(
                "sympy is required for get_sympy(). "
                "Install with: pip install sympy") from err

        gamma_expr = self.sympy_Gamma()

        # Use explicit override if subclass provides one; otherwise optionally
        # derive from sympy_Gamma via symbolic integration.
        if type(self).sympy_Gamma_hat is not EnvelopeFunction.sympy_Gamma_hat:
            gamma_hat_expr = self.sympy_Gamma_hat()
        elif compute_symbolic:
            gamma_hat_expr = self._compute_Gamma_hat_symbolic(gamma_expr)
        else:
            gamma_hat_expr = None

        if type(self).sympy_R_Gamma is not EnvelopeFunction.sympy_R_Gamma:
            R_gamma_expr = self.sympy_R_Gamma()
        elif compute_symbolic:
            R_gamma_expr = self._compute_R_Gamma_symbolic(gamma_expr)
        else:
            R_gamma_expr = None

        exprs = {
            "Gamma": gamma_expr,
            "Gamma_hat": gamma_hat_expr,
            "R_Gamma": R_gamma_expr,
        }
        lhs = {
            "Gamma": r"\Gamma(t)",
            "Gamma_hat": r"\hat{\Gamma}(\omega)",
            "R_Gamma": r"R_{\Gamma}(\tau)",
        }

        # Integral definitions shown when no closed-form analytic expression
        # exists.  Both integrals run over the whole real line, matching the
        # definitions used by the symbolic and numerical paths.
        t = sp.Symbol('t', real=True)
        omega = sp.Symbol(r'\omega', real=True)
        tau = sp.Symbol(r'\tau', real=True)
        Gamma = sp.Function(r'\Gamma')
        integral_forms = {
            "Gamma_hat": sp.Integral(Gamma(t) * sp.exp(-sp.I * omega * t),
                                     (t, -sp.oo, sp.oo)),
            "R_Gamma": sp.Integral(Gamma(t) * Gamma(t + tau),
                                   (t, -sp.oo, sp.oo)),
        }

        def _rhs_latex(key, expr):
            if expr is not None:
                return sp.latex(expr)
            integral = integral_forms.get(key)
            if integral is not None:
                return sp.latex(integral) + r" \quad \text{[numerical]}"
            return r"\text{[numerical]}"

        if display:
            status_tag = r" \text{[" + status + r"]}" if status else ""
            header = r"\textbf{" + type(self).__name__ + r"}" + status_tag
            try:
                from IPython.display import display as ipy_display, Math
                ipy_display(Math(header))
                for key, expr in exprs.items():
                    ipy_display(Math(lhs[key] + " = " + _rhs_latex(key, expr)))
            except ImportError:
                # Plain-terminal fallback: emit raw LaTeX strings.
                status_str = f" [{status}]" if status else ""
                print(f"{type(self).__name__}{status_str}")
                for key, expr in exprs.items():
                    print(f"  ${lhs[key]} = {_rhs_latex(key, expr)}$")

        return exprs

    def check_functions(self, n_pts=300, ax=None, show=False):
        """
        Compare analytic overrides of Gamma_hat and R_Gamma to the
        FFT-based numerical implementations.

        Checks whether the subclass has provided analytic overrides for
        ``Gamma_hat(omega)`` and/or ``R_Gamma(lag)``.  For each override
        found, evaluates both the analytic and numerical versions on a fine
        grid, plots the comparison, and computes the RMSE and max absolute
        error (both normalized by the numerical peak).

        Parameters
        ----------
        n_pts : int, optional
            Number of evaluation points on each grid.  Default 300.
        ax : matplotlib Axes or array of Axes, optional
            Axes to plot on.  If None, a new figure is created sized to
            the number of overridden functions.
        show : bool, optional
            If True, call ``plt.show()`` after plotting.  Default False.

        Returns
        -------
        errors : dict
            Keys are the names of overridden functions (``"Gamma_hat"``
            and/or ``"R_Gamma"``).  Each value is a dict with:

            - ``"rmse"``    : root-mean-square error (normalized by peak)
            - ``"max_err"`` : maximum absolute error (normalized by peak)
        """
        import time

        import matplotlib.pyplot as plt

        # Detect which methods the subclass has overridden
        has_analytic = {
            "Gamma_hat": type(self).Gamma_hat is not EnvelopeFunction.Gamma_hat,
            "R_Gamma": type(self).R_Gamma is not EnvelopeFunction.R_Gamma,
        }
        overridden = [name for name, flag in has_analytic.items() if flag]
        if not overridden:
            print(f"{type(self).__name__}: no analytic overrides found for "
                  "Gamma_hat or R_Gamma — nothing to check.")
            return {}

        # Ensure the numerical grids are built before we temporarily bypass them
        self._ensure_numerical_grids()

        # Evaluation grids
        lag_max = self.kernel_support()
        omega_max = 2.0 * np.pi / self.tau_spot * 5.0
        grids = {
            "Gamma_hat": np.linspace(0.0, omega_max, n_pts),
            "R_Gamma": np.linspace(0.0, lag_max, n_pts),
        }

        # Numerical baselines (always from the FFT grids, bypassing any override)
        def _numerical_R_Gamma(lag_arr):
            lag_abs = np.abs(lag_arr)
            return np.interp(lag_abs, self._num_R_lag_grid, self._num_R_vals)

        def _numerical_Gamma_hat(omega_arr):
            return np.interp(
                np.abs(omega_arr),
                self._num_Gh_omega_grid,
                np.sqrt(self._num_Gh_sq_vals),
            )

        numerical_funcs = {
            "Gamma_hat": _numerical_Gamma_hat,
            "R_Gamma": _numerical_R_Gamma,
        }
        analytic_funcs = {
            "Gamma_hat": lambda x: np.asarray(self.Gamma_hat(jnp.array(x))),
            "R_Gamma": lambda x: np.asarray(self.R_Gamma(jnp.array(x))),
        }
        xlabels = {
            "Gamma_hat": r"$\omega$ [rad/day]",
            "R_Gamma": r"Lag $\tau$ [days]",
        }
        ylabels = {
            "Gamma_hat": r"$\hat{\Gamma}(\omega)$",
            "R_Gamma": r"$R_{\Gamma}(\tau)$",
        }

        n_panels = len(overridden)
        if ax is None:
            fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4))
            if n_panels == 1:
                axes = [axes]
        else:
            axes = np.atleast_1d(ax)

        errors = {}
        for panel_ax, name in zip(axes, overridden):
            x = grids[name]

            analytic_funcs[name](x[:10])  # warmup: trigger JAX JIT compilation
            t0 = time.perf_counter()
            y_analytic = analytic_funcs[name](x)
            t_analytic = time.perf_counter() - t0

            t0 = time.perf_counter()
            y_numerical = numerical_funcs[name](x)
            t_numerical = time.perf_counter() - t0

            # Normalize errors by the numerical peak (guarding against 0).
            peak = float(np.max(np.abs(y_numerical)))
            if peak == 0.0:
                peak = 1.0
            residuals = y_analytic - y_numerical
            rmse = float(np.sqrt(np.mean(residuals ** 2))) / peak
            max_err = float(np.max(np.abs(residuals))) / peak
            errors[name] = {"rmse": rmse, "max_err": max_err}

            print(f"{name} RMSE = {rmse:.2e}, max err = {max_err:.2e}")
            print(f" analytic time = {t_analytic * 1e3:.3f} ms")
            print(f" numerical time = {t_numerical * 1e3:.3f} ms")

            panel_ax.plot(x, y_numerical, color="k", lw=2, label="Numerical (FFT)")
            panel_ax.plot(x, y_analytic, color="r", lw=1.5, ls="--", label="Analytic (user)")
            panel_ax.set_xlabel(xlabels[name], fontsize=13)
            panel_ax.set_ylabel(ylabels[name], fontsize=13)
            panel_ax.set_title(
                f"{name} — RMSE = {rmse:.2e}, max err = {max_err:.2e}",
                fontsize=12,
            )
            panel_ax.legend(fontsize=11)

        # Only decorate figures this method created itself.
        if n_panels > 0 and ax is None:
            fig.suptitle(f"{type(self).__name__}.check_functions()", fontsize=13)
            fig.tight_layout()
        if show:
            plt.show()
        return errors
# ── Concrete implementations ────────────────────────────────────────────────
class TrapezoidSymmetricEnvelope(EnvelopeFunction):
    """
    Symmetric trapezoidal envelope.

    Shape: the spot size rises linearly over tau_spot, stays on a plateau
    for lspot, then decays linearly over tau_spot.  ``Gamma`` is the
    *square* of that size profile; ``Gamma_integral``, ``Gamma_hat``,
    ``R_Gamma`` and ``sympy_Gamma`` all describe this same squared profile.

    Parameters
    ----------
    lspot : float or ParameterDistribution
        Plateau duration [days].
    tau_spot : float or ParameterDistribution
        Rise/decay timescale [days].

    When either parameter is a ``ParameterDistribution``, ``R_Gamma``
    returns the marginalized autocorrelation integrated over the
    distribution(s).  The ``lspot`` and ``tau_spot`` properties return
    the distribution means for backward compatibility.
    """

    def __init__(self, lspot, tau_spot):
        self._lspot_dist = as_distribution(lspot)
        self._tau_spot_dist = as_distribution(tau_spot)
        self._is_marginalized = (
            is_distributed(self._lspot_dist)
            or is_distributed(self._tau_spot_dist)
        )

    @property
    def tau_spot(self) -> float:
        return self._tau_spot_dist.mean

    @property
    def lspot(self) -> float:
        return self._lspot_dist.mean

    @property
    def lspot_distribution(self):
        """The full ParameterDistribution for lspot."""
        return self._lspot_dist

    @property
    def tau_spot_distribution(self):
        """The full ParameterDistribution for tau_spot."""
        return self._tau_spot_dist

    @property
    def param_dict(self) -> dict:
        return {"lspot": self.lspot, "tau_spot": self.tau_spot}

    def Gamma(self, t):
        """
        Squared trapezoidal size profile at relative times t (peak = 1).

        The linear trapezoid is squared so that this method agrees with the
        closed forms in this class: its integral is
        ``lspot + 2 * tau_spot / 3`` (``Gamma_integral``) and its
        autocorrelation at zero lag is ``lspot + 2 * tau_spot / 5``
        (``R_Gamma``).
        """
        t = jnp.asarray(t, dtype=float)
        half = self.lspot / 2.0
        tau_spot = self.tau_spot
        size = jnp.where(
            t < -(half + tau_spot), 0.0,
            jnp.where(
                t < -half, (t + half + tau_spot) / tau_spot,
                jnp.where(
                    t <= half, 1.0,
                    jnp.where(
                        t < half + tau_spot, (half + tau_spot - t) / tau_spot,
                        0.0))))
        return size ** 2

    def Gamma_hat(self, omega):
        """Closed-form Fourier transform of the squared trapezoid."""
        return _Gamma_hat(jnp.asarray(omega, dtype=float), self.lspot, self.tau_spot)

    def Gamma_hat_sq(self, omega):
        """Square of ``Gamma_hat(omega)``."""
        return self.Gamma_hat(omega) ** 2

    def R_Gamma(self, lag):
        """
        Autocorrelation of ``Gamma``; marginalized over the parameter
        distributions when either parameter is distributed.  Returns a
        flattened 1-D array.
        """
        lag = jnp.asarray(lag)
        if not self._is_marginalized:
            return _R_Gamma_symmetric(lag, self.lspot, self.tau_spot)
        return self._marginalized_R_Gamma(lag)

    def _marginalized_R_Gamma(self, lag, n_quad=16):
        """
        R_Gamma averaged over the distributions of lspot and/or tau_spot.

        Uses a tensor-product quadrature grid: for each combination
        (lspot_i, tau_j), evaluate R_Gamma and take the weighted average.
        The result is a flattened 1-D array, matching the non-marginalized
        path.
        """
        def _quad_1d(dist, n):
            # Gauss–Legendre nodes mapped onto the distribution's support,
            # weighted by its density and renormalized to sum to 1.
            lo, hi = dist.support
            if lo == hi:
                return np.array([lo]), np.array([1.0])
            nodes, weights = np.polynomial.legendre.leggauss(n)
            x = 0.5 * (hi - lo) * nodes + 0.5 * (hi + lo)
            w = 0.5 * (hi - lo) * weights
            pdf = np.array([dist(float(xi)) for xi in x])
            w = w * pdf
            w = w / np.sum(w)
            return x, w

        l_pts, l_wts = _quad_1d(self._lspot_dist, n_quad)
        t_pts, t_wts = _quad_1d(self._tau_spot_dist, n_quad)

        # _R_Gamma_symmetric flattens its input, so accumulate into a
        # flattened float array; seeding from ``zeros_like(lag)`` would
        # mis-broadcast (or fail) for multi-dimensional lag arrays.
        lag_flat = jnp.asarray(lag, dtype=float).ravel()
        R_sum = jnp.zeros_like(lag_flat)
        for li, lw in zip(l_pts, l_wts):
            for tj, tw in zip(t_pts, t_wts):
                R_sum = R_sum + float(lw * tw) * _R_Gamma_symmetric(
                    lag_flat, float(li), float(tj))
        return R_sum

    def kernel_support(self) -> float:
        """Exact support ``lspot + 2 * tau_spot`` of the autocorrelation."""
        # Use upper end of distributions for conservative support
        if self._is_marginalized:
            l_hi = self._lspot_dist.support[1]
            t_hi = self._tau_spot_dist.support[1]
            return l_hi + 2.0 * t_hi
        return self.lspot + 2.0 * self.tau_spot

    def Gamma_integral(self) -> float:
        """Closed-form integral of the squared trapezoid."""
        return self.lspot + 2.0 * self.tau_spot / 3.0

    def sympy_Gamma(self):
        import sympy as sp
        t = sp.Symbol('t', real=True)
        ell = sp.Symbol(r'\ell', positive=True)
        tau = sp.Symbol(r'\tau_{\rm spot}', positive=True)
        half = ell / 2
        alpha_norm = sp.Piecewise(
            (sp.Integer(0), t < -(half + tau)),
            ((t + half + tau) / tau, t < -half),
            (sp.Integer(1), t <= half),
            ((half + tau - t) / tau, t < half + tau),
            (sp.Integer(0), True),
        )
        return alpha_norm ** 2

    def sympy_Gamma_hat(self):
        import sympy as sp
        omega = sp.Symbol(r'\omega', real=True)
        ell = sp.Symbol(r'\ell', positive=True)
        tau = sp.Symbol(r'\tau_{\rm spot}', positive=True)
        return (4 / (tau**2 * omega**3) * (
            tau * omega * sp.cos(omega * ell / 2)
            + sp.sin(omega * ell / 2)
            - sp.sin(omega * ell / 2 + omega * tau)))
class TrapezoidAsymmetricEnvelope(EnvelopeFunction):
    """
    Asymmetric trapezoidal envelope with distinct emergence and decay rates.

    Shape: the spot size rises linearly over tau_em, stays on a plateau for
    lspot, then decays linearly over tau_dec.  ``Gamma`` is the *square* of
    that size profile; ``Gamma_integral``, ``R_Gamma`` and ``sympy_Gamma``
    all describe this same squared profile.

    Parameters
    ----------
    lspot : float
        Plateau duration [days].
    tau_em : float
        Emergence timescale [days].
    tau_dec : float
        Decay timescale [days].
    """

    def __init__(self, lspot: float, tau_em: float, tau_dec: float):
        self._lspot = float(lspot)
        self._tau_em = float(tau_em)
        self._tau_dec = float(tau_dec)
        # Precompute numerical grids once (base-class defaults use these)
        self._ensure_numerical_grids()

    @property
    def tau_spot(self) -> float:
        """Mean of the emergence and decay timescales [days]."""
        return (self._tau_em + self._tau_dec) / 2.0

    @property
    def tau_em(self) -> float:
        return self._tau_em

    @property
    def tau_dec(self) -> float:
        return self._tau_dec

    @property
    def lspot(self) -> float:
        return self._lspot

    @property
    def param_dict(self) -> dict:
        return {
            "lspot": self._lspot,
            "tau_em": self._tau_em,
            "tau_dec": self._tau_dec,
        }

    def Gamma(self, t):
        """
        Squared asymmetric-trapezoid size profile at relative times t
        (peak = 1).

        The linear trapezoid is squared so that this method agrees with the
        closed forms in this class: its integral is
        ``lspot + tau_em / 3 + tau_dec / 3`` (``Gamma_integral``) and its
        autocorrelation at zero lag is ``lspot + (tau_em + tau_dec) / 5``
        (``R_Gamma``).  The FFT-based ``Gamma_hat`` inherited from the base
        class is derived from this method and is therefore consistent too.
        """
        t = jnp.asarray(t, dtype=float)
        half = self._lspot / 2.0
        te, td = self._tau_em, self._tau_dec
        size = jnp.where(
            t < -(half + te), 0.0,
            jnp.where(
                t < -half, (t + half + te) / te,
                jnp.where(
                    t <= half, 1.0,
                    jnp.where(
                        t < half + td, (half + td - t) / td,
                        0.0))))
        return size ** 2

    def R_Gamma(self, lag):
        """Closed-form autocorrelation of ``Gamma`` (flattened 1-D result)."""
        # The autocorrelation is unchanged by time reversal, so the two
        # ramps may be swapped to satisfy the helper's te <= td ordering.
        te = min(self._tau_em, self._tau_dec)
        td = max(self._tau_em, self._tau_dec)
        return _R_Gamma_asymmetric(jnp.asarray(lag), self._lspot, te, td)

    def kernel_support(self) -> float:
        """Exact support ``lspot + tau_em + tau_dec`` of the autocorrelation."""
        return self._lspot + self._tau_em + self._tau_dec

    def Gamma_integral(self) -> float:
        """Closed-form integral of the squared asymmetric trapezoid."""
        return self._lspot + self._tau_em / 3.0 + self._tau_dec / 3.0

    def sympy_Gamma(self):
        import sympy as sp
        t = sp.Symbol('t', real=True)
        ell = sp.Symbol(r'\ell', positive=True)
        te = sp.Symbol(r'\tau_{\rm em}', positive=True)
        td = sp.Symbol(r'\tau_{\rm dec}', positive=True)
        half = ell / 2
        alpha_norm = sp.Piecewise(
            (sp.Integer(0), t < -(half + te)),
            ((t + half + te) / te, t < -half),
            (sp.Integer(1), t <= half),
            ((half + td - t) / td, t < half + td),
            (sp.Integer(0), True),
        )
        return alpha_norm ** 2
class SkewedGaussianEnvelope(EnvelopeFunction):
    """
    Skew-normal (Baranyi et al. 2021) envelope.

    Implements Eq. (1) of Baranyi et al. (2021) A&A 653, A59:
        Gamma(t) ∝ exp(-t²/(2σ²)) · (1 + erf(n·t / (σ·√2)))

    Gamma, its autocorrelation and its squared Fourier magnitude are all
    tabulated once in the constructor on grids spanning ±12 σ and then
    evaluated by linear interpolation.

    Parameters
    ----------
    sigma_sn : float
        Scale parameter [days].
    n_sn : float
        Skewness (dimensionless).  n_sn < 0: rapid rise / slow decay;
        n_sn > 0: slow rise / rapid decay; n_sn = 0: Gaussian.
    lspot : float, optional
        Unused (required by base schema); set to 0 (default).
    """

    def __init__(self, sigma_sn: float, n_sn: float, lspot: float = 0.0):
        self._sigma_sn = float(sigma_sn)
        self._n_sn = float(n_sn)
        self._lspot = float(lspot)

        # Build everything from the validated float attributes so that all
        # grids are guaranteed to share the same parameter values.
        _env_func = _skew_normal_envelope_func(self._sigma_sn, self._n_sn)

        # Precompute R_Gamma and |Gamma_hat|² on fine grids for interpolation
        self._R_lag_grid, self._R_vals = compute_R_Gamma_numerical(
            _env_func, tau_ref=self._sigma_sn)
        self._Gh_omega_grid, self._Gh_sq_vals = _compute_Gamma_hat_sq_numerical(
            _env_func, tau_ref=self._sigma_sn)

        # Also precompute Gamma itself on a t-grid for JAX-traceable Gamma(t)
        T = 12.0 * self._sigma_sn
        t_grid_np = np.linspace(-T, T, 4096)
        gamma_np = _env_func(t_grid_np)
        self._t_grid = jnp.array(t_grid_np)
        self._Gamma_vals = jnp.array(gamma_np)

    @property
    def tau_spot(self) -> float:
        return self._sigma_sn

    @property
    def sigma_sn(self) -> float:
        return self._sigma_sn

    @property
    def n_sn(self) -> float:
        return self._n_sn

    @property
    def lspot(self) -> float:
        return self._lspot

    @property
    def param_dict(self) -> dict:
        return {
            "sigma_sn": self._sigma_sn,
            "n_sn": self._n_sn,
            "lspot": self._lspot,
        }

    def Gamma(self, t):
        """JAX-traceable Gamma(t) via interpolation from precomputed grid."""
        t = jnp.asarray(t, dtype=float)
        return jnp.interp(t, self._t_grid, self._Gamma_vals)

    def Gamma_hat(self, omega):
        """|FT[Gamma]|(omega), interpolated from the precomputed grid."""
        omega = jnp.asarray(omega, dtype=float)
        return jnp.interp(
            jnp.abs(omega), self._Gh_omega_grid,
            jnp.sqrt(self._Gh_sq_vals))

    def Gamma_hat_sq(self, omega):
        """|FT[Gamma]|²(omega), interpolated from the precomputed grid."""
        omega = jnp.asarray(omega, dtype=float)
        return jnp.interp(jnp.abs(omega), self._Gh_omega_grid, self._Gh_sq_vals)

    def R_Gamma(self, lag):
        """Autocorrelation at ``|lag|`` (flattened), interpolated from the grid."""
        lag_abs = jnp.abs(jnp.asarray(lag, dtype=float).ravel())
        return jnp.interp(lag_abs, self._R_lag_grid, self._R_vals)

    def kernel_support(self) -> float:
        """Lag bound of ``12 * sigma_sn`` [days]."""
        return 12.0 * self._sigma_sn
class ExponentialEnvelope(EnvelopeFunction):
    """
    Bilateral exponential (double-sided) envelope.

        Gamma(t) = exp(-|t| / tau_spot)

    This gives a spot that is at peak at t = 0 and decays symmetrically
    with characteristic timescale tau_spot.  There is no plateau (lspot = 0).

    Analytical results:
        Gamma_hat(omega) = 2*tau_spot / (1 + (omega*tau_spot)²)   [Lorentzian]
        R_Gamma(lag)     = (tau_spot + |lag|) * exp(-|lag| / tau_spot)

    Parameters
    ----------
    tau_spot : float
        Decay timescale [days].
    """

    def __init__(self, tau_spot: float):
        self._tau_spot = float(tau_spot)

    @property
    def tau_spot(self) -> float:
        return self._tau_spot

    @property
    def lspot(self) -> float:
        return 0.0

    @property
    def param_dict(self) -> dict:
        return {"tau_spot": self._tau_spot}

    def Gamma(self, t):
        """Gamma(t) = exp(-|t| / tau_spot)."""
        t = jnp.asarray(t, dtype=float)
        return jnp.exp(-jnp.abs(t) / self._tau_spot)

    def Gamma_hat(self, omega):
        """|FT[Gamma]| = 2*tau_spot / (1 + (omega*tau_spot)²) (Lorentzian)."""
        omega = jnp.asarray(omega, dtype=float)
        return 2.0 * self._tau_spot / (1.0 + (omega * self._tau_spot) ** 2)

    def Gamma_hat_sq(self, omega):
        """Square of ``Gamma_hat(omega)``."""
        gh = self.Gamma_hat(omega)
        return gh ** 2

    def R_Gamma(self, lag):
        """R_Gamma(lag) = (tau_spot + |lag|) * exp(-|lag| / tau_spot)."""
        abs_lag = jnp.abs(jnp.asarray(lag, dtype=float))
        return (self._tau_spot + abs_lag) * jnp.exp(-abs_lag / self._tau_spot)

    def kernel_support(self) -> float:
        """Lag bound of ``6 * tau_spot`` [days]."""
        return 6.0 * self._tau_spot

    def Gamma_integral(self) -> float:
        """Closed-form integral of Gamma: ``2 * tau_spot``."""
        return 2.0 * self._tau_spot

    def sympy_Gamma(self):
        import sympy as sp
        t = sp.Symbol('t', real=True)
        tau = sp.Symbol(r'\tau_{\rm spot}', positive=True)
        return sp.exp(-sp.Abs(t) / tau)

    def sympy_Gamma_hat(self):
        import sympy as sp
        omega = sp.Symbol(r'\omega', real=True)
        tau = sp.Symbol(r'\tau_{\rm spot}', positive=True)
        return 2 * tau / (1 + (omega * tau)**2)

    def sympy_R_Gamma(self):
        import sympy as sp
        tau_lag = sp.Symbol(r'\tau', real=True)
        tau = sp.Symbol(r'\tau_{\rm spot}', positive=True)
        return (tau + sp.Abs(tau_lag)) * sp.exp(-sp.Abs(tau_lag) / tau)
class ExponentialAsymmetricEnvelope(EnvelopeFunction):
    """
    Asymmetric exponential envelope with separate rise and decay timescales.

        Gamma(t) = exp( t / tau_em)   for t < 0   (emergence / rise)
        Gamma(t) = exp(-t / tau_dec)  for t >= 0  (decay)

    The spot emerges with timescale tau_em and decays with timescale tau_dec.
    There is no plateau (lspot = 0).

    Analytical Fourier transform:
        FT[Gamma](omega) = tau_em  / (1 - i*omega*tau_em)
                         + tau_dec / (1 + i*omega*tau_dec)
                         = (tau_em + tau_dec)
                           / ((1 - i*omega*tau_em) * (1 + i*omega*tau_dec))

        |FT[Gamma]|(omega) = (tau_em + tau_dec)
                             / sqrt((1 + (omega*tau_em)^2)
                                    * (1 + (omega*tau_dec)^2))

    The transform is complex whenever tau_em != tau_dec, so the magnitude
    is NOT the sum of the two Lorentzians (that sum is only its real part).

    Parameters
    ----------
    tau_em : float
        Emergence (rise) timescale [days].
    tau_dec : float
        Decay timescale [days].
    """

    def __init__(self, tau_em: float, tau_dec: float):
        self._tau_em = float(tau_em)
        self._tau_dec = float(tau_dec)

    @property
    def tau_spot(self) -> float:
        """Effective timescale: max of the two timescales."""
        return max(self._tau_em, self._tau_dec)

    @property
    def lspot(self) -> float:
        return 0.0

    @property
    def param_dict(self) -> dict:
        return {"tau_em": self._tau_em, "tau_dec": self._tau_dec}

    def Gamma(self, t):
        """Two-sided exponential: rise for t < 0, decay for t >= 0."""
        t = jnp.asarray(t, dtype=float)
        return jnp.where(t < 0,
                         jnp.exp(t / self._tau_em),
                         jnp.exp(-t / self._tau_dec))

    def Gamma_hat(self, omega):
        """
        Fourier transform magnitude |FT[Gamma]|(omega).

        Equals ``(tau_em + tau_dec) / sqrt((1 + (omega*tau_em)²)
        (1 + (omega*tau_dec)²))``.  For tau_em == tau_dec this reduces to
        the Lorentzian of the symmetric exponential envelope.
        """
        omega = jnp.asarray(omega, dtype=float)
        return (self._tau_em + self._tau_dec) / jnp.sqrt(self.Gamma_hat_sq_denominator(omega))

    def Gamma_hat_sq_denominator(self, omega):
        """Product ``(1 + (omega*tau_em)²) * (1 + (omega*tau_dec)²)``."""
        omega = jnp.asarray(omega, dtype=float)
        return ((1.0 + (omega * self._tau_em) ** 2)
                * (1.0 + (omega * self._tau_dec) ** 2))

    def Gamma_hat_sq(self, omega):
        """
        |FT[Gamma]|²(omega) in closed form, consistent with ``Gamma_hat``.

        Overrides the FFT-interpolated base default so the power spectrum
        and the magnitude cannot disagree.
        """
        return ((self._tau_em + self._tau_dec) ** 2
                / self.Gamma_hat_sq_denominator(omega))

    def kernel_support(self) -> float:
        """Lag bound of ``6 * (tau_em + tau_dec)`` [days]."""
        return 6.0 * (self._tau_em + self._tau_dec)

    def Gamma_integral(self) -> float:
        """Closed-form integral of Gamma: ``tau_em + tau_dec``."""
        return self._tau_em + self._tau_dec