ICLR 2025

Optimized Multi-Token Joint Decoding With Auxiliary Model for LLM Inference

Zongyue Qin, Ziniu Hu, Zifan He, Neha Prakriya, Jason Cong, Yizhou Sun

Abstract

Large language models (LLMs) have achieved remarkable success across diverse tasks, but their inference is hindered by substantial time and energy demands because they generate only a single token at each decoding step. Previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, but each token is still generated from its own single-token distribution. This improves speed but does not improve output quality. In contrast, our work simultaneously boosts inference speed and improves output quality. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and improving task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model and incorporates a verification mechanism that not only ensures the accuracy of this approximation but also increases decoding efficiency over conventional speculative decoding. To further improve efficiency, we extend MTAD to multi-candidate multi-token assisted decoding (MMTAD), which uses tree-wise parallel decoding to verify multiple candidates efficiently. Theoretically, we show that MTAD and MMTAD closely approximate exact MTJD with bounded error. Empirical evaluations across various tasks show that our method improves downstream performance by 43% compared with standard single-token sampling. Furthermore, MTAD achieves a 1.26× speed-up and consumes 23.6% less energy than vanilla speculative decoding methods. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more productive and high-performance deployment of LLMs.
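To make the contrast between single-token and joint decoding concrete, the toy sketch below compares per-step greedy decoding with picking the most probable two-token block from the joint distribution. It is an illustration only, not the authors' MTJD, MTAD, or MMTAD implementation: the three-token vocabulary, the hand-written next-token table `NEXT`, and the helper names are hypothetical, and it assumes the greedy (argmax) variant of joint decoding. It also shows why exact MTJD is costly: the exhaustive joint search enumerates |V|^k blocks.

```python
import itertools
import math

# Hypothetical toy "language model": next-token distributions over a
# 3-token vocabulary, keyed by the prefix generated so far (prefixes up to length 1).
VOCAB = ["A", "B", "C"]
NEXT = {
    (): {"A": 0.5, "B": 0.4, "C": 0.1},
    ("A",): {"A": 0.4, "B": 0.3, "C": 0.3},
    ("B",): {"A": 0.9, "B": 0.05, "C": 0.05},
    ("C",): {"A": 0.34, "B": 0.33, "C": 0.33},
}


def joint_prob(seq):
    """Chain rule: p(x1, ..., xk) = prod_i p(xi | x<i)."""
    p = 1.0
    for i, tok in enumerate(seq):
        p *= NEXT[tuple(seq[:i])][tok]
    return p


def greedy_single_token(k):
    """Pick the argmax token at each step from its single-token distribution."""
    seq = []
    for _ in range(k):
        dist = NEXT[tuple(seq)]
        seq.append(max(dist, key=dist.get))
    return seq


def greedy_joint(k):
    """Pick the argmax k-token block from the joint distribution.

    Exhaustive search scores len(VOCAB) ** k candidate blocks, which is
    what makes exact joint decoding expensive for real vocabularies.
    """
    return list(max(itertools.product(VOCAB, repeat=k), key=joint_prob))


def perplexity(seq):
    """Per-token perplexity of a block under the toy model."""
    return math.exp(-math.log(joint_prob(seq)) / len(seq))


single = greedy_single_token(2)
joint = greedy_joint(2)
print(single, round(perplexity(single), 3))  # ['A', 'A'] 2.236
print(joint, round(perplexity(joint), 3))    # ['B', 'A'] 1.667
```

Per-step greedy commits to "A" because it is locally most likely (0.5), yielding a block with joint probability 0.2, whereas scoring both tokens jointly finds "B A" with probability 0.36 and therefore lower perplexity. With a realistic vocabulary of tens of thousands of tokens, this enumeration is infeasible, which motivates approximating the joint distribution with a smaller auxiliary model.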