Multi-Band GP Tutorial

This tutorial walks through fitting a multi-band GP to simultaneous photometry in multiple wavelength bands. The chromatic kernel exploits the fact that starspots are cooler than the photosphere, producing wavelength-dependent variability amplitudes that constrain the spot temperature \(T_{\rm spot}\).

What you will learn

  1. Visualize chromatic spot contrast from the Planck function ratio

  2. Simulate multi-band lightcurves scaled by \(c(\lambda)\)

  3. Prepare multi-band data with MultiBandData

  4. Fit multi-band photometry with MultiBandGPSolver to infer \(T_{\rm spot}\)

  5. Use amplitude ratios as a diagnostic for spot temperature

  6. Predict in each band with the GP posterior

Prerequisites

Familiarity with the single-band workflow from the Quickstart and GP Solver tutorials.

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp

from spotgp import (
    TrapezoidSymmetricEnvelope,
    VisibilityFunction,
    SpotEvolutionModel,
    LightcurveModel,
)
from spotgp.contrast import spot_contrast, contrast_factor
from spotgp.multiband import MultiBandData, MultiBandGPSolver

np.random.seed(42)

1. The chromatic spot contrast

A starspot at temperature \(T_{\rm spot}\) on a photosphere at \(T_{\rm phot}\) produces a wavelength-dependent contrast ratio:

\[ f_{\rm spot}(\lambda) = \frac{B_\lambda(T_{\rm spot})}{B_\lambda(T_{\rm phot})} \]

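The contrast factor \(c(\lambda) = 1 - f_{\rm spot}(\lambda)\) is the fractional flux deficit per unit spot area. It is largest at short wavelengths, where the cool spot is darkest relative to the photosphere, and decreases toward long wavelengths. It does not vanish there: in the Rayleigh-Jeans limit \(B_\lambda \propto T\), so \(c(\lambda) \to 1 - T_{\rm spot}/T_{\rm phot}\).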
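To check these numbers independently of spotgp, the two definitions can be written directly in NumPy. This is a minimal standalone sketch, not the implementation behind spotgp's spot_contrast or contrast_factor; the helper names planck_lambda, flux_ratio, and contrast are hypothetical, and wavelengths are assumed to be in Angstroms as elsewhere in this tutorial.

```python
import numpy as np

# Physical constants in SI units (exact SI-defined values)
H = 6.62607015e-34    # Planck constant [J s]
C = 2.99792458e8      # speed of light [m / s]
K_B = 1.380649e-23    # Boltzmann constant [J / K]

def planck_lambda(wavelength_aa, T):
    """Blackbody spectral radiance B_lambda(T) in SI units; wavelength in Angstroms."""
    lam = np.asarray(wavelength_aa, dtype=float) * 1e-10   # Angstrom -> metre
    x = H * C / (lam * K_B * T)
    # expm1 stays accurate when x << 1 (the Rayleigh-Jeans regime)
    return 2.0 * H * C**2 / lam**5 / np.expm1(x)

def flux_ratio(wavelength_aa, T_spot, T_phot):
    """f_spot(lambda) = B_lambda(T_spot) / B_lambda(T_phot)."""
    return planck_lambda(wavelength_aa, T_spot) / planck_lambda(wavelength_aa, T_phot)

def contrast(wavelength_aa, T_spot, T_phot):
    """c(lambda) = 1 - f_spot(lambda): fractional flux deficit per unit spot area."""
    return 1.0 - flux_ratio(wavelength_aa, T_spot, T_phot)

# Contrast of a 4000 K spot on a 5800 K photosphere at the marked band wavelengths
lam_bands = np.array([4640.0, 6400.0, 7865.0, 12350.0])   # g, Kepler, TESS, J
c_bands = contrast(lam_bands, 4000.0, 5800.0)
print(np.round(c_bands, 3))

# Far into the Rayleigh-Jeans tail, c approaches 1 - T_spot / T_phot
print(contrast(1e8, 4000.0, 5800.0), 1.0 - 4000.0 / 5800.0)
```

For a 4000 K spot on a 5800 K photosphere the long-wavelength floor is \(1 - 4000/5800 \approx 0.31\), so the spot stays visible in the infrared; its contrast simply stops falling.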

Let’s visualize this for a Sun-like star (\(T_{\rm phot} = 5800\) K) with spot temperatures between 3000 and 5000 K, marking a few common bandpasses.

lam = jnp.linspace(3000, 25000, 500)  # Angstroms
T_phot = 5800.0

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for T_spot in [3000, 3500, 4000, 4500, 5000]:
    f = np.array(spot_contrast(lam, T_spot, T_phot))
    c = np.array(contrast_factor(lam, T_spot, T_phot))
    label = f"$T_{{\\rm spot}} = {T_spot}$ K"
    axes[0].plot(lam / 1e4, f, label=label)
    axes[1].plot(lam / 1e4, c, label=label)

# Mark common bandpasses
bands_wl = {"SDSS $g$": 4640, "Kepler": 6400, "TESS": 7865, "2MASS $J$": 12350}
for name, wl in bands_wl.items():
    for ax in axes:
        ax.axvline(wl / 1e4, color="gray", ls=":", alpha=0.5, lw=1)
    axes[1].text(wl / 1e4, 0.98, name, fontsize=8, ha="center",
                 transform=axes[1].get_xaxis_transform())

axes[0].set_xlabel(r"Wavelength [$\mu$m]")
axes[0].set_ylabel(r"$f_{\rm spot}(\lambda)$")
axes[0].set_title("Spot-to-photosphere flux ratio")
axes[0].legend(fontsize=9)

axes[1].set_xlabel(r"Wavelength [$\mu$m]")
axes[1].set_ylabel(r"$c(\lambda) = 1 - f_{\rm spot}$")
axes[1].set_title("Contrast factor (flux deficit per unit area)")

fig.tight_layout()
plt.show()

2. Simulate multi-band lightcurves

We generate a single-band lightcurve from the geometric model, then scale it by the contrast factor \(c(\lambda)\) for each band. This is exactly what the factorized kernel assumes: the temporal structure is identical across bands, and the amplitude scales with \(c(\lambda)\).

# Define the spot evolution model (geometric part)
envelope = TrapezoidSymmetricEnvelope(lspot=8.0, tau_spot=3.0)
visibility = VisibilityFunction(peq=5.0, kappa=0.2, inc=np.pi / 3)
model = SpotEvolutionModel(envelope=envelope, visibility=visibility, sigma_k=0.01)

# Simulate lightcurve (geometric flux deficit, no wavelength dependence)
lc = LightcurveModel.from_spot_model(
    spot_model=model,
    nspot=25,
    tsim=60,
    tsamp=0.5,
)

# Flux variation relative to the median (the wavelength-independent "geometric" signal)
delta_F = lc.flux - np.median(lc.flux)
t = lc.t

print(f"Simulated {len(t)} data points over {t[-1] - t[0]:.0f} days")

# True spot and photosphere temperatures
T_phot_true = 5800.0
T_spot_true = 4200.0

# Define photometric bands
band_info = {
    "SDSS_g":  {"wavelength": 4640.0, "noise": 0.0015},
    "Kepler":  {"wavelength": 6400.0, "noise": 0.0008},
    "TESS":    {"wavelength": 7865.0, "noise": 0.0012},
}

# Scale geometric flux by contrast factor and add noise
bands = {}
for name, info in band_info.items():
    c = float(contrast_factor(jnp.array(info["wavelength"]), T_spot_true, T_phot_true))
    flux_band = 1.0 + c * delta_F + np.random.normal(0, info["noise"], len(t))
    bands[name] = {
        "x": t,
        "y": flux_band,
        "yerr": np.full(len(t), info["noise"]),
        "wavelength": info["wavelength"],
    }
    print(f"{name}: c({info['wavelength']:.0f} A) = {c:.4f}, "
          f"amplitude ~ {c * np.std(delta_F):.5f}")

fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
colors = {"SDSS_g": "#4477AA", "Kepler": "#228833", "TESS": "#CC3311"}

for ax, (name, b) in zip(axes, bands.items()):
    ax.errorbar(b["x"], b["y"], yerr=b["yerr"],
                fmt=".", ms=3, color=colors[name], alpha=0.6, elinewidth=0.5)
    ax.set_ylabel("Normalized flux")
    ax.set_title(f"{name} ({b['wavelength']:.0f} " + r"$\AA$)", fontsize=12)

axes[-1].set_xlabel("Time [days]")
fig.suptitle(f"Simulated multi-band lightcurves  "
             f"($T_{{\\rm phot}}={T_phot_true:.0f}$ K, "
             f"$T_{{\\rm spot}}={T_spot_true:.0f}$ K)",
             fontsize=14, y=1.01)
fig.tight_layout()
plt.show()

3. Create a MultiBandData container

MultiBandData merges observations from all bands into a single time-sorted array, tagging each observation with its band index. Each band is independently median-normalized.
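The merging step can be sketched in a few lines of NumPy. This is a hypothetical helper (merge_bands), not MultiBandData's implementation; in particular, it assumes that "median-normalized" means dividing each band's flux and uncertainties by that band's median flux.

```python
import numpy as np

def merge_bands(bands):
    """Merge a dict of per-band observations into time-sorted flat arrays.

    `bands` maps band name -> {"x", "y", "yerr", "wavelength"}, mirroring the
    dictionary passed to MultiBandData in this tutorial.
    """
    names = list(bands)
    x, y, yerr, idx = [], [], [], []
    for i, name in enumerate(names):
        b = bands[name]
        med = np.median(b["y"])
        x.append(np.asarray(b["x"], dtype=float))
        # Median-normalize each band independently; scale errors by the same factor
        y.append(np.asarray(b["y"], dtype=float) / med)
        yerr.append(np.asarray(b["yerr"], dtype=float) / med)
        idx.append(np.full(len(b["x"]), i))          # band index tag for every point
    x, y, yerr, idx = map(np.concatenate, (x, y, yerr, idx))
    order = np.argsort(x, kind="stable")             # one time-sorted array
    wavelengths = np.array([bands[n]["wavelength"] for n in names])
    return x[order], y[order], yerr[order], idx[order], names, wavelengths

# Toy usage: two bands with interleaved timestamps
toy = {
    "g": {"x": np.array([0.0, 1.0, 2.0]), "y": np.array([2.0, 2.2, 1.8]),
          "yerr": np.full(3, 0.02), "wavelength": 4640.0},
    "TESS": {"x": np.array([0.5, 1.5]), "y": np.array([1.0, 1.1]),
             "yerr": np.full(2, 0.01), "wavelength": 7865.0},
}
x, y, yerr, band_idx, names, wl = merge_bands(toy)
print(x)          # times in ascending order across both bands
print(band_idx)   # which band each sorted point came from
```

Keeping a band-index array alongside a single sorted time axis lets a kernel look up each point's contrast with one fancy-indexing step, while the time ordering is what a banded solver would need.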

data = MultiBandData(bands)

print(f"Total observations: {data.N}")
print(f"Number of bands:    {data.n_bands}")
print(f"Band names:         {data.band_names}")
print(f"Band wavelengths:   {data.band_wavelengths} Angstroms")
print(f"Time baseline:      {data.baseline:.1f} days")
print(f"Median cadence:     {data.median_dt:.2f} days")

4. Fit with MultiBandGPSolver

The solver uses the factorized chromatic kernel:

\[ \mathcal{K}(\tau;\,\lambda_i,\lambda_j) = c(\lambda_i)\,c(\lambda_j)\,\mathcal{K}_{\rm geom}(\tau) \]

where \(c(\lambda) = 1 - B_\lambda(T_{\rm spot})/B_\lambda(T_{\rm phot})\). With \(T_{\rm phot}\) held fixed at a known value, this adds one new free parameter (\(T_{\rm spot}\)) to the standard six-parameter geometric kernel.
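To make the factorization concrete, here is a small NumPy sketch that assembles the covariance matrix for merged multi-band data. The geometric kernel is a hypothetical stand-in (a quasi-periodic squared-exponential times periodic kernel), not spotgp's \(\mathcal{K}_{\rm geom}\), which is built from the envelope and visibility functions; only the \(c(\lambda_i)\,c(\lambda_j)\) scaling comes from the equation above, and the contrast values are illustrative.

```python
import numpy as np

def k_geom(tau, sigma_k=0.01, period=5.0, ell=8.0, gamma=2.0):
    """Hypothetical stand-in for K_geom(tau): a quasi-periodic kernel, NOT spotgp's."""
    return sigma_k**2 * np.exp(-0.5 * (tau / ell) ** 2
                               - gamma * np.sin(np.pi * tau / period) ** 2)

def chromatic_cov(t, band_idx, c_band, yerr):
    """K_ij = c(lambda_i) c(lambda_j) K_geom(|t_i - t_j|) + diag(yerr^2)."""
    c = np.asarray(c_band)[band_idx]             # contrast of each observation's band
    tau = np.abs(t[:, None] - t[None, :])        # pairwise time lags
    return np.outer(c, c) * k_geom(tau) + np.diag(yerr**2)

# Toy data: three bands observed at interleaved times
t = np.linspace(0.0, 20.0, 60)
band_idx = np.arange(t.size) % 3
c_band = np.array([0.87, 0.78, 0.71])            # illustrative contrast values
yerr = np.full(t.size, 1e-3)

K = chromatic_cov(t, band_idx, c_band, yerr)
np.linalg.cholesky(K)                            # raises LinAlgError if K is not positive definite

# The signal part equals diag(c) @ K_geom @ diag(c) ...
K_signal = K - np.diag(yerr**2)
D = np.diag(c_band[band_idx])
K_geom_mat = k_geom(np.abs(t[:, None] - t[None, :]))
print(np.allclose(K_signal, D @ K_geom_mat @ D))

# ... so its correlation matrix is the geometric one, independent of c
s = np.sqrt(np.diag(K_signal))
corr = K_signal / np.outer(s, s)
print(np.allclose(corr, K_geom_mat / k_geom(0.0)))
```

Because the contrast enters only as \({\rm diag}(c)\,\mathcal{K}_{\rm geom}\,{\rm diag}(c)\), all bands share the same temporal correlations and differ only in amplitude, which is what makes \(T_{\rm spot}\) identifiable from the relative amplitudes.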

| Parameter | Description | Typical bounds |
|---|---|---|
| peq | Equatorial rotation period [days] | (1, 50) |
| kappa | Differential rotation shear | (0, 1) |
| inc | Stellar inclination [rad] | (0.01, \(\pi/2\)) |
| lspot | Spot plateau duration [days] | (0.5, 30) |
| tau_spot | Spot rise/decay timescale [days] | (0.1, 10) |
| sigma_k | Geometric kernel amplitude | (1e-4, 0.1) |
| T_spot | Spot temperature [K] | (2500, 5500) |

# Hyperparameters for the geometric kernel
hparam = dict(
    peq=5.0,
    kappa=0.2,
    inc=np.pi / 3,
    lspot=8.0,
    tau_spot=3.0,
    sigma_k=0.01,
)

# Parameter bounds (use log-space for sigma_k)
bounds = {
    "peq":         (2.0, 15.0),
    "kappa":       (0.001, 0.8),
    "inc":         (0.1, np.pi - 0.1),
    "lspot":       (1.0, 20.0),
    "tau_spot":    (0.1, 8.0),
    "log_sigma_k": (-4.0, -0.5),
    "T_spot":      (2500.0, 5500.0),
}

gp = MultiBandGPSolver(
    data,
    hparam,
    T_phot=T_phot_true,
    T_spot_init=4000.0,  # starting guess
    bounds=bounds,
    matrix_solver="cholesky_banded",
)

print(f"Parameters:  {gp.param_keys}")
print(f"n_params:    {gp.n_params}")
print(f"theta0:      {dict(zip(gp.param_keys, np.array(gp.theta0)))}")

# Pre-compile JAX functions
gp.build_jax()

# Check that the log-likelihood is finite at the initial parameters
ll0 = gp.log_likelihood_at(gp.theta0)
print(f"log L(theta0) = {ll0:.2f}")
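The cell above only checks that the likelihood is finite at \(\theta_0\); fitting then amounts to maximizing (or sampling) this likelihood over all seven parameters. As a self-contained illustration of how multi-band data pin down \(T_{\rm spot}\), the sketch below holds the geometric hyperparameters fixed and profiles a GP log-likelihood over a grid of \(T_{\rm spot}\). Everything here is an assumption for illustration, not spotgp API: a hypothetical quasi-periodic stand-in for \(\mathcal{K}_{\rm geom}\), a dense Cholesky solve instead of the banded solver, zero-mean data, and the blackbody contrast written as \(c = 1 - {\rm expm1}(hc/\lambda k T_{\rm phot})/{\rm expm1}(hc/\lambda k T_{\rm spot})\), where the \(2hc^2/\lambda^5\) prefactor has cancelled.

```python
import numpy as np

HC_OVER_K = 1.438777e8  # h c / k_B expressed in Angstrom * K

def contrast(lam, T_spot, T_phot):
    """Blackbody c(lambda) = 1 - B(T_spot)/B(T_phot); the 2hc^2/lam^5 prefactor cancels."""
    return 1.0 - np.expm1(HC_OVER_K / (lam * T_phot)) / np.expm1(HC_OVER_K / (lam * T_spot))

def k_geom(tau, sigma_k=0.01, period=5.0, ell=8.0, gamma=2.0):
    """Hypothetical quasi-periodic stand-in for the geometric kernel."""
    return sigma_k**2 * np.exp(-0.5 * (tau / ell) ** 2
                               - gamma * np.sin(np.pi * tau / period) ** 2)

def log_likelihood(T_spot, t, y, yerr, band_idx, lam_band, T_phot):
    """Gaussian log-likelihood of zero-mean data under the factorized kernel."""
    c = contrast(lam_band, T_spot, T_phot)[band_idx]
    K = np.outer(c, c) * k_geom(np.abs(t[:, None] - t[None, :])) + np.diag(yerr**2)
    L = np.linalg.cholesky(K)
    z = np.linalg.solve(L, y)                      # z = L^{-1} y, so z.z = y^T K^{-1} y
    return -0.5 * z @ z - np.sum(np.log(np.diag(L))) - 0.5 * y.size * np.log(2 * np.pi)

# Simulate: one geometric signal, scaled by each band's true contrast
rng = np.random.default_rng(42)
lam_band = np.array([4640.0, 6400.0, 7865.0])     # g, Kepler, TESS [Angstrom]
T_phot, T_spot_true = 5800.0, 4200.0
n = 80
t_one = np.linspace(0.0, 40.0, n)
Kg = k_geom(np.abs(t_one[:, None] - t_one[None, :])) + 1e-10 * np.eye(n)
signal = np.linalg.cholesky(Kg) @ rng.standard_normal(n)

t = np.tile(t_one, 3)
band_idx = np.repeat(np.arange(3), n)
yerr = np.full(t.size, 2e-4)
y = contrast(lam_band, T_spot_true, T_phot)[band_idx] * np.tile(signal, 3)
y = y + yerr * rng.standard_normal(t.size)

# Profile the likelihood over T_spot with the geometric hyperparameters held fixed
T_grid = np.linspace(3000.0, 5500.0, 51)
ll = np.array([log_likelihood(T, t, y, yerr, band_idx, lam_band, T_phot) for T in T_grid])
T_best = T_grid[np.argmax(ll)]
print(f"grid maximum at T_spot = {T_best:.0f} K (true value {T_spot_true:.0f} K)")
```

A grid scan is used only because it is transparent for a single parameter; for the full seven-parameter problem a gradient-based optimizer or a sampler is the natural choice.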

5. Amplitude ratios between bands

The ratio of variability amplitudes between two bands depends only on the two temperatures \(T_{\rm spot}\) and \(T_{\rm phot}\) and on the band wavelengths:

\[ \frac{\sigma_k(\lambda_1)}{\sigma_k(\lambda_2)} = \frac{c(\lambda_1)}{c(\lambda_2)} = \frac{1 - f_{\rm spot}(\lambda_1)}{1 - f_{\rm spot}(\lambda_2)} \]

For a known \(T_{\rm phot}\), this ratio is independent of all geometric parameters (\(P_{\rm eq}\), \(\kappa\), \(I\), \(\ell_{\rm spot}\), \(\tau_{\rm spot}\), \(\sigma_k\)) and therefore provides a clean diagnostic for \(T_{\rm spot}\).
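Because the ratio is a smooth function of \(T_{\rm spot}\) once \(T_{\rm phot}\) and the two wavelengths are fixed, a measured ratio can be turned into a temperature estimate by one-dimensional root finding. A minimal sketch, assuming blackbody contrast at monochromatic effective wavelengths and using plain bisection; tspot_from_ratio and amp_ratio are hypothetical names, not part of spotgp, and the search bracket must contain the solution.

```python
import numpy as np

HC_OVER_K = 1.438777e8  # h c / k_B in Angstrom * K

def contrast(lam, T_spot, T_phot):
    """Blackbody c = 1 - B(T_spot)/B(T_phot); the 2hc^2/lam^5 factor cancels."""
    return 1.0 - np.expm1(HC_OVER_K / (lam * T_phot)) / np.expm1(HC_OVER_K / (lam * T_spot))

def amp_ratio(T_spot, lam1, lam2, T_phot):
    """sigma(lam1) / sigma(lam2) = c(lam1) / c(lam2)."""
    return contrast(lam1, T_spot, T_phot) / contrast(lam2, T_spot, T_phot)

def tspot_from_ratio(ratio_obs, lam1, lam2, T_phot, T_lo=2500.0, T_hi=5500.0, tol=1e-3):
    """Bisection for the T_spot at which amp_ratio equals ratio_obs."""
    def f(T):
        return amp_ratio(T, lam1, lam2, T_phot) - ratio_obs
    if f(T_lo) * f(T_hi) > 0:
        raise ValueError("ratio_obs is not bracketed by [T_lo, T_hi]")
    while T_hi - T_lo > tol:
        T_mid = 0.5 * (T_lo + T_hi)
        if f(T_lo) * f(T_mid) <= 0:
            T_hi = T_mid          # root lies in the lower half
        else:
            T_lo = T_mid          # root lies in the upper half
    return 0.5 * (T_lo + T_hi)

# Round trip: compute the ratio at a known temperature, then recover it
lam_g, lam_tess, T_phot = 4640.0, 7865.0, 5800.0
r = amp_ratio(4200.0, lam_g, lam_tess, T_phot)
T_rec = tspot_from_ratio(r, lam_g, lam_tess, T_phot)
print(f"ratio = {r:.4f} -> T_spot = {T_rec:.1f} K")
```

Real bandpasses have finite width, so monochromatic wavelengths are an approximation here; bandpass-integrated contrast is the subject of the Spectral Contrast Models tutorial.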

# Compute amplitude ratios at the true T_spot
theta_dict = {"T_spot": T_spot_true}

pairs = [("SDSS_g", "TESS"), ("Kepler", "TESS"), ("SDSS_g", "Kepler")]
print(f"Amplitude ratios at T_spot = {T_spot_true:.0f} K:")
for b1, b2 in pairs:
    lam1 = band_info[b1]["wavelength"]
    lam2 = band_info[b2]["wavelength"]
    ratio = gp.amplitude_ratio(lam1, lam2, theta=theta_dict)
    print(f"  sigma({b1}) / sigma({b2}) = {ratio:.4f}")

# Show how the amplitude ratio varies with T_spot
T_spot_range = np.linspace(2500, 5500, 200)
lam_g = band_info["SDSS_g"]["wavelength"]
lam_tess = band_info["TESS"]["wavelength"]

ratios = []
for T in T_spot_range:
    c_g = float(contrast_factor(jnp.array(lam_g), T, T_phot_true))
    c_t = float(contrast_factor(jnp.array(lam_tess), T, T_phot_true))
    ratios.append(c_g / c_t if c_t > 0 else np.nan)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(T_spot_range, ratios, color="steelblue", lw=2)
ax.axvline(T_spot_true, color="red", ls="--", label=f"True $T_{{\\rm spot}}$ = {T_spot_true:.0f} K")
ax.set_xlabel(r"$T_{\rm spot}$ [K]")
ax.set_ylabel(r"$\sigma(g) \,/\, \sigma(\mathrm{TESS})$")
ax.set_title("Amplitude ratio as a diagnostic for spot temperature")
ax.legend()
fig.tight_layout()
plt.show()

6. GP predictions per band

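The predict() method returns the posterior mean and variance at new times. Pass band_wavelength to get the prediction for a specific band, scaled by that band's contrast factor. For illustration, the solver below is built at the true hyperparameters and \(T_{\rm spot}\) rather than at fitted values.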
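Under the factorized kernel, predicting band \(b\) at new times only requires scaling the cross-covariance by \(c(\lambda_b)\). The sketch below shows the standard GP conditional mean and variance for that case. It is not the code behind predict(): it assumes zero-mean (mean-subtracted) data, a hypothetical quasi-periodic stand-in for \(\mathcal{K}_{\rm geom}\), fixed illustrative contrast values, and a hypothetical helper name, predict_band.

```python
import numpy as np

def k_geom(tau, sigma_k=0.01, period=5.0, ell=8.0, gamma=2.0):
    """Hypothetical quasi-periodic stand-in for K_geom(tau)."""
    return sigma_k**2 * np.exp(-0.5 * (tau / ell) ** 2
                               - gamma * np.sin(np.pi * tau / period) ** 2)

def predict_band(x_new, c_new, t, y, yerr, c_obs):
    """Posterior mean and variance of the latent signal in a band with contrast c_new."""
    K = np.outer(c_obs, c_obs) * k_geom(np.abs(t[:, None] - t[None, :])) + np.diag(yerr**2)
    Ks = c_new * c_obs[None, :] * k_geom(np.abs(x_new[:, None] - t[None, :]))  # (m, n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))   # K^{-1} y
    v = np.linalg.solve(L, Ks.T)                           # L^{-1} K_*^T
    mu = Ks @ alpha
    var = c_new**2 * k_geom(0.0) - np.sum(v**2, axis=0)
    return mu, var

# Toy data: irregular times, each point randomly assigned to one of three bands
rng = np.random.default_rng(1)
t = np.sort(rng.uniform(0.0, 30.0, 90))
band_idx = rng.integers(0, 3, t.size)
c_band = np.array([0.87, 0.78, 0.71])     # illustrative contrasts (g, Kepler, TESS)
c_obs = c_band[band_idx]
yerr = np.full(t.size, 5e-4)
y = c_obs * 0.01 * np.sin(2 * np.pi * t / 5.0) + yerr * rng.standard_normal(t.size)

x_new = np.linspace(0.0, 30.0, 200)
mu_g, var_g = predict_band(x_new, c_band[0], t, y, yerr, c_obs)
mu_t, var_t = predict_band(x_new, c_band[2], t, y, yerr, c_obs)
```

Since the cross-covariance is proportional to \(c(\lambda_b)\), the predictive mean and standard deviation in one band are exact rescalings of those in any other band by the ratio of contrast factors.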

# Switch to full solver for prediction (banded solver is for likelihood only)
gp_pred = MultiBandGPSolver(
    data, hparam, T_phot=T_phot_true, T_spot_init=T_spot_true,
    bounds=bounds, matrix_solver="cholesky_full",
)

xpred = np.linspace(t[0], t[-1], 300)

fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

for ax, (name, b) in zip(axes, bands.items()):
    color = colors[name]
    wl = b["wavelength"]

    mu, var = gp_pred.predict(xpred, band_wavelength=wl)
    sigma = np.sqrt(np.maximum(var, 0))

    ax.fill_between(xpred, mu - 2 * sigma, mu + 2 * sigma,
                    color=color, alpha=0.15)
    ax.plot(xpred, mu, color=color, lw=1.5, label="GP mean")
    ax.errorbar(b["x"], b["y"], yerr=b["yerr"],
                fmt=".", ms=3, color="gray", alpha=0.5, elinewidth=0.5)
    ax.set_ylabel("Flux")
    ax.set_title(f"{name} ({wl:.0f} " + r"$\AA$)")
    ax.legend(loc="upper right", fontsize=9)

axes[-1].set_xlabel("Time [days]")
fig.suptitle("Multi-band GP predictions", fontsize=14, y=1.01)
fig.tight_layout()
plt.show()

Summary

| Concept | Key class/function |
|---|---|
| Multi-band data container | MultiBandData(bands_dict) |
| Chromatic GP solver | MultiBandGPSolver(data, hparam, T_phot, ...) |
| Planck contrast (default) | contrast_factor(wavelength, T_spot, T_phot) |
| Amplitude ratio diagnostic | gp.amplitude_ratio(lam1, lam2, theta=...) |

Next steps

  • Use bandpass-integrated spectral models for realistic contrast: see the Spectral Contrast Models tutorial

  • Run full MCMC with the multi-band solver using BlackJAX sampling or Dynesty sampling

  • Combine multi-band photometry with spectroscopic constraints in a hierarchical framework (see Paper II)