ICLR 2025

Sparse Learning for State Space Models on Mobile

Xuan Shen, Hangyu Zheng, Yifan Gong, Zhenglun Kong, Changdi Yang, Zheng Zhan, Yushu Wu, Xue Lin, Yanzhi Wang, Pu Zhao, Wei Niu

Abstract

Transformer models have been widely adopted across domains for their ability to handle long-range dependencies and capture global context, driving the development of popular AI applications such as ChatGPT, Gemini, and Alexa. State Space Models (SSMs) have emerged as strong contenders in sequential modeling, challenging the dominance of Transformers. SSMs incorporate a selective mechanism that dynamically adjusts parameters based on the input data, which improves their performance. However, this mechanism also increases computational complexity and bandwidth demands, which makes deployment on resource-constrained mobile devices challenging. To address these challenges without sacrificing the accuracy of the selective mechanism, we propose a sparse learning framework that integrates architecture-aware compiler optimizations. We introduce an end-to-end solution, C_4^n kernel sparsity, which prunes n elements from every four contiguous weights, and we develop a compiler-based acceleration solution that ensures efficient execution of this sparsity pattern on mobile devices. Building on this kernel sparsity, our framework generates optimized sparse models that target specific sparsity or latency requirements across various model sizes. We further leverage the pruned weights to compensate the remaining weights, improving downstream task performance. For practical hardware acceleration, we propose C_4^n-specific optimizations combined with a layout transformation elimination strategy. This approach mitigates the inefficiencies that fine-grained pruning introduces in linear layers and also improves the performance of other operations. Experimental results show that our method achieves better task performance than other semi-structured pruning methods and achieves up to 7× speedup over the llama.cpp framework on mobile devices.

Our contributions are summarized as follows:

1. We design a special kernel, C_4^n, together with a set of comprehensive compiler optimizations for mobile devices, including C_4^n-specific optimizations and a layout transformation elimination strategy.
2. We propose a sparsity-oriented and/or latency-oriented sparse learning framework that explores the optimal pruning strategy with the proposed kernels for Mamba models.
3. We propose a weight compensation algorithm that rectifies the sparse model weights by calibrating with only 128 samples, further improving model effectiveness.
4. Experiments show that our framework achieves better task performance than other semi-structured pruning methods and a practical on-device speedup of up to 7× compared to llama.cpp.
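To make the C_4^n pattern concrete, here is a small pure-Python sketch of which weights it removes. It is not the authors' implementation: the helper names `c4n_prune` and `sparsity` are hypothetical, and the sketch assumes that groups of four run along each row of a weight matrix and that the n smallest-magnitude weights in each group are pruned (the paper's framework may choose which weights to prune differently). With n = 2 the pattern coincides with the familiar 2:4 semi-structured sparsity, and in general each row ends up with a fraction n/4 of zeros.

```python
from typing import List

Matrix = List[List[float]]


def c4n_prune(weights: Matrix, n: int) -> Matrix:
    """Zero out n of every 4 contiguous weights in each row (hypothetical helper).

    Assumptions (not from the paper): groups of four run along each row,
    row lengths are multiples of 4, and the n smallest-|w| entries of a
    group are the ones pruned.
    """
    if not 0 <= n <= 4:
        raise ValueError("n must be in [0, 4]")
    pruned: Matrix = []
    for row in weights:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4")
        out = list(row)
        for start in range(0, len(row), 4):
            # Indices of this group ordered by magnitude, smallest first
            # (Python's sort is stable, so ties prune the earlier index).
            order = sorted(range(start, start + 4), key=lambda i: abs(row[i]))
            for i in order[:n]:
                out[i] = 0.0
        pruned.append(out)
    return pruned


def sparsity(weights: Matrix) -> float:
    """Fraction of entries that are exactly zero."""
    total = sum(len(r) for r in weights)
    zeros = sum(1 for r in weights for w in r if w == 0.0)
    return zeros / total


if __name__ == "__main__":
    w = [[0.9, -0.1, 0.4, -0.7, 0.05, 0.3, -0.6, 0.2]]
    pruned = c4n_prune(w, 2)
    print(pruned)            # [[0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.6, 0.0]]
    print(sparsity(pruned))  # 0.5
```

The explicit zeros here only illustrate the mask. The speedups reported above depend on compiler-generated kernels that exploit this regular pattern, which the sketch does not model.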