Potatura Computazionale Adattiva per il Transformer con Dimenticanza
Adaptive Computation Pruning for the Forgetting Transformer
April 9, 2025
Autori: Zhixuan Lin, Johan Obando-Ceron, Xu Owen He, Aaron Courville
cs.AI
Abstract
Il recentemente proposto Forgetting Transformer (FoX) incorpora un gate di dimenticanza nell'attenzione softmax e ha dimostrato prestazioni costantemente migliori o equivalenti rispetto al Transformer standard basato su RoPE. In particolare, molte teste di attenzione in FoX tendono a dimenticare rapidamente, facendo sì che il loro output ad ogni passo temporale dipenda principalmente dal contesto locale. Sulla base di questa osservazione, proponiamo l'Adaptive Computation Pruning (ACP) per FoX, un metodo che pota dinamicamente i calcoli che coinvolgono le dipendenze input-output che sono fortemente attenuate dal gate di dimenticanza. Questo è ottenuto utilizzando una soglia di potatura impostata dinamicamente che garantisce che i pesi di attenzione potati rimangano trascurabili. Applichiamo ACP al pretraining di modelli linguistici con FoX e dimostriamo che riduce costantemente il numero di FLOP nell'attenzione softmax di circa il 70% su diverse dimensioni del modello e lunghezze del contesto, risultando in un miglioramento del throughput di addestramento di circa il 10% al 35%. Inoltre, lunghezze del contesto più lunghe producono maggiori risparmi computazionali. Tutti questi miglioramenti di velocità sono ottenuti senza alcuna degradazione delle prestazioni. Eseguiamo anche diverse analisi per fornire approfondimenti sul nostro metodo, come l'esame dei modelli di potatura e l'analisi della distribuzione dei risparmi di FLOP tra le diverse teste di attenzione. Il nostro codice è disponibile all'indirizzo https://github.com/zhixuan-lin/arctic-fox.
English
The recently proposed Forgetting Transformer (FoX) incorporates a forget gate
into softmax attention and has shown consistently better or on-par performance
compared to the standard RoPE-based Transformer. Notably, many attention heads
in FoX tend to forget quickly, causing their output at each timestep to rely
primarily on the local context. Based on this observation, we propose Adaptive
Computation Pruning (ACP) for FoX, a method that dynamically prunes
computations involving input-output dependencies that are strongly decayed by
the forget gate. This is achieved using a dynamically set pruning threshold
that ensures that the pruned attention weights remain negligible. We apply ACP
to language model pretraining with FoX and show it consistently reduces the
number of FLOPs in softmax attention by around 70% across different model sizes
and context lengths, resulting in a roughly 10% to 35% improvement in training
throughput. Furthermore, longer context lengths yield greater computational
savings. All these speed improvements are achieved without any performance
degradation. We also perform several analyses to provide deeper insights into
our method, such as examining the pruning patterns and analyzing the
distribution of FLOP savings across different attention heads. Our code is
available at https://github.com/zhixuan-lin/arctic-fox.Summary
AI-Generated Summary