Normalizing Flows are Capable Generative Models
December 9, 2024
Authors: Shuangfei Zhai, Ruixiang Zhang, Preetum Nakkiran, David Berthelot, Jiatao Gu, Huangjie Zheng, Tianrong Chen, Miguel Angel Bautista, Navdeep Jaitly, Josh Susskind
cs.AI
Abstract
Normalizing Flows (NFs) are likelihood-based models for continuous inputs.
They have demonstrated promising results on both density estimation and
generative modeling tasks, but have received relatively little attention in
recent years. In this work, we demonstrate that NFs are more powerful than
previously believed. We present TarFlow: a simple and scalable architecture
that enables highly performant NF models. TarFlow can be thought of as a
Transformer-based variant of Masked Autoregressive Flows (MAFs): it consists of
a stack of autoregressive Transformer blocks on image patches, alternating the
autoregression direction between layers. TarFlow is straightforward to train
end-to-end, and capable of directly modeling and generating pixels. We also
propose three key techniques to improve sample quality: Gaussian noise
augmentation during training, a post-training denoising procedure, and an
effective guidance method for both class-conditional and unconditional
settings. Putting these together, TarFlow sets new state-of-the-art results on
likelihood estimation for images, beating the previous best methods by a large
margin, and generates samples with quality and diversity comparable to
diffusion models, for the first time with a stand-alone NF model. We make our
code available at https://github.com/apple/ml-tarflow.
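The abstract describes TarFlow as a Transformer-based Masked Autoregressive Flow. It is a stack of autoregressive blocks over image patches, with the autoregression direction alternating between layers. The following is a minimal NumPy sketch of those flow mechanics only, under stated assumptions:

- It assumes the standard MAF affine parameterization z_t = (x_t - mu_t(x_<t)) * exp(-alpha_t(x_<t)).
- Each causal Transformer is replaced by a hypothetical strictly causal linear map, so the shift mu_t and log-scale alpha_t depend only on earlier patches.
- All class and function names are hypothetical.
- This is not the ml-tarflow code, and it omits noise augmentation, denoising, and guidance.

```python
import numpy as np


class ToyAutoregressiveAffineLayer:
    """Hypothetical MAF-style affine layer over T patches of D values each.

    Assumption: a strictly causal linear map stands in for the causal
    Transformer, so the shift mu[t] and log-scale alpha[t] see only x[:t].
    """

    def __init__(self, seq_len, dim, reverse, rng):
        self.reverse = reverse  # process patches in reversed order if True
        # Strictly lower-triangular mask: row t may only read patches s < t.
        mask = np.tril(np.ones((seq_len, seq_len)), k=-1)
        self.w_mu = 0.1 * rng.standard_normal((seq_len, seq_len)) * mask
        self.w_alpha = 0.1 * rng.standard_normal((seq_len, seq_len)) * mask

    def _params(self, x):
        # x has shape (T, D); mu and alpha have shape (T, D).
        mu = self.w_mu @ x
        alpha = np.tanh(self.w_alpha @ x)  # bounded log-scale for stability
        return mu, alpha

    def forward(self, x):
        """Data -> latent in one parallel pass; returns (z, log|det dz/dx|)."""
        if self.reverse:
            x = x[::-1]
        mu, alpha = self._params(x)
        z = (x - mu) * np.exp(-alpha)
        logdet = -alpha.sum()  # triangular Jacobian: diagonal is exp(-alpha)
        if self.reverse:
            z = z[::-1]
        return z, logdet

    def inverse(self, z):
        """Latent -> data, one patch at a time (x[t] needs x[:t] first)."""
        if self.reverse:
            z = z[::-1]
        x = np.zeros_like(z)
        for t in range(z.shape[0]):
            mu, alpha = self._params(x)  # row t reads only decoded x[:t]
            x[t] = z[t] * np.exp(alpha[t]) + mu[t]
        if self.reverse:
            x = x[::-1]
        return x


def log_likelihood(layers, x):
    """log p(x) = log N(z; 0, I) + sum of per-layer log-determinants."""
    z, total_logdet = x, 0.0
    for layer in layers:
        z, logdet = layer.forward(z)
        total_logdet += logdet
    log_prior = -0.5 * np.sum(z ** 2) - 0.5 * z.size * np.log(2 * np.pi)
    return log_prior + total_logdet


def sample(layers, rng, seq_len, dim):
    """Draw z ~ N(0, I) and invert the stack from the last layer down."""
    z = rng.standard_normal((seq_len, dim))
    for layer in reversed(layers):
        z = layer.inverse(z)
    return z


# Usage: 4 layers over 8 toy "patches" of 4 values, alternating direction.
rng = np.random.default_rng(0)
T, D = 8, 4
layers = [ToyAutoregressiveAffineLayer(T, D, reverse=(i % 2 == 1), rng=rng)
          for i in range(4)]
x = rng.standard_normal((T, D))
z = x
for layer in layers:
    z, _ = layer.forward(z)
x_rec = z
for layer in reversed(layers):
    x_rec = layer.inverse(x_rec)
print(np.allclose(x, x_rec))  # True: inverse undoes forward
print(log_likelihood(layers, x))
print(sample(layers, rng, T, D).shape)  # (8, 4)
```

The asymmetry between the two directions is characteristic of MAF-style flows.

- `forward` maps data to latent and is used for likelihood training. It computes every patch in one parallel pass.
- `inverse` maps latent to data and is used for sampling. It must decode patches one at a time, because each patch's shift and scale depend on the patches already decoded.

Reversing the patch order in alternate layers means that, across the stack, each patch's transform can depend on patches on both sides of it.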