NeurIPS 2023
Preference-grounded Token-level Guidance for Language Model Fine-tuning
Shentao Yang, Shujian Zhang, Congying Xia, Yihao Feng, Caiming Xiong, Mingyuan Zhou
35 citations
Abstract
Aligning language models (LMs) with preferences is an important problem in natural language generation. A key challenge is that preferences are typically provided at the sequence level, while LM training and generation both occur at the token level. There is therefore a granularity mismatch between the preference and the LM training losses, which may complicate the learning problem. In this paper, we address this issue by developing an alternating training process, in which we iterate between grounding the sequence-level preference into token-level training guidance and improving the LM with the learned guidance. For guidance learning, we design a framework that extends the pairwise-preference learning in imitation learning to both variable-length LM generation and the utilization of the preference among multiple generations. For LM training, based on the amount of supervised data, we present two minimalist learning objectives that utilize the learned guidance. In experiments, our method performs competitively on two distinct, representative LM tasks: discrete-prompt generation and text summarization. Source code is released at https://github.com/Shentao-YANG/Preference_Grounded_Guidance.

Update (01/07/25): In our follow-up work [113], we developed new techniques to successfully scale up the token-level RLHF framework in this paper to PPO + LLMs. As in this paper, we observed strong gains over the classical bandit RLHF, as tabulated in Table 1 below.

Table 1: Performance comparison between token-level RLHF and bandit RLHF on a PPO-trained LM policy, with the 8B-parameter Llama-family backbone model. The judge model is GPT-4o. For each backbone model, the highest value of each column is in bold. See Section 4.1 of Yin et al. [113] for experimental details.

Action Space | Backbone Model | AlpacaEval 2 (LC) | Arena-Hard | MT-Bench
Token | Llama-3.