Time Series Modelling And Forecasting With gallifrey

November 21, 2025
JAX, Gaussian Processes, Time Series

NOTE: This quickstart is taken from the gallifrey documentation (https://chrisboettner.github.io/gallifrey/). Check it out for more information and usage examples, and see the paper (https://arxiv.org/abs/2505.20394) for the mathematical details.

gallifrey is a Python package for Bayesian structure learning, inference, and analysis within Gaussian Process models, focused on time series data. Built on JAX and using Sequential Monte Carlo (SMC) techniques, it enables efficient and flexible modeling of complex time series data. This guide will walk you through the basic steps to get started with gallifrey.

Setup environment

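Before importing JAX (and with it gallifrey), we need to configure JAX to use all available CPU cores. The XLA flag below makes XLA expose each core as a separate CPU device. It only takes effect if it is set before JAX initialises, so it has to come before the first JAX import: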

import multiprocessing
import os

# enable jax to recognize all CPU cores
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# import necessary packages and set up plotting
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(
    context="poster",
    style="ticks",
    palette="rocket",
    font_scale=1,
    rc={
        "figure.figsize": (16, 7),
        "axes.grid": False,
        "font.family": "serif",
        "text.usetex": True,  # needs a working LaTeX installation; set to False otherwise
        "lines.linewidth": 3,
    },
)
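To check that the flag took effect, you can ask JAX how many CPU devices it sees. This check is our own addition rather than part of gallifrey, and it must run in a fresh Python process, since the flag is ignored once JAX has initialised.

```python
import multiprocessing
import os

# The flag must be set before JAX initialises, i.e. before the first JAX import
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)

import jax

# Each CPU core should now show up as its own XLA device
cpu_devices = jax.devices("cpu")
print(f"{len(cpu_devices)} CPU devices visible to JAX")
```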

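Next, import the core components from gallifrey. On import, gallifrey reports the flags it sets: 64-bit floating point in JAX (JAX_ENABLE_X64) and single-threaded OpenMP (OMP_NUM_THREADS=1):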

from gallifrey import GPConfig, GPModel, LinearSchedule
gallifrey: Setting flag `JAX_ENABLE_X64` to `True`
gallifrey: Setting flag `OMP_NUM_THREADS` to `1`

Generate mock data

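For this quickstart, we'll generate some mock data: a sinusoid whose amplitude grows linearly with x, plus Gaussian noise with variance 9. The points with x < 10 are used for training; the remaining points (10 ≤ x ≤ 15) are held back to check the forecast.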

rng_key = jr.PRNGKey(0)
# Mock data
key, data_key = jr.split(rng_key)
n = 160
noise_var = 9.0
x = jnp.linspace(0, 15, n)
y = (x + 0.01) * jnp.sin(x * 3.2) + jnp.sqrt(noise_var) * jr.normal(data_key, (n,))

# training set: keep only the points with x < 10
xtrain = x[(x < 10)]
ytrain = y[(x < 10)]
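The snippet above only keeps the training points; the held-out points are later plotted from the full x and y arrays. If you want the held-out region as separate arrays (for example to score the forecast), the same boolean mask can be inverted. This is a minimal sketch; the names xholdout and yholdout are our own, hypothetical choices. Because it runs without importing gallifrey, JAX uses its default 32-bit floats, so the noise draws need not match the quickstart's run exactly.

```python
import jax.numpy as jnp
import jax.random as jr

# Recreate the mock data from the quickstart
key, data_key = jr.split(jr.PRNGKey(0))
n = 160
noise_var = 9.0
x = jnp.linspace(0, 15, n)
y = (x + 0.01) * jnp.sin(x * 3.2) + jnp.sqrt(noise_var) * jr.normal(data_key, (n,))

# Boolean mask: True in the training region, False in the forecast region
train_mask = x < 10
xtrain, ytrain = x[train_mask], y[train_mask]
# Hypothetical names for the held-out points with x >= 10
xholdout, yholdout = x[~train_mask], y[~train_mask]

print(f"{xtrain.shape[0]} training points, {xholdout.shape[0]} held-out points")
```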

Initialize the GP Model

Now we can initialize the Gaussian Process model. There are a variety of settings that control the details of the model, but for this quickstart we stick with the default config; see the tutorials for details on the different options. We also have to set the number of particles: each particle is one Gaussian Process, and together they form the ensemble used to make predictions.

config = GPConfig()

key, model_key = jr.split(key)
gpmodel = GPModel(
    model_key,
    x=xtrain,
    y=ytrain,
    num_particles=8,
    config=config,
)
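Throughout this quickstart, randomness is handled the JAX way: a single root key is split whenever a new component (the mock data, the model, the SMC sampler) needs random numbers, and no key is used twice. The short sketch below, which uses only jax.random and nothing from gallifrey, illustrates why: independent subkeys give independent draws, while reusing a key silently reproduces the same numbers.

```python
import jax.random as jr

# One root key, split into a fresh subkey for each consumer
rng_key = jr.PRNGKey(0)
key, data_key = jr.split(rng_key)
key, model_key = jr.split(key)
key, smc_key = jr.split(key)

# Different subkeys produce different draws ...
a = jr.normal(data_key, (3,))
b = jr.normal(model_key, (3,))
# ... while reusing a key reproduces exactly the same draw
a_again = jr.normal(data_key, (3,))

print(a, b, a_again)
```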

Fit the GP Model using SMC

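And now we can fit the GP model to the data. We use Sequential Monte Carlo (SMC) sampling with a data annealing schedule: over 10 rounds the sampler sees a growing share of the training points, reweights the particles (resampling them when needed), and then updates each particle with MCMC and HMC moves. With verbosity=1, a progress report is printed for every round. Finally, update_state stores the final SMC state in the model.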
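The annealing schedule decides how many training points each SMC round uses. The fit below prints 1, 13, 24, 36, 48, 59, 71, 83, 94 and 106 (out of 106) points for its 10 rounds, which is consistent with 10 evenly spaced steps from one point to all points, rounded to whole numbers. The sketch reproduces those counts with a hypothetical helper, linear_schedule; it is a reconstruction from the printed log, not gallifrey's LinearSchedule implementation.

```python
import jax.numpy as jnp


def linear_schedule(num_data: int, num_steps: int) -> list[int]:
    """Hypothetical sketch: number of data points used in each SMC round.

    Spaces the rounds evenly between 1 point and all num_data points and
    rounds to the nearest integer. It matches the counts printed in this
    quickstart, but is not necessarily how gallifrey's LinearSchedule works.
    """
    counts = jnp.round(jnp.linspace(1, num_data, num_steps))
    return [int(c) for c in counts]


print(linear_schedule(106, 10))  # [1, 13, 24, 36, 48, 59, 71, 83, 94, 106]
```

Early rounds see very little data, so the particles can explore kernel structures cheaply before the full data set sharpens the posterior.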

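key, smc_key = jr.split(key)
final_smc_state, history = gpmodel.fit_smc(
    smc_key,
    # 10 annealing rounds over the len(xtrain) training points
    annealing_schedule=LinearSchedule().generate(len(xtrain), 10),
    n_mcmc=75,  # MCMC moves per particle in each round
    n_hmc=10,  # HMC steps; the log suggests these run after each accepted MCMC move
    verbosity=1,  # print a progress report for every round
)
# store the final particle state in the model
gpmodel = gpmodel.update_state(final_smc_state)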
Running SMC round [1/10] with [1/106] data points.
Weights: [0.05392501 0.10727033 0.23023163 0.07581755 0.09684554 0.17955717
 0.16175602 0.09459675]
Resampled: False (Normalised ESS: 0.83)
Particle 1 | Accepted: MCMC[49/75]  HMC[490/490]
Particle 2 | Accepted: MCMC[45/75]  HMC[450/450]
Particle 3 | Accepted: MCMC[51/75]  HMC[510/510]
Particle 4 | Accepted: MCMC[51/75]  HMC[510/510]
Particle 5 | Accepted: MCMC[48/75]  HMC[480/480]
Particle 6 | Accepted: MCMC[54/75]  HMC[540/540]
Particle 7 | Accepted: MCMC[57/75]  HMC[570/570]
Particle 8 | Accepted: MCMC[62/75]  HMC[620/620]
==================================================
Running SMC round [2/10] with [13/106] data points.
Weights: [8.05914360e-04 5.47702896e-06 3.93273140e-07 1.17978333e-02
 6.02643432e-01 6.84271626e-05 3.84672447e-01 6.07536689e-06]
Resampled: True (Normalised ESS: 0.24)
Particle 1 | Accepted: MCMC[48/75]  HMC[479/480]
Particle 2 | Accepted: MCMC[36/75]  HMC[360/360]
Particle 3 | Accepted: MCMC[46/75]  HMC[460/460]
Particle 4 | Accepted: MCMC[46/75]  HMC[460/460]
Particle 5 | Accepted: MCMC[42/75]  HMC[420/420]
Particle 6 | Accepted: MCMC[46/75]  HMC[460/460]
Particle 7 | Accepted: MCMC[43/75]  HMC[430/430]
Particle 8 | Accepted: MCMC[41/75]  HMC[410/410]
==================================================
Running SMC round [3/10] with [24/106] data points.
Weights: [0.01799203 0.38161775 0.31482115 0.00537581 0.0032308  0.01912249
 0.25110413 0.00673584]
Resampled: True (Normalised ESS: 0.41)
Particle 1 | Accepted: MCMC[43/75]  HMC[430/430]
Particle 2 | Accepted: MCMC[47/75]  HMC[470/470]
Particle 3 | Accepted: MCMC[50/75]  HMC[500/500]
Particle 4 | Accepted: MCMC[40/75]  HMC[400/400]
Particle 5 | Accepted: MCMC[48/75]  HMC[480/480]
Particle 6 | Accepted: MCMC[30/75]  HMC[300/300]
Particle 7 | Accepted: MCMC[47/75]  HMC[470/470]
Particle 8 | Accepted: MCMC[44/75]  HMC[440/440]
==================================================
Running SMC round [4/10] with [36/106] data points.
Weights: [0.03426517 0.0556542  0.11974631 0.09799773 0.45193096 0.17182724
 0.05309929 0.01547909]
Resampled: True (Normalised ESS: 0.47)
Particle 1 | Accepted: MCMC[40/75]  HMC[400/400]
Particle 2 | Accepted: MCMC[44/75]  HMC[440/440]
Particle 3 | Accepted: MCMC[46/75]  HMC[460/460]
Particle 4 | Accepted: MCMC[39/75]  HMC[390/390]
Particle 5 | Accepted: MCMC[47/75]  HMC[470/470]
Particle 6 | Accepted: MCMC[34/75]  HMC[339/340]
Particle 7 | Accepted: MCMC[36/75]  HMC[360/360]
Particle 8 | Accepted: MCMC[46/75]  HMC[460/460]
==================================================
Running SMC round [5/10] with [48/106] data points.
Weights: [0.021441   0.2299715  0.2355654  0.02415679 0.11430621 0.24059889
 0.09370489 0.04025532]
Resampled: False (Normalised ESS: 0.66)
Particle 1 | Accepted: MCMC[42/75]  HMC[419/420]
Particle 2 | Accepted: MCMC[47/75]  HMC[469/470]
Particle 3 | Accepted: MCMC[38/75]  HMC[380/380]
Particle 4 | Accepted: MCMC[42/75]  HMC[419/420]
Particle 5 | Accepted: MCMC[44/75]  HMC[439/440]
Particle 6 | Accepted: MCMC[46/75]  HMC[460/460]
Particle 7 | Accepted: MCMC[41/75]  HMC[410/410]
Particle 8 | Accepted: MCMC[43/75]  HMC[430/430]
==================================================
Running SMC round [6/10] with [59/106] data points.
Weights: [0.01346345 0.46789399 0.15199219 0.02610874 0.09903758 0.14641147
 0.06438393 0.03070865]
Resampled: True (Normalised ESS: 0.45)
Particle 1 | Accepted: MCMC[34/75]  HMC[340/340]
Particle 2 | Accepted: MCMC[31/75]  HMC[309/310]
Particle 3 | Accepted: MCMC[28/75]  HMC[279/280]
Particle 4 | Accepted: MCMC[42/75]  HMC[420/420]
Particle 5 | Accepted: MCMC[49/75]  HMC[490/490]
Particle 6 | Accepted: MCMC[41/75]  HMC[410/410]
Particle 7 | Accepted: MCMC[44/75]  HMC[440/440]
Particle 8 | Accepted: MCMC[39/75]  HMC[390/390]
==================================================
Running SMC round [7/10] with [71/106] data points.
Weights: [0.01389594 0.01428358 0.8226441  0.0120111  0.00830618 0.02457485
 0.09274809 0.01153616]
Resampled: True (Normalised ESS: 0.18)
Particle 1 | Accepted: MCMC[14/75]  HMC[140/140]
Particle 2 | Accepted: MCMC[40/75]  HMC[399/400]
Particle 3 | Accepted: MCMC[19/75]  HMC[189/190]
Particle 4 | Accepted: MCMC[23/75]  HMC[230/230]
Particle 5 | Accepted: MCMC[15/75]  HMC[150/150]
Particle 6 | Accepted: MCMC[20/75]  HMC[199/200]
Particle 7 | Accepted: MCMC[17/75]  HMC[168/170]
Particle 8 | Accepted: MCMC[33/75]  HMC[330/330]
==================================================
Running SMC round [8/10] with [83/106] data points.
Weights: [1.44958343e-01 4.55231902e-04 4.55909880e-01 1.89346520e-02
 9.01826052e-02 2.08687957e-01 8.00052122e-02 8.66118823e-04]
Resampled: True (Normalised ESS: 0.44)
Particle 1 | Accepted: MCMC[19/75]  HMC[190/190]
Particle 2 | Accepted: MCMC[32/75]  HMC[317/320]
Particle 3 | Accepted: MCMC[17/75]  HMC[169/170]
Particle 4 | Accepted: MCMC[24/75]  HMC[238/240]
Particle 5 | Accepted: MCMC[23/75]  HMC[227/230]
Particle 6 | Accepted: MCMC[18/75]  HMC[176/180]
Particle 7 | Accepted: MCMC[24/75]  HMC[237/240]
Particle 8 | Accepted: MCMC[24/75]  HMC[239/240]
==================================================
Running SMC round [9/10] with [94/106] data points.
Weights: [0.08233663 0.23588815 0.10630179 0.05492449 0.00296773 0.1164186
 0.25538018 0.14578241]
Resampled: False (Normalised ESS: 0.71)
Particle 1 | Accepted: MCMC[24/75]  HMC[235/240]
Particle 2 | Accepted: MCMC[17/75]  HMC[162/170]
Particle 3 | Accepted: MCMC[17/75]  HMC[163/170]
Particle 4 | Accepted: MCMC[18/75]  HMC[176/180]
Particle 5 | Accepted: MCMC[18/75]  HMC[177/180]
Particle 6 | Accepted: MCMC[19/75]  HMC[177/190]
Particle 7 | Accepted: MCMC[25/75]  HMC[235/250]
Particle 8 | Accepted: MCMC[20/75]  HMC[195/200]
==================================================
Running SMC round [10/10] with [106/106] data points.
Weights: [0.04970241 0.098085   0.03733657 0.01727045 0.00186775 0.09631474
 0.64934274 0.05008033]
Resampled: False (Normalised ESS: 0.28)
Particle 1 | Accepted: MCMC[23/75]  HMC[229/230]
Particle 2 | Accepted: MCMC[26/75]  HMC[249/260]
Particle 3 | Accepted: MCMC[22/75]  HMC[203/220]
Particle 4 | Accepted: MCMC[20/75]  HMC[172/200]
Particle 5 | Accepted: MCMC[15/75]  HMC[146/150]
Particle 6 | Accepted: MCMC[17/75]  HMC[151/170]
Particle 7 | Accepted: MCMC[17/75]  HMC[157/170]
Particle 8 | Accepted: MCMC[16/75]  HMC[148/160]
==================================================
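The "Normalised ESS" in the log is consistent with the standard effective sample size of the particle weights, 1/Σw², divided by the number of particles: it equals 1 when all particles carry the same weight and drops towards 1/N when a single particle dominates. The sketch below recomputes it from the weights printed for the first two rounds; the helper normalised_ess is our own, not a gallifrey function. In this run, every round except the last was resampled when the value was 0.47 or lower and not when it was 0.66 or higher, which suggests a threshold of around one half; see the gallifrey documentation for the exact rule.

```python
import jax.numpy as jnp


def normalised_ess(weights) -> float:
    """Effective sample size 1 / sum(w**2) of normalised weights, divided by N."""
    w = jnp.asarray(weights)
    w = w / w.sum()  # guard against printed weights that do not sum exactly to 1
    ess = 1.0 / jnp.sum(w**2)
    return float(ess / w.shape[0])


# Weights printed for SMC rounds 1 and 2 in the log above
round1 = [0.05392501, 0.10727033, 0.23023163, 0.07581755,
          0.09684554, 0.17955717, 0.16175602, 0.09459675]
round2 = [8.05914360e-04, 5.47702896e-06, 3.93273140e-07, 1.17978333e-02,
          6.02643432e-01, 6.84271626e-05, 3.84672447e-01, 6.07536689e-06]

print(round(normalised_ess(round1), 2), round(normalised_ess(round2), 2))  # 0.83 0.24
```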

Predictions using the model

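With the fitted model, we can now make predictions. Since the SMC sampler fits an entire ensemble of Gaussian processes, we combine the particles' predictive distributions into a single mixture distribution for forecasting. The model works on transformed inputs and outputs, so the test inputs (which extend past the data, up to x = 20) are passed through gpmodel.x_transform, and everything is plotted in the transformed space.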

xtest = gpmodel.x_transform(jnp.linspace(0, 20, 500))
dist = gpmodel.get_mixture_distribution(xtest)

predictive_mean = dist.mean()
predictive_std = dist.stddev()
plot = sns.lineplot(x=xtest, y=predictive_mean)
plot.fill_between(
    xtest,
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.3,
)

sns.scatterplot(
    x=gpmodel.x_transformed,
    y=gpmodel.y_transformed,
    label="Training Data",
    ax=plot,
    zorder=3,
)
sns.scatterplot(
    x=gpmodel.x_transform(x),
    y=gpmodel.y_transform(y),
    label="Test Data",
    ax=plot,
    zorder=2,
)
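The plot compares the forecast with the held-out points visually. To put a number on it, one option is to evaluate the mixture at the held-out inputs and compute the root-mean-square error together with the fraction of points inside the ±1 standard deviation band. The helper holdout_metrics below is a hypothetical sketch of ours, not part of gallifrey; the commented lines show how it could be applied using only the calls from this quickstart, in which case both metrics are in the model's transformed units.

```python
import jax.numpy as jnp


def holdout_metrics(mean, std, y_obs):
    """Hypothetical helper: RMSE and 1-sigma coverage of a forecast on held-out points."""
    mean, std, y_obs = (jnp.asarray(a) for a in (mean, std, y_obs))
    residuals = y_obs - mean
    rmse = jnp.sqrt(jnp.mean(residuals**2))
    # fraction of held-out observations inside the mean +/- 1 std band
    coverage = jnp.mean(jnp.abs(residuals) <= std)
    return float(rmse), float(coverage)


# Applied to the quickstart (transformed units):
# holdout = x >= 10
# dist_holdout = gpmodel.get_mixture_distribution(gpmodel.x_transform(x[holdout]))
# rmse, coverage = holdout_metrics(
#     dist_holdout.mean(), dist_holdout.stddev(), gpmodel.y_transform(y[holdout])
# )

# Toy inputs, only to exercise the function
rmse, coverage = holdout_metrics([0.0, 1.0], [1.0, 0.5], [0.5, 2.0])
print(rmse, coverage)
```

For a well-calibrated Gaussian forecast, roughly 68% of the points would fall inside the band; the mixture is not exactly Gaussian, so treat the coverage as a rough guide.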

Alternatively, we can look at the predictions from each particle individually.

dist = gpmodel.get_predictive_distributions(xtest)

means, stds = [], []
for d in dist:
    means.append(d.mean())
    stds.append(d.stddev())

fig, ax = plt.subplots()
for i in range(len(means)):
    sns.lineplot(x=xtest, y=means[i], label=f"Particle {i+1}", ax=ax)
    ax.fill_between(
        xtest,
        means[i] - stds[i],
        means[i] + stds[i],
        alpha=0.1,
    )

sns.scatterplot(
    x=gpmodel.x_transformed,
    y=gpmodel.y_transformed,
    ax=ax,
    zorder=3,
    color="C0",
)
sns.scatterplot(
    x=gpmodel.x_transform(x),
    y=gpmodel.y_transform(y),
    ax=ax,
    zorder=2,
    color="C1",
)
ax.legend(ncols=4)
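The mixture prediction and the per-particle predictions are related by standard mixture identities: the mixture mean is the weighted average of the particle means, and the mixture variance is the average within-particle variance plus the weighted spread of the particle means around the mixture mean. This is why the mixture band widens wherever the particles disagree. The sketch below computes these moments from per-particle means and standard deviations. It assumes you supply the mixture weights (for example the final particle weights; retrieving them from gallifrey is not shown here), and it is not necessarily how gallifrey computes dist.mean() and dist.stddev() internally.

```python
import jax.numpy as jnp


def mixture_moments(weights, means, stds):
    """Mean and standard deviation of a weighted mixture of predictive distributions.

    weights: shape (P,), mixture weights summing to 1.
    means, stds: shape (P, M), per-particle predictions at M test inputs.
    """
    w = jnp.asarray(weights)[:, None]
    means, stds = jnp.asarray(means), jnp.asarray(stds)
    mix_mean = jnp.sum(w * means, axis=0)
    # law of total variance: within-particle variance + spread of the particle means
    mix_var = jnp.sum(w * (stds**2 + (means - mix_mean) ** 2), axis=0)
    return mix_mean, jnp.sqrt(mix_var)


# In the quickstart, jnp.stack(means) and jnp.stack(stds) have shape (P, M).
# Toy case: two equally weighted particles that disagree about the mean
m, s = mixture_moments([0.5, 0.5], [[0.0], [2.0]], [[1.0], [1.0]])
print(m, s)  # mean 1.0, std sqrt(2)
```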

Conclusion

And we're done! This quickstart guide provides a minimal example to get you started. For more advanced features, customization options, and detailed explanations, please refer to the tutorials and full documentation.