kingoflolz/mesh-transformer-jax
A model-parallel transformer library in JAX and Haiku, designed to scale up to 40B parameters on TPUs and home of the GPT-J-6B language model.

This repository provides a Haiku library that uses JAX's xmap/pjit operators to implement model parallelism for transformers. Its parallelism scheme is similar to Megatron-LM's and is optimized for TPU meshes, with optional ZeRO-style sharding. It includes GPT-J-6B, a 6-billion-parameter autoregressive language model trained on The Pile, along with support for fine-tuning and partial checkpoint training. The codebase is designed to scale up to approximately 40B parameters.
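The core idea behind a Megatron-LM-style scheme can be illustrated on a single feed-forward block. The sketch below is a minimal, hypothetical illustration, not code from this repository: it simulates the shards on one host with NumPy instead of JAX's xmap/pjit, and the names `dense_mlp`, `megatron_mlp` and `n_shards` are made up for the example. The first weight matrix is split by columns and the second by rows, so each shard computes its slice of the hidden layer independently. A single sum then combines the partial outputs; on a real device mesh, that sum would be an all-reduce such as `jax.lax.psum`.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU, a common transformer activation (elementwise)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def dense_mlp(x, w_in, w_out):
    """Reference: unsharded feed-forward block."""
    return gelu(x @ w_in) @ w_out


def megatron_mlp(x, w_in, w_out, n_shards):
    """Hypothetical tensor-parallel MLP, with shards simulated on one host.

    Each "device" holds a column slice of w_in and the matching row slice
    of w_out, so the elementwise activation needs no communication.
    """
    w_in_shards = np.split(w_in, n_shards, axis=1)    # column-parallel split
    w_out_shards = np.split(w_out, n_shards, axis=0)  # row-parallel split
    partials = [gelu(x @ wi) @ wo for wi, wo in zip(w_in_shards, w_out_shards)]
    # Summing the partial outputs stands in for an all-reduce across the mesh
    return np.sum(partials, axis=0)


rng = np.random.default_rng(0)
d_model, d_ff, n_shards = 16, 64, 4
x = rng.standard_normal((2, d_model))
w_in = rng.standard_normal((d_model, d_ff))
w_out = rng.standard_normal((d_ff, d_model))

# The sharded computation matches the unsharded reference
print(np.allclose(dense_mlp(x, w_in, w_out), megatron_mlp(x, w_in, w_out, n_shards)))  # True
```

Splitting the matrices this way needs only one collective per block. The hidden activations never have to be gathered, which is why the scheme maps well onto a fast TPU interconnect.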