NeurIPS 2023
Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection
Yu Bai, Fan Chen, Huan Wang, Caiming Xiong, Song Mei
300 citations
Abstract
Neural sequence models based on the transformer architecture have demonstrated remarkable in-context learning (ICL) abilities: they can perform new tasks when prompted with training and test examples, without any parameter update to the model. This work advances the understanding of the strong ICL abilities of transformers. We first provide a comprehensive statistical theory for transformers to perform ICL by deriving end-to-end quantitative results for the expressive power, in-context prediction power, and sample complexity of pretraining. Concretely, we show that transformers can implement a broad class of standard machine learning algorithms in context, such as least squares, ridge regression, Lasso, convex risk minimization for generalized linear models (such as logistic regression), and gradient descent on two-layer neural networks, with near-optimal predictive power on various in-context data distributions. Using an efficient implementation of in-context gradient descent as the underlying mechanism, our transformer constructions admit mild bounds on the number of layers and heads, and can be learned with polynomially many pretraining sequences. Building on these "base" ICL algorithms, we show, intriguingly, that transformers can implement more complex ICL procedures involving in-context algorithm selection, akin to what a statistician can do in real life: a single transformer can adaptively select different base ICL algorithms (or even perform qualitatively different tasks) on different input sequences, without any explicit prompting of the right algorithm or task. We establish this in theory through explicit constructions and also observe the phenomenon experimentally.
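To make the "in-context gradient descent" mechanism concrete, the sketch below shows the statistical procedure a transformer is said to emulate, not the transformer construction itself: plain gradient descent on the ridge objective fitted to the in-context examples, followed by a prediction on the query. It is a minimal illustration under stated assumptions; the function names `ridge_gd` and `icl_predict`, the step-size rule, and the iteration count are hypothetical choices, not taken from the paper.

```python
import random


def dot(w, x):
    return sum(wj * xj for wj, xj in zip(w, x))


def ridge_gd(xs, ys, lam, iters=1000):
    """Hypothetical helper: GD on (1/2n) * sum_i (w.x_i - y_i)^2 + (lam/2) * ||w||^2."""
    n, d = len(xs), len(xs[0])
    # step <= 1/L, where L = lam + largest eigenvalue of the empirical covariance
    # (bounded by lam + mean ||x_i||^2), so gradient descent converges.
    step = 1.0 / (lam + sum(sum(v * v for v in x) for x in xs) / n)
    w = [0.0] * d
    for _ in range(iters):
        grad = [lam * wj for wj in w]  # gradient of the ridge penalty
        for x, y in zip(xs, ys):
            resid = dot(w, x) - y
            for j in range(d):
                grad[j] += resid * x[j] / n  # gradient of the squared loss
        w = [wj - step * gj for wj, gj in zip(w, grad)]
    return w


def icl_predict(xs, ys, x_query, lam=0.01):
    """Hypothetical helper: fit on the in-context examples, then predict the query label."""
    return dot(ridge_gd(xs, ys, lam), x_query)


if __name__ == "__main__":
    rng = random.Random(0)
    w_true = [1.0, -2.0, 0.5]
    xs = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(40)]
    ys = [dot(w_true, x) + 0.1 * rng.gauss(0, 1) for x in xs]
    # The noiseless target at the query is 1 - 2 + 0.5 = -0.5.
    print(icl_predict(xs, ys, [1.0, 1.0, 1.0]))
```

With `lam=0` this reduces to gradient descent on least squares; the adaptive step size is one convenient way to guarantee convergence without tuning.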
In theory, we construct two general mechanisms for algorithm selection, with concrete examples: (1) Pre-ICL testing, where the transformer determines the right task for the given sequence (such as choosing between regression and classification) by examining certain summary statistics of the input sequence; (2) Post-ICL validation, where the transformer selects a near-optimal algorithm for the given sequence from among multiple base ICL algorithms (such as ridge regression with multiple regularization strengths), using a train-validation split. As an example, we use the post-ICL validation mechanism to construct a transformer that can perform nearly Bayes-optimal ICL on a challenging task: noisy linear models with mixed noise levels. Experimentally, we demonstrate the strong in-context algorithm selection capabilities of standard transformer architectures.
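Post-ICL validation can be illustrated as an ordinary statistical procedure: fit each base algorithm (here ridge regression at several regularization strengths λ) on the training part of the context, score each on the held-out part, and predict the query with the winner. The sketch below is plain Python and is not the paper's transformer construction; the function names, the split point `n_train`, the λ grid, and the choice not to refit on the full context are all illustrative assumptions.

```python
import random


def dot(w, x):
    return sum(wj * xj for wj, xj in zip(w, x))


def ridge_gd(xs, ys, lam, iters=1000):
    """Hypothetical base algorithm: GD on the ridge objective (same as a plain ridge fit)."""
    n, d = len(xs), len(xs[0])
    # step <= 1/L (L = lam + top eigenvalue of the empirical covariance), so GD converges.
    step = 1.0 / (lam + sum(sum(v * v for v in x) for x in xs) / n)
    w = [0.0] * d
    for _ in range(iters):
        grad = [lam * wj for wj in w]
        for x, y in zip(xs, ys):
            resid = dot(w, x) - y
            for j in range(d):
                grad[j] += resid * x[j] / n
        w = [wj - step * gj for wj, gj in zip(w, grad)]
    return w


def post_icl_validation(xs, ys, x_query, lams, n_train):
    """Hypothetical helper: pick lam by validation MSE on a train/validation split."""
    if not 0 < n_train < len(xs):
        raise ValueError("n_train must leave at least one validation example")
    xtr, ytr = xs[:n_train], ys[:n_train]
    xva, yva = xs[n_train:], ys[n_train:]
    best = None
    for lam in lams:
        w = ridge_gd(xtr, ytr, lam)  # fit on the training split only
        val = sum((dot(w, x) - y) ** 2 for x, y in zip(xva, yva)) / len(xva)
        if best is None or val < best[0]:
            best = (val, lam, w)
    _, lam, w = best
    return lam, dot(w, x_query)  # predict with the selected fit (no refit)


if __name__ == "__main__":
    rng = random.Random(0)
    w_true = [1.0, -2.0, 0.5]
    xs = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(20)]
    ys = [dot(w_true, x) + rng.gauss(0, 1.0) for x in xs]  # noisy labels
    lam, pred = post_icl_validation(xs, ys, [1.0, 1.0, 1.0], [0.0, 0.1, 1.0, 10.0], n_train=14)
    print("selected lambda:", lam, "prediction:", pred)
```

Selecting by held-out error is what lets such a procedure adapt to an unknown noise level, since noisier sequences tend to favor larger λ. Refitting the selected λ on the full context is a reasonable variant not shown here.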