Cut Your Losses in Large-Vocabulary Language Models
November 13, 2024
Authors: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
cs.AI
Abstract
As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.
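
The identity behind this is that the per-token cross-entropy can be written as loss_i = logsumexp_j(e_i · c_j) − e_i · c_{y_i}, so only the correct-token logit and a running log-sum-exp over the vocabulary are needed. Below is a minimal PyTorch sketch of that blockwise idea; it is not the paper's custom kernel (which fuses the matrix multiplication and reduction on-chip and recomputes what it needs for the backward pass), and the names blockwise_cross_entropy and block_size are illustrative.

import torch

def blockwise_cross_entropy(embeddings, classifier, targets, block_size=4096):
    """embeddings: (N, D) hidden states; classifier: (V, D) head weights;
    targets: (N,) correct-token indices. Returns the mean cross-entropy loss
    without ever holding the full (N, V) logit matrix in memory at once."""
    N = embeddings.shape[0]
    # Logit of the correct token only: e_i . c_{y_i}, an (N,) vector.
    correct_logit = (embeddings * classifier[targets]).sum(dim=-1)

    # Running log-sum-exp over the vocabulary, accumulated block by block.
    lse = torch.full((N,), float("-inf"), device=embeddings.device)
    for start in range(0, classifier.shape[0], block_size):
        block = classifier[start:start + block_size]      # (B, D) slice of the head
        logits_block = embeddings @ block.T               # (N, B), discarded each iteration
        lse = torch.logaddexp(lse, torch.logsumexp(logits_block, dim=-1))

    # Cross-entropy: log-sum-exp over all logits minus the correct logit.
    return (lse - correct_logit).mean()

# Usage example with random data; the last line is a full-logit reference check.
if __name__ == "__main__":
    torch.manual_seed(0)
    N, D, V = 8, 64, 50000
    e, C = torch.randn(N, D), torch.randn(V, D)
    y = torch.randint(0, V, (N,))
    print(blockwise_cross_entropy(e, C, y))
    print(torch.nn.functional.cross_entropy(e @ C.T, y))

Note that this forward-only sketch illustrates the computation pattern; a practical implementation also needs a matching memory-efficient backward pass, which is where the custom kernel and the sparsity-based gradient skipping described above come in.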