Mamba架構第一次做大!混合Transformer,打敗Transformer

精彩精彩,第一個把爆火Mamba架構真正擴展到足夠大的工作來了。

520億參數,還是Mamba+Transformer混合架構。

它的名字叫Jamba。

取兩種架構之長,模型質量和效率兼得,要吞吐量有吞吐量,要低內存有低內存。

初步跑分顯示:

這項成果,來自以色列AI公司AI21labs。

Mamba原作者看了之後都激動轉發:

Mamba、Transformer,合體

由CMU和普林斯頓大學提出的Mamba,解決了Transformer的侷限性(隨着推理上下文越長,模型內存佔用量越大,同時推理速度變慢,由此導致算力消耗巨大)。

但它也有自己的缺點——

在不關注整個上下文的情況下,Mamba的輸出質量很差,尤其是在召回相關的任務上。

本着“既要也要”的原則,Jamba站出來提供兩全其美之作。

Jamba由Transformer、Mamba和MoE層組成,可同時優化內存、吞吐量和性能。

如下圖所示,爲了集成兩種架構,Jamba採用塊層(blocks-and-layers)組合的創新方法。

簡單來說,就是每個Jamba塊包含一個注意力層或一個Mamba層,再跟一個多層感知器MLP,總體比例保證爲每八層一個Transformer層。

其次,Jamba利用MoE來增加模型參數的總量,同時簡化推理中使用的活動參數量。

最終模型容量高了,計算需求也沒有相應的增加。

而爲了在單張GPU(80GB)上最大限度地提高模型吞吐量,Jamba還優化了所用MoE層和專家數量,最終爲日常推理工作負載留出足夠內存。

值得一提的是,在推理時,Jamba的MoE層僅需520億可用參數中的120億,就能同時保證比同等大小的僅Transformer模型更高效。

要知道,此前有人光是嘗試過擴展Mamba,就沒能做到30億參數之上。

因此,除了成功合體Mamba和Transformer,Jamba也達成了第二大成就:

同類中第一個達到生產級規模和質量的混合架構(SSM混Transformer)(ps. Mamba就是一種狀態空間模型SSM)。

吞吐量和效率up

初步評估顯示,Jamba在吞吐量和效率等關鍵指標上表現出色。

首先,Jamba可以在長上下文中提供3倍吞吐量,比Mixtral 8x7B等大小相當的Transformer模型都要高效。

如下圖所示,當上下文窗口達到128k時,Jamba的每秒token數近乎1500,而此時表現最好的Mixtral 8x7B應該纔在500往上的樣子。

其次,在單張GPU上,Jamba最多可以容納140k上下文,經濟又高效。

相比之下,Mixtral 8x7B爲64k,Llama2 70B則僅爲16k。

第三,Jamba的輸出質量也得到了保證。

在如下一系列推理基準上,4項中有3項它都拿下了SOTA。同時,在GSM8K等基準上,Jamba即使沒有奪魁,也和SOTA模型打了個不相上下。

總體來說,Jamba的性能接近Mixtral 8x7B。

最後,作者提示,別忘了,這些都還只是初步改造後的結果,後續還有很多優化空間(比如MoE並行、更快的Mamba實現)。所以到時性能會更強。

好消息:Jamba現在已經上線Hugging Face,並且劃重點:採用apache-2.0許可。

(Jamba的指令版本則將很快通過AI21labs平臺上線。)

網友看完都感動哭了。

傳送門:https://huggingface.co/ai21labs/Jamba-v0.1

參考鏈接:[1]https://www.ai21.com/blog/announcing-jamba[2]https://www.ai21.com/jamba[3]https://twitter.com/AI21Labs/status/1773350888427438424?s=20[4]https://twitter.com/tri_dao/status/1773418926518734957?s=20