pemami4911/neural-combinatorial-rl-pytorch

Teaching neural nets to sort and route salesmen with pointer networks

A PyTorch reimplementation of a 2016 paper that uses reinforcement learning to solve combinatorial optimization without hand-crafted heuristics.

610 stars · Python · Domain Apps · ML Frameworks
Velocity (7d): +0.2 ★ / day · Trend: steady

What it does

This repo implements a pointer network trained with policy gradients to attack two classic combinatorial problems: the planar symmetric Euclidean Traveling Salesman Problem (TSP) and a sorting task. The model learns to output a permutation of its input by attending over the input elements, sampling stochastically during training and decoding greedily at test time. For TSP, the reward is based on tour length (shorter is better); for sorting, the reward scales with the length of the longest increasing subsequence in the output.
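The two reward signals described above can be sketched in plain Python. This is an illustration rather than the repo's code: the function names are hypothetical, and dividing the subsequence length by the sequence length is an assumed normalization, since the text only says the sorting reward "scales with" that length.

```python
import math
from bisect import bisect_left


def tour_length(points, order):
    """Length of the closed tour that visits `points` in the given `order`."""
    total = 0.0
    for i in range(len(order)):
        x1, y1 = points[order[i]]
        x2, y2 = points[order[(i + 1) % len(order)]]  # wrap back to the start
        total += math.hypot(x2 - x1, y2 - y1)
    return total


def longest_increasing_subsequence(values):
    """Length of the longest strictly increasing subsequence (patience sorting)."""
    tails = []  # tails[k] = smallest tail of an increasing run of length k + 1
    for v in values:
        k = bisect_left(tails, v)
        if k == len(tails):
            tails.append(v)
        else:
            tails[k] = v
    return len(tails)


def sorting_reward(values):
    """Assumed normalization: LIS length divided by sequence length."""
    return longest_increasing_subsequence(values) / len(values)


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tour_length(square, [0, 1, 2, 3]))  # perimeter of the unit square
print(tour_length(square, [0, 2, 1, 3]))  # self-crossing tour, longer
print(sorting_reward([3, 1, 2, 4, 5]))    # 4 of 5 elements are in increasing order
```

A policy for TSP would be trained to make `tour_length` small, while a sorting policy would be trained to push `sorting_reward` toward 1.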

The interesting bit

The author skipped the paper’s critic network entirely and used an exponential moving average baseline instead, apparently after learning through correspondence with other reimplementers that this simpler approach worked better. The beam search decoder is also cheerfully admitted to be “not yet finished” and currently supports exactly one beam, i.e., greedy search.
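To make the baseline idea concrete, here is a minimal sketch of an exponential moving average over batch rewards. The class name, the `beta` decay value, and the choice to initialize from the first batch are all assumptions for illustration; the repo's actual hyperparameters and update rule may differ.

```python
class EMABaseline:
    """Running average of batch rewards, used in place of a learned critic."""

    def __init__(self, beta=0.9):
        self.beta = beta    # assumed decay rate, not taken from the repo
        self.value = None   # set from the first batch

    def update(self, rewards):
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            self.value = batch_mean
        else:
            self.value = self.beta * self.value + (1 - self.beta) * batch_mean
        return self.value

    def advantages(self, rewards):
        # Reward minus baseline: positive means better than the recent average.
        return [r - self.value for r in rewards]


baseline = EMABaseline(beta=0.5)
baseline.update([2.0, 4.0])              # first batch sets the baseline to its mean
baseline.update([5.0, 5.0])              # later batches are blended in
print(baseline.value)
print(baseline.advantages([5.0, 3.0]))
```

In a policy-gradient update, each advantage would typically weight the log-probability of the corresponding sampled permutation, so the baseline reduces variance without needing a second network.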

Key highlights

  • Implements RL pretraining with greedy decoding for TSP and sorting tasks
  • Uses stochastic policy via torch.multinomial() during training
  • Sorting generalization tested: the same model earns 99.7% reward on sort10 but only 55.9% on sort20
  • Attention visualization supported via --plot_attention True
  • Extensible to other tasks: add a dataset class and a scalar reward function
  • PyTorch 0.4 compatibility lives on a separate branch; main code targets 0.2–0.3
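
The difference between the stochastic policy and greedy decoding mentioned in the highlights can be sketched as a masked decoding loop. To stay dependency-free, this sketch uses `random.choices` from the standard library as a stand-in for `torch.multinomial()`; `score_fn` is a hypothetical placeholder for the pointer network's attention scores, and none of these names come from the repo.

```python
import math
import random


def masked_softmax(scores, visited):
    """Softmax over `scores`, with zero probability on visited positions."""
    live = [s for s, seen in zip(scores, visited) if not seen]
    peak = max(live)  # subtract the max for numerical stability
    exps = [0.0 if seen else math.exp(s - peak) for s, seen in zip(scores, visited)]
    total = sum(exps)
    return [e / total for e in exps]


def decode(score_fn, n, greedy, rng=random):
    """Build a permutation of range(n), one position per step."""
    visited = [False] * n
    order = []
    for _ in range(n):
        probs = masked_softmax(score_fn(order), visited)
        candidates = [i for i in range(n) if not visited[i]]
        if greedy:
            # Greedy decoding: take the most probable unvisited position.
            choice = max(candidates, key=probs.__getitem__)
        else:
            # Stochastic policy: sample in proportion to the probabilities.
            weights = [probs[i] for i in candidates]
            choice = rng.choices(candidates, weights=weights, k=1)[0]
        visited[choice] = True
        order.append(choice)
    return order


def fixed_scores(prefix):
    """Toy scores that ignore the prefix; a real model would condition on it."""
    return [0.1, 2.0, -1.0, 0.5]


print(decode(fixed_scores, 4, greedy=True))
print(decode(fixed_scores, 4, greedy=False, rng=random.Random(0)))
```

Masking already-chosen positions is what guarantees that every decoded sequence is a valid permutation, whichever decoding mode is used.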

Caveats

  • Beam search is unfinished (only greedy decoding works)
  • Several paper features remain unimplemented: RL pretraining-Sampling, Active Search, A3C-style async training, and variable-length inputs
  • TSP results shown for a single random seed over 50 epochs, so reproducibility is unclear
  • Dependencies are dated: Python 3.6, PyTorch 0.2–0.3

Verdict

Worth a look if you’re studying pointer networks or need a hackable baseline for neural combinatorial optimization. Skip it if you need production-grade TSP solvers or modern PyTorch: the code is a research reimplementation with acknowledged rough edges, not a maintained library.
