arcelien/pba

AutoML for image augmentation that won't melt your GPU

PBA learns when and how strongly to distort training images, trading the brute-force search of AutoAugment for a population-based approach that runs on a single workstation.

509 stars · Jupyter Notebook · ML Frameworks · Data Tooling
Velocity (7d): +0.2 ★ / day · Trend: steady

What it does

Population Based Augmentation (PBA) automatically learns schedules for data augmentation: not just which transforms to apply, but how strongly and at what point during training. It targets the same problem as Google’s AutoAugment, but with a fraction of the compute budget. The repo includes precomputed schedules for CIFAR-10/100 and SVHN, plus the code to discover new ones.
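To make "schedule" concrete, here is a minimal sketch of one way such a thing could be represented: one policy per epoch, where each policy lists transforms with a probability and a magnitude. The `Op` type, the 0-9 magnitude scale, the transform names, and the linear ramp are all assumptions for illustration; this is not the repo's file format, and a real schedule would be learned rather than hand-written.

```python
import random
from typing import List, NamedTuple


class Op(NamedTuple):
    """One augmentation entry (hypothetical structure, not the repo's format)."""
    name: str       # transform name, e.g. "rotate" (illustrative only)
    prob: float     # chance the transform is applied to an image, 0.0 to 1.0
    magnitude: int  # strength on an assumed 0-9 scale


Policy = List[Op]


def build_toy_schedule(num_epochs: int) -> List[Policy]:
    """Hand-written stand-in for a learned schedule: strength ramps up linearly."""
    schedule = []
    for epoch in range(num_epochs):
        frac = epoch / max(1, num_epochs - 1)  # 0.0 at the start, 1.0 at the end
        schedule.append([
            Op("rotate", prob=round(0.5 * frac, 2), magnitude=round(9 * frac)),
            Op("cutout", prob=round(0.8 * frac, 2), magnitude=round(6 * frac)),
        ])
    return schedule


def sample_ops(policy: Policy, rng: random.Random) -> List[Op]:
    """Choose which transforms fire for one image, each with its own probability."""
    return [op for op in policy if rng.random() < op.prob]
```

A training loop would look up `schedule[epoch]` at the start of each epoch and call `sample_ops` per image, so the same transform can be absent early in training and strong later on.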

The interesting bit

The cleverness is in the search: instead of training thousands of models to completion, PBA maintains a population of small models that train in parallel. Periodically, the weaker ones copy weights from better performers and mutate the augmentation policies they inherit. Search on Reduced SVHN finishes in about an hour on a Titan XP; on Reduced CIFAR-10 it takes about five hours. The full CIFAR-10 PyramidNet result still takes 9 days on a V100, but that is the cost of training the final model, not of the search.
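The copy-and-mutate step described above can be sketched in a few lines. This is a generic exploit-and-explore update in the style of population based training, not the repo's code: the `Member` class, the 25% truncation fraction, the uniform perturbation, and the policy-as-dict format are all assumptions made for the example.

```python
import copy
import random
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Member:
    """One model in the population (hypothetical stand-in)."""
    weights: List[float]      # placeholder for real model parameters
    policy: Dict[str, float]  # transform name -> probability (assumed format)
    score: float = 0.0        # validation accuracy from the latest interval


def exploit_and_explore(population: List[Member], rng: random.Random,
                        frac: float = 0.25, step: float = 0.1) -> List[Member]:
    """Bottom performers clone a top performer, then perturb the copied policy."""
    ranked = sorted(population, key=lambda m: m.score, reverse=True)
    cutoff = max(1, int(len(ranked) * frac))
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for loser in bottom:
        winner = rng.choice(top)
        # Exploit: inherit the winner's weights and augmentation policy.
        loser.weights = copy.deepcopy(winner.weights)
        loser.policy = dict(winner.policy)
        # Explore: nudge each probability, clamped to the valid range.
        for name in loser.policy:
            delta = rng.uniform(-step, step)
            loser.policy[name] = min(1.0, max(0.0, loser.policy[name] + delta))
    return population
```

Calling this between short training intervals means no model is ever retrained from scratch, which is where the savings over training every candidate policy to completion would come from.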

Key highlights

  • Matches reported AutoAugment results on CIFAR with “one thousand times less compute” (per the README)
  • Pre-computed schedules available for Wide-ResNet, Shake-Shake, and PyramidNet variants
  • Includes a Jupyter notebook (pba.ipynb) to visualize policies and applied augmentations
  • Search scripts support reduced datasets for quick iteration; full reproduction scripts included
  • Python 2 and 3 compatible

Caveats

  • TensorFlow-only; no PyTorch implementation in repo
  • The “1000x less compute” claim compares search cost, not total training cost; training the full-size models is still expensive
  • Code appears to be research-grade: functional but not packaged as a library

Verdict

Worth a look if you’re doing image classification on CIFAR-scale datasets and want augmentation policies without burning a TPU pod. Skip if you need plug-and-play PyTorch or are working far outside the tested model/dataset combinations.
