linsdex

A high-performance JAX library for linear stochastic differential equations, state space models, and Gaussian inference.

linsdex provides a modular framework for defining, simulating, and conditioning linear-Gaussian systems with support for parallelized inference on GPUs.

The library focuses on two things: high performance and numerical stability.

Key features

  • Linear SDEs with exact transition distributions for both time-invariant and time-varying systems
  • Parallel message passing using associative scans, giving \(O(\log T)\) parallel depth for filtering, smoothing, and sampling in chain-structured Gaussian CRFs
  • Three Gaussian parameterizations (Standard, Natural, Mixed), so each operation can be carried out in the form that is most numerically stable for its inference context
  • Specialized matrix library with Diagonal, Block, and Dense types that use symbolic tags to simplify expressions before they reach JAX
  • Diffusion model utilities, including a unified interface for converting between clean-data predictions, scores, probability flows, and SDE drifts
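The parallel message passing bullet rests on one fact: a linear recurrence \(x_t = A_t x_{t-1} + b_t\) is a chain of affine maps, and composing affine maps is associative, so every prefix can be computed with a parallel scan in \(O(\log T)\) depth. The sketch below is plain JAX, not linsdex's API. It uses `jax.lax.associative_scan` on the mean recurrence only, whereas a full Gaussian filter or smoother also has to propagate covariances. The helper names (`combine`, `parallel_affine_recurrence`, `sequential_affine_recurrence`) are illustrative.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two affine maps x -> A x + b, applying `earlier` first.

    associative_scan calls this on batched slices, so A has shape (..., d, d)
    and b has shape (..., d).
    """
    A1, b1 = earlier
    A2, b2 = later
    A = A2 @ A1  # A2 (A1 x + b1) + b2 = (A2 A1) x + (A2 b1 + b2)
    b = jnp.einsum("...ij,...j->...i", A2, b1) + b2
    return A, b


def parallel_affine_recurrence(As, bs, x0):
    """x_t = A_t x_{t-1} + b_t for t = 1..T, computed with O(log T) depth."""
    A_cum, b_cum = jax.lax.associative_scan(combine, (As, bs))
    # Prefix t is the composite map that takes x0 straight to x_{t+1}.
    return jnp.einsum("tij,j->ti", A_cum, x0) + b_cum


def sequential_affine_recurrence(As, bs, x0):
    """Reference implementation: the same recurrence, one step at a time."""
    def step(x, Ab):
        A, b = Ab
        x_next = A @ x + b
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (As, bs))
    return xs


T, d = 64, 3
kA, kb, kx = jax.random.split(jax.random.PRNGKey(0), 3)
As = 0.8 * jnp.eye(d) + 0.02 * jax.random.normal(kA, (T, d, d))  # small noise keeps maps stable
bs = jax.random.normal(kb, (T, d))
x0 = jax.random.normal(kx, (d,))

xs_par = parallel_affine_recurrence(As, bs, x0)
xs_seq = sequential_affine_recurrence(As, bs, x0)
print(jnp.max(jnp.abs(xs_par - xs_seq)))  # agreement up to float32 round-off
```

`combine` is associative but not commutative, so argument order matters: `associative_scan` passes the element that comes earlier in the sequence as the first argument.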

Stochastic differential equations

The library supports a hierarchy of SDEs. Time-invariant SDEs like Brownian motion, Ornstein-Uhlenbeck processes, and stochastic harmonic oscillators have constant drift and diffusion coefficients. Time-varying SDEs, such as variance-preserving processes used in diffusion models, extend this with time-dependent coefficients. Any linear SDE can be conditioned on Gaussian potentials to produce posterior processes.
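For a time-invariant SDE, an exact transition distribution means the Gaussian law of \(X_{t+\Delta t}\) given \(X_t\) is known in closed form, so paths can be sampled on any time grid with no discretization error. A minimal plain-JAX sketch for a scalar Ornstein-Uhlenbeck process, assuming the parameterization \(dX = -\theta X\,dt + \sigma\,dW\) (linsdex's own constructor arguments may differ), whose transition is \(\mathcal{N}\big(e^{-\theta \Delta t} x,\ \sigma^2 (1 - e^{-2\theta \Delta t}) / (2\theta)\big)\). The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def ou_transition(x, dt, theta, sigma):
    """Exact mean and variance of X_{t+dt} given X_t = x,
    for dX = -theta X dt + sigma dW (assumed parameterization)."""
    decay = jnp.exp(-theta * dt)
    mean = decay * x
    var = sigma**2 / (2.0 * theta) * (1.0 - decay**2)
    return mean, var


def sample_ou(key, times, x0, theta, sigma):
    """Sample one OU path on an arbitrary, possibly uneven, time grid."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    dts = jnp.diff(times)
    keys = jax.random.split(key, dts.shape[0])

    def step(x, inputs):
        k, dt = inputs
        mean, var = ou_transition(x, dt, theta, sigma)
        x_next = mean + jnp.sqrt(var) * jax.random.normal(k)
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (keys, dts))
    return jnp.concatenate([x0[None], xs])


times = jnp.array([0.0, 0.1, 0.15, 0.5, 2.0])  # uneven spacing needs no special care
keys = jax.random.split(jax.random.PRNGKey(0), 20_000)
# vmap over keys only; x0 = 1.0, theta = 1.0, sigma = 0.5 are shared by every path.
paths = jax.vmap(sample_ou, in_axes=(0, None, None, None, None))(keys, times, 1.0, 1.0, 0.5)
```

Because consecutive exact transitions compose, the marginal at \(t = 2\) is the same Gaussian whether the grid has one step or many; only the joint distribution over the intermediate points changes.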

Conditioning and inference

Given a base linear SDE and noisy observations, linsdex constructs a Gaussian conditional random field and uses parallel message passing to compute the posterior. This makes it possible to sample trajectories that stay close to the observed data while respecting the underlying dynamics.

import jax
import jax.numpy as jnp
from linsdex import BrownianMotion

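# 2-D Brownian motion with sigma = 1.0
sde = BrownianMotion(sigma=1.0, dim=2)
# Condition on the process starting at the origin at t0 = 0.0
conditioned_sde = sde.condition_on_starting_point(t0=0.0, x0=jnp.zeros(2))

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 100)     # one key per trajectory
times = jnp.linspace(0.0, 1.0, 500)   # shared grid of 500 evaluation times
# vmap over the keys only; every trajectory uses the same time grid
trajectories = jax.vmap(conditioned_sde.sample, in_axes=(0, None))(keys, times)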
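On a small grid, the posterior that conditioning produces can also be computed densely, which is a useful way to see what it is. The plain-JAX sketch below is not linsdex's implementation (linsdex uses parallel message passing instead of dense \(O(n^3)\) linear algebra). It conditions a scalar Brownian motion started at 0 on three noisy observations and performs the update twice: in natural form, where multiplying the prior by a Gaussian potential just adds precision matrices and precision-weighted means, and in standard mean/covariance form. The observation times, values, and noise variance are made up for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the dense inverses accurate

sigma, obs_var = 1.0, 0.01
times = jnp.linspace(0.05, 1.0, 20)  # X_0 = 0 is fixed, so t = 0 is left off the grid
prior_cov = sigma**2 * jnp.minimum(times[:, None], times[None, :])  # Cov[X_s, X_t]
n = times.shape[0]

# Hypothetical observations y = H x + noise at three grid points.
obs_idx = jnp.array([4, 11, 19])
y = jnp.array([0.5, -0.3, 1.2])
m = obs_idx.shape[0]
H = jnp.zeros((m, n)).at[jnp.arange(m), obs_idx].set(1.0)

# Natural form: multiplying Gaussian potentials adds their (J, h) parameters.
J_prior = jnp.linalg.inv(prior_cov)
h_prior = jnp.zeros(n)  # the prior mean is zero, so h_prior = J_prior @ 0 = 0
J_post = J_prior + H.T @ H / obs_var
h_post = h_prior + H.T @ y / obs_var
mean_natural = jnp.linalg.solve(J_post, h_post)
cov_natural = jnp.linalg.inv(J_post)

# Standard (mean, covariance) form of the same update, for comparison.
S = H @ prior_cov @ H.T + obs_var * jnp.eye(m)
K = prior_cov @ H.T @ jnp.linalg.inv(S)
mean_standard = K @ y
cov_standard = prior_cov - K @ H @ prior_cov

# Posterior trajectories: 100 draws that pass close to the observations.
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(0), mean_natural, cov_natural, shape=(100,)
)
```

The two results agree. The natural form makes combining potentials a matter of addition, at the cost of a solve or inverse to recover the mean and covariance, which is one reason to keep more than one parameterization around.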