張量乘積注意力就是你所需的

Tensor Product Attention Is All You Need

January 11, 2025
作者: Yifan Zhang, Yifeng Liu, Huizhuo Yuan, Zhen Qin, Yang Yuan, Quanquan Gu, Andrew Chi-Chih Yao
cs.AI

摘要

將語言模型擴展以處理較長的輸入序列通常需要大型的鍵-值(KV)緩存,這導致推論過程中存在重大的內存開銷。在本文中,我們提出了張量乘積注意力(TPA),這是一種使用張量分解來緊湊表示查詢、鍵和值的新型注意力機制,顯著地縮小了推論時的KV緩存大小。通過將這些表示因子分解為上下文低秩組件(上下文分解),並與RoPE無縫集成,TPA實現了模型質量的提升以及內存效率。基於TPA,我們引入了Tensor ProducT ATTenTion Transformer(T6),這是一種用於序列建模的新模型架構。通過對語言建模任務的廣泛實證評估,我們展示了T6在各種指標上超越了標準Transformer基準模型,包括MHA、MQA、GQA和MLA,包括困惑度和一系列知名評估基準。值得注意的是,TPA的內存效率使其能夠在固定資源限制下處理更長的序列,解決了現代語言模型中的一個關鍵可擴展性挑戰。代碼可在https://github.com/tensorgi/T6找到。
English
Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, significantly shrinking KV cache size at inference time. By factorizing these representations into contextual low-rank components (contextual factorization) and seamlessly integrating with RoPE, TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation of language modeling tasks, we demonstrate that T6 exceeds the performance of standard Transformer baselines including MHA, MQA, GQA, and MLA across various metrics, including perplexity and a range of renowned evaluation benchmarks. Notably, TPAs memory efficiency enables the processing of significantly longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. The code is available at https://github.com/tensorgi/T6.

Summary

AI-Generated Summary

PDF664January 14, 2025