Source code for spotgp.transit

"""
transit.py — Keplerian orbit and transit light curve computation.

Implements a Keplerian orbit model for binary or planetary systems,
following the parameterization of the ``exoplanet`` package
(https://docs.exoplanet.codes).  All heavy computation uses JAX for
GPU/CPU acceleration and JIT compilation.
"""

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

jax.config.update("jax_enable_x64", True)

try:
    from .lightcurve import LightcurveModel
    from .gp_solver import GPSolver
except ImportError:
    from lightcurve import LightcurveModel
    from gp_solver import GPSolver

__all__ = [
    "KeplerianOrbit",
    "QuadLimbDarkLightCurve",
    "SpotTransitModel",
]


# =====================================================================
# Kepler's equation solver
# =====================================================================

@jax.jit
def _kepler_starter(M, ecc):
    """Secant-step initial guess for Kepler's equation.

    Applies one secant step to ``g(E) = E - ecc*sin(E) - M`` using the
    two points ``E = M`` and ``E = M + ecc``, where
    ``g(M) = -ecc*sin(M)`` and ``g(M + ecc) = ecc*(1 - sin(M + ecc))``.
    The resulting estimate is

        E0 = M + ecc*sin(M) / (1 - sin(M + ecc) + sin(M))

    The denominator equals ``1 - 2*cos(M + ecc/2)*sin(ecc/2)``, which
    stays strictly positive for ``0 <= ecc < 1``, so no guard against
    division by zero is needed in that range.

    Parameters
    ----------
    M : array_like
        Mean anomaly [rad].
    ecc : float
        Eccentricity (0 <= ecc < 1).

    Returns
    -------
    E0 : array_like
        Initial estimate of the eccentric anomaly [rad].
    """
    sin_M = jnp.sin(M)
    # NOTE: the sign pattern matters -- it is sin(M + ecc) that is
    # subtracted and sin(M) that is added.
    return M + ecc * sin_M / (1.0 - jnp.sin(M + ecc) + sin_M)


@jax.jit
def _kepler_refine(E, M, ecc):
    """Apply a single Newton-Raphson update to an eccentric-anomaly guess.

    The update targets the root of ``g(E) = E - ecc*sin(E) - M``, whose
    derivative is ``g'(E) = 1 - ecc*cos(E)``.

    Parameters
    ----------
    E : array_like
        Current estimate of the eccentric anomaly [rad].
    M : array_like
        Mean anomaly [rad].
    ecc : float
        Eccentricity.

    Returns
    -------
    E_new : array_like
        Updated estimate ``E - g(E) / g'(E)`` [rad].
    """
    residual = E - ecc * jnp.sin(E) - M
    slope = 1.0 - ecc * jnp.cos(E)
    return E - residual / slope


@jax.jit
def _solve_kepler(M, ecc):
    """Return the eccentric anomaly E satisfying  M = E - e sin(E).

    An initial guess from ``_kepler_starter`` is polished with five
    Newton-Raphson steps (``_kepler_refine``).  Only JAX operations are
    involved, so the routine can be JIT-compiled and differentiated.

    Parameters
    ----------
    M : array_like
        Mean anomaly [rad].
    ecc : float
        Eccentricity (0 <= ecc < 1).

    Returns
    -------
    E : array_like
        Eccentric anomaly [rad].
    """
    E = _kepler_starter(M, ecc)
    # A plain Python loop is unrolled while tracing, so the compiled
    # computation is a fixed chain of five refinement steps.
    for _ in range(5):
        E = _kepler_refine(E, M, ecc)
    return E


# =====================================================================
# True anomaly from eccentric anomaly
# =====================================================================

@jax.jit
def _E_to_true_anomaly(E, ecc):
    """Map eccentric anomaly to true anomaly.

    Uses the half-angle relation
    ``tan(f/2) = sqrt((1 + e) / (1 - e)) * tan(E/2)``, evaluated through
    ``arctan2`` so the quadrant is taken from the signs of the sine and
    cosine terms.

    Parameters
    ----------
    E : array_like
        Eccentric anomaly [rad].
    ecc : float
        Eccentricity.

    Returns
    -------
    f : array_like
        True anomaly [rad].
    """
    half_E = E / 2.0
    sin_term = jnp.sqrt(1.0 + ecc) * jnp.sin(half_E)
    cos_term = jnp.sqrt(1.0 - ecc) * jnp.cos(half_E)
    return 2.0 * jnp.arctan2(sin_term, cos_term)


# =====================================================================
# KeplerianOrbit
# =====================================================================

class KeplerianOrbit:
    """Keplerian orbit for a binary or planetary system.

    This class computes the 3-D position of the secondary body relative
    to the primary (host star) as a function of time.  It also provides
    the projected sky-plane separation needed for transit light curve
    evaluation.

    The parameterization follows the ``exoplanet`` package convention:
    the reference direction is along the line of sight, with the sky
    plane defined by (x, y) and the z-axis pointing *toward* the
    observer.  A transit occurs when z > 0 and the sky-plane separation
    is less than the sum of the stellar and companion radii.

    Parameters
    ----------
    period : float
        Orbital period [days].  Must be positive.
    semi_major_axis : float or None
        Semi-major axis.  The constructor divides this value by
        ``stellar_radius`` to obtain ``a_over_Rstar``, so it is in units
        of stellar radii only when ``stellar_radius`` is left at 1.0;
        otherwise it must be given in the same units as
        ``stellar_radius`` (solar radii).  If ``None``, it is computed
        in solar radii from ``period`` and ``stellar_mass`` via Kepler's
        third law.
    ecc : float
        Orbital eccentricity, ``0 <= ecc < 1`` (default 0).
    incl : float or None
        Orbital inclination [rad].  If ``None``, it is computed from
        ``impact_param``; if that is also ``None`` the orbit is edge-on.
    impact_param : float or None
        Impact parameter (0 = central transit).  Ignored when ``incl``
        is provided directly.
    omega : float
        Argument of periastron of the *secondary* [rad] (default 0).
    Omega : float
        Longitude of the ascending node [rad] (default 0).
    t_periastron : float or None
        Time of periastron passage [days].  At most one of
        ``t_periastron`` and ``t0`` may be given.
    t0 : float or None
        Reference transit mid-time [days].  If provided instead of
        ``t_periastron``, the periastron time is derived so that transit
        center falls at ``t0``.  If neither is given, ``t0 = 0`` is used.
    radius_ratio : float
        Companion-to-star radius ratio Rp / R* (default 0.1).
    stellar_mass : float
        Stellar mass in solar masses (default 1.0).  Used only when
        ``semi_major_axis`` is ``None``.
    stellar_radius : float
        Stellar radius in solar radii (default 1.0).  Divides the
        semi-major axis to give ``a_over_Rstar``.

    Raises
    ------
    ValueError
        If ``period`` is not positive, if ``ecc`` lies outside
        ``[0, 1)``, or if both ``t_periastron`` and ``t0`` are given.

    Notes
    -----
    Angles follow the standard celestial mechanics convention:

    * ``omega`` is the argument of periastron of the *orbiting body*.
    * ``incl = pi/2`` corresponds to an edge-on orbit.
    """

    # Gravitational constant * solar mass [R_sun^3 / (M_sun * day^2)]
    _G_MSUN_RSUN3_DAY2 = 2942.2062

    def __init__(
        self,
        period,
        semi_major_axis=None,
        ecc=0.0,
        incl=None,
        impact_param=None,
        omega=0.0,
        Omega=0.0,
        t_periastron=None,
        t0=None,
        radius_ratio=0.1,
        stellar_mass=1.0,
        stellar_radius=1.0,
    ):
        self.period = float(period)
        self.ecc = float(ecc)
        self.omega = float(omega)
        self.Omega = float(Omega)
        self.radius_ratio = float(radius_ratio)
        self.stellar_mass = float(stellar_mass)
        self.stellar_radius = float(stellar_radius)

        # Fail early: outside these ranges the square roots and divisions
        # below would silently yield NaN/inf instead of an orbit.
        # (Written as "not ... " so NaN inputs are rejected as well.)
        if not self.period > 0.0:
            raise ValueError(f"period must be positive, got {self.period}")
        if not 0.0 <= self.ecc < 1.0:
            raise ValueError(
                f"ecc must satisfy 0 <= ecc < 1, got {self.ecc}"
            )

        # Semi-major axis --------------------------------------------------
        if semi_major_axis is not None:
            self.semi_major_axis = float(semi_major_axis)
        else:
            # Kepler's third law: a^3 = G M P^2 / (4 pi^2)
            a_cubed = (
                self._G_MSUN_RSUN3_DAY2
                * self.stellar_mass
                * self.period ** 2
                / (4.0 * np.pi ** 2)
            )
            self.semi_major_axis = float(a_cubed ** (1.0 / 3.0))

        # Store in units of stellar radii
        self.a_over_Rstar = self.semi_major_axis / self.stellar_radius

        # Inclination ------------------------------------------------------
        if incl is not None:
            self.incl = float(incl)
            self.impact_param = float(
                self.a_over_Rstar
                * np.cos(self.incl)
                * (1.0 - self.ecc ** 2)
                / (1.0 + self.ecc * np.sin(self.omega))
            )
        elif impact_param is not None:
            self.impact_param = float(impact_param)
            cos_incl = (
                self.impact_param
                * (1.0 + self.ecc * np.sin(self.omega))
                / (self.a_over_Rstar * (1.0 - self.ecc ** 2))
            )
            # Keep arccos defined even if the requested impact parameter
            # is larger than the orbit allows.
            cos_incl = np.clip(cos_incl, -1.0, 1.0)
            self.incl = float(np.arccos(cos_incl))
        else:
            # Default to edge-on
            self.incl = 0.5 * np.pi
            self.impact_param = 0.0

        # Reference time ---------------------------------------------------
        if t_periastron is not None and t0 is not None:
            raise ValueError(
                "Specify exactly one of t_periastron and t0, not both."
            )
        if t_periastron is not None:
            self.t_periastron = float(t_periastron)
            self.t0 = self._t_periastron_to_t0(self.t_periastron)
        elif t0 is not None:
            self.t0 = float(t0)
            self.t_periastron = self._t0_to_t_periastron(self.t0)
        else:
            self.t0 = 0.0
            self.t_periastron = self._t0_to_t_periastron(self.t0)

    # -----------------------------------------------------------------
    # Time reference conversions
    # -----------------------------------------------------------------

    def _true_anomaly_at_transit(self):
        """True anomaly at inferior conjunction (transit center)."""
        return 0.5 * np.pi - self.omega

    def _mean_anomaly_at_transit(self):
        """Mean anomaly at transit center, measured from periastron."""
        f_transit = self._true_anomaly_at_transit()
        # Inverse of the E -> f half-angle relation.
        E_transit = 2.0 * np.arctan2(
            np.sqrt(1.0 - self.ecc) * np.sin(f_transit / 2.0),
            np.sqrt(1.0 + self.ecc) * np.cos(f_transit / 2.0),
        )
        return E_transit - self.ecc * np.sin(E_transit)

    def _t_periastron_to_t0(self, t_peri):
        """Compute transit mid-time from time of periastron."""
        M_transit = self._mean_anomaly_at_transit()
        return t_peri + M_transit * self.period / (2.0 * np.pi)

    def _t0_to_t_periastron(self, t0):
        """Compute time of periastron from transit mid-time."""
        M_transit = self._mean_anomaly_at_transit()
        return t0 - M_transit * self.period / (2.0 * np.pi)

    # -----------------------------------------------------------------
    # Orbital position
    # -----------------------------------------------------------------

    def _get_true_anomaly(self, t):
        """Compute true anomaly at times *t*.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        f : jnp.ndarray
            True anomaly [rad].
        """
        t = jnp.asarray(t, dtype=jnp.float64)
        M = 2.0 * jnp.pi * (t - self.t_periastron) / self.period
        E = _solve_kepler(M, self.ecc)
        return _E_to_true_anomaly(E, self.ecc)

    def get_position(self, t):
        """Compute the 3-D position of the companion relative to the star.

        The coordinate system is:

        * **x**, **y** — in the plane of the sky; the orientation within
          that plane is set by ``Omega``.
        * **z** — along the line of sight, *toward* the observer.

        A transit occurs when ``z > 0`` (companion is in front of star).
        With ``Omega = 0`` the companion sits at ``x = 0``,
        ``y = -r cos(incl)`` at transit center.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        x, y, z : jnp.ndarray
            Position of the companion in units of stellar radii.
        """
        f = self._get_true_anomaly(t)

        # Star-companion distance in units of stellar radii
        r = self.a_over_Rstar * (1.0 - self.ecc ** 2) / (
            1.0 + self.ecc * jnp.cos(f)
        )

        # Angle from the ascending node, measured in the orbital plane
        cos_f_w = jnp.cos(f + self.omega)
        sin_f_w = jnp.sin(f + self.omega)

        cos_Omega = jnp.cos(self.Omega)
        sin_Omega = jnp.sin(self.Omega)
        cos_incl = jnp.cos(self.incl)
        sin_incl = jnp.sin(self.incl)

        # z positive = toward observer (transit geometry)
        z = r * sin_f_w * sin_incl

        # Sky-plane coordinates, rotated by Omega
        x_sky = -r * (
            cos_Omega * cos_f_w - sin_Omega * sin_f_w * cos_incl
        )
        y_sky = -r * (
            sin_Omega * cos_f_w + cos_Omega * sin_f_w * cos_incl
        )

        return x_sky, y_sky, z

    def get_sky_separation(self, t):
        """Projected sky-plane separation between companion and star.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        rho : jnp.ndarray
            Sky-plane separation in units of stellar radii.
        """
        x, y, _ = self.get_position(t)
        return jnp.sqrt(x ** 2 + y ** 2)

    def get_radial_velocity(self, t, K=None):
        """Radial velocity of the star due to the companion.

        Evaluates ``cos(f + omega) + ecc * cos(omega)`` and, when ``K``
        is given, scales it by ``K``.

        NOTE(review): ``omega`` here is the argument of periastron of the
        secondary; confirm the sign convention against the intended
        definition of stellar RV before comparing with data.

        Parameters
        ----------
        t : array_like
            Times [days].
        K : float or None
            RV semi-amplitude [m/s].  If ``None``, the returned values
            are in natural units (useful for relative comparison only).

        Returns
        -------
        rv : jnp.ndarray
            Radial velocity [m/s if K is given, else natural units].
        """
        f = self._get_true_anomaly(t)
        rv = jnp.cos(f + self.omega) + self.ecc * jnp.cos(self.omega)
        if K is not None:
            rv = K * rv
        return rv

    def in_transit(self, t):
        """Boolean mask for times during which a transit is occurring.

        A point is in transit when the companion is in front of the star
        (``z > 0``) and the sky-plane separation is less than
        ``1 + radius_ratio`` stellar radii.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        mask : jnp.ndarray of bool
        """
        x, y, z = self.get_position(t)
        rho = jnp.sqrt(x ** 2 + y ** 2)
        return (z > 0) & (rho < 1.0 + self.radius_ratio)
# ===================================================================== # Quadratic limb-darkened transit light curve # =====================================================================
class QuadLimbDarkLightCurve:
    """Compute a transit light curve with quadratic limb darkening.

    Uses the analytic Mandel & Agol (2002) uniform-source solution with
    a polynomial limb-darkening correction, following the ``exoplanet``
    implementation style.

    Parameters
    ----------
    u1 : float
        First quadratic limb-darkening coefficient.
    u2 : float
        Second quadratic limb-darkening coefficient.

    The specific intensity profile is:

    .. math::

        I(\\mu) / I(1) = 1 - u_1 (1 - \\mu) - u_2 (1 - \\mu)^2

    where :math:`\\mu = \\cos\\theta` and :math:`\\theta` is the angle
    from disk center.

    Notes
    -----
    The JIT-compiled kernels are static methods that receive ``u1`` and
    ``u2`` as ordinary (traced) arguments.  Jitting the instance methods
    with ``self`` as a static argument would bake the coefficients into
    the compiled function at first call, so later changes to ``self.u1``
    or ``self.u2`` would be silently ignored, and each instance would be
    kept alive by the compilation cache.
    """

    def __init__(self, u1, u2):
        self.u1 = float(u1)
        self.u2 = float(u2)

    # -----------------------------------------------------------------
    # Mandel & Agol uniform-source occultation
    # -----------------------------------------------------------------

    @staticmethod
    @jax.jit
    def _uniform_occultation(z, p):
        """Fractional flux blocked by a uniform-source occultation.

        Parameters
        ----------
        z : array_like
            Projected separation in units of stellar radii.
        p : float
            Companion-to-star radius ratio.

        Returns
        -------
        delta_F : array_like
            Fraction of stellar flux blocked (>= 0).
        """
        z = jnp.abs(z)
        p = jnp.abs(p)

        # Case 1: companion covers the whole disk.  Because z >= 0 this
        # can only be true when p >= 1.
        covers_star = z <= p - 1.0

        # Case 2: no overlap
        no_overlap = z >= 1.0 + p

        # Case 3: companion fully on disk -> blocks its own area p^2
        on_disk = z <= 1.0 - p

        # Case 4: partial overlap — area of intersection of two circles
        # of radii 1 and p separated by distance z.  z is floored to
        # avoid dividing by zero; that lane is masked out by the cases
        # above anyway.
        z_safe = jnp.clip(z, 1e-12, None)
        kappa1 = jnp.arccos(
            jnp.clip((1.0 - p ** 2 + z_safe ** 2) / (2.0 * z_safe), -1, 1)
        )
        kappa0 = jnp.arccos(
            jnp.clip(
                (p ** 2 + z_safe ** 2 - 1.0) / (2.0 * p * z_safe), -1, 1
            )
        )
        partial_flux = (
            p ** 2 * kappa0
            + kappa1
            - 0.5 * jnp.sqrt(
                jnp.clip(
                    4.0 * z_safe ** 2
                    - (1.0 + z_safe ** 2 - p ** 2) ** 2,
                    0.0,
                    None,
                )
            )
        ) / jnp.pi

        return jnp.where(
            covers_star,
            1.0,
            jnp.where(
                no_overlap,
                0.0,
                jnp.where(on_disk, p ** 2, partial_flux),
            ),
        )

    # -----------------------------------------------------------------
    # Limb-darkening correction
    # -----------------------------------------------------------------

    @staticmethod
    @jax.jit
    def _limb_dark_kernel(u1, u2, z):
        """Ratio of local to disk-averaged intensity at separation *z*."""
        # Normalization: integral of I(mu) over the disk
        disk_norm = 1.0 - u1 / 3.0 - u2 / 6.0

        # Mean mu under the companion shadow (approximate for small p);
        # clipped so that separations beyond the limb give mu = 0.
        mu_eff = jnp.sqrt(jnp.clip(1.0 - z ** 2, 0.0, 1.0))

        # Intensity at the shadow position
        shadow = 1.0 - u1 * (1.0 - mu_eff) - u2 * (1.0 - mu_eff) ** 2
        return shadow / disk_norm

    def _limb_dark_factor(self, z, p):
        """Approximate limb-darkening correction factor.

        For a quadratic limb-darkening law, the transit depth depends on
        where the companion sits on the stellar disk.  This evaluates the
        intensity profile at the companion's projected position and
        divides it by the disk-averaged intensity, giving a
        multiplicative correction to the uniform-source solution.

        Parameters
        ----------
        z : array_like
            Projected separation in stellar radii.
        p : float
            Radius ratio.  Accepted for interface symmetry with
            ``_uniform_occultation``; not used in the approximation.

        Returns
        -------
        factor : array_like
            Multiplicative correction (1.0 = no limb darkening).
        """
        # Coefficients are read at call time so attribute updates apply.
        return self._limb_dark_kernel(self.u1, self.u2, z)

    # -----------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------

    @staticmethod
    @jax.jit
    def _flux_kernel(u1, u2, x, y, z_pos, radius_ratio):
        """JIT-compiled relative flux from sky position and coefficients."""
        rho = jnp.sqrt(x ** 2 + y ** 2)
        delta_uniform = QuadLimbDarkLightCurve._uniform_occultation(
            rho, radius_ratio
        )
        ld_corr = QuadLimbDarkLightCurve._limb_dark_kernel(u1, u2, rho)
        # Only block flux when companion is in front (z > 0)
        delta = jnp.where(z_pos > 0.0, delta_uniform * ld_corr, 0.0)
        return 1.0 - delta

    def _compute_flux(self, t, x, y, z_pos, radius_ratio):
        """Relative flux for companion positions ``(x, y, z_pos)``.

        Parameters
        ----------
        t : array_like
            Times [days].  Kept in the signature for callers; the flux
            depends only on the positions.
        x, y : array_like
            Sky-plane position of the companion [stellar radii].
        z_pos : array_like
            Line-of-sight position; flux is blocked only where it is
            positive.
        radius_ratio : float
            Companion-to-star radius ratio.

        Returns
        -------
        flux : jnp.ndarray
            Relative flux (1.0 = out of transit).
        """
        return self._flux_kernel(
            self.u1, self.u2, x, y, z_pos, radius_ratio
        )

    def get_light_curve(self, orbit, t):
        """Compute the transit light curve for a Keplerian orbit.

        Parameters
        ----------
        orbit : KeplerianOrbit
            Orbital model instance.  Must provide ``get_position(t)``
            and a ``radius_ratio`` attribute.
        t : array_like
            Times [days].

        Returns
        -------
        flux : jnp.ndarray
            Relative flux (1.0 = out of transit).
        """
        t = jnp.asarray(t, dtype=jnp.float64)
        x, y, z_pos = orbit.get_position(t)
        return self._compute_flux(t, x, y, z_pos, orbit.radius_ratio)
# ===================================================================== # Combined spot + transit model # =====================================================================
class SpotTransitModel:
    """Combined stellar variability (spots) + planetary transit model.

    The total flux is the product of the spot-modulated stellar flux and
    the transit flux:

    .. math::

        F(t) = F_{\\mathrm{spots}}(t) \\times F_{\\mathrm{transit}}(t)

    This is the standard multiplicative model: the transit removes a
    fraction of the *current* stellar brightness (including any spot
    modulation on the visible hemisphere).

    The spot component can be supplied in two ways:

    1. **Explicit spots** via a ``LightcurveModel`` instance — a forward
       simulation with individual spots drawn from the spot evolution
       model.
    2. **GP-based** via a ``GPSolver`` instance — draws from the GP prior
       or posterior to represent the stellar variability stochastically.

    Parameters
    ----------
    orbit : KeplerianOrbit
        Keplerian orbit of the transiting companion.
    limbdark : QuadLimbDarkLightCurve
        Limb-darkened transit light curve calculator.
    spot_model : LightcurveModel or GPSolver or None
        Source of stellar variability.  If ``None``, the spot component
        is unity (transit only).
    """

    def __init__(self, orbit, limbdark, spot_model=None):
        self.orbit = orbit
        self.limbdark = limbdark
        self.spot_model = spot_model

    # -----------------------------------------------------------------
    # Transit-only component
    # -----------------------------------------------------------------

    def get_transit_flux(self, t):
        """Compute the transit light curve alone (no spots).

        Delegates to ``self.limbdark.get_light_curve(self.orbit, t)``.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        flux_transit : jnp.ndarray
            Relative flux from the transit (1.0 = out of transit).
        """
        return self.limbdark.get_light_curve(self.orbit, t)

    # -----------------------------------------------------------------
    # Spot-only component
    # -----------------------------------------------------------------

    def get_spot_flux(self, t):
        """Compute the spot-modulated stellar flux alone (no transit).

        Uses whichever spot model was provided at construction time:
        ``Flux(t)`` for a ``LightcurveModel``, or the first element
        returned by ``predict(t)`` for a ``GPSolver``.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        flux_spots : ndarray
            Relative flux from stellar variability.

        Raises
        ------
        ValueError
            If no spot model was provided.
        TypeError
            If the spot model is neither a ``LightcurveModel`` nor a
            ``GPSolver``.
        """
        if self.spot_model is None:
            raise ValueError("No spot model was provided.")

        if isinstance(self.spot_model, LightcurveModel):
            t_arr = np.asarray(t, dtype=float)
            return np.asarray(self.spot_model.Flux(t_arr))

        if isinstance(self.spot_model, GPSolver):
            # Only the first returned value is used; the second is
            # discarded.
            mu, _ = self.spot_model.predict(np.asarray(t))
            return np.asarray(mu)

        raise TypeError(
            f"spot_model must be a LightcurveModel or GPSolver, "
            f"got {type(self.spot_model).__name__}"
        )

    # -----------------------------------------------------------------
    # Combined light curve
    # -----------------------------------------------------------------

    def get_light_curve(self, t):
        """Compute the combined spot + transit light curve.

        When no spot model is set, the transit flux is returned on its
        own.

        Parameters
        ----------
        t : array_like
            Times [days].

        Returns
        -------
        flux : jnp.ndarray
            Combined relative flux.
        """
        t = jnp.asarray(t, dtype=jnp.float64)
        flux_transit = self.get_transit_flux(t)

        if self.spot_model is None:
            return flux_transit

        flux_spots = jnp.asarray(self.get_spot_flux(t))
        return flux_spots * flux_transit

    # -----------------------------------------------------------------
    # Sampling (GP mode)
    # -----------------------------------------------------------------

    def sample_light_curves(self, t, n_samples=5, source="prior",
                            rng=None):
        """Draw combined spot + transit light curve samples.

        Only available when the spot model is a ``GPSolver``.  Each
        sample draws an independent realization of the spot variability
        and multiplies it by the deterministic transit model.

        Parameters
        ----------
        t : array_like
            Times [days].
        n_samples : int
            Number of samples to draw (default 5).
        source : {'prior', 'posterior'}
            Draw from the GP prior or posterior (default ``'prior'``).
        rng : numpy.random.Generator or None
            Random number generator for reproducibility.

        Returns
        -------
        t : jnp.ndarray, shape (M,)
            Evaluation times.
        samples : jnp.ndarray, shape (n_samples, M)
            Combined flux samples.

        Raises
        ------
        TypeError
            If the spot model is not a ``GPSolver``.
        """
        if not isinstance(self.spot_model, GPSolver):
            raise TypeError(
                "sample_light_curves requires a GPSolver spot model, "
                f"got {type(self.spot_model).__name__}"
            )

        t = jnp.asarray(t, dtype=jnp.float64)
        t_np = np.asarray(t)

        # Draw spot variability samples
        _, spot_samples = self.spot_model.sample_lightcurves(
            xpred=t_np, n_samples=n_samples, source=source, rng=rng,
        )
        spot_samples = jnp.asarray(spot_samples)

        # Deterministic transit component (same for every sample)
        flux_transit = self.get_transit_flux(t)

        # Multiply each spot sample by the transit; the transit row is
        # broadcast across the sample axis.
        combined = spot_samples * flux_transit[None, :]
        return t, combined

    # -----------------------------------------------------------------
    # Animation
    # -----------------------------------------------------------------

    def animate_lightcurve(self, t=None, fps=30, duration=10.0,
                           outfile=None, dpi=150, show_spots=True,
                           show_grid=True, show_params=True,
                           figsize=(14, 5.5), save_last_frame=None,
                           show_dr=True, label_size=18):
        """Animate the spotted star with transiting planet and combined
        light curve.

        Left panel shows the rotating stellar disk with spots and the
        planet transiting across it.  Right panel traces the combined
        (spots x transit) light curve over time, with the transit-only
        and spot-only components shown as faint references.

        Requires a ``LightcurveModel`` as the spot model.

        Parameters
        ----------
        t : array_like or None
            Time grid [days].  If ``None``, uses the spot model's
            internal grid ``spot_model.t``.
        fps : int
            Frames per second (default 30).
        duration : float
            Animation duration in seconds (default 10).  The number of
            frames is ``int(fps * duration)``.
        outfile : str or None
            Output file path (.mp4 or .gif).  If ``None``, returns the
            animation object without saving.  A path ending in ``.gif``
            is written with ``PillowWriter``; anything else with
            ``FFMpegWriter``.
        dpi : int
            Resolution (default 150).
        show_spots : bool
            If ``True``, show individual spot contributions on the
            light curve panel (default ``True``).
        show_grid : bool
            If ``True``, draw latitude/longitude grid on the star
            (default ``True``).
        show_params : bool
            If ``True``, show parameter annotation above the figure
            (default ``True``).
        figsize : tuple
            Figure size (default ``(14, 5.5)``).
        save_last_frame : str or None
            If provided, save the last frame as a static image.
        show_dr : bool
            If ``True``, color the stellar disk by differential
            rotation rate (default ``True``).
        label_size : int or float
            Font size for labels and tick marks (default 18).

        Returns
        -------
        anim : matplotlib.animation.FuncAnimation
            The figure it is attached to is closed with ``plt.close``
            before this method returns.

        Raises
        ------
        TypeError
            If the spot model is not a ``LightcurveModel``.
        """
        import matplotlib.pyplot as plt
        import matplotlib.animation as animation
        from matplotlib.patches import Circle

        # -- lazy imports from lightcurve helpers -------------------------
        try:
            from .lightcurve import (
                _projected_spot_patch, _alphak, _betak,
            )
        except ImportError:
            from lightcurve import (
                _projected_spot_patch, _alphak, _betak,
            )

        if not isinstance(self.spot_model, LightcurveModel):
            raise TypeError(
                "animate_lightcurve requires a LightcurveModel spot "
                f"model, got {type(self.spot_model).__name__}"
            )

        sm = self.spot_model
        orbit = self.orbit

        # -- time grid & fluxes -------------------------------------------
        if t is None:
            t = sm.t
        t = np.asarray(t, dtype=float)
        n_times = len(t)
        t_jax = jnp.array(t)

        flux_spots = np.asarray(sm.Flux(t))
        flux_transit = np.asarray(self.get_transit_flux(t))
        flux_combined = flux_spots * flux_transit

        # Per-spot flux deficits (set as side effect of Flux)
        dspots = sm.dspots  # (nspot, n_times)

        inc = sm.inc
        nspot = sm.nspot
        spot_longs = np.atleast_1d(sm.long)
        spot_lats = np.atleast_1d(sm.lat)
        spot_tmaxs = sm.tmax

        # -- precompute spot geometry for every frame ---------------------
        spot_alphas = np.zeros((nspot, n_times))
        spot_longs_t = np.zeros((nspot, n_times))
        for k in range(nspot):
            # Spot size: time-dependent when the model grows spots,
            # otherwise fixed at alpha_max.
            if sm.grow:
                spot_alphas[k] = np.asarray(_alphak(
                    t_jax, spot_tmaxs[k], sm.lspot,
                    sm.tem, sm.tdec, sm.alpha_max))
            else:
                spot_alphas[k] = sm.alpha_max
            # Spot longitude: time-dependent when the model rotates,
            # otherwise fixed at the initial longitude.
            if sm.rotate:
                _, longk_t = _betak(
                    t_jax, spot_longs[k], spot_lats[k], spot_tmaxs[k],
                    sm.peq, sm.kappa, sm.inc)
                spot_longs_t[k] = np.asarray(longk_t)
            else:
                spot_longs_t[k] = spot_longs[k]

        # -- precompute planet sky position (in stellar-radii units) ------
        px, py, pz = orbit.get_position(t)
        px = np.asarray(px)
        py = np.asarray(py)
        pz = np.asarray(pz)
        planet_r = orbit.radius_ratio  # in stellar radii

        # ================================================================
        # Figure layout
        # ================================================================
        fig, (ax_star, ax_lc) = plt.subplots(
            1, 2, figsize=figsize,
            gridspec_kw={"width_ratios": [1, 1.6]})

        # -- star panel ---------------------------------------------------
        ax_star.set_aspect("equal")
        ax_star.set_xlim(-1.55, 1.55)
        ax_star.set_ylim(-1.55, 1.55)
        ax_star.set_axis_off()

        if show_dr:
            from matplotlib.colors import Normalize
            import matplotlib.cm as cm

            omega_eq = 2 * np.pi / sm.peq
            cmap = cm.coolwarm

            if sm.kappa == 0:
                # Rigid rotation: a uniform mid-colormap tint plus a
                # text label giving the single rotation rate.
                stellar_disk = Circle(
                    (0, 0), 1.0, fc="lightyellow", ec="k", lw=1.5,
                    zorder=-1)
                ax_star.add_patch(stellar_disk)
                n_pix = 300
                xp = np.linspace(-1, 1, n_pix)
                yp = np.linspace(-1, 1, n_pix)
                XP, YP = np.meshgrid(xp, yp)
                R2 = XP ** 2 + YP ** 2
                omega_map = np.where(R2 <= 1.0, 0.5, np.nan)
                norm = Normalize(vmin=0.0, vmax=1.0)
                dr_img = ax_star.imshow(
                    omega_map, extent=[-1, 1, -1, 1], origin="lower",
                    interpolation="bilinear", cmap=cmap, norm=norm,
                    alpha=0.3, zorder=0)
                clip_circle = Circle(
                    (0, 0), 1.0, transform=ax_star.transData)
                dr_img.set_clip_path(clip_circle)
                ax_star.text(
                    -1.3, 0.0,
                    rf"$\Omega = {omega_eq:.3f}$ [rad/d]",
                    fontsize=label_size - 2, ha="center", va="center",
                    rotation=90, transform=ax_star.transData)
            else:
                # Differential rotation: colour range spans the rates
                # omega_eq and omega_eq * (1 - kappa); swapped if
                # kappa is negative so vmin <= vmax.
                omega_min = omega_eq * (1 - sm.kappa)
                omega_max = omega_eq
                if omega_min > omega_max:
                    omega_min, omega_max = omega_max, omega_min
                norm = Normalize(vmin=omega_min, vmax=omega_max)

                n_pix = 300
                xp = np.linspace(-1, 1, n_pix)
                yp = np.linspace(-1, 1, n_pix)
                XP, YP = np.meshgrid(xp, yp)
                R2 = XP ** 2 + YP ** 2
                # Height of the visible hemisphere above the sky plane
                CZ = np.sqrt(np.clip(1.0 - R2, 0, None))
                sin_lat = (-np.sin(inc) * YP + np.cos(inc) * CZ)
                sin_lat = np.clip(sin_lat, -1.0, 1.0)
                lat_map = np.arcsin(sin_lat)
                omega_map = omega_eq * (
                    1 - sm.kappa * np.sin(lat_map) ** 2)

                stellar_disk = Circle(
                    (0, 0), 1.0, fc="lightyellow", ec="k", lw=1.5,
                    zorder=-1)
                ax_star.add_patch(stellar_disk)
                dr_img = ax_star.imshow(
                    omega_map, extent=[-1, 1, -1, 1], origin="lower",
                    interpolation="bilinear", cmap=cmap, norm=norm,
                    alpha=0.3, zorder=0)
                clip_circle = Circle(
                    (0, 0), 1.0, transform=ax_star.transData)
                dr_img.set_clip_path(clip_circle)

                cbar = fig.colorbar(
                    dr_img, ax=ax_star, fraction=0.046, pad=0.04,
                    location="left")
                cbar.set_label(r"$\Omega$ [rad/d]", fontsize=label_size)
                cbar.ax.tick_params(labelsize=label_size - 2)
                cbar.ax.text(
                    0.6, 1.02, "faster",
                    transform=cbar.ax.transAxes, ha="center",
                    va="bottom", fontsize=label_size - 2, color="red")
                cbar.ax.text(
                    0.6, -0.02, "slower",
                    transform=cbar.ax.transAxes, ha="center",
                    va="top", fontsize=label_size - 2, color="blue")
        else:
            stellar_disk = Circle(
                (0, 0), 1.0, fc="lightyellow", ec="k", lw=1.5, zorder=0)
            ax_star.add_patch(stellar_disk)

        # Grid lines
        if show_grid:
            phi_grid = np.linspace(0, 2 * np.pi, 200)
            for lat_deg in [0, 30, 60, -30, -60]:
                lat_r = np.radians(lat_deg)
                gx = (-np.sin(inc) * np.sin(lat_r)
                      + np.cos(inc) * np.cos(lat_r) * np.cos(phi_grid))
                gy = np.cos(lat_r) * np.sin(phi_grid)
                gz = (np.cos(inc) * np.sin(lat_r)
                      + np.sin(inc) * np.cos(lat_r) * np.cos(phi_grid))
                # Draw only the part of each circle with gz > 0
                mask = gz > 0
                # Equator dashed and heavier; other latitudes thin solid
                style = (("k--", 0.6, 0.3) if lat_deg == 0
                         else ("k-", 0.3, 0.2))
                ax_star.plot(
                    np.where(mask, gy, np.nan),
                    np.where(mask, gx, np.nan),
                    style[0], lw=style[1], alpha=style[2])

        # Rotation axis arrow
        ax_star.annotate(
            "", xy=(0, 1.4), xytext=(0, -0.3),
            arrowprops=dict(arrowstyle="->, head_width=0.08",
                            color="0.5", lw=1.2))

        # Spot patches: one opaque patch and one faint "ghost" patch per
        # spot, both created empty and filled in per frame.
        spot_colors = plt.cm.Set1(np.linspace(0, 1, max(nspot, 1)))
        spot_patches_art = []
        ghost_patches_art = []
        for k in range(nspot):
            c = spot_colors[k % len(spot_colors)]
            patch, = ax_star.fill([], [], color=c, alpha=0.85, zorder=2)
            ghost, = ax_star.fill([], [], color=c, alpha=0.15, zorder=1,
                                  linestyle="--", edgecolor=c,
                                  linewidth=0.8)
            spot_patches_art.append(patch)
            ghost_patches_art.append(ghost)

        # Planet circle (dark disk, high zorder so it draws over spots)
        planet_patch = Circle(
            (0, 0), planet_r, fc="0.15", ec="k", lw=0.6,
            zorder=10, visible=False)
        ax_star.add_patch(planet_patch)

        time_text = ax_star.text(
            0, -1.45, "", fontsize=label_size, ha="center", va="top")

        # -- light curve panel --------------------------------------------
        # Everything is plotted as a percentage dip below unit flux.
        dip_combined = (1 - flux_combined) * 100
        dip_transit = (1 - flux_transit) * 100
        dip_spots = (1 - flux_spots) * 100
        dip_per_spot = dspots * 100

        # Floor the axis range so a flat light curve still gets a
        # non-degenerate y-axis.
        dip_max = max(np.max(dip_combined), 1e-3)
        dip_range = dip_max if dip_max > 0 else 1.0
        ax_lc.set_xlim(t[0], t[-1])
        ax_lc.set_ylim(-0.05 * dip_range, dip_max + 0.15 * dip_range)
        ax_lc.invert_yaxis()
        ax_lc.set_xlabel("Time [days]", fontsize=label_size)
        ax_lc.set_ylabel(r"Flux dip [\%]", fontsize=label_size)
        ax_lc.tick_params(labelsize=label_size - 2)
        ax_lc.minorticks_on()

        # Full light curves as faint background
        ax_lc.plot(t, dip_combined, "k-", lw=0.3, alpha=0.12, zorder=0)
        ax_lc.plot(t, dip_transit, "-", color="C0", lw=0.3, alpha=0.12,
                   zorder=0)
        ax_lc.plot(t, dip_spots, "-", color="C1", lw=0.3, alpha=0.12,
                   zorder=0)

        # Traced lines (build up over time)
        lc_line, = ax_lc.plot([], [], "k-", lw=1.4, zorder=4,
                              label="Combined")
        transit_line, = ax_lc.plot([], [], "-", color="C0", lw=1.0,
                                   alpha=0.7, zorder=3, label="Transit")
        spots_line, = ax_lc.plot([], [], "-", color="C1", lw=1.0,
                                 alpha=0.7, zorder=3, label="Spots")

        # Per-spot traces
        spot_lc_lines = []
        if show_spots:
            for k in range(nspot):
                c = spot_colors[k % len(spot_colors)]
                ln, = ax_lc.plot([], [], "-", color=c, lw=0.6,
                                 alpha=0.4, zorder=1)
                spot_lc_lines.append(ln)

        # Vertical time marker
        vline = ax_lc.axvline(
            0, color="C3", lw=1.0, alpha=0.7, ls="--", zorder=5)

        ax_lc.legend(fontsize=label_size - 4, loc="upper right")
        fig.tight_layout()

        # Parameter annotation
        if show_params:
            spot_text = (
                rf"$P_{{\rm eq}}={sm.peq:.1f}$ d, "
                rf"$\kappa={sm.kappa:.2f}$, "
                rf"$I={sm.inc_deg:.0f}^\circ$, "
                rf"$N_{{\rm spot}}={sm.nspot}$"
            )
            transit_text = (
                rf"$P_{{\rm orb}}={orbit.period:.2f}$ d, "
                rf"$R_p/R_*={orbit.radius_ratio:.3f}$, "
                rf"$b={orbit.impact_param:.2f}$, "
                rf"$e={orbit.ecc:.2f}$"
            )
            fig.text(0.5, 0.99,
                     spot_text + r" $|$ " + transit_text,
                     fontsize=label_size - 2, ha="center", va="top")
            # Leave headroom for the annotation line
            fig.subplots_adjust(top=0.90)

        # ================================================================
        # Animation loop
        # ================================================================
        n_frames = int(fps * duration)
        # Map each frame to an index into the time grid, evenly spread
        # from the first to the last sample.
        frame_indices = np.linspace(
            0, n_times - 1, n_frames).astype(int)
        empty_xy = np.empty((0, 2))

        def update(frame_num):
            idx = frame_indices[frame_num]
            t_now = t[idx]

            # -- update spots on the star ---------------------------------
            for k in range(nspot):
                alpha_k = spot_alphas[k, idx]
                # Spots of negligible size are hidden entirely
                if alpha_k < 1e-6:
                    spot_patches_art[k].set_xy(empty_xy)
                    ghost_patches_art[k].set_xy(empty_xy)
                    continue
                lon_k = spot_longs_t[k, idx]
                lat_k = spot_lats[k]
                fx, fy, bx, by = _projected_spot_patch(
                    lon_k, lat_k, alpha_k, inc)
                # A polygon needs at least three vertices
                if fx is not None and len(fx) >= 3:
                    spot_patches_art[k].set_xy(
                        np.column_stack([fx, fy]))
                else:
                    spot_patches_art[k].set_xy(empty_xy)
                if bx is not None and len(bx) >= 3:
                    ghost_patches_art[k].set_xy(
                        np.column_stack([bx, by]))
                else:
                    ghost_patches_art[k].set_xy(empty_xy)

            # -- update planet on the star --------------------------------
            # Planet sky coords: x_sky -> horizontal, y_sky -> vertical
            # on the star panel.  Show only when in front (pz > 0).
            if pz[idx] > 0:
                planet_patch.set_visible(True)
                planet_patch.center = (px[idx], py[idx])
            else:
                planet_patch.set_visible(False)

            time_text.set_text(rf"$t = {t_now:.2f}$ d")

            # -- update light curve traces --------------------------------
            lc_line.set_data(t[:idx + 1], dip_combined[:idx + 1])
            transit_line.set_data(t[:idx + 1], dip_transit[:idx + 1])
            spots_line.set_data(t[:idx + 1], dip_spots[:idx + 1])
            if show_spots:
                for k in range(nspot):
                    spot_lc_lines[k].set_data(
                        t[:idx + 1], dip_per_spot[k, :idx + 1])
            vline.set_xdata([t_now])

            return (spot_patches_art + ghost_patches_art
                    + [planet_patch, time_text, lc_line,
                       transit_line, spots_line, vline]
                    + spot_lc_lines)

        anim = animation.FuncAnimation(
            fig, update, frames=n_frames,
            interval=1000 / fps, blit=False)

        if outfile is not None:
            import os
            # Create the destination directory when the path has one
            outdir = os.path.dirname(outfile)
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            if outfile.endswith(".gif"):
                writer = animation.PillowWriter(fps=fps)
            else:
                writer = animation.FFMpegWriter(fps=fps, bitrate=2000)
            print(f"Rendering {n_frames} frames to {outfile}...")
            anim.save(outfile, writer=writer, dpi=dpi)
            print("Done.")

        if save_last_frame is not None:
            # Redraw the final frame explicitly before saving the still
            update(n_frames - 1)
            fig.savefig(save_last_frame, dpi=dpi, bbox_inches="tight")
            print(f"Last frame saved to {save_last_frame}")

        plt.close(fig)
        return anim