SeerAttention : Apprentissage de l'Attention Intrinsèquement Éparse dans Vos LLMs

SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs

October 17, 2024
Auteurs: Yizhao Gao, Zhichen Zeng, Dayou Du, Shijie Cao, Hayden Kwok-Hay So, Ting Cao, Fan Yang, Mao Yang
cs.AI

Résumé

L'attention est la pierre angulaire des modèles de langage de grande taille (LLM) modernes. Cependant, sa complexité quadratique limite l'efficacité et la scalabilité des LLM, en particulier pour ceux avec une fenêtre de contexte longue. Une approche prometteuse pour surmonter cette limitation est d'exploiter la parcimonie dans l'attention. Cependant, les solutions basées sur la parcimonie existantes reposent principalement sur des motifs prédéfinis ou des heuristiques pour approximer la parcimonie. Cette pratique ne parvient pas à capturer pleinement la nature dynamique de la parcimonie de l'attention dans les tâches basées sur le langage. Cet article soutient que la parcimonie de l'attention devrait être apprise plutôt que prédéfinie. À cette fin, nous concevons SeerAttention, un nouveau mécanisme d'attention qui complète l'attention conventionnelle avec une porte apprenante qui sélectionne de manière adaptative des blocs significatifs dans une carte d'attention et considère les autres blocs comme parcimonieux. Une telle parcimonie au niveau des blocs équilibre efficacement précision et accélération. Pour permettre l'apprentissage efficace du réseau de portes, nous développons une implémentation FlashAttention personnalisée qui extrait la vérité terrain au niveau des blocs de la carte d'attention avec un minimum de surcharge. SeerAttention s'applique non seulement à la post-formation, mais excelle également dans le fine-tuning à long contexte. Nos résultats montrent qu'aux étapes de post-formation, SeerAttention surpasse significativement les méthodes d'attention parcimonieuses statiques ou basées sur des heuristiques de pointe, tout en étant plus polyvalent et flexible pour s'adapter à des longueurs de contexte variables et à des taux de parcimonie. Lorsqu'appliqué au fine-tuning à long contexte avec YaRN, SeerAttention peut atteindre un remarquable taux de parcimonie de 90% avec une longueur de contexte de 32k et une perte de perplexité minimale, offrant un gain de vitesse de 5,67 fois par rapport à FlashAttention-2.
English
Attention is the cornerstone of modern Large Language Models (LLMs). Yet its quadratic complexity limits the efficiency and scalability of LLMs, especially for those with a long-context window. A promising approach addressing this limitation is to leverage the sparsity in attention. However, existing sparsity-based solutions predominantly rely on predefined patterns or heuristics to approximate sparsity. This practice falls short to fully capture the dynamic nature of attention sparsity in language-based tasks. This paper argues that attention sparsity should be learned rather than predefined. To this end, we design SeerAttention, a new Attention mechanism that augments the conventional attention with a learnable gate that adaptively selects significant blocks in an attention map and deems the rest blocks sparse. Such block-level sparsity effectively balances accuracy and speedup. To enable efficient learning of the gating network, we develop a customized FlashAttention implementation that extracts the block-level ground truth of attention map with minimum overhead. SeerAttention not only applies to post-training, but also excels in long-context fine-tuning. Our results show that at post-training stages, SeerAttention significantly outperforms state-of-the-art static or heuristic-based sparse attention methods, while also being more versatile and flexible to adapt to varying context lengths and sparsity ratios. When applied to long-context fine-tuning with YaRN, SeerAttention can achieve a remarkable 90% sparsity ratio at a 32k context length with minimal perplexity loss, offering a 5.67x speedup over FlashAttention-2.

Summary

AI-Generated Summary

PDF252November 16, 2024