lucidrains/memorizing-transformers-pytorch
A PyTorch implementation of Memorizing Transformers, a transformer architecture augmented with approximate nearest neighbor memory retrieval.

This repository provides a PyTorch implementation of the Memorizing Transformers paper from ICLR 2022. The model augments standard transformer attention with an external memory that is queried by approximate nearest neighbor (KNN) search. When processing a sequence, the model retrieves the stored keys and values of relevant past tokens from this memory, which extends its effective context beyond the local window. The implementation uses cosine similarity attention with a learned temperature for the KNN attention layer, and supports hybrid attention across local and distant (retrieved) contexts.
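The retrieval and attention steps described above can be sketched in a few lines. This is a minimal, framework-free illustration written in plain Python for readability, not the repository's code or API. The function names (`knn_retrieve`, `hybrid_attention`) are hypothetical, an exact top-k search stands in for the approximate nearest neighbor index, the temperature is a fixed number rather than a learned parameter, and reading "hybrid attention" as a single softmax over local and retrieved logits is an assumption.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def knn_retrieve(query, memory_keys, memory_values, top_k):
    """Return the top_k memory entries by cosine similarity to the query.

    Assumption: exact search is used here for clarity; a real system
    would use an approximate nearest neighbor index instead.
    """
    q = l2_normalize(query)
    scored = [(dot(q, l2_normalize(k)), i) for i, k in enumerate(memory_keys)]
    scored.sort(reverse=True)
    picked = [i for _, i in scored[:top_k]]
    return [memory_keys[i] for i in picked], [memory_values[i] for i in picked]


def hybrid_attention(query, local_keys, local_values,
                     memory_keys, memory_values, top_k=2, temperature=10.0):
    """Attend jointly over local keys and retrieved memory keys.

    Logits are cosine similarities scaled by `temperature` (a fixed
    stand-in for a learned scale), normalized by one shared softmax.
    """
    mem_k, mem_v = knn_retrieve(query, memory_keys, memory_values, top_k)
    keys = local_keys + mem_k
    values = local_values + mem_v

    q = l2_normalize(query)
    logits = [temperature * dot(q, l2_normalize(k)) for k in keys]

    # Numerically stable softmax over local and retrieved logits together.
    peak = max(logits)
    weights = [math.exp(l - peak) for l in logits]
    total = sum(weights)
    weights = [w / total for w in weights]

    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return output, weights
```

Because queries and keys are normalized, the logits are bounded by the temperature, so the temperature alone controls how sharply attention concentrates on the best-matching local or retrieved entry.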