NeurIPS 2025

Neural Attention Search

Difan Deng, Marius Lindauer

Cited by 431

Abstract

We present Neural Attention Search (NAtS), an end-to-end learnable sparse transformer that automatically evaluates the importance of each token within a sequence and determines whether the corresponding token can be dropped after several steps. To this end, we design a search space that contains three token types: (i) Global Tokens are preserved and queried by all subsequent tokens; (ii) Local Tokens survive until the next global token appears; and (iii) Sliding Window Tokens influence the inference of only a fixed number of subsequent tokens. Similar to the One-Shot Neural Architecture Search approach, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size and the inference costs of the models while maintaining their performance.

However, querying information from historical sequences requires a complexity of O(L^2) w.r.t. the input sequence length L. KV caching can reduce this time complexity to O(L) per generated token by storing all the historical KV values. Nevertheless, with the increasing model size of recent LLMs, even O(L) time and memory complexity can become a bottleneck at inference time. Indeed, not all the tokens in a sequence are equally important [46]. Many of them are redundant and do not contribute to the final output. Humans can recognize this without pre-defined fixed rules and can summarize the context into much smaller content or discard parts of it. Transformers can also learn this ability implicitly: many tokens in the attention map receive only very low weights [96] and have little influence on the final predictions. However, because the transformer learns this information implicitly, we do not know in advance how the important tokens are distributed in the context.
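To make the three token types from the abstract concrete, the following is a minimal sketch of how fixed token types could translate into a causal attention mask and a per-step KV cache size. It is not the NAtS implementation: the token types are given by hand rather than learned, the names `build_mask` and `cache_sizes` and the `window` parameter are hypothetical, and it assumes that a local token is still visible to the global token that ends its lifetime.

```python
from typing import List

# Hypothetical labels for the three token types named in the abstract.
GLOBAL, LOCAL, WINDOW = "G", "L", "W"


def build_mask(types: List[str], window: int) -> List[List[bool]]:
    """Return mask[i][j] == True if query token i may attend to key token j.

    Assumptions (not taken from the paper):
      * every token attends to itself;
      * a local token stays visible up to and including the next global token;
      * a sliding-window token j is visible to queries i with i - j < window.
    """
    n = len(types)
    for t in types:
        if t not in (GLOBAL, LOCAL, WINDOW):
            raise ValueError(f"unknown token type: {t!r}")

    # next_global[j]: index of the first global token strictly after j (n if none).
    next_global = [n] * n
    nxt = n
    for j in range(n - 1, -1, -1):
        next_global[j] = nxt
        if types[j] == GLOBAL:
            nxt = j

    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only keys at or before the query
            if j == i or types[j] == GLOBAL:
                mask[i][j] = True
            elif types[j] == LOCAL:
                mask[i][j] = i <= next_global[j]
            else:  # sliding-window token
                mask[i][j] = (i - j) < window
    return mask


def cache_sizes(mask: List[List[bool]]) -> List[int]:
    """Number of KV entries each query step actually needs."""
    return [sum(row) for row in mask]


if __name__ == "__main__":
    token_types = [GLOBAL, LOCAL, WINDOW, LOCAL, GLOBAL, WINDOW, WINDOW, LOCAL]
    sparse = build_mask(token_types, window=2)
    for row in sparse:
        print("".join("x" if visible else "." for visible in row))
    print("KV entries per step:", cache_sizes(sparse))
```

With dense causal attention the cache grows by one entry per step, whereas in this sketch local and sliding-window entries drop out of the mask once their lifetime ends, so the per-step cache size stays well below the sequence length.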
Selecting these tokens and recognizing the attention distributions might require additional expert knowledge, obtained either by inspecting the attention maps [27, 58, 95, 96] or by applying specific fixed rules [13, 16, 17, 31, 84]. Since this knowledge is already contained in the transformer models, we can instead ask the model itself to evaluate the importance of each token and learn to predict the optimal type for the given input tokens automatically. Unlike prior works that rely on human expertise or predefined rules to identify important tokens [15, 27, 28, 31, 54, 84, 85, 96], we propose a novel approach that evaluates the importance of each token by assigning it a role. For example, some tokens are preserved until the end, while others survive for only a short time. These roles measure the importance of each token and determine whether it survives over the next few tokens. Rather

39th Conference on Neural Information Processing Systems (NeurIPS 2025).