Fast Controlled Generation from Language Models with Adaptive Weighted Rejection Sampling
April 7, 2025
Authors: Benjamin Lipkin, Benjamin LeBrun, Jacob Hoover Vigly, João Loula, David R. MacIver, Li Du, Jason Eisner, Ryan Cotterell, Vikash Mansinghka, Timothy J. O'Donnell, Alexander K. Lew, Tim Vieira
cs.AI
Abstract
The dominant approach to generating from language models subject to some
constraint is locally constrained decoding (LCD), incrementally sampling tokens
at each time step such that the constraint is never violated. Typically, this
is achieved through token masking: looping over the vocabulary and excluding
non-conforming tokens. There are two important problems with this approach. (i)
Evaluating the constraint on every token can be prohibitively expensive -- LM
vocabularies often exceed 100,000 tokens. (ii) LCD can distort the global
distribution over strings, sampling tokens based only on local information,
even if they lead down dead-end paths. This work introduces a new algorithm
that addresses both these problems. First, to avoid evaluating a constraint on
the full vocabulary at each step of generation, we propose an adaptive
rejection sampling algorithm that typically requires orders of magnitude fewer
constraint evaluations. Second, we show how this algorithm can be extended to
produce low-variance, unbiased estimates of importance weights at a very small
additional cost -- estimates that can be soundly used within previously
proposed sequential Monte Carlo algorithms to correct for the myopic behavior
of local constraint enforcement. Through extensive empirical evaluation in
text-to-SQL, molecular synthesis, goal inference, pattern matching, and JSON
domains, we show that our approach is superior to state-of-the-art baselines,
supporting a broader class of constraints and improving both runtime and
performance. Additional theoretical and empirical analyses show that our
method's runtime efficiency is driven by its dynamic use of computation,
scaling with the divergence between the unconstrained and constrained LM, and
as a consequence, runtime improvements are greater for better models.
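For intuition, here is a minimal Python sketch of the kind of adaptive rejection loop the abstract describes. It is illustrative, not the paper's implementation: the names (`awrs_sample`, `is_valid`) are ours, and the weight estimator used is the naive unbiased one (the indicator that the very first draw satisfies the constraint, whose expectation is exactly the probability mass of valid tokens); the paper develops a lower-variance unbiased estimator.

```python
import numpy as np

def awrs_sample(probs, is_valid, rng):
    """Sample one constrained token without scanning the full vocabulary.

    probs    : next-token distribution from the LM (1-D array summing to 1)
    is_valid : token_id -> bool, the (potentially expensive) constraint check
    Returns (token, weight_estimate, num_constraint_evals).
    """
    p = np.asarray(probs, dtype=float).copy()
    num_evals = 0
    weight_est = None
    while True:
        # Propose from the LM restricted to tokens not yet rejected;
        # the constraint is evaluated only on tokens we actually draw.
        token = int(rng.choice(len(p), p=p / p.sum()))
        num_evals += 1
        accepted = is_valid(token)
        if weight_est is None:
            # The first draw comes from the unrestricted distribution, so
            # this indicator is an unbiased estimate of the valid mass
            # (and it is independent of the token finally accepted).
            weight_est = float(accepted)
        if accepted:
            return token, weight_est, num_evals
        p[token] = 0.0  # adapt: exclude the rejected token from future draws
        if p.sum() <= 0.0:
            raise ValueError("constraint rejects every token in the support")

# Toy usage: a 100k-token vocabulary where only 1-in-7 token ids are valid.
rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(100_000))
token, w, evals = awrs_sample(probs, lambda t: t % 7 == 0, rng)
```

Because the accepted token is distributed exactly as the LM restricted to valid tokens, and the weight estimate is unbiased for the valid probability mass, such (token, weight) pairs can feed an importance-weighted scheme like the sequential Monte Carlo correction the abstract mentions. When the LM already concentrates mass on valid tokens, the loop terminates after very few constraint evaluations, which matches the runtime scaling described above.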