pyro-ppl/numpyro
Probabilistic programming library that provides a NumPy backend for Pyro using JAX for automatic differentiation and GPU/TPU compilation.

NumPyro is a lightweight probabilistic programming framework built on JAX for high-performance numerical computation. It implements a range of inference algorithms, with a focus on Markov chain Monte Carlo (MCMC) methods such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS), and it also supports variational inference. Users write probabilistic models in NumPy-compatible syntax and rely on JAX's automatic differentiation and just-in-time compilation to run them on CPU, GPU, and TPU.