NeurIPS 2023
The Crucial Role of Normalization in Sharpness-Aware Minimization
Yan Dai, Kwangjun Ahn, Suvrit Sra
Abstract
Sharpness-Aware Minimization (SAM) is a recently proposed gradient-based optimizer (Foret et al., ICLR 2021) that greatly improves the prediction performance of deep neural networks. Consequently, there has been a surge of interest in explaining its empirical success. We focus, in particular, on understanding the role played by normalization, a key component of the SAM updates. We theoretically and empirically study the effect of normalization in SAM for both convex and non-convex functions, revealing two key roles played by normalization: i) it helps stabilize the algorithm; and ii) it enables the algorithm to drift along a continuum (manifold) of minima, a property that recent theoretical works identify as key to better performance. We further argue that these two properties of normalization make SAM robust to the choice of hyper-parameters, supporting the practicality of SAM. Our conclusions are backed by various experiments.

* The first two authors contributed equally. Work done while Yan Dai was visiting MIT.

¹ In principle, the normalization in Equation 1 may make SAM ill-defined. However, Wen et al. (2023, Appendix B) showed that, except for countably many learning rates, SAM (with any ρ) is well-defined for almost all initializations. Hence, throughout the paper, we assume that the SAM iterates are always well-defined.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).
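As a rough illustration of the stabilizing role of normalization described in the abstract, the sketch below compares the standard SAM update of Foret et al. (perturb by ρ along the normalized gradient ∇f(x)/‖∇f(x)‖, then take a gradient step from the perturbed point) with an unnormalized variant, often called USAM, that perturbs by ρ∇f(x) instead. Everything concrete here is an assumption made for the example, not taken from the paper: the one-dimensional quadratic f(x) = x² (curvature 2), the hyper-parameters lr=0.5 and rho=1.0, the choice to return x unchanged when the gradient is exactly zero, and the hypothetical helper names `sam_step`, `usam_step`, and `run`.

```python
from typing import Callable, List

Grad = Callable[[float], float]


def sam_step(x: float, grad: Grad, lr: float, rho: float) -> float:
    """One SAM step: move rho along the *normalized* gradient, then descend."""
    g = grad(x)
    norm = abs(g)  # in 1-D the Euclidean norm is the absolute value
    if norm == 0.0:
        # Normalization is undefined at a stationary point; this toy sketch
        # (an assumption, not the paper's convention) just returns x unchanged.
        return x
    x_adv = x + rho * g / norm
    return x - lr * grad(x_adv)


def usam_step(x: float, grad: Grad, lr: float, rho: float) -> float:
    """One unnormalized (USAM-style) step: same as SAM without dividing by the norm."""
    x_adv = x + rho * grad(x)
    return x - lr * grad(x_adv)


def run(step: Callable[[float, Grad, float, float], float],
        x0: float, grad: Grad, lr: float, rho: float, iters: int) -> List[float]:
    """Return the trajectory [x0, x1, ..., x_iters] produced by `step`."""
    xs = [x0]
    for _ in range(iters):
        xs.append(step(xs[-1], grad, lr, rho))
    return xs


def quad_grad(x: float) -> float:
    """Gradient of the toy quadratic f(x) = x**2, whose curvature is 2."""
    return 2.0 * x


sam_traj = run(sam_step, x0=1.0, grad=quad_grad, lr=0.5, rho=1.0, iters=10)
usam_traj = run(usam_step, x0=1.0, grad=quad_grad, lr=0.5, rho=1.0, iters=10)
print(sam_traj[-1], usam_traj[-1])  # → 1.0 1024.0
```

With these settings the unnormalized update reduces to x_{t+1} = -2 x_t and blows up, whereas the SAM iterates bounce between +1 and -1 without diverging, because the normalized perturbation has fixed length ρ regardless of how large the gradient is. This is a single toy instance meant only to make the update rules concrete; it does not substitute for the paper's analysis of convex and non-convex functions.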