Normalizing Flows are Capable Generative Models
December 9, 2024
Authors: Shuangfei Zhai, Ruixiang Zhang, Preetum Nakkiran, David Berthelot, Jiatao Gu, Huangjie Zheng, Tianrong Chen, Miguel Angel Bautista, Navdeep Jaitly, Josh Susskind
cs.AI
Abstract
Normalizing Flows (NFs) are likelihood-based models for continuous inputs.
They have demonstrated promising results on both density estimation and
generative modeling tasks, but have received relatively little attention in
recent years. In this work, we demonstrate that NFs are more powerful than
previously believed. We present TarFlow: a simple and scalable architecture
that enables highly performant NF models. TarFlow can be thought of as a
Transformer-based variant of Masked Autoregressive Flows (MAFs): it consists of
a stack of autoregressive Transformer blocks on image patches, alternating the
autoregression direction between layers. TarFlow is straightforward to train
end-to-end, and capable of directly modeling and generating pixels. We also
propose three key techniques to improve sample quality: Gaussian noise
augmentation during training, a post-training denoising procedure, and an
effective guidance method for both class-conditional and unconditional
settings. Putting these together, TarFlow sets new state-of-the-art results on
likelihood estimation for images, beating the previous best methods by a large
margin, and generates samples with quality and diversity comparable to
diffusion models, for the first time with a stand-alone NF model. We make our
code available at https://github.com/apple/ml-tarflow.
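The abstract describes an autoregressive affine flow in the style of MAF. The layers are stacked so that the autoregression direction flips from one layer to the next, and training inputs are perturbed with Gaussian noise. The minimal pure-Python sketch below illustrates that mechanism under explicit assumptions. It is not TarFlow's implementation:

- Each scalar stands in for an image patch.
- The conditioner is a hypothetical linear function of the prefix mean, standing in for the causal Transformer.
- The tanh bound on the log-scale is an illustrative choice.
- The names `AffineARLayer`, `TinyFlow`, and `noise_augment` are made up for this example.

```python
import math
import random


class AffineARLayer:
    """One autoregressive affine flow layer in the MAF style.

    Hypothetical stand-in: the conditioner is a tiny linear function of the
    mean of the preceding elements, not a causal Transformer.
    """

    def __init__(self, dim, rng, reverse=False):
        self.reverse = reverse  # flip the autoregressive order for this layer
        self.b_mu = [rng.gauss(0.0, 0.5) for _ in range(dim)]
        self.w_mu = [rng.gauss(0.0, 0.5) for _ in range(dim)]
        self.b_alpha = [rng.gauss(0.0, 0.5) for _ in range(dim)]
        self.w_alpha = [rng.gauss(0.0, 0.5) for _ in range(dim)]

    def _cond(self, prefix, i):
        # Shift mu_i and log-scale alpha_i depend only on elements before i.
        m = sum(prefix) / len(prefix) if prefix else 0.0
        mu = self.b_mu[i] + self.w_mu[i] * m
        alpha = 0.5 * math.tanh(self.b_alpha[i] + self.w_alpha[i] * m)
        return mu, alpha

    def forward(self, x):
        """x -> z and log|det dz/dx|; each z_i needs only observed x_<i."""
        seq = x[::-1] if self.reverse else list(x)
        z, logdet = [], 0.0
        for i, xi in enumerate(seq):
            mu, alpha = self._cond(seq[:i], i)
            z.append((xi - mu) * math.exp(-alpha))
            logdet -= alpha
        return (z[::-1] if self.reverse else z), logdet

    def inverse(self, z):
        """z -> x; sequential, because x_i needs the already-decoded x_<i."""
        seq = z[::-1] if self.reverse else list(z)
        x = []
        for i, zi in enumerate(seq):
            mu, alpha = self._cond(x, i)
            x.append(zi * math.exp(alpha) + mu)
        return x[::-1] if self.reverse else x


class TinyFlow:
    """Stack of AR layers with alternating direction and a N(0, I) prior."""

    def __init__(self, dim, n_layers, seed=0):
        rng = random.Random(seed)
        self.dim = dim
        self.layers = [AffineARLayer(dim, rng, reverse=(k % 2 == 1))
                       for k in range(n_layers)]

    def encode(self, x):
        z, logdet = list(x), 0.0
        for layer in self.layers:
            z, ld = layer.forward(z)
            logdet += ld
        return z, logdet

    def decode(self, z):
        x = list(z)
        for layer in reversed(self.layers):
            x = layer.inverse(x)
        return x

    def log_prob(self, x):
        # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|
        z, logdet = self.encode(x)
        log_pz = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
        return log_pz + logdet

    def sample(self, rng):
        return self.decode([rng.gauss(0.0, 1.0) for _ in range(self.dim)])


def noise_augment(x, sigma, rng):
    """Gaussian noise augmentation: train on x + sigma * eps, eps ~ N(0, I)."""
    return [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]


if __name__ == "__main__":
    flow = TinyFlow(dim=4, n_layers=4)
    rng = random.Random(1)
    x = noise_augment([0.1, -0.3, 0.7, 0.2], sigma=0.05, rng=rng)
    print("log p(x) =", flow.log_prob(x))
    print("sample   =", flow.sample(rng))
```

The sketch shows an asymmetry between the two directions. Encoding x to z uses only observed inputs, so a real model can compute every position in one parallel pass under a causal mask, which keeps likelihood training cheap. Decoding z to x, as in sampling, has to proceed position by position, because each element depends on the prefix decoded before it.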