EMNLP 2024
GRASS: Compute Efficient Low-Memory LLM Training with Structured Sparse Gradients
Aashiq Muhamed, Oscar Li, David P. Woodruff, Mona T. Diab, Virginia Smith
Abstract
Large language model (LLM) training and finetuning are often bottlenecked by limited GPU memory. Existing projection-based optimization methods address this by projecting gradients into a lower-dimensional subspace to reduce optimizer state memory. However, they typically rely on dense projection matrices, which can introduce computational and memory overheads. In this work, we propose GRASS (GRAdient Structured Sparsification), a novel approach that leverages sparse projections to transform gradients into structured sparse updates. This design significantly reduces memory usage for optimizer states. It also minimizes gradient memory footprint, computation, and communication costs, leading to substantial throughput improvements. Extensive experiments on pretraining and finetuning tasks demonstrate that GRASS achieves performance competitive with full-rank training and existing projection-based methods. Notably, GRASS enables half-precision pretraining of a 13B-parameter LLaMA model on a single 40GB A100 GPU, a feat infeasible for previous methods. It also yields up to a 2× throughput improvement on an 8-GPU system. Code is released here.¹

…
    P ← compute_P(∇L(W^{(t)}))    ▷ P ∈ ℝ^{m×r}
    // [Optional] Update optimizer state
    S^{(t)} ← update_state(S^{(t)})
  end if
  …
  S^{(t+1)}, Δ^{(t+1)} ← opt.update(S^{(t)}, G_C)
  …    ▷ Apply update
end for

Algorithm 2: MeSO implementations
FLORA
  Compute dense P: Sample P_{ij} i.i.d. from N(0, 1/r).
  Update_state: Updates momentum as P_{(t+1)}^⊤ P_{(t)} S^{(t)}.
  Compute G_C: Computes G_C using dense matmul.
  Apply update: Updates full W after a dense matmul.
GALORE
  Compute dense P: Top-r left singular vectors of the gradient G_W.
  Update_state: Maintains optimizer state.
  Compute G_C: Computes G_C using dense matmul.
  Apply update: Updates full W after a dense matmul.
GRASS (ours)
  Compute sparse P: Computes the selection matrix B and the diagonal scaling matrix ρ based on row norms of G_W.
  Update_state: Resets S^{(t)} to zero as necessary.
  Compute G_C: Uses matrix associativity and sparse matmul.
  Apply update: Sparse update of W after a sparse matmul.

…structured sparse matrices for P, demonstrating their advantages in memory, computation, and communication efficiency across both pretraining and finetuning. Our main contributions include:
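The GRASS entry in Algorithm 2 builds a sparse P from a selection matrix B and a diagonal scaling ρ chosen from the row norms of G_W. The NumPy sketch below is one plausible instantiation, not the authors' implementation. Several details are assumptions: that P = Bρ, that rows are sampled with replacement with probability proportional to their squared norm, and the 1/sqrt(r·p) rescaling, which makes E[P P^⊤] = I. The helper names `grass_projection`, `project`, `apply_sparse_update`, and `dense_P` are hypothetical.

```python
import numpy as np


def grass_projection(grad, r, rng):
    """Hypothetical sketch of a GRASS-style sparse projection P = B @ rho.

    Assumption (not stated in the text): r row indices are sampled with
    replacement, with probability proportional to each row's squared norm,
    and rho rescales pick j by 1 / sqrt(r * p_j) so that E[P @ P.T] = I.
    """
    m = grad.shape[0]
    row_sq = np.sum(grad * grad, axis=1)       # squared row norms of G_W
    probs = row_sq / row_sq.sum()              # assumes G_W is not all zeros
    idx = rng.choice(m, size=r, replace=True, p=probs)  # nonzero rows of B
    scale = 1.0 / np.sqrt(r * probs[idx])      # diagonal entries of rho
    return idx, scale


def project(grad, idx, scale):
    # G_C = P^T G_W = rho B^T G_W: gather r rows and rescale them (no dense matmul).
    return scale[:, None] * grad[idx]


def apply_sparse_update(W, idx, scale, delta):
    # Returns W + P @ delta while touching only the selected rows.
    # Assumption: learning rate and sign are already folded into delta.
    # np.add.at accumulates correctly when a row index was sampled twice.
    W = W.copy()
    np.add.at(W, idx, scale[:, None] * delta)
    return W


def dense_P(m, idx, scale):
    # Materialize P (m x r) for checking only; the sparse path never builds it.
    P = np.zeros((m, len(idx)))
    P[idx, np.arange(len(idx))] = scale
    return P
```

Because each column of P has a single nonzero entry, both the projection and the update reduce to row gathers and scatters. Only the r selected rows of W change at each step.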
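The "Compute G_C" step of the GRASS entry in Algorithm 2 relies on matrix associativity. The NumPy sketch below assumes a linear layer Y = X W^⊤ with W, G_W ∈ ℝ^{m×n}, and treats the sparse P as given by row indices and scales. Both the layer convention and the function name `compressed_grad_linear` are illustrative assumptions. The sketch does not model how or when the selection itself is refreshed.

```python
import numpy as np


def compressed_grad_linear(X, dY, idx, scale):
    """Sketch: G_C = P^T G_W for a linear layer Y = X @ W.T, without forming G_W.

    Assumed shapes: X is the (b, n) layer input, dY is the (b, m) gradient
    w.r.t. the output, W and G_W are (m, n), and P = B @ rho is encoded by
    (idx, scale) with one nonzero per column.
    """
    # Dense route: G_W = dY.T @ X (m x n, O(b*m*n) work), then P.T @ G_W.
    # Associativity: P^T (dY^T X) = (dY P)^T X, and dY @ P only gathers r
    # columns of dY and rescales them, leaving O(b*r*n) work.
    dY_P = dY[:, idx] * scale      # (b, r)
    return dY_P.T @ X              # (r, n)
```

Under these assumptions, the full m×n weight gradient is never materialized. Only the r×n compressed gradient is computed and stored.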