Few-shot learning by averaging things and hoping they cluster
A clean PyTorch reimplementation of Prototypical Networks that learns class prototypes from a handful of examples.

What it does
Trains a network to embed images into a vector space where each class is represented by a prototype: the average (barycentre) of its embedded support samples. Training and evaluation run in “episodes.” In each one, a few labeled support examples per class define the prototypes, and query samples are classified by their distance to these class centroids (the original paper uses squared Euclidean distance). The repo implements the full training loop, a custom batch sampler for episode construction, and the prototypical loss, all trained and evaluated on the Omniglot dataset.
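The classification rule itself is small enough to sketch without PyTorch. The version below assumes the embeddings have already been computed (in the repo a network produces them; here they are plain lists of floats), and `compute_prototypes` and `classify` are hypothetical names, not the repo's API:

```python
from typing import Dict, List, Sequence

Vector = List[float]

def compute_prototypes(support: Dict[str, Sequence[Vector]]) -> Dict[str, Vector]:
    """Average the embedded support samples of each class into one prototype."""
    prototypes = {}
    for label, embeddings in support.items():
        dim = len(embeddings[0])
        # Component-wise mean over this class's support embeddings.
        prototypes[label] = [
            sum(e[i] for e in embeddings) / len(embeddings) for i in range(dim)
        ]
    return prototypes

def squared_euclidean(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))

def classify(query: Vector, prototypes: Dict[str, Vector]) -> str:
    """Assign the query to the class whose prototype is nearest."""
    return min(prototypes, key=lambda label: squared_euclidean(query, prototypes[label]))

# Toy 2-way, 2-shot episode with 2-D embeddings.
support = {
    "a": [[0.0, 0.0], [2.0, 0.0]],
    "b": [[0.0, 4.0], [2.0, 6.0]],
}
protos = compute_prototypes(support)
print(protos)                        # {'a': [1.0, 0.0], 'b': [1.0, 5.0]}
print(classify([1.5, 1.0], protos))  # a
print(classify([0.5, 4.0], protos))  # b
```

Nothing here is learned: all the work happens in the embedding network, which is why the same averaging rule can serve as a 1-shot or 5-shot classifier.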
The interesting bit
The whole trick is that “class average in embedding space” works as a classifier when each class has only one or five labeled examples, because the embedding is trained episode by episode to make exactly that rule succeed. The repo includes the loss both as a plain function and as a PyTorch-style loss class, plus a PrototypicalBatchSampler that handles the episode machinery (random class selection and support/query splitting) so you don’t have to wire it yourself.
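The episode machinery and the loss can be sketched in plain Python. This is not the repo's PrototypicalBatchSampler or loss code; `sample_episode` and `episode_loss` are hypothetical, dependency-free illustrations of the same idea: pick `n_way` classes at random, split each class's examples into `n_support` support and `n_query` query items, then score each query with a softmax over negative squared distances to the prototypes, as in the original paper.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = List[float]

def sample_episode(labels: Sequence[int], n_way: int, n_support: int,
                   n_query: int, rng: random.Random):
    """Pick n_way classes, then split each class into disjoint support/query indices."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with enough examples for both splits are eligible.
    eligible = sorted(c for c, idxs in by_class.items()
                      if len(idxs) >= n_support + n_query)
    classes = rng.sample(eligible, n_way)  # raises ValueError if too few classes
    support, query = {}, {}
    for c in classes:
        chosen = rng.sample(by_class[c], n_support + n_query)  # no replacement
        support[c] = chosen[:n_support]
        query[c] = chosen[n_support:]
    return support, query

def squared_euclidean(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))

def episode_loss(prototypes: Dict[int, Vector],
                 queries: Sequence[Tuple[Vector, int]]) -> Tuple[float, float]:
    """Mean negative log-likelihood and accuracy under softmax(-distance)."""
    total_nll, correct = 0.0, 0
    for embedding, true_label in queries:
        logits = {c: -squared_euclidean(embedding, p) for c, p in prototypes.items()}
        # Log-sum-exp with the max subtracted for numerical stability.
        m = max(logits.values())
        log_z = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total_nll += log_z - logits[true_label]
        correct += int(max(logits, key=logits.get) == true_label)
    return total_nll / len(queries), correct / len(queries)

# Usage: a 3-way, 2-shot, 3-query episode drawn from 4 classes of 6 examples each.
labels = [c for c in range(4) for _ in range(6)]
support, query = sample_episode(labels, n_way=3, n_support=2, n_query=3,
                                rng=random.Random(0))
loss, acc = episode_loss({0: [0.0, 0.0], 1: [3.0, 0.0]},
                         [([0.0, 0.0], 0), ([3.0, 0.0], 1)])
print(f"loss={loss:.6f} acc={acc}")
```

Sampling without replacement inside each class keeps support and query disjoint, so the loss is always computed on examples that did not contribute to their own prototype.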
Key highlights
- Reproduces paper results within ~0.3–0.9 percentage points on Omniglot (1-shot and 5-shot, 5-way and 20-way)
- Custom sampler and loss implemented as reusable PyTorch components
- Uses Vinyals’ dataset splits and rotation augmentation for direct comparability with the original paper
- Training script is explicitly marked as “for demonstration purposes”
- MIT licensed
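The rotation augmentation in the Vinyals-style setup treats each Omniglot character rotated by 0, 90, 180 or 270 degrees as a separate class, which multiplies the number of classes by four. A minimal sketch on toy nested lists, assuming a hypothetical dict-of-images layout rather than the repo's actual dataset class (`rotate90` and `augment_with_rotations` are illustrative names):

```python
from typing import Dict, List, Tuple

Image = List[List[int]]

def rotate90(img: Image) -> Image:
    """Rotate a 2-D grid 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]

def augment_with_rotations(dataset: Dict[str, List[Image]]) -> Dict[Tuple[str, int], List[Image]]:
    """Treat every (character, rotation) pair as its own class."""
    augmented = {}
    for char, images in dataset.items():
        current = images
        for degrees in (0, 90, 180, 270):
            augmented[(char, degrees)] = current
            # Rotate once more for the next class variant.
            current = [rotate90(img) for img in current]
    return augmented

aug = augment_with_rotations({"char_a": [[[1, 2], [3, 4]]]})
print(len(aug))               # 4
print(aug[("char_a", 90)])    # [[[3, 1], [4, 2]]]
print(aug[("char_a", 180)])   # [[[4, 3], [2, 1]]]
```

Because a rotated character becomes a new class rather than an extra sample of the old one, the augmentation enlarges the pool of classes episodes can draw from.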
Caveats
- Training code is demo-grade, not production infrastructure
- Only Omniglot is supported, and there is no built-in t-SNE visualization, even though the README shows one taken from the paper
- Ships its own Omniglot dataset code while awaiting an official torchvision Omniglot integration, which may or may not happen
Verdict
Worth a look if you need a readable, hackable baseline for few-shot classification in PyTorch. Skip it if you want a framework with built-in dataset flexibility or industrial training pipelines.