p-MoD: Building Mixture-of-Depths MLLMs via Progressive Ratio Decay
December 5, 2024
Authors: Jun Zhang, Desen Meng, Ji Qi, Zhenpeng Huang, Tao Wu, Limin Wang
cs.AI
Abstract
Despite the remarkable performance of multimodal large language models
(MLLMs) across diverse tasks, the substantial training and inference costs
impede their advancement. The majority of computation stems from the
overwhelming volume of vision tokens processed by the transformer decoder. In
this paper, we propose to build efficient MLLMs by leveraging the
Mixture-of-Depths (MoD) mechanism, where each transformer decoder layer selects
essential vision tokens to process while skipping redundant ones. However,
integrating MoD into MLLMs is non-trivial. To address the challenges of
training and inference stability as well as limited training data, we adapt the
MoD module with two novel designs: tanh-gated weight normalization (TanhNorm)
and symmetric token reweighting (STRing). Moreover, we observe that vision
tokens exhibit higher redundancy in deeper layers and thus design a progressive
ratio decay (PRD) strategy, which gradually reduces the token retention ratio
layer by layer, employing a shifted cosine schedule. This crucial design fully
unleashes the potential of MoD, significantly boosting the efficiency and
performance of our models. To validate the effectiveness of our approach, we
conduct extensive experiments with two baseline models across 14 benchmarks.
Our model, p-MoD, matches or even surpasses the performance of the baseline
models, with only 55.6% TFLOPs and 53.8% KV cache storage during inference, and
77.7% GPU hours during training.
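The progressive ratio decay (PRD) strategy described above reduces the vision-token retention ratio layer by layer following a shifted cosine schedule. The minimal sketch below illustrates one way such a schedule could be parameterized; the function name `retention_ratio`, the `floor` parameter, and the exact shift and scaling are illustrative assumptions, not the formulation reported in the paper.

```python
import math

def retention_ratio(layer_idx: int, num_layers: int, floor: float = 0.1) -> float:
    """Illustrative shifted-cosine decay of the vision-token retention ratio.

    Shallow decoder layers keep (almost) all vision tokens; deeper layers keep
    progressively fewer, bottoming out at `floor`. This parameterization is a
    hypothetical stand-in, not the schedule reported in the paper.
    """
    progress = layer_idx / max(num_layers - 1, 1)        # 0.0 at the first layer, 1.0 at the last
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays smoothly from 1.0 to 0.0
    return floor + (1.0 - floor) * cosine                # decays from 1.0 down to `floor`

# Example: retention ratios across a 32-layer decoder
ratios = [retention_ratio(l, 32) for l in range(32)]
```

Under this toy schedule, early decoder layers process nearly all vision tokens while the deepest layers retain only a small fraction, which matches the abstract's observation that vision-token redundancy grows with depth.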