"""
spot_model.py — SpotEvolutionModel.
SpotEvolutionModel combines an EnvelopeFunction and a VisibilityFunction
with an amplitude parameter (sigma_k) to fully describe the statistical
spot evolution model used by AnalyticKernel, NumericalKernel, GPSolver,
and LightcurveModel.
LatitudeDistributionFunction, VisibilityFunction, EdgeOnVisibilityFunction,
and low-level helpers are re-exported here for backward compatibility.
"""
from __future__ import annotations
import numpy as np
_UNSET = object() # sentinel to distinguish "not passed" from explicit None
import jax.numpy as jnp
try:
from .distributions import as_distribution, is_distributed, DeltaDistribution
except ImportError:
from distributions import as_distribution, is_distributed, DeltaDistribution
try:
from .envelope import (
EnvelopeFunction,
TrapezoidSymmetricEnvelope,
TrapezoidAsymmetricEnvelope,
SkewedGaussianEnvelope,
ExponentialEnvelope,
_R_Gamma_symmetric,
_R_Gamma_asymmetric,
)
from .params import resolve_hparam
from .latitude import LatitudeDistributionFunction, UniformDoubleHemisphereBand
from .visibility import (
VisibilityFunction,
EdgeOnVisibilityFunction,
FullGeometryVisibilityFunction,
_cn_general_jax,
_cn_squared_coefficients_jax,
_gauss_legendre_grid,
)
except ImportError:
from envelope import (
EnvelopeFunction,
TrapezoidSymmetricEnvelope,
TrapezoidAsymmetricEnvelope,
SkewedGaussianEnvelope,
ExponentialEnvelope,
_R_Gamma_symmetric,
_R_Gamma_asymmetric,
)
from params import resolve_hparam
from latitude import LatitudeDistributionFunction, UniformDoubleHemisphereBand
from visibility import (
VisibilityFunction,
EdgeOnVisibilityFunction,
FullGeometryVisibilityFunction,
_cn_general_jax,
_cn_squared_coefficients_jax,
_gauss_legendre_grid,
)
__all__ = [
# Re-exported for backward compatibility
"LatitudeDistributionFunction",
"VisibilityFunction",
"EdgeOnVisibilityFunction",
"FullGeometryVisibilityFunction",
"_cn_general_jax",
"_cn_squared_coefficients_jax",
"_gauss_legendre_grid",
# Defined here
"SpotEvolutionModel",
]
# ── SpotEvolutionModel ──────────────────────────────────────────────────────
[docs]
class SpotEvolutionModel:
"""
Complete statistical spot evolution model.
Combines an EnvelopeFunction (spot size evolution) with a
VisibilityFunction (stellar rotation and inclination) and a kernel
amplitude parameter sigma_k.
Parameters
----------
envelope : EnvelopeFunction or None
Spot size envelope (e.g. TrapezoidSymmetricEnvelope). When None
the kernel contains only the visibility function (R_Gamma = 1).
visibility : VisibilityFunction or None
Stellar visibility function (peq, kappa, inc). When None the
kernel contains only the envelope function.
sigma_k : float, optional
Kernel amplitude prefactor. Provide either sigma_k directly or
(nspot_rate, fspot, alpha_max) for the physical parameterization.
nspot_rate : float, optional
Spot emergence rate [spots/day]. Used when sigma_k is not given.
fspot : float, optional
Spot contrast fraction (default 0).
alpha_max : float, optional
Peak spot angular radius [rad]. Used when sigma_k is not given.
latitude_distribution : LatitudeDistributionFunction, optional
Latitude probability density for spot placement and kernel
integration. Defaults to a uniform distribution over
[-pi/2, pi/2].
Notes
-----
Exactly one of the following amplitude specifications must be supplied:
- ``sigma_k`` directly,
- ``(nspot_rate, a_spot)`` where ``a_spot = (1 - fspot) * alpha_max**2``,
- ``(nspot_rate, fspot, alpha_max)`` (physical parameterization).
"""
def __init__(
self,
envelope: EnvelopeFunction = _UNSET,
visibility: VisibilityFunction = _UNSET,
sigma_k: float = None,
nspot_rate: float = None,
fspot: float = 0.0,
alpha_max: float = None,
a_spot: float = None,
latitude_distribution: LatitudeDistributionFunction = _UNSET,
):
# Resolve each component and record its provenance for get_sympy()
if envelope is _UNSET:
self.envelope = None
self._envelope_status = "not specified"
elif envelope is None:
self.envelope = None
self._envelope_status = "not specified"
else:
if not isinstance(envelope, EnvelopeFunction):
raise TypeError(
f"envelope must be an EnvelopeFunction, got {type(envelope)}")
self.envelope = envelope
self._envelope_status = "user defined"
if visibility is _UNSET:
self.visibility = None
self._visibility_status = "not specified"
elif visibility is None:
self.visibility = None
self._visibility_status = "not specified"
else:
if not isinstance(visibility, VisibilityFunction):
raise TypeError(
f"visibility must be a VisibilityFunction, got {type(visibility)}")
self.visibility = visibility
self._visibility_status = "user defined"
if latitude_distribution is _UNSET:
self.latitude_distribution = LatitudeDistributionFunction()
self._latitude_status = "default"
elif latitude_distribution is None:
self.latitude_distribution = LatitudeDistributionFunction()
self._latitude_status = "not specified"
else:
if not isinstance(latitude_distribution, LatitudeDistributionFunction):
raise TypeError(
f"latitude_distribution must be a LatitudeDistributionFunction, "
f"got {type(latitude_distribution)}")
self.latitude_distribution = latitude_distribution
self._latitude_status = "user defined"
self.fspot = float(fspot)
self.alpha_max = float(alpha_max) if alpha_max is not None else None
if sigma_k is not None:
self._sigma_k_dist = as_distribution(sigma_k)
self._nspot_rate = None
self._a_spot = None
elif a_spot is not None and nspot_rate is not None:
self._nspot_rate = float(nspot_rate)
self._a_spot = float(a_spot)
computed = float(np.sqrt(float(nspot_rate))) * float(a_spot)
self._sigma_k_dist = as_distribution(computed)
elif nspot_rate is not None and alpha_max is not None:
self._nspot_rate = float(nspot_rate)
self._a_spot = (1.0 - float(fspot)) * float(alpha_max) ** 2
computed = float(np.sqrt(float(nspot_rate))) * self._a_spot
self._sigma_k_dist = as_distribution(computed)
else:
raise ValueError(
"SpotEvolutionModel requires either sigma_k, "
"(nspot_rate, a_spot), or (nspot_rate, fspot, alpha_max).")
# ── Convenience accessors ───────────────────────────────────────────────
@property
def sigma_k(self) -> float:
"""Point estimate (mean) of sigma_k. Backward-compatible float."""
return self._sigma_k_dist.mean
@sigma_k.setter
def sigma_k(self, value):
self._sigma_k_dist = as_distribution(value)
@property
def sigma_k_distribution(self):
"""The full ParameterDistribution for sigma_k."""
return self._sigma_k_dist
@property
def nspot_rate(self) -> float:
"""Spot emergence rate [spots/day], or None if not specified."""
return self._nspot_rate
@property
def a_spot(self) -> float:
"""Spot contrast-area product (1-f_spot)*alpha_max^2, or None."""
return self._a_spot
@property
def sigma_k_sq_expected(self) -> float:
"""E[sigma_k^2] under the distribution. Exact for DeltaDistribution."""
return self._sigma_k_dist.expectation(lambda x: x ** 2)
@property
def peq(self) -> float:
return self.visibility.peq if self.visibility is not None else None
@property
def kappa(self) -> float:
return self.visibility.kappa if self.visibility is not None else None
@property
def inc(self) -> float:
return self.visibility.inc if self.visibility is not None else None
@property
def lspot(self) -> float:
return self.envelope.lspot if self.envelope is not None else None
@property
def tau_spot(self) -> float:
return self.envelope.tau_spot if self.envelope is not None else None
# ── Parameter keys ──────────────────────────────────────────────────────
@property
def use_physical_amplitude(self) -> bool:
"""True when the model uses (nspot_rate, a_spot) instead of sigma_k."""
return self._nspot_rate is not None
@property
def param_keys(self) -> tuple:
"""
Ordered parameter names for the theta vector used in GPSolver.
When both are present, starts with (peq, kappa, inc) from the
visibility function, followed by the envelope-specific keys, then
the amplitude parameter(s). When the model was constructed with
``(nspot_rate, a_spot)`` or ``(nspot_rate, fspot, alpha_max)``,
the last two keys are ``("nspot_rate", "a_spot")``; otherwise
the last key is ``("sigma_k",)``.
"""
vis_keys = self.visibility.param_keys if self.visibility is not None else ()
env_keys = (tuple(self.envelope.param_dict.keys())
if self.envelope is not None else ())
lat_keys = self.latitude_distribution.param_keys
if self.use_physical_amplitude:
return vis_keys + env_keys + lat_keys + ("nspot_rate", "a_spot")
return vis_keys + env_keys + lat_keys + ("sigma_k",)
@property
def theta0(self) -> np.ndarray:
"""Initial parameter vector from current model values."""
vals = {}
if self.visibility is not None:
vals.update(self.visibility.param_dict)
if self.envelope is not None:
vals.update(self.envelope.param_dict)
vals.update(self.latitude_distribution.param_dict)
if self.use_physical_amplitude:
vals["nspot_rate"] = self._nspot_rate
vals["a_spot"] = self._a_spot
else:
vals["sigma_k"] = self.sigma_k
return np.array([float(vals[k]) for k in self.param_keys])
[docs]
def theta_from_hparam(self, hparam: dict) -> np.ndarray:
"""
Build a theta vector from a (possibly partial) hparam dict.
Missing keys fall back to the model's current values.
"""
current = dict(zip(self.param_keys, self.theta0))
current.update({k: float(v) for k, v in hparam.items()
if k in current})
return np.array([current[k] for k in self.param_keys])
# ── JAX-compilable R_Gamma function ─────────────────────────────────────
[docs]
def get_r_gamma_func(self):
"""
Return a JAX-traceable function r_gamma(theta_arr, lag) -> R_Gamma.
The theta_arr layout follows self.param_keys:
[peq, kappa, inc, <envelope params...>, sigma_k]
When envelope is None, returns a function that always yields 1.0
(pure visibility kernel). The R_Gamma function is selected based
on the envelope type and captured (together with any precomputed
grids) in a closure so that the returned callable is safe to use
inside jax.jit.
"""
if self.envelope is None:
def r_gamma(theta_arr, lag): # noqa: ARG001
return jnp.ones_like(jnp.asarray(lag))
return r_gamma
n_vis = len(self.visibility.param_keys) if self.visibility is not None else 0
if isinstance(self.envelope, TrapezoidSymmetricEnvelope):
def r_gamma(theta_arr, lag):
lspot = theta_arr[n_vis] # index 3
tau_spot = theta_arr[n_vis + 1] # index 4
return _R_Gamma_symmetric(lag, lspot, tau_spot)
elif isinstance(self.envelope, TrapezoidAsymmetricEnvelope):
def r_gamma(theta_arr, lag):
lspot = theta_arr[n_vis] # index 3
tau_em = theta_arr[n_vis + 1] # index 4
tau_dec = theta_arr[n_vis + 2] # index 5
te = jnp.minimum(tau_em, tau_dec)
td = jnp.maximum(tau_em, tau_dec)
return _R_Gamma_asymmetric(lag, lspot, te, td)
elif isinstance(self.envelope, SkewedGaussianEnvelope):
# sigma_sn and n_sn are in theta but R_Gamma uses the
# precomputed interpolation grid (fixed at init time).
lag_grid = self.envelope._R_lag_grid
R_vals = self.envelope._R_vals
def r_gamma(theta_arr, lag): # noqa: ARG001 (theta_arr unused)
return jnp.interp(jnp.abs(lag), lag_grid, R_vals)
elif isinstance(self.envelope, ExponentialEnvelope):
def r_gamma(theta_arr, lag):
tau_spot = theta_arr[n_vis] # index 3 (no lspot for exponential)
abs_lag = jnp.abs(lag)
return (tau_spot + abs_lag) * jnp.exp(-abs_lag / tau_spot)
else:
# Generic fallback via precomputed grid (like skew-normal)
env = self.envelope
tau_ref = env.tau_spot if env.tau_spot > 0 else 1.0
env_np = lambda t_arr: np.asarray(env.Gamma(jnp.array(t_arr)))
from .envelope import compute_R_Gamma_numerical
lag_grid, R_vals = compute_R_Gamma_numerical(env_np, tau_ref)
def r_gamma(theta_arr, lag): # noqa: ARG001
return jnp.interp(jnp.abs(lag), lag_grid, R_vals)
return r_gamma
# ── JAX-compilable latitude weight function ────────────────────────────
[docs]
def get_lat_weight_func(self):
"""
Return a JAX-traceable function ``f(theta_arr, phi_grid) -> weights``
that computes per-node latitude weights from the theta vector.
When the latitude distribution has no free parameters, returns None
(the caller should use the static weights precomputed at init).
The theta_arr layout follows ``self.param_keys``.
"""
lat_dist = self.latitude_distribution
if not lat_dist.param_dict:
return None
# Index of first latitude param in theta_arr
n_vis = len(self.visibility.param_keys) if self.visibility is not None else 0
n_env = len(self.envelope.param_dict) if self.envelope is not None else 0
lat_offset = n_vis + n_env
if isinstance(lat_dist, UniformDoubleHemisphereBand):
def lat_weight_fn(theta_arr, phi_grid):
lat_min = theta_arr[lat_offset]
lat_max = theta_arr[lat_offset + 1]
abs_phi = jnp.abs(phi_grid)
return jnp.where((abs_phi > lat_min) & (abs_phi < lat_max),
1.0, 0.0)
return lat_weight_fn
# Generic fallback: not JAX-traceable, but works for fixed params
return None
# ── Bandwidth support ───────────────────────────────────────────────────
[docs]
def bandwidth_support(self, param_keys, bounds_arr) -> float:
"""
Estimate the kernel support using upper bounds of parameters.
Used by GPSolver._compute_bandwidth to determine the banded
Cholesky bandwidth as a compile-time constant.
Parameters
----------
param_keys : sequence of str
Parameter names in the same order as bounds_arr.
bounds_arr : array_like, shape (n_params, 2)
Lower and upper bounds for each parameter.
"""
keys = list(param_keys)
bounds_arr = np.asarray(bounds_arr)
def upper(key, fallback):
if key in keys:
return float(bounds_arr[keys.index(key), 1])
log_key = f"log_{key}"
if log_key in keys:
return 10.0 ** float(bounds_arr[keys.index(log_key), 1])
return float(fallback)
if self.envelope is None:
return 0.0
if isinstance(self.envelope, TrapezoidSymmetricEnvelope):
return (upper("lspot", self.lspot)
+ 2.0 * upper("tau_spot", self.tau_spot))
elif isinstance(self.envelope, TrapezoidAsymmetricEnvelope):
return (upper("lspot", self.lspot)
+ upper("tau_em", self.envelope.tau_em)
+ upper("tau_dec", self.envelope.tau_dec))
elif isinstance(self.envelope, SkewedGaussianEnvelope):
return 12.0 * upper("sigma_sn", self.envelope.sigma_sn)
elif isinstance(self.envelope, ExponentialEnvelope):
return 6.0 * upper("tau_spot", self.tau_spot)
else:
return self.envelope.kernel_support()
# ── Serialization ───────────────────────────────────────────────────────
[docs]
def to_hparam(self) -> dict:
"""
Convert to a flat hparam dict for backward compatibility.
The returned dict is accepted by resolve_hparam and by the
old-style constructors of AnalyticKernel, GPSolver, etc.
"""
d = {}
if self.visibility is not None:
d.update(self.visibility.param_dict)
if self.envelope is not None:
d.update(self.envelope.param_dict)
d["sigma_k"] = self.sigma_k
if self._nspot_rate is not None:
d["nspot_rate"] = self._nspot_rate
if self._a_spot is not None:
d["a_spot"] = self._a_spot
if self.alpha_max is not None:
d["alpha_max"] = self.alpha_max
if self.fspot:
d["fspot"] = self.fspot
return d
[docs]
@classmethod
def from_hparam(cls, hparam: dict) -> "SpotEvolutionModel":
"""
Construct a SpotEvolutionModel from a raw hparam dict.
Accepts the same dict format that resolve_hparam accepts,
including all envelope types and amplitude modes.
"""
p = resolve_hparam(hparam)
visibility = VisibilityFunction(p["peq"], p["kappa"], p["inc"])
if "sigma_sn" in hparam and "n_sn" in hparam:
envelope = SkewedGaussianEnvelope(
p["sigma_sn"], p["n_sn"], p.get("lspot", 0.0))
elif "tau_em" in hparam and "tau_dec" in hparam:
envelope = TrapezoidAsymmetricEnvelope(
p["lspot"], p["tau_em"], p["tau_dec"])
else:
envelope = TrapezoidSymmetricEnvelope(p["lspot"], p["tau_spot"])
return cls(envelope, visibility, sigma_k=p["sigma_k"])
[docs]
def get_sympy(self, display=True, compute_symbolic=False):
"""
Display sympy expressions for the full spot evolution model.
Delegates to ``EnvelopeFunction.get_sympy()`` and
``VisibilityFunction.get_sympy()`` in sequence.
Requires sympy (``pip install sympy``).
Parameters
----------
display : bool, optional
If True (default), render equations as formatted LaTeX in a
Jupyter notebook (via IPython.display) or print them as LaTeX
strings in a plain terminal.
compute_symbolic : bool, optional
Passed through to ``EnvelopeFunction.get_sympy()``. If True,
attempt to derive Gamma_hat and R_Gamma symbolically from
sympy_Gamma() when no explicit override exists. Defaults to
False.
Returns
-------
dict
``{"envelope": envelope_exprs, "visibility": visibility_exprs,
"latitude": latitude_exprs}``
where each value is the dict returned by the respective
``get_sympy()`` call.
"""
def _display_not_specified(label):
if display:
try:
from IPython.display import display as ipy_display, Math
ipy_display(Math(r"\textbf{" + label
+ r"} \text{ [not specified]}"))
except ImportError:
print(f"{label} [not specified]")
if self.envelope is not None:
envelope_exprs = self.envelope.get_sympy(
display=display, status=self._envelope_status,
compute_symbolic=compute_symbolic)
else:
_display_not_specified("EnvelopeFunction")
envelope_exprs = None
if self.visibility is not None:
visibility_exprs = self.visibility.get_sympy(
display=display, status=self._visibility_status)
else:
_display_not_specified("VisibilityFunction")
visibility_exprs = None
latitude_exprs = self.latitude_distribution.get_sympy(
display=display, status=self._latitude_status)
return {"envelope": envelope_exprs, "visibility": visibility_exprs, "latitude": latitude_exprs}
def __repr__(self) -> str:
if self.envelope is not None:
env_str = (f"{self.envelope.__class__.__name__}"
f"({self.envelope.param_dict})")
else:
env_str = "None"
if self.visibility is not None:
vis_str = (f"VisibilityFunction"
f"(peq={self.peq}, kappa={self.kappa}, inc={self.inc:.3f})")
else:
vis_str = "None"
return (
f"SpotEvolutionModel(\n"
f" envelope={env_str},\n"
f" visibility={vis_str},\n"
f" sigma_k={self.sigma_k}\n)"
)