p-MoD: Building Mixture-of-Depths MLLMs via Progressive Ratio Decay
December 5, 2024
Authors: Jun Zhang, Desen Meng, Ji Qi, Zhenpeng Huang, Tao Wu, Limin Wang
cs.AI
Abstract
Despite the remarkable performance of multimodal large language models
(MLLMs) across diverse tasks, the substantial training and inference costs
impede their advancement. The majority of computation stems from the
overwhelming volume of vision tokens processed by the transformer decoder. In
this paper, we propose to build efficient MLLMs by leveraging the
Mixture-of-Depths (MoD) mechanism, where each transformer decoder layer selects
essential vision tokens to process while skipping redundant ones. However,
integrating MoD into MLLMs is non-trivial. To address the challenges of
training and inference stability as well as limited training data, we adapt the
MoD module with two novel designs: tanh-gated weight normalization (TanhNorm)
and symmetric token reweighting (STRing). Moreover, we observe that vision
tokens exhibit higher redundancy in deeper layers and thus design a progressive
ratio decay (PRD) strategy, which gradually reduces the token retention ratio
layer by layer, employing a shifted cosine schedule. This crucial design fully
unleashes the potential of MoD, significantly boosting the efficiency and
performance of our models. To validate the effectiveness of our approach, we
conduct extensive experiments with two baseline models across 14 benchmarks.
Our model, p-MoD, matches or even surpasses the performance of the baseline
models, with only 55.6% of the TFLOPs and 53.8% of the KV cache storage during
inference, and 77.7% of the GPU hours during training.
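To make the mechanism concrete, below is a minimal PyTorch sketch of how a Mixture-of-Depths layer could route vision tokens: a lightweight scorer ranks the tokens, only the top fraction given by the layer's retention ratio is processed, and the router weight is passed through a tanh gate in the spirit of TanhNorm. The module name, scorer design, and exact gating form are illustrative assumptions, not the authors' released code.

```python
import torch
import torch.nn as nn

class MoDVisionTokenRouter(nn.Module):
    """Illustrative Mixture-of-Depths router for vision tokens (a sketch,
    not the paper's implementation)."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Lightweight scorer that assigns one routing score per vision token.
        self.scorer = nn.Linear(hidden_dim, 1)

    def forward(self, vision_tokens: torch.Tensor, retention_ratio: float):
        # vision_tokens: (batch, num_tokens, hidden_dim)
        scores = self.scorer(vision_tokens).squeeze(-1)            # (B, N)
        k = max(1, int(vision_tokens.shape[1] * retention_ratio))
        topk = scores.topk(k, dim=-1)                              # values/indices: (B, k)
        # Tanh keeps the routing weight in (-1, 1) and near zero at initialization
        # (TanhNorm-style gating; the exact formulation is an assumption).
        gate = torch.tanh(topk.values).unsqueeze(-1)               # (B, k, 1)
        selected = torch.gather(
            vision_tokens, 1,
            topk.indices.unsqueeze(-1).expand(-1, -1, vision_tokens.shape[-1]),
        )                                                          # (B, k, H)
        return selected, gate, topk.indices
```

In a MoD-style layer, the skipped tokens would bypass the layer through the residual stream, and the gated outputs of the processed tokens would be scattered back to their original positions afterwards.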
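The progressive ratio decay itself can be pictured as a simple per-layer schedule. The sketch below assumes a shifted cosine that keeps nearly all vision tokens in the first decoder layer and decays toward zero in the last; the exact shift and endpoints used in the paper may differ.

```python
import math

def retention_ratio(layer_idx: int, num_layers: int, shift: float = 0.5) -> float:
    """Hypothetical shifted-cosine schedule for the per-layer vision-token
    retention ratio (the paper's exact parameterization may differ)."""
    # Progress through the decoder, offset by `shift` so early layers
    # retain (almost) all vision tokens.
    t = (layer_idx + shift) / (num_layers - 1 + shift)
    # Cosine decay from ~1.0 at the first layer toward 0.0 at the last layer.
    return max(0.0, math.cos(0.5 * math.pi * t))

# Example: a 32-layer decoder retains fewer vision tokens as depth increases.
ratios = [retention_ratio(layer, num_layers=32) for layer in range(32)]
```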