Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models
January 21, 2025
Authors: Zihan Qiu, Zeyu Huang, Bo Zheng, Kaiyue Wen, Zekun Wang, Rui Men, Ivan Titov, Dayiheng Liu, Jingren Zhou, Junyang Lin
cs.AI
Abstract
This paper revisits the implementation of Load-balancing Loss (LBL) when training Mixture-of-Experts (MoEs) models. Specifically, the LBL for MoEs is defined as $N_E \sum_{i=1}^{N_E} f_i p_i$, where $N_E$ is the total number of experts, $f_i$ is the frequency with which expert $i$ is selected, and $p_i$ is the average gating score of expert $i$. Existing MoE training frameworks usually employ a parallel training strategy in which $f_i$ and the LBL are calculated within a micro-batch and then averaged across the parallel groups. In essence, a micro-batch for training billion-scale LLMs normally contains very few sequences, so the micro-batch LBL is almost at the sequence level, and the router is pushed to distribute tokens evenly within each sequence. Under this strict constraint, even tokens from a domain-specific sequence (e.g., code) are uniformly routed to all experts, thereby inhibiting expert specialization. In this work, we propose calculating the LBL using a global batch to loosen this constraint. Because a global batch contains much more diverse sequences than a micro-batch, this encourages load balance at the corpus level. Specifically, we introduce an extra communication step to synchronize $f_i$ across micro-batches and then use it to calculate the LBL. Through experiments on training MoE-based LLMs (up to 42.8B total parameters and 400B tokens), we surprisingly find that the global-batch LBL strategy yields excellent performance gains in both pre-training perplexity and downstream tasks. Our analysis reveals that the global-batch LBL also greatly improves the domain specialization of MoE experts.
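
As a rough illustration of the quantities named in the abstract, the sketch below shows one way to compute $f_i$, $p_i$, and the micro-batch LBL for a softmax-then-top-k router in PyTorch. The function name, signature, and routing details are assumptions for illustration, not the authors' released code.

```python
# Minimal sketch (not the paper's implementation) of the micro-batch LBL
# for a softmax router with top-k expert selection.
import torch


def micro_batch_lbl(router_logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """LBL = N_E * sum_{i=1}^{N_E} f_i * p_i for a single micro-batch.

    router_logits: (num_tokens, N_E) raw router outputs for the micro-batch.
    """
    num_experts = router_logits.size(-1)
    gates = torch.softmax(router_logits, dim=-1)               # (num_tokens, N_E)

    # p_i: average gating score of expert i over the micro-batch tokens.
    p = gates.mean(dim=0)                                       # (N_E,)

    # f_i: fraction of (token, slot) routing assignments sent to expert i.
    topk_idx = torch.topk(gates, top_k, dim=-1).indices.reshape(-1)
    counts = torch.zeros(num_experts, device=gates.device)
    counts.scatter_add_(0, topk_idx,
                        torch.ones_like(topk_idx, dtype=counts.dtype))
    f = counts / counts.sum()                                   # sums to 1

    return num_experts * torch.sum(f * p)
```

For example, `micro_batch_lbl(torch.randn(4096, 64), top_k=8)` would score a micro-batch of 4096 tokens routed over 64 experts. Note that $f_i$ is a hard count and carries no gradient, so the gradient of the loss flows through the gating scores $p_i$.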
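The extra communication step described in the abstract can be pictured as an all-reduce of $f_i$ over the parallel group; the helper below is a hedged sketch assuming PyTorch distributed training, not the paper's actual framework code.

```python
# Sketch of synchronizing f_i across micro-batches so the LBL is computed
# against the global batch; assumes torch.distributed is used for parallelism.
import torch
import torch.distributed as dist


def global_batch_frequency(local_f: torch.Tensor) -> torch.Tensor:
    """Average per-micro-batch expert frequencies f_i across all ranks, so the
    balance constraint applies to the global batch rather than each micro-batch."""
    if dist.is_available() and dist.is_initialized():
        global_f = local_f.clone()
        dist.all_reduce(global_f, op=dist.ReduceOp.SUM)   # sum f_i over parallel ranks
        return global_f / dist.get_world_size()           # average -> global-batch f_i
    return local_f  # single-process fallback: global batch equals the micro-batch
```

Each rank would then compute `num_experts * torch.sum(global_batch_frequency(f) * p)` with its local $p_i$, so the gradient path is unchanged while the frequency term reflects all micro-batches in the parallel group.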