NeurIPS 2023
Enhancing Sharpness-Aware Optimization Through Variance Suppression
Bingcong Li, Georgios B. Giannakis
Abstract
Sharpness-aware minimization (SAM) has well-documented merits in enhancing the generalization of deep neural networks, even without sizable data augmentation. Embracing the geometry of the loss function, where neighborhoods of 'flat minima' heighten generalization ability, SAM seeks 'flat valleys' by minimizing the maximum loss caused by an adversary perturbing parameters within the neighborhood. Although critical for accounting for the sharpness of the loss function, such an 'over-friendly adversary' can curtail the utmost level of generalization. The novel approach of this contribution stabilizes adversaries through variance suppression (VaSSO) to avoid such friendliness. VaSSO's provable stability safeguards its numerical improvement over SAM in model-agnostic tasks, including image classification and machine translation. In addition, experiments confirm that VaSSO endows SAM with robustness against high levels of label noise. Code is available at https://github.com/BingcongLi/VaSSO.
Introduction
Although deep neural networks (DNNs) have advanced the concept of "learning from data" and markedly improved performance across several applications in vision and language (Devlin et al., 2018; Tom et al., 2020), their overparametrized nature makes them prone to overfitting the training data (Zhang et al., 2021a). This has raised concerns about generalization, which is of key practical importance yet typically lags behind training performance. Improving generalizability is challenging. Common approaches include (model) regularization and data augmentation (Srivastava et al., 2014). While integrating regularization such as weight decay and dropout into training is the default choice, these methods are often insufficient for DNNs, especially when coping with complicated network architectures (Chen et al., 2022). Another line of effort resorts to suitable optimization schemes that attempt to find a generalizable local minimum.
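The abstract describes SAM as minimizing the maximum loss caused by an adversary that perturbs parameters within a neighborhood. The sketch below is a minimal NumPy illustration of one such step. It follows the widely used first-order SAM update (Foret et al., 2021): the adversary is approximated by ε = ρ d/‖d‖, and the descent step uses the gradient evaluated at w + ε. Several pieces are assumptions made only for illustration: the toy least-squares loss, the hyperparameters, and the helper names (`sam_step`, `demo`). The optional `theta` knob is also hypothetical. It smooths the adversary direction with an exponential moving average of mini-batch gradients, which is one simple way to reduce its batch-to-batch variance. It is not claimed to be VaSSO's exact update.

```python
import numpy as np


def loss(w, A, b):
    """Toy least-squares loss 0.5 * ||A w - b||^2 (hypothetical stand-in for a DNN loss)."""
    r = A @ w - b
    return 0.5 * float(r @ r)


def grad(w, A, b):
    """Gradient of the toy loss with respect to w."""
    return A.T @ (A @ w - b)


def sam_step(w, A, b, lr=0.1, rho=0.05, d_prev=None, theta=None):
    """One SAM-style step on a mini-batch (A, b); returns (new_w, direction).

    The adversary is eps = rho * d / ||d||. With d equal to the raw gradient,
    this is the first-order approximation of the loss-maximizing perturbation
    in the rho-ball around w. If theta is given (hypothetical knob), d is instead
    an exponential moving average of mini-batch gradients (an assumed variant).
    """
    g = grad(w, A, b)
    if theta is None or d_prev is None:
        d = g                                   # plain SAM direction
    else:
        d = (1.0 - theta) * d_prev + theta * g  # smoothed direction (assumption)
    eps = rho * d / (np.linalg.norm(d) + 1e-12)  # perturbation of norm ~rho
    g_adv = grad(w + eps, A, b)                  # gradient at the perturbed point
    return w - lr * g_adv, d


def demo(theta=None, steps=200, seed=0):
    """Run SAM-style steps on a random consistent least-squares problem."""
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(64, 5))
    b = A @ rng.normal(size=5)  # consistent system, so the minimum loss is 0
    w, d = np.zeros(5), None
    initial = loss(w, A, b)
    for _ in range(steps):
        idx = rng.choice(64, size=8, replace=False)  # mini-batch indices
        w, d = sam_step(w, A[idx], b[idx], lr=0.02, rho=0.05, d_prev=d, theta=theta)
    return initial, loss(w, A, b)


if __name__ == "__main__":
    print(demo())           # adversary from each mini-batch gradient alone
    print(demo(theta=0.1))  # adversary from an EMA of mini-batch gradients
```

With `theta=None` (or `theta=1.0`), each perturbation is built from the current mini-batch gradient alone. With smaller `theta`, the perturbation direction changes less from batch to batch. Because the perturbation keeps norm ρ near the minimum, the iterates hover close to the minimizer rather than converging exactly.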
For example, SGD is preferable to Adam on certain overparameterized problems since it converges to maximum-margin solutions (Wilson et al., 2017). Decoupling weight decay from Adam also empirically improves generalizability (Loshchilov and Hutter, 2017). Unfortunately, the underlying mechanism remains unclear, and whether these generalization merits carry over to other intricate learning tasks calls for additional theoretical elaboration. Our main focus, sharpness-aware minimization (SAM), is a highly compelling optimization approach that achieves state-of-the-art generalizability by exploiting the sharpness of the loss landscape (Foret et al., 2021; Chen et al., 2022). A high-level interpretation of sharpness is how violently the loss fluctuates within a neighborhood. It has been shown through large-scale empirical studies that sharpness-based measures highly