A VAE reference that admits its own limitations
A clean PyTorch/TensorFlow/JAX reference implementation for variational autoencoders, with enough self-awareness to note where it falls short of published results.

What it does
Trains variational autoencoders on binarized MNIST in three frameworks (PyTorch, TensorFlow, and JAX), using either a mean-field posterior or the more expressive inverse autoregressive flow. The author recommends the PyTorch version, which reaches a test-set marginal log-likelihood of -97.10 nats with a basic encoder-decoder and improves to -95.33 nats with normalizing flows.
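To make the two posterior options concrete, here is a minimal, framework-free sketch in plain Python. It draws a reparameterized sample from a mean-field Gaussian and then pushes it through one gated IAF step, tracking log q(z | x) along the way. The helper names, the toy shift and gate functions, and the gated parameterization are assumptions for illustration, not code from the repository. A real model would produce the shifts and gates with a masked autoregressive network and batch everything in PyTorch, TensorFlow, or JAX.

```python
import math
import random

def sample_mean_field(mu, log_var, rng):
    """Reparameterized draw z0 = mu + sigma * eps from a diagonal Gaussian,
    returned together with its log-density log q(z0 | x)."""
    z0, log_q = [], 0.0
    for m, lv in zip(mu, log_var):
        eps = rng.gauss(0.0, 1.0)
        z0.append(m + math.exp(0.5 * lv) * eps)
        # log N(z0; m, sigma^2) = -0.5 * (log(2*pi) + log_var + eps^2)
        log_q += -0.5 * (math.log(2 * math.pi) + lv + eps * eps)
    return z0, log_q

def iaf_step(z, log_q, shift_fn, gate_fn):
    """One gated IAF step: z_i' = s_i * z_i + (1 - s_i) * m_i, where the
    shift m_i and gate s_i depend only on the inputs z_1..z_{i-1}.
    shift_fn and gate_fn are hypothetical stand-ins for a masked network."""
    new_z = []
    for i in range(len(z)):
        prefix = z[:i]                              # autoregressive context
        m = shift_fn(prefix)
        s = 1.0 / (1.0 + math.exp(-gate_fn(prefix)))  # sigmoid gate in (0, 1)
        new_z.append(s * z[i] + (1.0 - s) * m)
        # The Jacobian is triangular with diagonal s_i, so the density
        # changes by -sum(log s_i).
        log_q -= math.log(s)
    return new_z, log_q
```

Usage, with toy functions chosen only so the example runs:

```python
rng = random.Random(0)
z0, log_q0 = sample_mean_field([0.0, 1.0], [0.0, -1.0], rng)
z1, log_q1 = iaf_step(z0, log_q0,
                      shift_fn=lambda p: 0.5 * sum(p),
                      gate_fn=lambda p: 1.0 + 0.1 * len(p))
```

Because each output coordinate only looks at earlier input coordinates, all of the shifts and gates can be computed in one parallel pass. This is what makes IAF cheap to sample from during training.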
The interesting bit
The README is unusually honest about the gap between its numbers and the literature. The author explicitly notes that the flow-based model still trails the inverse autoregressive flow paper's roughly -80 nats because this implementation skips convolutions and residual blocks. That candor is rarer than the code itself. There is also a JAX port that runs about 3× faster than PyTorch for the mean-field case: 0.81 minutes versus 2.49.
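Both figures in that paragraph follow from simple arithmetic on the quoted numbers. The speedup is the ratio of the two run times, and the remaining gap is the difference between the published and reproduced log-likelihoods. Higher, less negative, values are better.

```latex
\text{speedup} = \frac{2.49\ \text{min}}{0.81\ \text{min}} \approx 3.07,
\qquad
\text{gap} \approx -80 - (-95.33) \approx 15\ \text{nats}
```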
Key highlights
- Three framework implementations (PyTorch, TensorFlow, JAX) in one repo
- Inverse autoregressive flow as a drop-in --variational flow option
- Importance-sampled marginal likelihood estimates, not just ELBO
- Accompanying blog post walks through VAE fundamentals
- DOI-backed for citation purposes
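The importance-sampled estimate in the highlights averages the weights p(x, z_k) / q(z_k | x) over K posterior samples before taking the log. The ELBO instead averages the log-weights, so by Jensen's inequality it can never exceed the importance-sampled value on the same samples. The sketch below shows both estimators side by side. It assumes a toy one-dimensional Gaussian model (p(z) = N(0, 1), p(x | z) = N(z, 1), and a deliberately too-wide q(z | x) = N(x/2, 1)) in place of the repository's Bernoulli MNIST decoder, because that model has an exact log p(x) = log N(x; 0, 2) to check against. The function names are hypothetical, and the number of samples the repository uses is not stated here.

```python
import math
import random

def log_normal(x, mean, var):
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    """Numerically stable log(sum(exp(v))), subtracting the max first."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_weights(x, num_samples, rng):
    """log w_k = log p(x, z_k) - log q(z_k | x) with z_k ~ q(z | x).
    Toy model (assumption): p(z) = N(0, 1), p(x | z) = N(z, 1),
    approximate posterior q(z | x) = N(x / 2, 1)."""
    out = []
    for _ in range(num_samples):
        z = rng.gauss(x / 2, 1.0)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        out.append(log_joint - log_normal(z, x / 2, 1.0))
    return out

def elbo_estimate(log_w):
    """Monte Carlo ELBO: the mean of the log-weights."""
    return sum(log_w) / len(log_w)

def importance_sampled_log_marginal(log_w):
    """log p(x) ~= log((1/K) * sum_k w_k), computed in log space."""
    return logsumexp(log_w) - math.log(len(log_w))
```

Usage:

```python
rng = random.Random(0)
lw = log_weights(1.5, 5000, rng)
print(elbo_estimate(lw), importance_sampled_log_marginal(lw), log_normal(1.5, 0.0, 2.0))
```

With a mismatched q, the ELBO sits below the true log p(x) by roughly KL(q || posterior). The importance-sampled estimate closes most of that gap as K grows, which is why it is the fairer number for comparing against published log-likelihoods.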
Caveats
- No convolutions or residual blocks, so the ceiling is intentionally modest
- TODO lists multi-GPU/TPU support and jaxtyping as unfinished
- The Anaconda environment file for PyTorch is oddly named environment-jax.yml
Verdict
Grab this if you want a readable, no-magic VAE baseline to extend or compare against. Skip it if you need state-of-the-art density estimation out of the box—the author will tell you as much.