Offline Reinforcement Learning for LLM Multi-Step Reasoning
December 20, 2024
Authors: Huaijie Wang, Shibo Hao, Hanze Dong, Shenao Zhang, Yilin Bao, Ziran Yang, Yi Wu
cs.AI
Abstract
Improving the multi-step reasoning ability of large language models (LLMs)
with offline reinforcement learning (RL) is essential for quickly adapting them
to complex tasks. While Direct Preference Optimization (DPO) has shown promise
in aligning LLMs with human preferences, it is less suitable for multi-step
reasoning tasks because (1) DPO relies on paired preference data, which is not
readily available for multi-step reasoning tasks, and (2) it treats all tokens
uniformly, making it ineffective for credit assignment in multi-step reasoning
tasks, which often come with sparse reward. In this work, we propose OREO
(Offline Reasoning Optimization), an offline RL method for enhancing LLM
multi-step reasoning. Building on insights from previous work on maximum
entropy reinforcement learning, it jointly learns a policy model and value
function by optimizing the soft Bellman Equation. We show in principle that it
reduces the need to collect pairwise data and enables better credit assignment.
Empirically, OREO surpasses existing offline learning methods on multi-step
reasoning benchmarks, including mathematical reasoning tasks (GSM8K, MATH) and
embodied agent control (ALFWorld). The approach can be extended to a
multi-iteration framework when additional resources are available. Furthermore,
the learned value function can be leveraged to guide the tree search for free,
which can further boost performance during test time.
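For context, here is a minimal sketch of the soft Bellman consistency condition used in KL-regularized maximum-entropy RL, which the abstract refers to; the notation (reference policy \pi_{\mathrm{ref}}, regularization coefficient \beta, per-step reward r) is illustrative and not taken verbatim from the paper. Assuming deterministic transitions, as in token- or step-level text generation, the optimal policy \pi^* and optimal value function V^* satisfy

V^*(s_t) = r(s_t, a_t) + V^*(s_{t+1}) - \beta \log \frac{\pi^*(a_t \mid s_t)}{\pi_{\mathrm{ref}}(a_t \mid s_t)}

for every transition (s_t, a_t, s_{t+1}). A policy network and a value network can therefore be trained jointly on offline trajectories by minimizing the squared violation of this equation, which is one way a method of this kind can propagate a sparse terminal reward back to individual reasoning steps without requiring paired preference data.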