zsdonghao/u-net-brain-tumor

A U-Net that segments tumors one label at a time

An older TensorLayer reference implementation for BRATS brain tumor segmentation, with some honest notes about its own rough edges.

535 stars · Python · Computer Vision · Domain Apps
Velocity (7d): +0.2 ★/day · Trend: steady

What it does

Trains a U-Net to segment brain tumors from the BRATS 2017 dataset, which provides four MRI modalities (FLAIR, T1, T1c, T2) and a segmentation map with four label values (background, necrotic/non-enhancing tumor core, edema, enhancing tumor) for each volume. The code splits the data into training and validation sets, applies heavy augmentation, and trains separate single-label networks rather than one multi-class model.
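The data preparation described above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the function names, the per-modality z-score normalization, the patient-level split, and the task-to-label mapping are all assumptions. Only the BRATS label values (1, 2, 4) come from the dataset convention.

```python
import numpy as np

# BRATS 2017 label values: 0 = background, 1 = necrotic / non-enhancing
# tumor core, 2 = peritumoral edema, 4 = GD-enhancing tumor.
# ASSUMPTION: this task-name -> label mapping is hypothetical.
TASK_LABELS = {
    "all": (1, 2, 4),      # every tumor label merged into one foreground class
    "necrotic": (1,),
    "edema": (2,),
    "enhance": (4,),
}


def stack_modalities(flair, t1, t1c, t2):
    """Z-score each modality over its nonzero (brain) voxels, then stack as channels."""
    channels = []
    for vol in (flair, t1, t1c, t2):
        vol = vol.astype(np.float32)
        brain = vol[vol > 0]
        mean, std = (brain.mean(), brain.std()) if brain.size else (0.0, 1.0)
        channels.append((vol - mean) / (std + 1e-8))
    return np.stack(channels, axis=-1)  # shape (..., 4)


def binary_target(seg, task):
    """Collapse a multi-label segmentation into the binary mask for one task."""
    return np.isin(seg, TASK_LABELS[task]).astype(np.uint8)


def split_train_valid(case_ids, valid_fraction=0.2, seed=0):
    """Shuffle patient IDs and split them so no patient lands in both sets."""
    rng = np.random.default_rng(seed)
    ids = np.array(case_ids)
    rng.shuffle(ids)
    n_valid = int(round(len(ids) * valid_fraction))
    return ids[n_valid:].tolist(), ids[:n_valid].tolist()
```

For example, `binary_target(np.array([[0, 1], [2, 4]]), "edema")` gives `[[0, 0], [1, 0]]`: each network sees only its own sub-region as foreground. Splitting by patient rather than by slice keeps adjacent, near-identical slices from leaking into the validation set.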

The interesting bit

Training one binary dice-loss network per label is unusual: instead of one network predicting all tumor regions, you train a separate model for necrotic tissue, edema, or enhancing tumor, and presumably combine their outputs afterwards. The augmentation pipeline includes elastic deformation, which matters more than usual here because tumor boundaries are so irregular.
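To show why elastic deformation suits irregular boundaries, here is a minimal sketch of the classic random-displacement-field approach, written with SciPy rather than the TensorLayer utilities the repository uses. The `alpha`/`sigma` defaults and function names are hypothetical. The key detail is that image and mask share one displacement field, with nearest-neighbour sampling for the mask so labels stay integer.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_deform(image, mask, alpha=720.0, sigma=24.0, rng=None):
    """Apply one random elastic warp to an (H, W, C) image and its (H, W) mask.

    alpha scales the displacement; sigma controls how smooth it is.
    Both defaults are illustrative assumptions, not values from the repository.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = mask.shape
    # Random per-pixel displacements, smoothed so neighbouring pixels move together.
    dx = gaussian_filter(rng.uniform(-1, 1, (h, w)), sigma) * alpha
    dy = gaussian_filter(rng.uniform(-1, 1, (h, w)), sigma) * alpha
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.array([rows + dy, cols + dx])

    # Linear interpolation for intensities; nearest neighbour (order=0) for
    # labels, so the warped mask only contains values from the original mask.
    warped = np.stack(
        [map_coordinates(image[..., c], coords, order=1, mode="reflect")
         for c in range(image.shape[-1])],
        axis=-1,
    )
    warped_mask = map_coordinates(mask, coords, order=0, mode="reflect")
    return warped, warped_mask


def random_flip(image, mask, rng):
    """Flip image and mask together left-right with probability 0.5."""
    if rng.random() < 0.5:
        return image[:, ::-1], mask[:, ::-1]
    return image, mask
```

Smoothing the random field with a Gaussian is what makes the warp "elastic": small `sigma` gives jagged, noise-like distortion, while large `sigma` gives gentle, tissue-like bending of tumor outlines.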

Key highlights

  • Built on TensorLayer (not raw TensorFlow or PyTorch)
  • Dice loss and hard dice/IOU metrics built into the training loop
  • Augmentation: flips, rotation, shift, shear, zoom, plus elastic transforms
  • Single-task networks: --task=all targets all tumor labels together; necrotic, edema, or enhance targets a single sub-region
  • Author explicitly flags the data pipeline as slow and welcomes contributions to modernize it
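The dice loss and hard dice/IoU metrics listed above reduce to a few lines of arithmetic. The repository computes them through TensorLayer; this NumPy version only illustrates the formulas, and its smoothing constant `eps` and flatten-everything reduction are assumptions that may differ from the repository's settings.

```python
import numpy as np


def soft_dice(pred, target, eps=1e-5):
    """Soft dice on probabilities in [0, 1]: (2|P*T| + eps) / (|P| + |T| + eps)."""
    pred = pred.astype(np.float64).ravel()
    target = target.astype(np.float64).ravel()
    inter = np.sum(pred * target)
    return (2.0 * inter + eps) / (np.sum(pred) + np.sum(target) + eps)


def dice_loss(pred, target, eps=1e-5):
    """Training objective: 1 - soft dice, so a perfect overlap gives 0."""
    return 1.0 - soft_dice(pred, target, eps)


def hard_dice(pred, target, threshold=0.5, eps=1e-5):
    """Dice after thresholding predictions to {0, 1}; a monitoring metric only."""
    return soft_dice((pred > threshold).astype(np.float64), target, eps)


def iou(pred, target, threshold=0.5, eps=1e-5):
    """Intersection over union of the thresholded prediction and the target."""
    p = (pred > threshold).ravel()
    t = target.astype(bool).ravel()
    inter = np.logical_and(p, t).sum()
    union = np.logical_or(p, t).sum()
    return (inter + eps) / (union + eps)
```

With `target = [1, 1, 0, 0]` and `pred = [0.9, 0.6, 0.2, 0.0]`, the soft dice is 3.0 / 3.7 ≈ 0.81, while the hard dice and IoU are both 1.0 because thresholding recovers the mask exactly. An all-zero prediction gives a soft dice of about eps / 2, so the loss sits at roughly 1, which is the non-converged state described under Caveats.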

Caveats

  • Data loading is self-admitted as “not the fastest way”; author suggests TensorFlow Dataset API instead
  • Default config uses only half the training data for speed; you must manually edit prepare_data_with_valid.py to use the full set
  • A dice loss stuck at 1 (dice near 0) means the network has not converged; the README's only fix is “please try restart it”
  • BRATS dataset requires a separate application process; the author will not provide it
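The README's "restart it" advice for a loss stuck at 1 can be automated with a small retry wrapper. This is a hypothetical sketch that is not part of the repository: `train_once` stands in for whatever function you use to train from scratch with a given seed and return the final dice loss, and the 0.99 threshold is an arbitrary choice.

```python
import random


def train_with_restarts(train_once, max_attempts=3, stuck_threshold=0.99):
    """Re-run training with a fresh random seed while the final dice loss stays near 1.

    train_once(seed) is a HYPOTHETICAL callable: it trains from scratch and
    returns the final training loss. Returns (seed, losses) for the first
    run whose loss drops below stuck_threshold.
    """
    losses = []
    for _ in range(max_attempts):
        seed = random.randrange(2**31)
        loss = train_once(seed)
        losses.append(loss)
        if loss < stuck_threshold:
            return seed, losses
    raise RuntimeError(
        f"loss stayed >= {stuck_threshold} in all {max_attempts} attempts: {losses}"
    )
```

Returning the successful seed makes a lucky run reproducible. In practice you would check the loss after a few epochs rather than at the end, so a stuck run is abandoned early.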

Verdict

Worth a look if you need a working TensorLayer U-Net baseline for medical segmentation, or want to study the per-label dice-loss approach. Skip it if you want production-ready data pipelines or modern PyTorch/TF2 implementations — this is reference code with visible cobwebs.
