Taglia le tue perdite nei modelli linguistici ad ampio vocabolario

Cut Your Losses in Large-Vocabulary Language Models

November 13, 2024
Autori: Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl
cs.AI

Abstract

Man mano che i modelli linguistici crescono sempre di più, crescono anche i loro vocabolari. Ciò ha spostato in modo sproporzionato l'impronta di memoria dei LLM durante l'addestramento su un singolo strato: l'entropia incrociata nel calcolo della perdita. L'entropia incrociata costruisce una matrice di logit con voci per ciascuna coppia di token di input e elementi del vocabolario e, per modelli piccoli, consuma un ordine di grandezza di memoria maggiore rispetto al resto del LLM combinato. Proponiamo Cut Cross-Entropy (CCE), un metodo che calcola la perdita di entropia incrociata senza materializzare i logit per tutti i token nella memoria globale. Piuttosto, CCE calcola solo il logit per il token corretto e valuta il log-sum-exp su tutti i logit al volo. Implementiamo un kernel personalizzato che esegue le moltiplicazioni delle matrici e la riduzione del log-sum-exp sul vocabolario nella memoria flash, rendendo trascurabile il consumo di memoria globale per il calcolo dell'entropia incrociata. Ciò ha un effetto drammatico. Prendendo come esempio il modello Gemma 2 (2B), CCE riduce l'impronta di memoria del calcolo della perdita da 24 GB a 1 MB e il consumo di memoria totale durante il tempo di addestramento della testa del classificatore da 28 GB a 1 GB. Per migliorare il throughput di CCE, sfruttiamo la sparità intrinseca del softmax e proponiamo di saltare gli elementi del calcolo del gradiente che hanno un contributo trascurabile (cioè al di sotto della precisione numerica) al gradiente. Gli esperimenti dimostrano che la drastica riduzione del consumo di memoria è realizzata senza sacrificare la velocità di addestramento o la convergenza.
English
As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.

Summary

AI-Generated Summary

PDF384November 15, 2024