Getting started
gallifrey is a Python package for Bayesian structure learning, inference, and analysis within Gaussian Process models, focused on time series data. Built on JAX and using Sequential Monte Carlo (SMC) techniques, it enables efficient and flexible modeling of complex time series data. This guide will walk you through the basic steps to get started with gallifrey.
Setup environment
Before importing gallifrey, we need to configure JAX to utilize all available CPU cores. This is done using the following code snippet:
import multiprocessing
import os
# enable jax to recognize all CPU cores
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# import necessary packages and set up plotting
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(
    context="poster",
    style="ticks",
    palette="rocket",
    font_scale=1,
    rc={
        "figure.figsize": (16, 7),
        "axes.grid": False,
        "font.family": "serif",
        "text.usetex": True,
        "lines.linewidth": 3,
    },
)
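Note that the `XLA_FLAGS` assignment at the top of this block overwrites any flags already present in the environment. As a minimal sketch of a variant that keeps existing flags, the helper below rebuilds the flag string. `add_host_device_flag` is a hypothetical name introduced here for illustration and is not part of JAX or gallifrey. As in the snippet above, the variable has to be set before gallifrey (and therefore JAX) is imported.

```python
import multiprocessing
import os


def add_host_device_flag(existing: str, n_devices: int) -> str:
    """Return an XLA_FLAGS string that sets the host device count to n_devices.

    Hypothetical helper: other flags in `existing` are preserved, and an
    earlier setting of the same flag is replaced instead of duplicated.
    """
    flag_name = "--xla_force_host_platform_device_count"
    kept = [f for f in existing.split() if not f.startswith(flag_name + "=")]
    kept.append(f"{flag_name}={n_devices}")
    return " ".join(kept)


# Set the variable before importing gallifrey, as in the guide.
os.environ["XLA_FLAGS"] = add_host_device_flag(
    os.environ.get("XLA_FLAGS", ""), multiprocessing.cpu_count()
)
```

Whether preserving other flags matters depends on your setup; for a fresh environment the direct assignment shown above does the same job.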
Next, import the core components from gallifrey:
from gallifrey import GPConfig, GPModel, LinearSchedule
gallifrey: Setting flag `JAX_ENABLE_X64` to `True`
gallifrey: Setting flag `OMP_NUM_THREADS` to `1`
Generate mock data
For this quickstart, we'll generate some mock data for demonstration. We create a simple data set and use only the points with x < 10 for training, so the rest of the range can serve as a check on the forecast.
rng_key = jr.PRNGKey(0)
# Mock data
key, data_key = jr.split(rng_key)
n = 160
noise_var = 9.0
x = jnp.linspace(0, 15, n)
y = (x + 0.01) * jnp.sin(x * 3.2) + jnp.sqrt(noise_var) * jr.normal(data_key, (n,))
# keep only the points with x < 10 for training
xtrain = x[(x < 10)]
ytrain = y[(x < 10)]
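The code above only materialises the training part of the split. As a framework-free sketch of the same split that also builds the held-out part (x ≥ 10), the following uses only the standard library. `mock_series` is a hypothetical helper, and because it draws noise with Python's `random` module rather than `jax.random`, the noise values differ from the data generated above.

```python
import math
import random


def mock_series(n=160, x_max=15.0, noise_var=9.0, seed=0):
    """Noisy signal (x + 0.01) * sin(3.2 x) on n evenly spaced points in [0, x_max]."""
    rng = random.Random(seed)
    xs = [x_max * i / (n - 1) for i in range(n)]
    noise_std = math.sqrt(noise_var)  # a variance of 9 means a standard deviation of 3
    ys = [
        (x + 0.01) * math.sin(3.2 * x) + noise_std * rng.gauss(0.0, 1.0)
        for x in xs
    ]
    return xs, ys


xs, ys = mock_series()

# Same mask as in the guide: x < 10 for training, the rest is held out.
train = [(x, y) for x, y in zip(xs, ys) if x < 10]
test = [(x, y) for x, y in zip(xs, ys) if x >= 10]

print(len(train), len(test))  # 106 54
```

The 106 training points match the `[.../106]` counts that the fit reports later in this guide.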
Initialize the GP Model
Now we can initialize the Gaussian Process model. There are a variety of settings for the details of the model, but for this quickstart we can stick with the default config. Please see the tutorials for more details on the different options. We also have to set the number of particles, which together form an ensemble of Gaussian Processes used to make predictions.
config = GPConfig()
key, model_key = jr.split(key)
gpmodel = GPModel(
    model_key,
    x=xtrain,
    y=ytrain,
    num_particles=8,
    config=config,
)
Fit the GP Model using SMC
Now we can fit the GP model to the data. We use Sequential Monte Carlo sampling with a data annealing schedule.
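Data annealing here means that the sampler sees the training data in growing portions, one portion per SMC round, as the `[1/106]`, `[13/106]`, ... counts in the log further down show. As a rough illustration, and explicitly not gallifrey's implementation of `LinearSchedule`, evenly spaced and rounded counts between 1 and the number of training points reproduce those per-round counts for 106 points and 10 rounds. `linear_data_schedule` is a hypothetical helper.

```python
def linear_data_schedule(num_points: int, num_rounds: int) -> list[int]:
    """Evenly spaced data counts from 1 to num_points, one per round (sketch)."""
    if num_rounds < 2:
        raise ValueError("need at least two rounds")
    step = (num_points - 1) / (num_rounds - 1)
    return [round(1 + k * step) for k in range(num_rounds)]


print(linear_data_schedule(106, 10))
# [1, 13, 24, 36, 48, 59, 71, 83, 94, 106]
```

The actual fit for this guide uses the library's own schedule: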
key, smc_key = jr.split(key)
final_smc_state, history = gpmodel.fit_smc(
    smc_key,
    annealing_schedule=LinearSchedule().generate(len(xtrain), 10),
    n_mcmc=75,
    n_hmc=10,
    verbosity=1,
)
gpmodel = gpmodel.update_state(final_smc_state)
Running SMC round [1/10] with [1/106] data points.
Weights: [0.05392501 0.10727033 0.23023163 0.07581755 0.09684554 0.17955717 0.16175602 0.09459675]
Resampled: False (Normalised ESS: 0.83)
Particle 1 | Accepted: MCMC[49/75] HMC[490/490]
Particle 2 | Accepted: MCMC[45/75] HMC[450/450]
Particle 3 | Accepted: MCMC[51/75] HMC[510/510]
Particle 4 | Accepted: MCMC[51/75] HMC[510/510]
Particle 5 | Accepted: MCMC[48/75] HMC[480/480]
Particle 6 | Accepted: MCMC[54/75] HMC[540/540]
Particle 7 | Accepted: MCMC[57/75] HMC[570/570]
Particle 8 | Accepted: MCMC[62/75] HMC[620/620]
==================================================
Running SMC round [2/10] with [13/106] data points.
Weights: [8.05914360e-04 5.47702896e-06 3.93273140e-07 1.17978333e-02 6.02643432e-01 6.84271626e-05 3.84672447e-01 6.07536689e-06]
Resampled: True (Normalised ESS: 0.24)
Particle 1 | Accepted: MCMC[48/75] HMC[479/480]
Particle 2 | Accepted: MCMC[36/75] HMC[360/360]
Particle 3 | Accepted: MCMC[46/75] HMC[460/460]
Particle 4 | Accepted: MCMC[46/75] HMC[460/460]
Particle 5 | Accepted: MCMC[42/75] HMC[420/420]
Particle 6 | Accepted: MCMC[46/75] HMC[460/460]
Particle 7 | Accepted: MCMC[43/75] HMC[430/430]
Particle 8 | Accepted: MCMC[41/75] HMC[410/410]
==================================================
Running SMC round [3/10] with [24/106] data points.
Weights: [0.01799203 0.38161775 0.31482115 0.00537581 0.0032308 0.01912249 0.25110413 0.00673584]
Resampled: True (Normalised ESS: 0.41)
Particle 1 | Accepted: MCMC[43/75] HMC[430/430]
Particle 2 | Accepted: MCMC[47/75] HMC[470/470]
Particle 3 | Accepted: MCMC[50/75] HMC[500/500]
Particle 4 | Accepted: MCMC[40/75] HMC[400/400]
Particle 5 | Accepted: MCMC[48/75] HMC[480/480]
Particle 6 | Accepted: MCMC[30/75] HMC[300/300]
Particle 7 | Accepted: MCMC[47/75] HMC[470/470]
Particle 8 | Accepted: MCMC[44/75] HMC[440/440]
==================================================
Running SMC round [4/10] with [36/106] data points.
Weights: [0.03426517 0.0556542 0.11974631 0.09799773 0.45193096 0.17182724 0.05309929 0.01547909]
Resampled: True (Normalised ESS: 0.47)
Particle 1 | Accepted: MCMC[40/75] HMC[400/400]
Particle 2 | Accepted: MCMC[44/75] HMC[440/440]
Particle 3 | Accepted: MCMC[46/75] HMC[460/460]
Particle 4 | Accepted: MCMC[39/75] HMC[390/390]
Particle 5 | Accepted: MCMC[47/75] HMC[470/470]
Particle 6 | Accepted: MCMC[34/75] HMC[339/340]
Particle 7 | Accepted: MCMC[36/75] HMC[360/360]
Particle 8 | Accepted: MCMC[46/75] HMC[460/460]
==================================================
Running SMC round [5/10] with [48/106] data points.
Weights: [0.021441 0.2299715 0.2355654 0.02415679 0.11430621 0.24059889 0.09370489 0.04025532]
Resampled: False (Normalised ESS: 0.66)
Particle 1 | Accepted: MCMC[42/75] HMC[419/420]
Particle 2 | Accepted: MCMC[47/75] HMC[469/470]
Particle 3 | Accepted: MCMC[38/75] HMC[380/380]
Particle 4 | Accepted: MCMC[42/75] HMC[419/420]
Particle 5 | Accepted: MCMC[44/75] HMC[439/440]
Particle 6 | Accepted: MCMC[46/75] HMC[460/460]
Particle 7 | Accepted: MCMC[41/75] HMC[410/410]
Particle 8 | Accepted: MCMC[43/75] HMC[430/430]
==================================================
Running SMC round [6/10] with [59/106] data points.
Weights: [0.01346345 0.46789399 0.15199219 0.02610874 0.09903758 0.14641147 0.06438393 0.03070865]
Resampled: True (Normalised ESS: 0.45)
Particle 1 | Accepted: MCMC[34/75] HMC[340/340]
Particle 2 | Accepted: MCMC[31/75] HMC[309/310]
Particle 3 | Accepted: MCMC[28/75] HMC[279/280]
Particle 4 | Accepted: MCMC[42/75] HMC[420/420]
Particle 5 | Accepted: MCMC[49/75] HMC[490/490]
Particle 6 | Accepted: MCMC[41/75] HMC[410/410]
Particle 7 | Accepted: MCMC[44/75] HMC[440/440]
Particle 8 | Accepted: MCMC[39/75] HMC[390/390]
==================================================
Running SMC round [7/10] with [71/106] data points.
Weights: [0.01389594 0.01428358 0.8226441 0.0120111 0.00830618 0.02457485 0.09274809 0.01153616]
Resampled: True (Normalised ESS: 0.18)
Particle 1 | Accepted: MCMC[14/75] HMC[140/140]
Particle 2 | Accepted: MCMC[40/75] HMC[399/400]
Particle 3 | Accepted: MCMC[19/75] HMC[189/190]
Particle 4 | Accepted: MCMC[23/75] HMC[230/230]
Particle 5 | Accepted: MCMC[15/75] HMC[150/150]
Particle 6 | Accepted: MCMC[20/75] HMC[199/200]
Particle 7 | Accepted: MCMC[17/75] HMC[168/170]
Particle 8 | Accepted: MCMC[33/75] HMC[330/330]
==================================================
Running SMC round [8/10] with [83/106] data points.
Weights: [1.44958343e-01 4.55231902e-04 4.55909880e-01 1.89346520e-02 9.01826052e-02 2.08687957e-01 8.00052122e-02 8.66118823e-04]
Resampled: True (Normalised ESS: 0.44)
Particle 1 | Accepted: MCMC[19/75] HMC[190/190]
Particle 2 | Accepted: MCMC[32/75] HMC[317/320]
Particle 3 | Accepted: MCMC[17/75] HMC[169/170]
Particle 4 | Accepted: MCMC[24/75] HMC[238/240]
Particle 5 | Accepted: MCMC[23/75] HMC[227/230]
Particle 6 | Accepted: MCMC[18/75] HMC[176/180]
Particle 7 | Accepted: MCMC[24/75] HMC[237/240]
Particle 8 | Accepted: MCMC[24/75] HMC[239/240]
==================================================
Running SMC round [9/10] with [94/106] data points.
Weights: [0.08233663 0.23588815 0.10630179 0.05492449 0.00296773 0.1164186 0.25538018 0.14578241]
Resampled: False (Normalised ESS: 0.71)
Particle 1 | Accepted: MCMC[24/75] HMC[235/240]
Particle 2 | Accepted: MCMC[17/75] HMC[162/170]
Particle 3 | Accepted: MCMC[17/75] HMC[163/170]
Particle 4 | Accepted: MCMC[18/75] HMC[176/180]
Particle 5 | Accepted: MCMC[18/75] HMC[177/180]
Particle 6 | Accepted: MCMC[19/75] HMC[177/190]
Particle 7 | Accepted: MCMC[25/75] HMC[235/250]
Particle 8 | Accepted: MCMC[20/75] HMC[195/200]
==================================================
Running SMC round [10/10] with [106/106] data points.
Weights: [0.04970241 0.098085 0.03733657 0.01727045 0.00186775 0.09631474 0.64934274 0.05008033]
Resampled: False (Normalised ESS: 0.28)
Particle 1 | Accepted: MCMC[23/75] HMC[229/230]
Particle 2 | Accepted: MCMC[26/75] HMC[249/260]
Particle 3 | Accepted: MCMC[22/75] HMC[203/220]
Particle 4 | Accepted: MCMC[20/75] HMC[172/200]
Particle 5 | Accepted: MCMC[15/75] HMC[146/150]
Particle 6 | Accepted: MCMC[17/75] HMC[151/170]
Particle 7 | Accepted: MCMC[17/75] HMC[157/170]
Particle 8 | Accepted: MCMC[16/75] HMC[148/160]
==================================================
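The "Normalised ESS" values in the log above are consistent with the standard effective sample size estimate, 1 / Σ wᵢ², divided by the number of particles. The sketch below recomputes it from the printed weights of rounds 1 and 7. This is an illustration of that formula under the assumption that it is what the log reports, not a copy of gallifrey's code, and `normalised_ess` is a hypothetical helper.

```python
def normalised_ess(weights):
    """Effective sample size 1 / sum(w_i^2), divided by the number of particles."""
    total = sum(weights)
    w = [wi / total for wi in weights]  # make sure the weights sum to one
    ess = 1.0 / sum(wi * wi for wi in w)
    return ess / len(w)


# Weights copied from the log above.
round_1 = [0.05392501, 0.10727033, 0.23023163, 0.07581755,
           0.09684554, 0.17955717, 0.16175602, 0.09459675]
round_7 = [0.01389594, 0.01428358, 0.8226441, 0.0120111,
           0.00830618, 0.02457485, 0.09274809, 0.01153616]

print(f"{normalised_ess(round_1):.2f}")  # 0.83
print(f"{normalised_ess(round_7):.2f}")  # 0.18
```

A value near 1 means the weights are spread evenly over the particles. A value near 1/8 means a single particle dominates, as in round 7, where one weight is about 0.82.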
Predictions using the model
With the fitted model, we can now make predictions. Since the SMC sampler fits an entire ensemble of Gaussian processes, we can use the mixture of predictive distributions for our forecasting.
xtest = gpmodel.x_transform(jnp.linspace(0, 20, 500))
dist = gpmodel.get_mixture_distribution(xtest)
predictive_mean = dist.mean()
predictive_std = dist.stddev()
plot = sns.lineplot(x=xtest, y=predictive_mean)
plot.fill_between(
    xtest,
    predictive_mean - predictive_std,
    predictive_mean + predictive_std,
    alpha=0.3,
)
sns.scatterplot(
    x=gpmodel.x_transformed,
    y=gpmodel.y_transformed,
    label="Training Data",
    ax=plot,
    zorder=3,
)
sns.scatterplot(
    x=gpmodel.x_transform(x),
    y=gpmodel.y_transform(y),
    label="Test Data",
    ax=plot,
    zorder=2,
)
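To see how an ensemble turns into a single mean and standard deviation, here is a small sketch of the moments of a weighted Gaussian mixture at one test point. It assumes that the mixture is a weighted combination of per-particle Gaussian predictive distributions, which this guide does not spell out, and `mixture_mean_std` is a hypothetical helper rather than a gallifrey function.

```python
import math


def mixture_mean_std(weights, means, stds):
    """Mean and standard deviation of a weighted mixture of 1-D Gaussians."""
    total = sum(weights)
    w = [wi / total for wi in weights]
    mean = sum(wi * mi for wi, mi in zip(w, means))
    # E[y^2] of the mixture is the weighted sum of each component's E[y^2].
    second_moment = sum(
        wi * (si**2 + mi**2) for wi, mi, si in zip(w, means, stds)
    )
    variance = max(second_moment - mean**2, 0.0)
    return mean, math.sqrt(variance)


# Two equally weighted particles that disagree about the mean.
mean, std = mixture_mean_std([0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
print(mean, std)
```

In this toy case each component has a standard deviation of 1, but the mixture's is √2: disagreement between particles widens the combined uncertainty band.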
Alternatively, we can look at the prediction from each particle individually.
dist = gpmodel.get_predictive_distributions(xtest)
means, stds = [], []
for d in dist:
    means.append(d.mean())
    stds.append(d.stddev())
fig, ax = plt.subplots()
for i in range(len(means)):
    sns.lineplot(x=xtest, y=means[i], label=f"Particle {i+1}", ax=ax)
    ax.fill_between(
        xtest,
        means[i] - stds[i],
        means[i] + stds[i],
        alpha=0.1,
    )
sns.scatterplot(
    x=gpmodel.x_transformed,
    y=gpmodel.y_transformed,
    ax=ax,
    zorder=3,
    color="C0",
)
sns.scatterplot(
    x=gpmodel.x_transform(x),
    y=gpmodel.y_transform(y),
    ax=ax,
    zorder=2,
    color="C1",
)
ax.legend(ncols=4)
Conclusion
And we're done! This quickstart guide provides a minimal example to get you started. For more advanced features, customization options, and detailed explanations, please refer to the tutorials and full documentation.