No More Adam: Learning Rate Scaling at Initialization is All You Need
December 16, 2024
Authors: Minghao Xu, Lichuan Xiang, Xu Cai, Hongkai Wen
cs.AI
Abstract
In this work, we question the necessity of adaptive gradient methods for
training deep neural networks. SGD-SaI is a simple yet effective enhancement to
stochastic gradient descent with momentum (SGDM). SGD-SaI performs learning
rate Scaling at Initialization (SaI) to distinct parameter groups, guided by
their respective gradient signal-to-noise ratios (g-SNR). By adjusting learning
rates without relying on adaptive second-order momentum, SGD-SaI helps prevent
training imbalances from the very first iteration and cuts the optimizer's
memory usage by half compared to AdamW. Despite its simplicity and efficiency,
SGD-SaI consistently matches or outperforms AdamW in training a variety of
Transformer-based tasks, effectively overcoming a long-standing challenge of
using SGD for training Transformers. SGD-SaI excels in ImageNet-1K
classification with Vision Transformers (ViT) and GPT-2 pretraining for large
language models (LLMs, transformer decoder-only), demonstrating robustness to
hyperparameter variations and practicality for diverse applications. We further
tested its robustness on tasks like LoRA fine-tuning for LLMs and diffusion
models, where it consistently outperforms state-of-the-art optimizers. From a
memory efficiency perspective, SGD-SaI achieves substantial memory savings for
optimizer states, reducing memory usage by 5.93 GB for GPT-2 (1.5B parameters)
and 25.15 GB for Llama2-7B compared to AdamW in full-precision training
settings.
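
To make the procedure described in the abstract concrete, the sketch below illustrates the idea in PyTorch: one backward pass at initialization, a per-tensor g-SNR statistic, fixed per-group learning-rate scales derived from it, then plain SGD with momentum. This is an illustration only, not the authors' implementation; the g-SNR proxy (gradient norm over element-wise standard deviation), the normalization by the mean g-SNR, and the toy model and data are assumptions introduced here for clarity.

```python
# Minimal sketch of "Scaling at Initialization" (SaI) on top of SGDM.
# NOT the paper's reference code; the g-SNR formula and scaling rule below
# are assumptions made for illustration.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data stand in for a Transformer; each named parameter tensor
# is treated as its own parameter group.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
x, y = torch.randn(64, 16), torch.randint(0, 4, (64,))
loss_fn = nn.CrossEntropyLoss()

# --- Scaling at Initialization: a single backward pass before training ---
loss_fn(model(x), y).backward()

eps, base_lr = 1e-8, 1e-2
gsnr = {}
for name, p in model.named_parameters():
    g = p.grad
    gsnr[name] = (g.norm() / (g.std() + eps)).item()  # assumed g-SNR proxy
mean_gsnr = sum(gsnr.values()) / len(gsnr)

# One SGDM param group per tensor, each with a fixed, pre-scaled learning rate.
# No adaptive second-moment state is kept, only the momentum buffer, so
# optimizer memory is roughly half of AdamW's.
param_groups = [
    {"params": [p], "lr": base_lr * gsnr[name] / mean_gsnr}
    for name, p in model.named_parameters()
]
optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=1e-2)

# --- Ordinary SGDM training loop; the learning rates are never re-scaled ---
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```

As a back-of-the-envelope check on the reported memory savings, dropping AdamW's second-moment buffer removes one fp32 value (4 bytes) per parameter: about 1.5e9 x 4 B (roughly 5.6-5.9 GiB) for a 1.5B-parameter GPT-2 and about 6.7e9 x 4 B (roughly 25 GiB) for Llama2-7B, consistent with the 5.93 GB and 25.15 GB figures above.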