兩個小模型互相驗證,直接比肩大模型?微軟的rStar甚至沒用CoT
機器之心報道
編輯:Panda
衆所周知,LLM 很強大,但執行復雜推理的能力還不夠強。
舉個例子,在 GSM8K 數據集上,Mistral-7B 即使使用思維鏈(CoT)等技術,也只能達到 36.5% 的準確度。儘管微調確實也能有效地提升推理能力,但大多數 LLM 依靠的微調數據都是經過 GPT-4 等更強大模型蒸餾過的,甚至可能原本就是這些強大模型合成的。
同時,研究者們也在積極開發一種能提供輔助但也更困難的方法:使用一個更優的教師 LLM 來提升推理能力。
爲了在沒有更優模型的前提下提升推理能力,一種頗有希望的範式是利用 LLM 自身之中的知識。舉個例子,一種名爲 RAP 的方法採用了一種自我探索式的解決方法,即通過自我獎勵的反饋來迭代式地提升 LLM 的推理性能。不幸的是,研究表明這一範式具有兩大根本性問題。
第一,在執行推理時,LLM 往往難以有效地探索解答空間。這種自我探索式方法往往會因推理步驟質量不佳而受困於某個解答空間,即使多次嘗試也是如此。
第二,即使自我探索找到了高質量的推理步驟,小版本的大型語言模型(SLM)也難以辨別哪些推理步驟的質量更高,也難以確定最終答案是否正確,由此難以有效地引導自我探索。研究表明,基於基本的常規獎勵的自我探索引導得到的結果並不比隨機猜測更好。
更麻煩的是,小版本的大型語言模型(SLM)更容易出現上述兩個問題,因爲它們的能力更差一些。舉個例子,GPT-4 能通過自我優化來提升輸出結果,但 SLM 卻很難做到這一點,甚至可能導致輸出結果質量下降。這會嚴重妨礙神經語言模型的推廣應用。
針對這些問題,微軟亞洲研究院和哈佛大學的一個研究團隊提出了 Self-play muTuAl Reasoning,即自博弈相互推理,簡稱 rStar。簡單來說,該方法就類似於讓兩個學習平平的人互相檢查考卷答案,最終提升得分,甚至達到比肩學霸的程度。該團隊宣稱 rStar 「無需微調或更優模型就能提升 SLM 的推理能力」。
方法
爲了解決上述難題,rStar 的做法是將推理過程分成了解答生成和相互驗證兩部分,如圖 2 所示。
針對第一個難題,該團隊引入了一個集合,其中包含豐富的類似人類的推理動作,可透徹地探索多種不同的推理任務空間。
針對第二個難題,他們設計了一個專門針對 SLM 的獎勵函數,這能對中間步驟進行評估,從而避免依賴它們那往往並不可靠的自我評估。
此外,該團隊還使用了另一個 SLM 作爲判別器來增強 MCTS 過程,與判別器 SLM 互相驗證每條軌跡的正確性。
使用 MCTS Rollout 自己生成推理軌跡
一個包含豐富的類人推理動作的集合。MCTS 生成的核心在於動作空間,其定義了樹探索的範圍。大多數基於 MCTS 的方法在構建樹時都使用了單一動作類型。比如 RAP 中的動作是提出下一個子問題,而 AlphaMath 和 MindStar 中的動作是生成下一推理步驟。但是,依賴單一動作類型可能容易導致空間探索效果不佳。
爲了解決這個問題,該團隊回顧了人類執行推理的方法。不同的人解決問題的方法也不同:某些人會將問題分解成子問題,另一些則會直接解決問題,還有些人則會換個視角重新表述問題。此外,人們還會根據當前狀態調整自己的方法,按需求選擇不同的動作。
受人類推理過程的啓發,該團隊構建了一個更爲豐富的數據集,其中包含 5 類動作,以儘可能地提升 SLM 正確解決複雜推理問題的潛力。
動作 1:提議一步思路。針對給定問題,該動作會讓 LLM 基於已有的推理步驟生成接下來的一步思路。
動作 2:提議餘下的思路步驟。該動作與標準 CoT 一樣,能實現「快速思考」,從而解決只需少量步驟的簡單問題。給定已經生成的推理步驟,它會讓 LLM 直接生成剩餘步驟,直到得到最終答案。
動作 3:提議下一個子問題及其答案。
動作 4:再次回答這個子問題。考慮到動作 3 有可能無法正確回答對應的子問題,因此這個動作的作用是再次回答它。
動作 5:重新表述問題 / 子問題。這個新動作是以更簡單的方式重新表述該問題。具體來說,這裡是讓 LLM 清晰列出問題陳述中的所有條件。
以上五個動作定義了一個高度多樣化的動作空間 {A1, A2, A3, A4, A5}。
在每個步驟 i,MCTS 從該空間選取一個動作 a_i。然後基於當前狀態(即之前生成的軌跡 x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_{i−1}),使用該動作 a_i 讓 LLM 生成下一推理步驟 s_i。請注意某些動作需要按順序執行。圖 3 給出了一個示例。
如表 1 所示,在提升最終推理準確度方面,每個動作都具有重要作用。
MCTS 的另一個關鍵組件是獎勵函數,其作用是評估每個動作的價值併爲樹的擴展提供指示。針對 SLM,該團隊設計了一個簡單卻有效的獎勵函數。他們的方法靈感來自 AlphaGo,即基於每個中間節點對最終正確答案的貢獻對它們進行評分。這樣一來,經常得到正確答案的動作就能獲得更高獎勵,它們也就更可能在未來的 MCTS 樹擴展中被選取。
這裡將執行動作 a 後生成的節點 s 的獎勵值定義爲 Q (s, a)。一開始,所有未被探索過的節點都被分配了 Q (s_i, a_i) = 0,從而實現隨機的樹擴展。在抵達首個端節點 n_d 時,根據其是否得到正確答案而計算一個獎勵分數 Q (s_d, a_d)。
然後,沿軌跡 t = x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_d 將該分數反向傳播給每個中間節點。具體來說,對於每個 s_i,都以如下方式更新其 Q 值:Q (s_i, a_i) = Q (s_i, a_i) + Q (s_d, a_d)。爲了計算端節點的 Q (s_d, a_d),這裡使用的獎勵值是自洽多數投票的似然(置信度)。
下面描述 MCTS 生成候選推理軌跡的方式。從初始的根節點 s_0 開始,執行包括選取、擴展、模擬和反向傳播在內的多種搜索。具體來說,模擬使用的是默認的 Rollout 策略。爲了得到更準確的獎勵估計,該團隊會執行多次 Rollout。爲了平衡探索與利用,他們使用了著名的 UCT(樹的置信度上界)來選取每個節點。這個選取過程的數學形式爲:
其中 N (s, a) 是之前的迭代中節點 s 被訪問的次數,N_parent (s) 表示對 s 的父節點的訪問次數。Q (s, a) 是估計的獎勵值,會在反向傳播過程中得到更新。c 是平衡探索與利用的常量。
一旦搜索到達某個端節點(可能是一個終端狀態,也可能到達了預定義的最大樹深度 d),便能得到一條從根到端節點的軌跡。將 Rollout 迭代得到的所有軌跡收集起來作爲候選解答。接下來就需要對它們進行驗證。
使用互恰性選擇推理軌跡
基於收集到的所有軌跡,該團隊提出使用推理互恰性來選擇答案。
如圖 2 所示,除了目標 SLM 外,該團隊還引入了一個判別器 SLM,其作用是爲每個候選軌跡提供外部無監督反饋。
具體來說,對於 t = x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_d,遮掩從某個隨機採樣的步驟 i 處開始的推理步驟。然後將之前的推理軌跡 t = x ⊕ s_1 ⊕ s_2 ⊕ ... ⊕ s_{i-1} 作爲 prompt 提供給判別器 SLM,讓其補全剩餘步驟。由於將之前的 i-1 個推理步驟作爲了提示,因此難度降低了,判別器 SLM 便更有可能給出正確答案。
圖 4 中比較了判別器 SLM 補全的答案是否與原始軌跡 t 匹配。如果兩者一致,則認爲 t 是可以最終選擇的已驗證軌跡。
由目標 SLM 選取最終軌跡。在對所有候選軌跡使用了推理互恰性之後,再回到目標 SLM,讓其從已驗證軌跡中選出最終軌跡。爲了計算每條軌跡的最終分數,該團隊的做法是用其獎勵乘以通過 Rollout 得到的其端節點的置信度分數。最終分數最高的軌跡被選作解答。
實驗
實驗設置
rStar 適用於多種 LLM 和推理任務。該團隊評估了 5 個 SLM:Phi3-mini、LLaMA2-7B、Mistral-7B、LLaMA3-8B、LLaMA3-8B-Instruct。
測試的推理任務有 5 個,其中包括 4 個數學任務(GSM8K、GSM-Hard、MATH、SVAMP)和 1 個常識任務(StrategyQA)。
實驗細節請訪問原論文。
主要結果
該團隊首先評估了 rStar 在一般推理基準上的有效性。表 2 比較了 rStar 和其它當前最佳方法在不同 SLM 和推理數據集上的準確度。爲了演示新生成器的效果,該團隊還提供了 rStar (generator @maj) 的準確度,即不使用判別器,僅使用多數投票來驗證答案而得到的準確度。
該團隊指出了其中的三項關鍵結果:
1. 得到 rStar 助力的 SLM 解決問題的能力更強。比如,在 GSM8K 數據集上,使用少樣本 CoT 的 LLaMA2-7B 的準確度只有 12.51%。但有了 rStar 的幫助,其準確度提升到了 63.91%,這一成績接近使用微調得到的準確度,如圖 1 所示。類似地,使用 rStar 的 Mistral 的性能甚至比微調版的 MetaMath 還高 4.18%。這樣的提升表明,SLM 本身已經具備很強的推理能力,但需要引導才能生成和選出正確解答。
2.rStar 可以穩定地將被評估的多種 SLM 在不同任務上的推理準確度提升至當前最佳水平。相較之下,其它對比方法都無法穩定地在所有四個基準上取得優良表現。舉個例子,儘管 SC(自我一致性)擅長三個數學任務,但卻無法有效解決 StrategyQA 的邏輯推理任務。
3. 即使沒有新提出的用於驗證推理軌跡的判別器,新提出的 MCTS 生成器在提升 SLM 的推理準確度方面依然效果很好。比如,在 GSM8K 數據集上,rStar (generator @maj) 的準確度比 RAP 高 2.88%-16.39%、比 ToT 高 10.60%- 38.37%、比 SC 高 1.69% - 7.34%。
該團隊還在一個更高難度的數學數據集上評估了 rStar。爲此他們選擇了 GSM-Hard 和 MATH 數據集。遵照同類研究的慣例,他們使用了 MATH-500,這是來自 MATH 數據集的一個包含代表性問題的子集。這樣做是爲了提升評估速度。如表 2 和 3 所示,rStar 能夠顯著提高 SLM 在這些高難度數學數據集上的推理準確度。
消融研究
rStar 使用了 Rollout 策略來執行 MCTS 樹擴展。更多 Rollout 會生成更多候選解答軌跡,但也會擡高推理成本。圖 5 比較了在 GSM8K 上,SC、RAP 和 rStar 使用不同 Rollout 時的準確度。
這裡得到兩個關鍵觀察結果:
1. 即使僅 2 次 Rollout,rStar 也能大幅提升 SLM 的推理準確度,這表明了其有效性;
2.Rollout 更多時對 rStar 和 SC 都有利,而 RAP 在 4 次 Rollout 之後往往會飽和甚至下降。一個原因是 RAP 的單類型動作空間會限制 MCTS 探索的效果。
該團隊比較了 MCTS 生成器與其它三種生成器的效果。如表 4 所示,新提出的 MCTS 生成器全面勝過其它生成器。此外,針對 SLM 調整過的獎勵函數的有效性也得到了證明,因爲自我評估會降低新生成器的準確度。
該團隊設置了兩個評估實驗。
第一個實驗是將判別方法與多數投票和自我驗證方法進行比較。結果見表 5(左),可以看到判別方法的優勢非常顯著。
第二個實驗則是研究不同的判別器模型的影響。結果見表 5(右),可以看到選擇不同的判別器模型通常不會影響推理互恰性方法驗證答案的效果。值得注意的是,即使使用強大的 GPT-4 作爲判別器,性能也只有略微提升(從 91.13% 提升到 92.57%)。這表明推理互恰性方法可以有效地使用 SLM 來驗證答案。