google/flax
A neural network library for JAX providing flexible APIs (Flax NNX and Flax Linen) for building, inspecting, and training deep learning models.

Flax is a Google-maintained neural network library for JAX, designed to serve the growing JAX ML research ecosystem. It provides two APIs. Flax NNX, released in 2024, adds first-class support for Python reference semantics, which makes models easier to create, inspect, and debug. Flax Linen, the original API from 2020, follows a more functional style in which model parameters are held outside the module and passed in explicitly. With NNX, users express neural networks as regular Python objects that support reference sharing and mutability.