sanchit-gandhi/whisper-jax
Optimized JAX/Flax implementation of OpenAI's Whisper speech-to-text model achieving up to 70x speedup over the original PyTorch version.

This repository provides a high-performance implementation of OpenAI's Whisper model in JAX and Flax. It builds on the Hugging Face Transformers Whisper implementation and is substantially faster at inference, particularly on TPU hardware, where it can transcribe 30 minutes of audio in approximately 30 seconds. The project offers a FlaxWhisperPipeline abstraction for easy usage and supports CPU, GPU, and TPU execution.
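
A minimal sketch of how the pipeline might be called, plus the arithmetic behind the TPU figure quoted above. Several things here are assumptions rather than statements from this summary: the import spelling `FlaxWhisperPipline` (the name the package is believed to export; adjust it if your installed version differs), the `openai/whisper-large-v2` checkpoint, the `"text"` key in the returned dictionary, and the helper names `real_time_factor` and `transcribe`, which are hypothetical and exist only for illustration.

```python
def real_time_factor(audio_seconds: float, wall_seconds: float) -> float:
    """Return how many seconds of audio are transcribed per second of wall-clock time."""
    if wall_seconds <= 0:
        raise ValueError("wall_seconds must be positive")
    return audio_seconds / wall_seconds


def transcribe(audio_path: str, checkpoint: str = "openai/whisper-large-v2") -> str:
    """Transcribe one audio file with the whisper-jax pipeline (hypothetical wrapper)."""
    # Imported lazily so the helper above can be used without whisper-jax installed.
    # Assumption: the package exports the class under this spelling.
    from whisper_jax import FlaxWhisperPipline

    # Assumption: the checkpoint name is a Hugging Face Hub identifier.
    pipeline = FlaxWhisperPipline(checkpoint)

    # The first call triggers JIT compilation, so it is slower than later calls.
    outputs = pipeline(audio_path)

    # Assumption: the pipeline returns a dict with the transcription under "text".
    return outputs["text"]


if __name__ == "__main__":
    # 30 minutes of audio in roughly 30 seconds is about 60x faster than real time.
    print(real_time_factor(30 * 60, 30))
```

The real-time factor is a different measurement from the "up to 70x" figure, which compares against the original PyTorch implementation rather than against audio duration.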