NeurIPS 2023

Likelihood-Based Diffusion Language Models

Ishaan Gulrajani, Tatsunori B. Hashimoto

Abstract

Despite a growing interest in diffusion-based language models, existing work has not shown that these models can attain nontrivial likelihoods on standard language modeling benchmarks. In this work, we take the first steps towards closing the likelihood gap between autoregressive and diffusion-based language models, with the goal of building and releasing a diffusion model which outperforms a small but widely-known autoregressive model. We pursue this goal through algorithmic improvements, scaling laws, and increased compute. On the algorithmic front, we introduce several methodological improvements for the maximum-likelihood training of diffusion language models. We then study scaling laws for our diffusion models and find compute-optimal training regimes which differ substantially from autoregressive models. Using our methods and scaling analysis, we train and release Plaid 1B, a large diffusion language model which outperforms GPT-2 124M in likelihood on benchmark datasets and generates fluent samples in unconditional and zero-shot control settings.

1 Variational Diffusion Models for language

In this background section, we formally define continuous diffusion models over text sequences, adopting the Variational Diffusion Models (VDM) framework [18], which is a natural fit for likelihood-based training (see Karras et al. [16] for a survey on other formalisms). For brevity, we simplify some details in our exposition and refer the reader to Kingma et al. [18] for details. Consistent with prior work (e.g. Li et al. [21]), our basic approach will be to map discrete text sequences into a continuous space with a token-wise embedding function and then construct a diffusion model on the embedded data.

Forward diffusion process. Consider a sequence of tokens $x = [x^{(1)}, \ldots, x^{(L)}]$ drawn from the data distribution $q(x)$. We transform $x$ into a sequence $\mathbf{x}$ of embedding vectors using an invertible token-wise embedding function $\mathrm{Embed}(\cdot)$, such that $\mathbf{x}^{(i)} := \mathrm{Embed}(x^{(i)})$.
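As a rough illustration of the embedding step just described, the sketch below maps token ids to vectors position by position and recovers them by nearest-neighbour lookup. It is only a minimal sketch under stated assumptions: the names `make_embedding_table`, `embed`, and `unembed` are hypothetical, the table is fixed and random rather than whatever embedding the authors actually use, and invertibility here relies on all table rows being distinct.

```python
import random


def make_embedding_table(vocab_size, dim, seed=0):
    """Hypothetical embedding table: one random Gaussian vector per token id.

    Assumption for illustration only; the source does not specify how the
    embedding is parameterised at this point.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(vocab_size)]


def embed(tokens, table):
    """Token-wise map: each token id x^(i) is sent to Embed(x^(i)) independently."""
    return [list(table[tok]) for tok in tokens]


def unembed(vectors, table):
    """Invert the embedding by picking the closest table row (squared Euclidean distance)."""
    def nearest(vec):
        return min(
            range(len(table)),
            key=lambda tok: sum((a - b) ** 2 for a, b in zip(vec, table[tok])),
        )
    return [nearest(vec) for vec in vectors]


table = make_embedding_table(vocab_size=50, dim=8)
tokens = [3, 17, 17, 42, 0]        # a toy sequence of L = 5 token ids
vectors = embed(tokens, table)     # L vectors, each of dimension 8
recovered = unembed(vectors, table)
```

Because the map acts on each position separately, repeated tokens receive identical vectors, and the original sequence can be read back exactly as long as no two tokens share an embedding.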
The forward process is a Markov chain over latent variables $z_t$ from $t = 0$ to $t = 1$ which progressively adds Gaussian noise to the embeddings $\mathbf{x}$. Let $\sigma^2(t)$ be some monotonic function that specifies the total noise