ICML2025
Puzzle: Distillation-Based NAS for Inference-Optimized LLMs
Akhiad Bercovich, Tomer Ronen, Talor Abramovich, Nir Ailon, Nave Assaf, Mohammad Dabbah, Ido Galil, Amnon Geifman, Yonatan Geifman, Izhak Golan, Netanel Haber, Ehud Karpas, Roi Koren, Itay Levy, Pavlo Molchanov, Shahar Mor, Zach Moshe, Najeeb Nabwani, Omri Puny, Ran Rubin, Itamar Schen, Ido Shahaf, Oren Tropp, Omer Ullman Argov, Ran Zilberstein, Ran El-Yaniv
Abstract
Large language models (LLMs) offer remarkable capabilities, yet their high inference costs restrict wider adoption. While increasing parameter counts improves accuracy, it also widens the gap between state-of-the-art capabilities and practical deployability. We present Puzzle, a hardware-aware framework that accelerates LLM inference while preserving model capabilities. Using neural architecture search (NAS) at large scale, Puzzle optimizes models with tens of billions of parameters. Our approach uses blockwise local knowledge distillation (BLD) for parallel architecture exploration and employs mixed-integer programming for precise constraint optimization. We showcase our framework's impact via Llama-3.1-Nemotron-51B-Instruct (Nemotron-51B) and Llama-3.3-Nemotron-49B, two publicly available models derived from Llama-70B-Instruct. Both models achieve a 2.17× inference throughput speedup, fitting on a single NVIDIA H100 GPU while retaining 98.4% of the original model's benchmark accuracies. These are the most accurate models supporting single H100 GPU inference with large batch sizes, despite training on 45B tokens at most, far fewer than the 15T used to train Llama-70B. Lastly, we show that lightweight alignment on these derived models allows them to surpass the parent model in specific capabilities. Our work establishes that powerful LLMs can be optimized for efficient deployment with only negligible loss in quality, underscoring that inference performance, not parameter count alone, should guide model selection.

Introduction
LLMs require a substantial number of parameters for their training process to converge easily and achieve better generalization [29, 25, 5, 8]. This overparameterization not only facilitates optimization, but also provides greater capacity to store knowledge and learn complex patterns across diverse tasks, explaining why larger models consistently demonstrate superior performance [29, 25].
However, once trained, many parameters and computations turn out to be redundant for inference, as evidenced by the success of various computational efficiency techniques [21, 9, 58, 7, 40, 27, 3]. Yet LLM architectures remain largely uniform, comprising repeated identical layers, with little consideration given to balancing each block's computational cost against its contribution to overall model predictive performance, a design choice primarily driven by training stability and ease of scaling rather than inference efficiency.

This work addresses how to transform a trained LLM from a structure suited for training into one optimized for efficient inference on specific hardware (such as H100), while preserving its accumulated knowledge and predictive performance. Given a "parent model", our approach explores a large search space of architecture configurations to identify efficient options tailored to meet specific hardware and task-related constraints. This exploration requires a method to reliably estimate the performance of each potential configuration, allowing us to identify models that balance efficiency and accuracy for deployment.

*These authors contributed equally. Other co-authors are listed alphabetically.

[Figure: the three steps of Puzzle. Subblock labels shown in the diagram: MHA, GQA, MQA, Linear, no-op, FFN.]
Step 1: Crafting the "puzzle pieces". Block-wise local distillation is applied to every alternative subblock replacement in parallel, and each replacement is scored for quality and inference cost to build a "library" of blocks.
Step 2: Assembling the puzzle architecture. Mixed-integer programming is used to assemble a heterogeneous architecture that optimizes quality under constraints such as throughput, latency, and memory usage.
Step 3: Uptraining. The reassembled architecture is trained with global knowledge distillation to strengthen inter-block compatibility.
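To make Step 2 concrete, the following is a minimal sketch of the selection problem it describes: choose one block variant per layer so that total quality is maximized while total inference cost stays within a budget. Everything here is an illustrative assumption rather than the paper's implementation. The `Variant` type, the `assemble` function, the scores, and the costs are hypothetical; quality and cost are assumed to add up across layers; and an exhaustive search stands in for a mixed-integer programming solver, which is only workable for a toy library of this size.

```python
from itertools import product
from typing import NamedTuple


class Variant(NamedTuple):
    """One entry in the block library (all fields are hypothetical)."""
    name: str       # label for the block variant, e.g. "GQA+FFN"
    quality: float  # assumed quality score, higher is better
    cost: float     # assumed inference cost in arbitrary units


def assemble(library: list[list[Variant]], budget: float) -> tuple[list[Variant], float]:
    """Pick exactly one variant per layer, maximizing summed quality
    subject to summed cost <= budget.

    Exhaustive search is used here in place of a MIP solver; it visits
    every combination, so it only scales to very small libraries.
    """
    best_choice, best_quality = None, float("-inf")
    for choice in product(*library):
        cost = sum(v.cost for v in choice)
        quality = sum(v.quality for v in choice)
        if cost <= budget and quality > best_quality:
            best_choice, best_quality = list(choice), quality
    if best_choice is None:
        raise ValueError("no architecture fits the budget")
    return best_choice, best_quality


# Toy library: three layers, each with three candidate variants.
# The numbers are made up to show that the best choice can differ per layer.
library = [
    [Variant("MHA+FFN", 1.00, 4.0), Variant("GQA+FFN", 0.95, 3.0), Variant("no-op", 0.10, 0.0)],
    [Variant("MHA+FFN", 1.00, 4.0), Variant("GQA+FFN", 0.70, 3.0), Variant("no-op", 0.60, 0.0)],
    [Variant("MHA+FFN", 1.00, 4.0), Variant("GQA+FFN", 0.90, 3.0), Variant("no-op", 0.20, 0.0)],
]

chosen, total_quality = assemble(library, budget=7.0)
print([v.name for v in chosen], sum(v.cost for v in chosen))
```

With these made-up numbers, the search keeps the full block where removing it would hurt most and drops the middle layer, which illustrates why a heterogeneous architecture can beat a uniform one under the same budget. A real system would hand the same objective and constraints to a MIP solver and could add further constraints, such as memory, as extra linear inequalities.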