SeerAttention: 당신의 LLMs에서 내재적으로 희소한 주의력 학습

SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs

October 17, 2024
저자: Yizhao Gao, Zhichen Zeng, Dayou Du, Shijie Cao, Hayden Kwok-Hay So, Ting Cao, Fan Yang, Mao Yang
cs.AI

초록

현대 대형 언어 모델(Large Language Models, LLMs)의 중심에는 주의(Attention)가 있다. 그러나 이차 복잡도는 특히 긴 문맥 창을 갖는 LLMs의 효율성과 확장성을 제한한다. 이 제한을 해결하는 유망한 접근 방식은 주의의 희소성을 활용하는 것이다. 그러나 기존의 희소성 기반 솔루션은 주로 사전 정의된 패턴이나 휴리스틱을 사용하여 희소성을 근사한다. 이러한 방법은 언어 기반 작업에서 주의의 희소성의 동적 성질을 완전히 포착하지 못한다. 본 논문은 주의의 희소성을 사전 정의하는 대신 학습되어야 한다고 주장한다. 이를 위해 우리는 SeerAttention이라는 새로운 주의 메커니즘을 설계했다. 이 메커니즘은 학습 가능한 게이트를 사용하여 기존 주의에 적응적으로 중요한 블록을 선택하고 나머지 블록을 희소하다고 판단한다. 이러한 블록 수준의 희소성은 정확도와 가속을 효과적으로 균형있게 유지한다. 게이트 네트워크를 효율적으로 학습하기 위해 우리는 주의 맵의 블록 수준 ground truth를 최소한의 오버헤드로 추출하는 사용자 정의 FlashAttention 구현을 개발했다. SeerAttention은 사후 훈련에만 적용되는 것이 아니라 긴 문맥의 세밀한 튜닝에서도 뛰어난 성과를 보여준다. 결과는 사후 훈련 단계에서 SeerAttention이 최첨단 정적 또는 휴리스틱 기반 희소 주의 방법보다 획기적으로 우수하며, 다양한 문맥 길이와 희소 비율에 적응하는 데 더욱 다재다능하고 유연하다는 것을 보여준다. YaRN을 사용한 긴 문맥 세밀 튜닝에 적용할 때, SeerAttention은 최소의 편의 손실로 32k 문맥 길이에서 놀라운 90%의 희소 비율을 달성할 수 있으며, FlashAttention-2보다 5.67배의 가속을 제공한다.
English
Attention is the cornerstone of modern Large Language Models (LLMs). Yet its quadratic complexity limits the efficiency and scalability of LLMs, especially for those with a long-context window. A promising approach addressing this limitation is to leverage the sparsity in attention. However, existing sparsity-based solutions predominantly rely on predefined patterns or heuristics to approximate sparsity. This practice falls short to fully capture the dynamic nature of attention sparsity in language-based tasks. This paper argues that attention sparsity should be learned rather than predefined. To this end, we design SeerAttention, a new Attention mechanism that augments the conventional attention with a learnable gate that adaptively selects significant blocks in an attention map and deems the rest blocks sparse. Such block-level sparsity effectively balances accuracy and speedup. To enable efficient learning of the gating network, we develop a customized FlashAttention implementation that extracts the block-level ground truth of attention map with minimum overhead. SeerAttention not only applies to post-training, but also excels in long-context fine-tuning. Our results show that at post-training stages, SeerAttention significantly outperforms state-of-the-art static or heuristic-based sparse attention methods, while also being more versatile and flexible to adapt to varying context lengths and sparsity ratios. When applied to long-context fine-tuning with YaRN, SeerAttention can achieve a remarkable 90% sparsity ratio at a 32k context length with minimal perplexity loss, offering a 5.67x speedup over FlashAttention-2.

Summary

AI-Generated Summary

PDF252November 16, 2024