ICLR 2025
Transformers Provably Learn Two-Mixture of Linear Classification via Gradient Flow
Hongru Yang, Zhangyang Wang, Jason D. Lee, Yingbin Liang
Abstract
Understanding how transformers learn and utilize hidden connections between words is crucial to understanding the behavior of large language models. To study this mechanism, we consider the task of two-mixture of linear classification, which features a hidden correspondence structure between words, and study the training dynamics of a symmetric two-headed transformer with ReLU neurons. Motivated by the stage-wise learning phenomenon observed in our experiments, we design and theoretically analyze a three-stage training algorithm, which can effectively characterize the actual gradient descent dynamics when the neuron weights and the softmax attention are trained simultaneously. The first stage is a neuron learning stage, in which the neurons align with the underlying signals. The second stage is an attention feature learning stage, in which we analyze how the attention learns to utilize the relationship between tokens to solve certain hard samples. Meanwhile, the attention features evolve from a nearly non-separable state at initialization to a well-separated state. The third stage is a convergence stage, in which the population loss is driven toward zero. The key technique in our analysis of softmax attention is to identify a critical subsystem inside a large dynamical system and to bound the growth of this nonlinear subsystem by a linear system. Along the way, we utilize a novel structure called the mean-field infinite-width transformer. Finally, we discuss the setting with more than two mixtures. We empirically show that our analysis of the gradient flow dynamics is difficult to generalize even when the number of mixtures is three, although the transformer can still successfully learn such a distribution. On the other hand, we show by construction that there exists a transformer that can solve mixture of linear classification for an arbitrary number of mixtures.
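As a concrete illustration of why a hidden correspondence between words can make some samples hard, the following Python sketch builds a hypothetical two-mixture toy task under our own simplifying assumptions; it is not the data model analyzed in this paper. Each example pairs a selector token, represented here simply by its mixture index k in {0, 1}, with a Gaussian feature token x in R^d, and the label is sign(<w_k, x>) with the assumed opposite classifiers w_0 = u and w_1 = -u. The helpers sample_two_mixture, dot, and accuracy are illustrative names, not part of any library.

```python
import random
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def sample_two_mixture(n: int, d: int, seed: int = 0) -> Tuple[List[Tuple[int, Vector]], List[int], Vector]:
    """Hypothetical toy two-mixture task (an assumption, not the paper's exact model).

    Returns (examples, labels, u), where each example is (k, x):
      k -- mixture index carried by the selector token, uniform on {0, 1}
      x -- feature token drawn from N(0, I_d)
    and the label is sign(<w_k, x>) with assumed classifiers w_0 = u, w_1 = -u.
    """
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(d)]
    examples, labels = [], []
    for _ in range(n):
        k = rng.randrange(2)
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        w = u if k == 0 else [-c for c in u]
        labels.append(1 if dot(w, x) > 0 else -1)
        examples.append((k, x))
    return examples, labels, u


def accuracy(preds: List[int], labels: List[int]) -> float:
    """Fraction of predictions that match the labels."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


examples, labels, u = sample_two_mixture(n=2000, d=8)

# Uses the selector: apply the classifier of the indicated mixture component.
with_selector = [1 if dot(u if k == 0 else [-c for c in u], x) > 0 else -1 for k, x in examples]

# Ignores the selector: one fixed linear classifier applied to the feature token alone.
without_selector = [1 if dot(u, x) > 0 else -1 for _, x in examples]

print(accuracy(with_selector, labels))     # 1.0
print(accuracy(without_selector, labels))  # near 0.5: correct only on the k == 0 examples
```

Because the two components are equally likely and share the same feature distribution while their labels are flipped, any linear classifier that looks at x alone has expected accuracy exactly 1/2 in this toy setting, whereas reading the selector and applying the matching classifier is exact. A transformer would have to infer k from the selector token rather than receive it directly, which is the kind of token-to-token relationship the attention must learn to exploit.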
* Work done while visiting Princeton University.

How do transformers learn and utilize the hidden correspondence structure to solve mixture of linear classification via gradient descent?

Our contributions. In this work, we study the training dynamics of a two-headed transformer in the two-mixture setting. Our contributions are summarized as follows: