Papers I Read Notes and Summaries

Averaging Weights leads to Wider Optima and Better Generalization

Introduction

  • The paper proposes the Stochastic Weight Averaging (SWA) procedure for improving the generalization performance of models trained with SGD (with a cyclical or constant learning rate).

  • Specifically, the model is checkpointed at several points along the training trajectory, and these checkpoints are averaged (in the parameter space) to obtain a single model.

  • Link to the paper: https://arxiv.org/abs/1803.05407

Idea

  • “Stochastic” in the name refers to the idea that, with a cyclical or constant learning rate, the SGD proposals are approximately samples from the network’s loss surface and are hence stochastic.

  • SWA uses a learning rate schedule that allows exploration in the weight space.

  • With a cyclical or constant learning rate, SGD explores points (model instances) at the periphery of the set of high-performing networks.

  • With different initializations, SGD will find different points (of low training loss) on this boundary, but will not move inside it.

  • Averaging these points provides a mechanism to move inside the periphery (a toy illustration follows this list).

  • The train and test error surfaces, while similar, are not perfectly aligned. Hence, averaging several models (along the optimization trajectory) can lead to a more robust model.
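
The geometric intuition above can be sanity-checked on a toy convex problem. The sketch below is my own illustration (not an experiment from the paper): points sampled on a level set of a 2-D quadratic loss all have the same loss, yet their average lies strictly inside the level set and has a lower loss.

```python
# Toy illustration: averaging points on the "periphery" (a level set) of a
# convex loss moves strictly inside it. All names and values are illustrative.
import numpy as np

rng = np.random.default_rng(0)
A = np.diag([1.0, 10.0])                 # anisotropic quadratic: L(w) = 0.5 * w^T A w
loss = lambda w: 0.5 * w @ A @ w

# Sample 20 points on the level set L(w) = c by rescaling random directions.
c = 1.0
dirs = rng.normal(size=(20, 2))
points = np.array([d * np.sqrt(2 * c / (d @ A @ d)) for d in dirs])

print("loss at individual points:", [round(float(loss(p)), 3) for p in points[:3]])  # all ~1.0
print("loss at their average    :", round(float(loss(points.mean(axis=0))), 3))      # < 1.0
```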

Algorithm

  • Given a model $w$ and some training budget $B$, train the model in the conventional way for approx 75% of the budget.

  • Starting from that point, continue training with the remaining budget, with a constant or cyclical learning rate.

  • With a constant learning rate, checkpoint the model at the end of each epoch; with a cyclical learning rate, checkpoint the model at the end of each cycle (i.e., at the lowest learning rate in the cycle).

  • Average all the checkpointed weights (in parameter space) to obtain the SWA model (a minimal training-loop sketch follows this list).

  • If the model has Batch Normalization layers, run one additional pass over the training data to compute the running mean and standard deviation of the activations for the SWA model.

  • The computational and memory overhead of computing the SWA model is small: only one additional copy of the weights needs to be maintained, and the running average is cheap to update.

  • The paper highlights the ensembling-like effect of SWA by showing that if the model checkpoints ($w_i$) are generated by training with Fast Geometric Ensembling (FGE), the difference between averaging the weights and averaging the predictions is of the order $O(\Delta)$, where $\Delta = \max_i \| w_i - w_{\text{SWA}} \|$.

  • Note that, unlike an explicit ensemble, SWA does not incur the overhead of extra forward passes during inference.
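
A minimal sketch of the procedure above, assuming a PyTorch-style `model`, `optimizer`, `train_loader`, and `loss_fn`; the hyperparameter names and values (`epochs`, `swa_start`, `swa_lr`) are illustrative, and the constant-learning-rate variant is shown.

```python
# Sketch of SWA training: conventional SGD for ~75% of the budget, then a
# running average of per-epoch checkpoints, then a BatchNorm statistics pass.
import copy
import torch

def train_with_swa(model, optimizer, train_loader, loss_fn,
                   epochs=150, swa_start=112, swa_lr=0.01):
    swa_state, n_averaged = None, 0

    for epoch in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

        if epoch >= swa_start:
            # Constant-LR variant: drop to swa_lr and checkpoint once per epoch.
            # (For the cyclical variant, checkpoint at the end of each cycle instead.)
            for group in optimizer.param_groups:
                group["lr"] = swa_lr
            current = copy.deepcopy(model.state_dict())
            if swa_state is None:
                swa_state = current
            else:
                # Running average in parameter space: w_SWA <- (n * w_SWA + w) / (n + 1)
                for k, v in swa_state.items():
                    if v.is_floating_point():
                        swa_state[k] = (v * n_averaged + current[k]) / (n_averaged + 1)
            n_averaged += 1

    swa_model = copy.deepcopy(model)
    swa_model.load_state_dict(swa_state)

    # Extra pass over the data to recompute BatchNorm running statistics for the
    # averaged weights (they are not the average of the checkpoints' statistics).
    for m in swa_model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None            # cumulative (equally weighted) average
    swa_model.train()
    with torch.no_grad():
        for x, _ in train_loader:
            swa_model(x)
    return swa_model
```

Recent PyTorch releases ship helpers for this recipe in `torch.optim.swa_utils` (`AveragedModel`, `SWALR`, and `update_bn`), which can replace the manual averaging and BatchNorm pass above.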

Experiments

  • Datasets: CIFAR10, CIFAR100, ImageNet

  • Models: VGG-16, WideResNet, 164-layer Preactivation ResNet, Shake-Shake, and PyramidNet.

  • Baselines: conventional SGD training, an exponentially decaying running average of the SGD weights, and FGE (the two averaging rules are sketched at the end of this section).

  • In all the CIFAR experiments, SWA consistently outperforms SGD when given the same training budget, and its performance continues to improve as the budget is increased.

  • SWA also achieves performance comparable to FGE, despite FGE being an ensemble method.

  • On ImageNet, SWA is applied to pre-trained models and improves performance in all cases.

  • An ablation experiment (on CIFAR-100) shows that it is possible to train a network (with SWA) using a fixed learning rate. In that setup, using SWA improves performance by 16%.
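
For reference, the difference between SWA and the exponentially decaying average baseline mentioned above comes down to the averaging rule. A small self-contained sketch (my own; `alpha` is an illustrative decay rate, not a value from the paper):

```python
# Uniform (SWA-style) vs. exponentially decaying averaging of weight checkpoints.
import numpy as np

def uniform_average(checkpoints):
    """Every checkpoint contributes equally: w_avg <- (n * w_avg + w) / (n + 1)."""
    avg, n = None, 0
    for w in checkpoints:
        avg = w.copy() if avg is None else (avg * n + w) / (n + 1)
        n += 1
    return avg

def exponential_average(checkpoints, alpha=0.99):
    """Older checkpoints are exponentially down-weighted: w_avg <- a * w_avg + (1 - a) * w."""
    avg = None
    for w in checkpoints:
        avg = w.copy() if avg is None else alpha * avg + (1 - alpha) * w
    return avg

checkpoints = [np.random.default_rng(i).normal(size=4) for i in range(10)]
print("uniform    :", uniform_average(checkpoints))
print("exponential:", exponential_average(checkpoints))
```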