NeurIPS 2020
The Pitfalls of Simplicity Bias in Neural Networks
Harshay Shah, Kaustav Tamuly, Aditi Raghunathan, Prateek Jain, Praneeth Netrapalli
Cited 446 times
Abstract
Several works have proposed Simplicity Bias (SB), the tendency of standard training procedures such as Stochastic Gradient Descent (SGD) to find simple models, to justify why neural networks generalize well [1, 49, 74]. However, the precise notion of simplicity remains vague. Furthermore, previous settings [67, 24] that use SB to justify why neural networks generalize well do not simultaneously capture the non-robustness of neural networks, a phenomenon widely observed in practice [71, 36]. We attempt to reconcile SB and the superior standard generalization of neural networks with the non-robustness observed in practice by designing datasets that (a) incorporate a precise notion of simplicity, (b) comprise multiple predictive features with varying levels of simplicity, and (c) capture the non-robustness of neural networks trained on real data. Through theoretical analysis and targeted experiments on these datasets, we make four observations: (i) SB of SGD and its variants can be extreme: neural networks can rely exclusively on the simplest feature and remain invariant to all predictive complex features. (ii) The extreme aspect of SB could explain why seemingly benign distribution shifts and small adversarial perturbations significantly degrade model performance. (iii) Contrary to conventional wisdom, SB can also hurt generalization on the same data distribution, as SB persists even when the simplest feature has less predictive power than the more complex features. (iv) Common approaches to improving generalization and robustness (ensembles and adversarial training) can fail to mitigate SB and its pitfalls. Given the role of SB in training neural networks, we hope that the proposed datasets and methods serve as an effective testbed for evaluating novel algorithmic approaches aimed at avoiding the pitfalls of SB.
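To make points (a) and (b) concrete, the following is a minimal illustrative sketch, not the authors' actual datasets. It assumes a two-dimensional toy problem in which one coordinate is a "simple" feature (the label is its sign, so a linear rule suffices) and the other is a "complex" feature (the label depends on whether the value lies in an inner band or in one of two outer bands, so no single threshold suffices). All function names, band boundaries, and the shuffling probe are hypothetical choices made for illustration.

```python
import numpy as np


def make_two_feature_dataset(n, seed=0):
    """Toy data with two fully predictive features of different simplicity.

    Hypothetical construction for illustration only.
    """
    rng = np.random.default_rng(seed)
    y = rng.choice([-1, 1], size=n)

    # Simple feature: its sign equals the label (linearly separable).
    x_simple = y * rng.uniform(0.1, 1.0, size=n)

    # Complex feature: label -1 lies in an inner band, label +1 in one of
    # two outer bands, so separating it needs at least two thresholds.
    inner = rng.uniform(-0.3, 0.3, size=n)
    outer = rng.choice([-1, 1], size=n) * rng.uniform(0.5, 1.0, size=n)
    x_complex = np.where(y == 1, outer, inner)

    X = np.stack([x_simple, x_complex], axis=1)
    return X, y


def shuffle_feature(X, col, seed=0):
    """Return a copy of X with one column permuted across examples,
    which breaks that feature's relationship with the label."""
    rng = np.random.default_rng(seed)
    X_shuffled = X.copy()
    X_shuffled[:, col] = rng.permutation(X_shuffled[:, col])
    return X_shuffled


def accuracy(predict, X, y):
    """Fraction of examples on which predict(X) matches y."""
    return float(np.mean(predict(X) == y))


def simple_rule(X):
    # Uses only the simple feature (column 0).
    return np.where(X[:, 0] > 0, 1, -1)


def complex_rule(X):
    # Uses only the complex feature (column 1).
    return np.where(np.abs(X[:, 1]) > 0.4, 1, -1)


if __name__ == "__main__":
    X, y = make_two_feature_dataset(2000)
    print("simple rule, clean data:      ", accuracy(simple_rule, X, y))
    print("complex rule, clean data:     ", accuracy(complex_rule, X, y))
    print("simple rule, complex shuffled:", accuracy(simple_rule, shuffle_feature(X, 1), y))
    print("simple rule, simple shuffled: ", accuracy(simple_rule, shuffle_feature(X, 0), y))
```

Under these assumptions, a predictor that reads only the simple feature is unaffected when the complex feature is shuffled, and falls to roughly chance accuracy when the simple feature is shuffled. Applying the same probe to a trained network is one way to check which feature it actually relies on, which is the kind of exclusive reliance that observation (i) describes.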