MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models
September 26, 2024
Authors: Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang
cs.AI
Abstract
Large Language Models (LLMs) are distinguished by their massive parameter
counts, which typically result in significant redundancy. This work introduces
MaskLLM, a learnable pruning method that establishes Semi-structured (or
"N:M") Sparsity in LLMs, aimed at reducing computational overhead during
inference. Instead of developing a new importance criterion, MaskLLM explicitly
models N:M patterns as a learnable distribution through Gumbel Softmax
sampling. This approach facilitates end-to-end training on large-scale datasets
and offers two notable advantages: 1) High-quality Masks - our method
effectively scales to large datasets and learns accurate masks; 2)
Transferability - the probabilistic modeling of mask distribution enables the
transfer learning of sparsity across domains or tasks. We assessed MaskLLM
using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3,
with sizes ranging from 843M to 15B parameters, and our empirical results show
substantial improvements over state-of-the-art methods. For instance, leading
approaches achieve a perplexity (PPL) of 10 or greater on Wikitext compared to
the dense model's 5.12 PPL, but MaskLLM achieves a significantly lower 6.72 PPL
solely by learning the masks with frozen weights. Furthermore, MaskLLM's
learnable nature allows customized masks for lossless application of 2:4
sparsity to downstream tasks or domains. Code is available at
https://github.com/NVlabs/MaskLLM.
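The abstract states the core mechanism only at a high level: under 2:4 sparsity, each group of four consecutive weights admits C(4,2) = 6 candidate binary masks, and MaskLLM learns a categorical distribution over these candidates, sampling from it differentiably via Gumbel Softmax so the mask can be trained end-to-end while the weights stay frozen. Below is a minimal PyTorch sketch of that idea; it is not the authors' implementation (see the repository above for that), and the Learnable24Mask class, the temperature value, and the toy loss are illustrative assumptions.

# Minimal sketch (not the authors' code): learn a 2:4 mask for a frozen
# weight matrix by training per-group logits over the 6 candidate patterns,
# sampled differentiably with PyTorch's F.gumbel_softmax.
import itertools
import torch
import torch.nn.functional as F

# All 2:4 patterns: length-4 binary vectors with exactly two ones -> (6, 4).
PATTERNS = torch.tensor(
    [p for p in itertools.product([0.0, 1.0], repeat=4) if sum(p) == 2]
)

class Learnable24Mask(torch.nn.Module):
    """Learnable 2:4 mask over a frozen weight tensor (name is illustrative)."""

    def __init__(self, weight: torch.Tensor, tau: float = 4.0):
        super().__init__()
        assert weight.numel() % 4 == 0, "weight count must be divisible by 4"
        self.register_buffer("weight", weight)  # frozen: only logits train
        n_groups = weight.numel() // 4
        # One categorical distribution (6 logits) per group of 4 weights.
        self.logits = torch.nn.Parameter(torch.zeros(n_groups, 6))
        self.tau = tau  # Gumbel-Softmax temperature (assumed value)

    def forward(self) -> torch.Tensor:
        # hard=True yields a discrete one-hot sample in the forward pass
        # while the straight-through estimator routes gradients to logits.
        onehot = F.gumbel_softmax(self.logits, tau=self.tau, hard=True)  # (G, 6)
        mask = onehot @ PATTERNS.to(onehot.device)  # (G, 4) binary mask
        return self.weight * mask.view_as(self.weight)

# Usage: only the mask logits receive gradients; the weights never change.
w = torch.randn(8, 16)
masked = Learnable24Mask(w)
opt = torch.optim.AdamW(masked.parameters(), lr=1e-2)

sparse_w = masked()
# Every group of 4 weights has exactly 2 survivors, i.e. valid 2:4 sparsity.
assert ((sparse_w.view(-1, 4) != 0).sum(dim=1) == 2).all()

# One illustrative step: a stand-in loss for the real language-modeling loss.
loss = (masked() - w).pow(2).mean()
loss.backward()
opt.step()

Because the learned object is a distribution over patterns rather than a single fixed binary mask, the trained logits can serve as a prior when re-learning masks on new data, which is what the abstract refers to as transferability of sparsity across domains or tasks.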