Taglia le tue perdite nei modelli linguistici ad ampio vocabolario
Cut Your Losses in Large-Vocabulary Language Models
November 13, 2024
Autori: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
cs.AI
Abstract
Man mano che i modelli linguistici crescono sempre di più, crescono anche i loro vocabolari. Ciò ha spostato in modo sproporzionato l'impronta di memoria dei LLM durante l'addestramento su un singolo strato: l'entropia incrociata nel calcolo della perdita. L'entropia incrociata costruisce una matrice di logit con voci per ciascuna coppia di token di input e elementi del vocabolario e, per modelli piccoli, consuma un ordine di grandezza di memoria maggiore rispetto al resto del LLM combinato. Proponiamo Cut Cross-Entropy (CCE), un metodo che calcola la perdita di entropia incrociata senza materializzare i logit per tutti i token nella memoria globale. Piuttosto, CCE calcola solo il logit per il token corretto e valuta il log-sum-exp su tutti i logit al volo. Implementiamo un kernel personalizzato che esegue le moltiplicazioni delle matrici e la riduzione del log-sum-exp sul vocabolario nella memoria flash, rendendo trascurabile il consumo di memoria globale per il calcolo dell'entropia incrociata. Ciò ha un effetto drammatico. Prendendo come esempio il modello Gemma 2 (2B), CCE riduce l'impronta di memoria del calcolo della perdita da 24 GB a 1 MB e il consumo di memoria totale durante il tempo di addestramento della testa del classificatore da 28 GB a 1 GB. Per migliorare il throughput di CCE, sfruttiamo la sparità intrinseca del softmax e proponiamo di saltare gli elementi del calcolo del gradiente che hanno un contributo trascurabile (cioè al di sotto della precisione numerica) al gradiente. Gli esperimenti dimostrano che la drastica riduzione del consumo di memoria è realizzata senza sacrificare la velocità di addestramento o la convergenza.
English
As language models grow ever larger, so do their vocabularies. This has
shifted the memory footprint of LLMs during training disproportionately to one
single layer: the cross-entropy in the loss computation. Cross-entropy builds
up a logit matrix with entries for each pair of input tokens and vocabulary
items and, for small models, consumes an order of magnitude more memory than
the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that
computes the cross-entropy loss without materializing the logits for all tokens
into global memory. Rather, CCE only computes the logit for the correct token
and evaluates the log-sum-exp over all logits on the fly. We implement a custom
kernel that performs the matrix multiplications and the log-sum-exp reduction
over the vocabulary in flash memory, making global memory consumption for the
cross-entropy computation negligible. This has a dramatic effect. Taking the
Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss
computation from 24 GB to 1 MB, and the total training-time memory consumption
of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we
leverage the inherent sparsity of softmax and propose to skip elements of the
gradient computation that have a negligible (i.e., below numerical precision)
contribution to the gradient. Experiments demonstrate that the dramatic
reduction in memory consumption is accomplished without sacrificing training
speed or convergence.Summary
AI-Generated Summary