No More Adam: Learning Rate Scaling at Initialization is All You Need
December 16, 2024
Authors: Minghao Xu, Lichuan Xiang, Xu Cai, Hongkai Wen
cs.AI
Abstract
In this work, we question the necessity of adaptive gradient methods for
training deep neural networks. SGD-SaI is a simple yet effective enhancement to
stochastic gradient descent with momentum (SGDM). SGD-SaI performs learning
rate Scaling at Initialization (SaI) to distinct parameter groups, guided by
their respective gradient signal-to-noise ratios (g-SNR). By adjusting learning
rates without relying on adaptive second-order momentum, SGD-SaI helps prevent
training imbalances from the very first iteration and cuts the optimizer's
memory usage by half compared to AdamW. Despite its simplicity and efficiency,
SGD-SaI consistently matches or outperforms AdamW in training a variety of
Transformer-based tasks, effectively overcoming a long-standing challenge of
using SGD for training Transformers. SGD-SaI excels in ImageNet-1K
classification with Vision Transformers (ViT) and GPT-2 pretraining for large
language models (LLMs, transformer decoder-only), demonstrating robustness to
hyperparameter variations and practicality for diverse applications. We further
tested its robustness on tasks like LoRA fine-tuning for LLMs and diffusion
models, where it consistently outperforms state-of-the-art optimizers. From a
memory efficiency perspective, SGD-SaI achieves substantial memory savings for
optimizer states, reducing memory usage by 5.93 GB for GPT-2 (1.5B parameters)
and 25.15 GB for Llama2-7B compared to AdamW in full-precision training
settings.
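The recipe described in the abstract is simple enough to sketch. The snippet below illustrates the general idea on top of PyTorch's stock SGD: one backward pass at initialization, a per-tensor g-SNR statistic, and learning-rate scales that stay fixed for the rest of training. The g-SNR formula, the per-tensor grouping, the max-normalization of the scales, and the helper names (gsnr, build_sgd_sai, loss_fn, batch) are illustrative assumptions, not the authors' reference implementation.

# Minimal sketch of the SaI idea on top of plain SGD with momentum (assumptions
# noted above; this is not the paper's reference code).
import torch

def gsnr(grad, eps=1e-8):
    # Assumed gradient signal-to-noise ratio: mean gradient magnitude over the
    # elementwise standard deviation within one parameter tensor.
    signal = grad.abs().mean()
    noise = grad.float().std(unbiased=False)
    return (signal / (noise + eps)).item()

def build_sgd_sai(model, loss_fn, batch, base_lr=1e-3, momentum=0.9, weight_decay=0.01):
    # One forward/backward pass at initialization to obtain gradients.
    model.zero_grad()
    loss_fn(model(batch["x"]), batch["y"]).backward()

    # One parameter group per tensor, its learning rate scaled by the tensor's
    # g-SNR (normalized here by the largest g-SNR; the paper's exact scaling
    # rule may differ).
    snrs = {n: gsnr(p.grad) for n, p in model.named_parameters() if p.grad is not None}
    max_snr = max(snrs.values())
    groups = [{"params": [p], "lr": base_lr * snrs[n] / max_snr}
              for n, p in model.named_parameters() if n in snrs]
    model.zero_grad()

    # The scales are frozen from here on; training proceeds with vanilla SGDM,
    # so only one momentum buffer per parameter is kept (roughly half of
    # AdamW's first- plus second-moment state).
    return torch.optim.SGD(groups, lr=base_lr, momentum=momentum, weight_decay=weight_decay)

Training then proceeds with ordinary optimizer.step() calls; because the per-group scales come from a single statistic computed at step zero, no second-moment estimates need to be stored or updated during training.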
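As a rough consistency check on those savings (our arithmetic, not the paper's): AdamW keeps two full-precision moment tensors per parameter while momentum SGD keeps one, so the dropped second-moment buffer accounts for about 4 bytes per parameter in fp32,

\[
\Delta M \approx N_{\text{params}} \times 4\,\mathrm{B}:\qquad
1.5\times 10^{9} \times 4\,\mathrm{B} \approx 6\,\mathrm{GB}\ (\text{GPT-2}),\qquad
6.7\times 10^{9} \times 4\,\mathrm{B} \approx 25\,\mathrm{GB}\ (\text{Llama2-7B}),
\]

which is in line with the reported 5.93 GB and 25.15 GB figures.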