Persistent Sampling As An Extension To Sequential Monte Carlo

November 03, 2025
Statistics, Python, Monte Carlo

In the previous post, we introduced Sequential Monte Carlo (SMC) methods as an alternative to traditional MCMC for Bayesian inference. SMC is a natural choice for sampling multimodal posteriors, or in situations where new data arrives over time. SMC is also attractive because it parallelises well and because it produces an estimate of the model evidence $Z$ as a by-product.

However, SMC also seems a bit wasteful. We update the entire ensemble of particles in every iteration, but in the end we keep only the very last state to approximate the posterior. That means if we have 1000 particles and run SMC for 50 tempering steps, we evaluate $1000 \times 50 = 50{,}000$ particle positions, but end up with a final posterior sample of only $1000$ points. If we want a larger final sample, we have to increase the number of particles, which increases the computational cost of every single tempering step. In MCMC, on the other hand, we can simply run the chains for longer if we want more samples.

To improve the efficiency of SMC, Karamanis et al. (2025) recently introduced Persistent Sampling (PS), an extension of SMC with the key difference that it retains and reuses a growing ensemble of all particles from previous iterations rather than discarding particles after each iteration.

I recently implemented Persistent Sampling in the BlackJAX sampler library. In this post I want to give a brief introduction to PS and demonstrate its practical application to a regression problem in BlackJAX.

From SMC to Persistent Sampling

SMC defines a sequence of intermediate distributions

$$\pi_0(\theta), \pi_1(\theta), \ldots, \pi_T(\theta)$$

where $\pi_0(\theta) = p(\theta)$ is the prior and $\pi_T(\theta) = p(\theta | y)$ is the posterior. We can construct this sequence using a method called likelihood tempering:

$$\pi_t(\theta) \propto p(y | \theta)^{\beta_t} p(\theta)$$

with $0 = \beta_0 \leq \beta_1 \leq \cdots \leq \beta_T = 1$.
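As a minimal sketch of likelihood tempering, the unnormalised log density of $\pi_t$ is the log prior plus $\beta_t$ times the log likelihood. The toy Gaussian prior and likelihood below (`toy_log_prior`, `toy_log_likelihood`) are illustrative assumptions and not the model used later in this post.

```python
import math


def toy_log_prior(theta: float) -> float:
    """Standard normal prior (illustrative assumption)."""
    return -0.5 * theta**2 - 0.5 * math.log(2 * math.pi)


def toy_log_likelihood(theta: float, y: float = 3.0) -> float:
    """Single Gaussian observation y ~ N(theta, 1) (illustrative assumption)."""
    return -0.5 * (y - theta) ** 2 - 0.5 * math.log(2 * math.pi)


def tempered_log_density(theta: float, beta: float) -> float:
    """Unnormalised log pi_t(theta) = beta * log p(y | theta) + log p(theta)."""
    return beta * toy_log_likelihood(theta) + toy_log_prior(theta)


# beta = 0 recovers the prior, beta = 1 the unnormalised posterior
betas = [0.0, 0.25, 0.5, 0.75, 1.0]
values = [tempered_log_density(1.0, beta) for beta in betas]
```

Increasing $\beta$ gradually switches on the likelihood, which is what lets the particles move from the prior to the posterior in small steps.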

At each iteration $t$, SMC maintains $N$ weighted particles $\{\theta_t^{(i)}, w_t^{(i)}\}_{i=1}^N$ that approximate $\pi_t$. The algorithm proceeds through reweighting, resampling, and MCMC move steps.
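To make the resampling step more concrete, here is a rough NumPy sketch of systematic resampling, the scheme we will select later in the BlackJAX example. It is a textbook version written for illustration; the function name `systematic_resample` is hypothetical and this is not BlackJAX's implementation.

```python
import numpy as np


def systematic_resample(weights: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return the indices of the resampled particles (systematic resampling)."""
    n = len(weights)
    # a single uniform offset, then n evenly spaced positions in [0, 1)
    positions = (rng.uniform() + np.arange(n)) / n
    cumulative = np.cumsum(weights / np.sum(weights))
    cumulative[-1] = 1.0  # guard against round-off in the last entry
    # each position selects the particle whose cumulative-weight interval contains it
    return np.searchsorted(cumulative, positions, side="right")


rng = np.random.default_rng(0)
weights = np.array([0.1, 0.2, 0.3, 0.4])
indices = systematic_resample(weights, rng)
```

With this scheme, particle $i$ is copied either $\lfloor N w_i \rfloor$ or $\lceil N w_i \rceil$ times, which keeps the resampling noise lower than with plain multinomial draws.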

Standard SMC discards the particle ensemble from iteration $t-1$ after resampling to create the ensemble for iteration $t$.

Persistent Sampling takes a different approach. Instead of discarding the particles, we retain them and treat them as samples from a mixture distribution:

$$\tilde{\pi}_t(\theta) = \frac{1}{t-1} \sum_{s=1}^{t-1} \pi_s(\theta)$$

At iteration $t$, we have a persistent ensemble of $(t-1) \times N$ particles consisting of all particles generated in iterations $1, \ldots, t-1$.

To use particles from previous iterations to estimate expectations under $\pi_t$, we need to properly weight the particles. In PS, the weight for particle $\theta^{(i)}_{t'}$ (generated at iteration $t'$) when estimating expectations at iteration $t$ is

$$W^{(i)}_{t,t'} = \frac{\pi_t(\theta^{(i)}_{t'})}{\frac{1}{t-1}\sum_{s=1}^{t-1} \frac{\pi_s(\theta^{(i)}_{t'})}{\hat{Z}_s}} \cdot \frac{1}{\hat{Z}_t}$$

For likelihood tempering, this simplifies to:

$$W^{(i)}_{t,t'} = \frac{p(y | \theta^{(i)}_{t'})^{\beta_t}}{\frac{1}{t-1}\sum_{s=1}^{t-1} \frac{p(y | \theta^{(i)}_{t'})^{\beta_s}}{\hat{Z}_s}} \cdot \frac{1}{\hat{Z}_t}$$

where $\hat{Z}_s$ is the marginal likelihood estimate at iteration $s$, which can be computed directly from the weights. With these weights, the Persistent Sampling algorithm works as follows:

At each iteration $t$:

  1. Compute persistent weights $W^{(i)}_{t,t'}$ for all $(t-1) \times N$ particles
  2. Resample $N$ particles from this persistent ensemble according to their weights
  3. Apply MCMC moves to the resampled particles to target $\pi_t$
  4. Add these $N$ new particles to the persistent ensemble

This process continues until $\beta_T = 1$ and we have a final persistent ensemble of $T \times N$ weighted particles approximating the posterior. If we want to increase the sample size even further, we can simply continue to run the sampler with the tempering parameter $\beta$ fixed to $1$, adding more particles to the ensemble every iteration. We can even run the algorithm until a quantity of interest (e.g. the model evidence $Z$) has converged.
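The whole procedure fits into a short NumPy sketch. Everything here is a toy assumption chosen so that the result can be checked: a standard normal prior, a single observation $y \sim \mathcal{N}(\theta, 1)$, random-walk Metropolis moves, multinomial resampling, and prior draws used as the first mixture component. The helper names are hypothetical, and this is not how BlackJAX implements PS.

```python
import numpy as np

rng = np.random.default_rng(1)
y, n = 1.0, 1000  # toy observation and particles per iteration (assumptions)
# temper from 0 to 1 in 10 steps, then keep beta = 1 for 5 more iterations
betas = np.concatenate([np.linspace(0.0, 1.0, 11), np.ones(5)])


def log_lik(theta):
    """Toy likelihood: y ~ N(theta, 1)."""
    return -0.5 * (y - theta) ** 2 - 0.5 * np.log(2 * np.pi)


def log_target(theta, beta):
    """Unnormalised log pi_beta with a standard normal prior."""
    return beta * log_lik(theta) - 0.5 * theta**2 - 0.5 * np.log(2 * np.pi)


def persistent_log_weights(ll, used_betas, log_zs, beta_t):
    """Unnormalised log persistent weights and the log evidence estimate."""
    # log of p(y|theta)^beta_s / Z_s for every past iteration s, shape (s, M)
    terms = np.asarray(used_betas)[:, None] * ll[None, :] - np.asarray(log_zs)[:, None]
    log_den = np.logaddexp.reduce(terms, axis=0) - np.log(len(used_betas))
    log_w = beta_t * ll - log_den
    # the evidence estimate is the average of the unnormalised weights
    return log_w, np.logaddexp.reduce(log_w) - np.log(len(log_w))


def rw_metropolis(theta, beta, n_steps=5, scale=1.0):
    """Random-walk Metropolis moves targeting pi_beta, applied to all particles."""
    log_p = log_target(theta, beta)
    for _ in range(n_steps):
        proposal = theta + scale * rng.normal(size=theta.shape)
        log_p_new = log_target(proposal, beta)
        accept = np.log(rng.uniform(size=theta.shape)) < log_p_new - log_p
        theta = np.where(accept, proposal, theta)
        log_p = np.where(accept, log_p_new, log_p)
    return theta


# start from exact prior draws (beta = 0, Z = 1)
ensemble, used_betas, log_zs = [rng.normal(size=n)], [0.0], [0.0]
for beta_t in betas[1:]:
    theta_all = np.concatenate(ensemble)
    # step 1: persistent weights for the whole ensemble
    log_w, log_z_t = persistent_log_weights(log_lik(theta_all), used_betas, log_zs, beta_t)
    # step 2: resample N particles according to the weights
    probs = np.exp(log_w - np.logaddexp.reduce(log_w))
    resampled = rng.choice(theta_all, size=n, p=probs)
    # steps 3 and 4: move the particles and add them to the ensemble
    ensemble.append(rw_metropolis(resampled, beta_t))
    used_betas.append(float(beta_t))
    log_zs.append(float(log_z_t))

# weighted posterior estimate from the full persistent ensemble
theta_all = np.concatenate(ensemble)
log_w, _ = persistent_log_weights(log_lik(theta_all), used_betas, log_zs, 1.0)
weights = np.exp(log_w - np.logaddexp.reduce(log_w))
posterior_mean = float(np.sum(weights * theta_all))
exact_log_z = -0.5 * np.log(2 * np.pi) - 0.5 * np.log(2.0) - y**2 / 4
```

In this conjugate toy model the posterior mean is $y / 2$ and the evidence is $\mathcal{N}(y \mid 0, 2)$, so `posterior_mean` and `log_zs[-1]` can be compared against the exact values.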

The downside of this process is the increased memory footprint: we have to keep track of $t \times N$ particles instead of just $N$ as in standard SMC. This can become a problem for very long tempering schedules.
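A quick back-of-the-envelope calculation shows the scale of this overhead. The sketch assumes 32-bit floats and counts particle positions only; stored log-likelihoods and weights would add a little more. The function name `ensemble_bytes` is hypothetical.

```python
def ensemble_bytes(
    n_particles: int, n_iterations: int, n_params: int, bytes_per_value: int = 4
) -> int:
    """Memory needed to store the particle positions of an ensemble."""
    return n_particles * n_iterations * n_params * bytes_per_value


# 1000 particles, 6 parameters, 50 iterations
smc_bytes = ensemble_bytes(1000, 1, 6)  # standard SMC keeps one generation
ps_bytes = ensemble_bytes(1000, 50, 6)  # PS keeps all 50 generations
print(f"SMC: {smc_bytes / 1e3:.0f} kB, PS: {ps_bytes / 1e6:.1f} MB")
```

For small models like the one below this is negligible, but the cost grows linearly with both the number of iterations and the number of parameters.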

A Regression Example

To demonstrate the application of Persistent Sampling, we'll tackle a simple regression problem. We fit a logistic growth curve to noisy data with outliers. As mentioned, I recently implemented PS in BlackJAX, so we will use that framework.

# Install BlackJAX from the GitHub main branch, rather than PyPI (at time of writing, the PyPI version does not yet include Persistent Sampling)
#!pip install git+https://github.com/blackjax-devs/blackjax
# setup
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp
import blackjax
from blackjax.smc.persistent_sampling import compute_persistent_ess
from blackjax.smc.resampling import systematic
from blackjax.smc import extend_params

# Set random seed
key = jax.random.PRNGKey(2025)

# Plotting setup
sns.set_style("white")
plt.rcParams["figure.figsize"] = (12, 6)

Problem Setup

The logistic function is

$$f(x; L, k, x_0) = \frac{L}{1 + e^{-k(x - x_0)}}$$

where

  • $L$ is the maximum value
  • $k$ is the growth rate
  • $x_0$ is the inflection point

We'll generate some synthetic data and add measurement noise and outliers to make this problem challenging and realistic. The measurement noise distribution is given by a Gaussian $\mathcal{N}(0, \sigma^2)$ with standard deviation $\sigma$, as usual. For the outliers, we contaminate approximately 10% of the data with values from another Gaussian distribution $\mathcal{N}(0, \sigma_{out}^2)$ with $\sigma_{out} \gg \sigma$.

def logistic_function(
    x: jnp.ndarray,
    L: jnp.ndarray,
    k: jnp.ndarray,
    x0: jnp.ndarray,
) -> jnp.ndarray:
    """Logistic growth function."""
    return L / (1 + jnp.exp(-k * (x - x0)))


# true parameters
true_parameters = {
    "L": jnp.array(30.0),  # Maximum value
    "k": jnp.array(0.4),  # Growth rate
    "x0": jnp.array(10.0),  # Midpoint
    "sigma": jnp.array(2.0),  # Noise standard deviation
    "sigma_out": jnp.array(15.0),  # Outlier noise standard deviation
    "w": jnp.array(0.1),  # Outlier proportion
}

# generate clean data
n_points = 60
x_data = jnp.linspace(0, 20, n_points)
y_clean = logistic_function(
    x_data,
    true_parameters["L"],
    true_parameters["k"],
    true_parameters["x0"],
)

# add Gaussian noise
key, noise_key = jax.random.split(key)
y_noise = y_clean + true_parameters["sigma"] * jax.random.normal(noise_key, (n_points,))

# add outliers (10% of data points)
n_outliers = int(true_parameters["w"] * n_points)
key, outlier_key = jax.random.split(key)
outlier_indices = jax.random.choice(
    outlier_key,
    n_points,
    shape=(n_outliers,),
    replace=False,
)
outlier_values = y_clean[outlier_indices] + true_parameters[
    "sigma_out"
] * jax.random.normal(outlier_key, (n_outliers,))
y_data = y_noise.at[outlier_indices].set(outlier_values)
# plot the data
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x_data, y_clean, c="purple", linewidth=2, label="True function", alpha=0.8)
ax.scatter(x_data, y_data, c="purple", s=50, alpha=0.4, label="Observed data")
ax.scatter(
    x_data[outlier_indices],
    y_data[outlier_indices],
    s=100,
    color="purple",
    marker="x",
    linewidths=3,
    label="Outliers",
)
ax.set_title("Logistic Function with Outliers", fontsize=16, fontweight="bold")
ax.legend(fontsize=12)
plt.tight_layout()

Bayesian Model with Mixture Likelihood

To handle outliers, we'll use a mixture model likelihood that explicitly models both normal observations and outliers as separate components.

Prior distributions:

  • $\log L \sim \mathcal{N}(\log(20), \log(3)^2)$
  • $\log k \sim \mathcal{N}(\log(1), \log(0.2)^2)$
  • $x_0 \sim \mathcal{N}(12, 4^2)$
  • $\log \sigma \sim \mathcal{N}(\log(3), \log(2)^2)$
  • $\log \sigma_{\text{out}} \sim \mathcal{N}(\log(10), \log(10)^2)$, the standard deviation of the outlier component
  • $w \sim \text{Beta}(4, 20)$, the mixing weight (probability of an outlier)

Likelihood (mixture model):

For each observation $y_i$, the likelihood is a mixture of two Gaussians:

$$p(y_i \mid L, k, x_0, \sigma, \sigma_{\text{out}}, w) = (1-w) \cdot \mathcal{N}(y_i \mid f(x_i; L, k, x_0), \sigma^2) + w \cdot \mathcal{N}(y_i \mid f(x_i; L, k, x_0), \sigma_{\text{out}}^2)$$

where:

  • The first component models normal observations with small noise $\sigma$
  • The second component models outliers with large noise $\sigma_{\text{out}}$
  • $w$ controls the probability that an observation is an outlier
def log_prior(theta: dict) -> jnp.ndarray:
    """Log prior density."""
    log_p_L = jax.scipy.stats.norm.logpdf(
        theta["log_L"], loc=jnp.log(20), scale=jnp.log(3)
    )
    # log(0.2) is negative, so take its magnitude to get a valid standard deviation
    log_p_k = jax.scipy.stats.norm.logpdf(
        theta["log_k"], loc=jnp.log(1), scale=jnp.abs(jnp.log(0.2))
    )
    log_p_x0 = jax.scipy.stats.norm.logpdf(theta["x0"], loc=12, scale=4)

    log_p_sigma = jax.scipy.stats.norm.logpdf(
        theta["log_sigma"], loc=jnp.log(3), scale=jnp.log(2)
    )
    log_p_sigma_out = jax.scipy.stats.norm.logpdf(
        theta["log_sigma_out"], loc=jnp.log(10), scale=jnp.log(10)
    )
    log_p_w = jax.scipy.stats.beta.logpdf(theta["w"], a=4, b=20)

    return log_p_L + log_p_k + log_p_x0 + log_p_sigma + log_p_sigma_out + log_p_w


def log_likelihood(theta: dict) -> jnp.ndarray:
    """Log likelihood using Gaussian mixture model."""
    # predicted values
    y_pred = logistic_function(
        x_data, jnp.exp(theta["log_L"]), jnp.exp(theta["log_k"]), theta["x0"]
    )

    # component 1: normal observations with small noise
    log_p_normal = jax.scipy.stats.norm.logpdf(
        y_data, loc=y_pred, scale=jnp.exp(theta["log_sigma"])
    )

    # component 2: outliers with large noise
    log_p_outlier = jax.scipy.stats.norm.logpdf(
        y_data, loc=y_pred, scale=jnp.exp(theta["log_sigma_out"])
    )

    # mixture log likelihood
    w = theta["w"]
    log_mixture = jax.scipy.special.logsumexp(
        jnp.stack([jnp.log(1 - w) + log_p_normal, jnp.log(w) + log_p_outlier], axis=0),
        axis=0,
    )

    return jnp.sum(log_mixture)

Setting up Persistent Sampling

Next, we set up the sampler, which works almost exactly as it does for standard SMC. We'll deliberately use a very small number of particles so that the advantage of the growing ensemble in PS is easier to see.

# Create initial particles
num_particles = 100

# sample from prior to initialize particles
key, init_key = jax.random.split(key)
keys = jax.random.split(init_key, 6)

# sample each parameter from its prior
log_L_init = jax.random.normal(keys[0], (num_particles,)) * jnp.log(3) + jnp.log(20)
log_k_init = jax.random.normal(keys[1], (num_particles,)) * jnp.log(0.2) + jnp.log(1)
x0_init = jax.random.normal(keys[2], (num_particles,)) * 4.0 + 12.0
log_sigma_init = jax.random.normal(keys[3], (num_particles,)) * jnp.log(2) + jnp.log(3)
log_sigma_out_init = jax.random.normal(keys[4], (num_particles,)) * jnp.log(
    10
) + jnp.log(10)
w_init = jax.random.beta(keys[5], a=4.0, b=20.0, shape=(num_particles,))

initial_particles = {
    "log_L": log_L_init,
    "log_k": log_k_init,
    "x0": x0_init,
    "log_sigma": log_sigma_init,
    "log_sigma_out": log_sigma_out_init,
    "w": w_init,
}

With the tempering schedule, we'll demonstrate one of the advantages of Persistent Sampling. For this problem, 30 tempering steps should be enough for a good posterior estimate. But to grow the persistent ensemble, we'll keep sampling the posterior for another 20 iterations after $\beta = 1$ is reached. Each additional iteration increases the effective sample size and thus improves our estimate.

# tempering schedule
num_temperatures = 50
num_growing = 20
# Tempering schedule: 30 values rising linearly from 0 to 1, followed by 20 more iterations at beta = 1
tempering_schedule = jnp.concatenate(
    [
        jnp.linspace(0.0, 1.0, num_temperatures - num_growing),
        jnp.ones(num_growing),
    ]
)

# HMC parameters for MCMC moves
hmc_parameters = {
    "step_size": 0.02,
    "inverse_mass_matrix": jnp.ones(6),  # 6 parameters: L, k, x0, sigma, sigma_out, w
    "num_integration_steps": 20,
}

print(f"Sampling. Configuration:")
print(f"  Number of particles: {num_particles}")
print(f"  Number of temperatures: {num_temperatures}")
print(
    f"  MCMC kernel: HMC with {hmc_parameters['num_integration_steps']} integration steps"
)
Sampling. Configuration:
  Number of particles: 100
  Number of temperatures: 50
  MCMC kernel: HMC with 20 integration steps

Define Kernel and Inference Loop

# create Persistent Sampling kernel
ps_kernel = blackjax.persistent_sampling_smc(
    logprior_fn=log_prior,
    loglikelihood_fn=log_likelihood,
    n_schedule=num_temperatures,
    mcmc_step_fn=blackjax.hmc.build_kernel(),
    mcmc_init_fn=blackjax.hmc.init,
    mcmc_parameters=extend_params(hmc_parameters),
    resampling_fn=systematic,
    num_mcmc_steps=5,
)


# create inference loop
def inference_loop(key, kernel, initial_state, schedule):
    """Run SMC inference loop."""

    def one_step(carry, lmbda):
        key, state = carry
        key, step_key = jax.random.split(key)
        new_state, _ = kernel.step(step_key, state, lmbda)
        ess = compute_persistent_ess(
            jnp.log(new_state.persistent_weights), normalize_weights=True
        )
        return (key, new_state), ess

    init_state = kernel.init(initial_state)
    (_, final_state), ess_history = jax.lax.scan(one_step, (key, init_state), schedule)
    return final_state, ess_history

Running The Sampler

%%time

key, ps_key = jax.random.split(key)
ps_final_state, ess_history = inference_loop(
    ps_key,
    ps_kernel,
    initial_particles,
    tempering_schedule,
)

ps_ess = compute_persistent_ess(
    jnp.log(ps_final_state.persistent_weights), normalize_weights=True
)
print(f"Persistent Sampling Results:")
print(
    f"  Final ESS: {ps_ess:.1f} (from {num_particles * num_temperatures} persistent particles, {ps_ess/(num_particles * num_temperatures):.2%})"
)
print(f"  Log evidence estimate: {ps_final_state.log_Z:.2f}\n")
Persistent Sampling Results:
  Final ESS: 3622.2 (from 5000 persistent particles, 72.44%)
  Log evidence estimate: -160.14

CPU times: user 7.72 s, sys: 37.9 ms, total: 7.76 s
Wall time: 2.94 s

The Posterior Estimate

Let's first examine the posterior like we would in standard SMC. We'll simply plot the distribution using the particles at the last step of the algorithm. We'll also plot the true values, to check how close we get.

## get the posterior samples from the final step
final_posterior_sample = ps_final_state.particles
final_posterior_sample = pd.DataFrame(
    {
        "L": jnp.exp(final_posterior_sample["log_L"]),
        "k": jnp.exp(final_posterior_sample["log_k"]),
        "x0": final_posterior_sample["x0"],
        "sigma": jnp.exp(final_posterior_sample["log_sigma"]),
        "sigma_out": jnp.exp(final_posterior_sample["log_sigma_out"]),
        "w": final_posterior_sample["w"],
    }
)

# plot the posterior distributions
sns.pairplot(
    final_posterior_sample,
    diag_kind="hist",
    plot_kws={"s": 15, "color": "purple", "alpha": 0.5},
    diag_kws={"fill": True, "color": "purple", "alpha": 0.5},
    corner=True,
)

# add true parameter values to the plots
axes = plt.gcf().axes
for ax in axes:
    if ax.get_xlabel() in true_parameters:
        param_name = ax.get_xlabel()
        true_value = true_parameters[param_name]
        ax.axvline(
            true_value,
            color="darkred",
            linewidth=2,
        )
    if ax.get_ylabel() in true_parameters:
        param_name = ax.get_ylabel()
        true_value = true_parameters[param_name]
        ax.axhline(
            true_value,
            color="darkred",
            linewidth=2,
        )

plt.suptitle(
    "Posterior Distributions from Final Sample", fontsize=16, fontweight="bold"
)
plt.tight_layout()
plt.show()

That doesn't look too bad. The values seem sensible. But the posterior is clearly sampled far too sparsely. This is where the persistent ensemble comes in. Let's plot the distribution again using the whole ensemble. We need to be careful to correctly weight the particles in this case.
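To make the role of the weights concrete, here is a small NumPy sketch of weighted summary statistics. The helper names and the four-point toy sample are illustrative assumptions; with real output, one would pass a flattened parameter array together with the matching persistent weights.

```python
import numpy as np


def weighted_mean_and_std(samples: np.ndarray, weights: np.ndarray) -> tuple[float, float]:
    """Mean and standard deviation of a weighted sample."""
    w = weights / np.sum(weights)
    mean = np.sum(w * samples)
    var = np.sum(w * (samples - mean) ** 2)
    return float(mean), float(np.sqrt(var))


def weighted_quantile(samples: np.ndarray, weights: np.ndarray, q: float) -> float:
    """Quantile of a weighted sample, read off the weighted empirical CDF."""
    order = np.argsort(samples)
    cdf = np.cumsum(weights[order]) / np.sum(weights)
    idx = min(int(np.searchsorted(cdf, q)), len(samples) - 1)
    return float(samples[order][idx])


# toy sample in which the last point carries most of the weight
samples = np.array([1.0, 2.0, 3.0, 4.0])
weights = np.array([1.0, 1.0, 1.0, 5.0])
mean, std = weighted_mean_and_std(samples, weights)
median = weighted_quantile(samples, weights, 0.5)
```

Ignoring the weights would give a mean of 2.5 here, while the weighted mean is 3.25. Particles from early, low-$\beta$ iterations are pulled towards the prior in the same way, so treating the persistent ensemble as an unweighted sample would bias the posterior estimate.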

## get the posterior samples from the whole persistent ensemble
persistent_posterior_sample = pd.DataFrame(
    {
        "L": jnp.exp(ps_final_state.persistent_particles["log_L"].flatten()),
        "k": jnp.exp(ps_final_state.persistent_particles["log_k"].flatten()),
        "x0": ps_final_state.persistent_particles["x0"].flatten(),
        "sigma": jnp.exp(ps_final_state.persistent_particles["log_sigma"].flatten()),
        "sigma_out": jnp.exp(
            ps_final_state.persistent_particles["log_sigma_out"].flatten()
        ),
        "w": ps_final_state.persistent_particles["w"].flatten(),
    }
)
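# persistent_weights is treated as linear (non-log) weights here, consistent with the
# jnp.log(...) applied to it in the ESS computation, so normalising is a plain division
persistent_weights = ps_final_state.persistent_weights.flatten()
normalized_weights = persistent_weights / jnp.sum(persistent_weights)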

# remove outliers for plotting purposes
lower_bounds = persistent_posterior_sample.quantile(0.05)
upper_bounds = persistent_posterior_sample.quantile(0.95)
mask = (
    (persistent_posterior_sample >= lower_bounds)
    & (persistent_posterior_sample <= upper_bounds)
).all(axis=1)

# Plot the posterior distributions as weighted KDEs
sns.pairplot(
    persistent_posterior_sample[mask],
    kind="kde",
    diag_kind="hist",
    plot_kws={
        "fill": True,
        "color": "purple",
        "alpha": 0.5,
        "levels": 5,
        "weights": normalized_weights[mask.values],
    },
    diag_kws={
        "color": "purple",
        "fill": True,
        "alpha": 0.5,
        "bins": 30,
        "weights": normalized_weights[mask.values],
    },
    corner=True,
)

# add true parameter values to the plots
axes = plt.gcf().axes
for ax in axes:
    if ax.get_xlabel() in true_parameters:
        param_name = ax.get_xlabel()
        true_value = true_parameters[param_name]
        ax.axvline(
            true_value,
            color="darkred",
            linewidth=2,
        )
    if ax.get_ylabel() in true_parameters:
        param_name = ax.get_ylabel()
        true_value = true_parameters[param_name]
        ax.axhline(
            true_value,
            color="darkred",
            linewidth=2,
        )

plt.suptitle(
    "Posterior Distributions from Persistent Ensemble", fontsize=16, fontweight="bold"
)
plt.tight_layout()
plt.show()

This gives us a much smoother distribution! And that with only 100 particles per iteration. Let's also visualise how the effective sample size grows over time.

# plot the ESS history
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(ess_history, c="purple", marker="o", markersize=10, linewidth=5, alpha=0.8)

ax.set_xlabel("Tempering Step", fontsize=14)
ax.set_ylabel("Effective Sample Size", fontsize=14)
ax.set_title(
    "Effective Sample Size (ESS) over Tempering Schedule",
    fontsize=16,
    fontweight="bold",
)
ax.axvline(
    num_temperatures - num_growing,
    color="darkred",
    label=r"Tempering param = 1 from here",
    linestyle="--",
)
ax.legend(fontsize=12)
plt.tight_layout()
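For reference, the effective sample size of a weighted sample is commonly defined as $(\sum_i w_i)^2 / \sum_i w_i^2$ (the Kish ESS). The sketch below computes it from log-weights in plain NumPy. We assume, but have not checked line by line, that this matches what `compute_persistent_ess` reports; the function name `effective_sample_size` is hypothetical.

```python
import numpy as np


def effective_sample_size(log_weights: np.ndarray) -> float:
    """Kish ESS, (sum w)^2 / sum(w^2), computed from log-weights."""
    # subtracting the maximum avoids overflow and does not change the ratio
    w = np.exp(log_weights - np.max(log_weights))
    return float(np.sum(w) ** 2 / np.sum(w**2))


ess_uniform = effective_sample_size(np.zeros(100))  # equal weights
ess_degenerate = effective_sample_size(np.array([0.0, -50.0, -50.0]))  # one dominant weight
```

Equal weights give an ESS equal to the number of particles, while a single dominant weight drives it towards 1.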

We can also check how the marginal likelihood, i.e. the model evidence, evolves over the iterations.

# plot the log evidence estimates
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(
    ps_final_state.persistent_log_Z,
    c="purple",
    marker="o",
    markersize=10,
    linewidth=5,
    alpha=0.8,
)

ax.set_xlabel("Tempering Step", fontsize=14)
ax.set_ylabel(r"Log Model Evidence $\log Z$", fontsize=14)
ax.set_title(
    "Marginal Likelihood Estimate over Tempering Schedule",
    fontsize=16,
    fontweight="bold",
)
ax.axvline(
    num_temperatures - num_growing,
    color="darkred",
    label=r"Tempering param = 1 from here",
    linestyle="--",
)
ax.legend(fontsize=12)
plt.tight_layout()

This also looks good. The fact that the model evidence estimate stays essentially constant once we reach $\beta = 1$ suggests that the estimate is stable.
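If we wanted to turn this visual check into a stopping rule, one simple option is to stop once the last few estimates agree to within a tolerance. The window size, the tolerance, and the function name `has_converged` are arbitrary assumptions, and the trace below is a synthetic sequence made up for illustration, not output from the sampler.

```python
def has_converged(estimates, window: int = 5, tol: float = 0.05) -> bool:
    """True if the last `window` estimates all lie within `tol` of each other."""
    if len(estimates) < window:
        return False
    recent = list(estimates[-window:])
    return max(recent) - min(recent) < tol


# synthetic trace that settles geometrically towards -160 (illustration only)
log_z_trace = [-160.0 - 10.0 * 0.5**k for k in range(14)]
early_check = has_converged(log_z_trace[:6])  # still moving
late_check = has_converged(log_z_trace)  # settled
```

A flat trace is necessary but not sufficient for a correct estimate, so a check like this is best combined with a look at the ESS.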

Final Thoughts

Persistent Sampling is an exciting evolution of Sequential Monte Carlo methods. By retaining particles from all iterations, PS reduces correlation, increases the effective sample size, and improves evidence estimates, all while maintaining the same per-particle computational complexity. This makes it a powerful tool for Bayesian inference, especially when model comparison is important, and I hope to see it applied more often in practice. My aim with the BlackJAX implementation and this blog post is to make it more accessible to the research community. If you use it in your research, please let me know! I'd be very curious about actual use cases.

Further Reading

Persistent Sampling:

  • Karamanis et al. (2025): "Persistent Sampling: Enhancing the Efficiency of Sequential Monte Carlo" - arXiv:2407.20722

Sequential Monte Carlo:

  • Del Moral, Doucet, and Jasra (2006): "Sequential Monte Carlo samplers"
  • Chopin and Papaspiliopoulos (2020): "An Introduction to Sequential Monte Carlo"
  • Dai et al. (2022): "An Invitation to Sequential Monte Carlo Samplers"

BlackJAX: