LLM 다단계 추론을 위한 오프라인 강화 학습
Offline Reinforcement Learning for LLM Multi-Step Reasoning
December 20, 2024
저자: Huaijie Wang, Shibo Hao, Hanze Dong, Shenao Zhang, Yilin Bao, Ziran Yang, Yi Wu
cs.AI
초록
대형 언어 모델(LLMs)의 다단계 추론 능력을 향상시키는 것은 복잡한 작업에 빠르게 적응하기 위해 오프라인 강화 학습(RL)이 필수적입니다. 직접 선호도 최적화(DPO)는 LLMs를 인간의 선호도와 조화롭게 만드는 데 유망한 가능성을 보여주었지만, 다단계 추론 작업에는 적합하지 않습니다. 왜냐하면 (1) DPO는 다단계 추론 작업에 즉시 사용할 수 없는 짝 지어진 선호 데이터에 의존하며, (2) 모든 토큰을 균일하게 처리하여 종종 희박한 보상이 따르는 다단계 추론 작업에서 신용 할당에 효과적이지 않습니다. 본 연구에서는 LLM 다단계 추론을 향상시키기 위한 오프라인 RL 방법인 OREO(Offline Reasoning Optimization)를 제안합니다. 최대 엔트로피 강화 학습 이전 연구의 통찰을 기반으로, 소프트 벨만 방정식을 최적화함으로써 정책 모델과 가치 함수를 함께 학습합니다. 이를 통해 짝 지어진 데이터 수집 필요성을 줄이고 더 나은 신용 할당을 가능하게 합니다. 경험적으로, OREO는 수학적 추론 작업(GSM8K, MATH) 및 실체화된 에이전트 제어(ALFWorld)를 포함한 다단계 추론 벤치마크에서 기존의 오프라인 학습 방법을 능가합니다. 이 방법은 추가 자원을 사용할 수 있는 경우 다단계 반복 프레임워크로 확장할 수 있습니다. 또한, 학습된 가치 함수는 트리 탐색을 무료로 안내하는 데 활용될 수 있으며, 이는 테스트 시 성능을 더 향상시킬 수 있습니다.
English
Improving the multi-step reasoning ability of large language models (LLMs)
with offline reinforcement learning (RL) is essential for quickly adapting them
to complex tasks. While Direct Preference Optimization (DPO) has shown promise
in aligning LLMs with human preferences, it is less suitable for multi-step
reasoning tasks because (1) DPO relies on paired preference data, which is not
readily available for multi-step reasoning tasks, and (2) it treats all tokens
uniformly, making it ineffective for credit assignment in multi-step reasoning
tasks, which often come with sparse reward. In this work, we propose OREO
(Offline Reasoning Optimization), an offline RL method for enhancing LLM
multi-step reasoning. Building on insights from previous works of maximum
entropy reinforcement learning, it jointly learns a policy model and value
function by optimizing the soft Bellman Equation. We show in principle that it
reduces the need to collect pairwise data and enables better credit assignment.
Empirically, OREO surpasses existing offline learning methods on multi-step
reasoning benchmarks, including mathematical reasoning tasks (GSM8K, MATH) and
embodied agent control (ALFWorld). The approach can be extended to a
multi-iteration framework when additional resources are available. Furthermore,
the learned value function can be leveraged to guide the tree search for free,
which can further boost performance during test time.