相关博客
【自然语言处理】【大模型】RWKV:基于RNN的LLM
【自然语言处理】【大模型】CodeGen:一个用于多轮程序合成的代码大语言模型
【自然语言处理】【大模型】CodeGeeX:用于代码生成的多语言预训练模型
【自然语言处理】【大模型】LaMDA:用于对话应用程序的语言模型
【自然语言处理】【大模型】DeepMind的大模型Gopher
【自然语言处理】【大模型】Chinchilla:训练计算利用率最优的大语言模型
【自然语言处理】【大模型】大语言模型BLOOM推理工具测试
【自然语言处理】【大模型】GLM-130B:一个开源双语预训练语言模型
【自然语言处理】【大模型】用于大型Transformer的8-bit矩阵乘法介绍
【自然语言处理】【大模型】BLOOM:一个176B参数且可开放获取的多语言模型
【自然语言处理】【大模型】PaLM:基于Pathways的大语言模型
【自然语言处理】【chatGPT系列】大语言模型可以自我改进
【自然语言处理】【ChatGPT系列】FLAN:微调语言模型是Zero-Shot学习器
【自然语言处理】【ChatGPT系列】ChatGPT的智能来自哪里?

RWKV:基于RNN的LLM

​ 基于Transformer的LLM已经取得了巨大的成功,但是其在显存消耗和计算复杂度上都很高。RWKV是一个基于RNN的LLM,其能够像Transformer那样高效的并行训练,也能够像RNN那样高效的推理。

一、背景知识

1. RNN

​ RNN是指一类神经网络模型结构,其中最具有代表性的是LSTM:
f t = σg( Wfxt+ Ufh t − 1+ bf)i t = σg( Wixt+ Uih t − 1+ bi)o t = σg( Woxt+ Uoh t − 1+ bo)c~t = σc( Wcxt+ Uch t − 1+ bc)c t = ft⊙ c t − 1+ it⊙c ~t h t = ot⊙ σh( ct) \begin{align} f_t&=\sigma_g(W_fx_t+U_f h_{t-1}+b_f) \tag*{(1)} \\ i_t&=\sigma_g(W_ix_t+U_i h_{t-1}+b_i) \tag*{(2)} \\ o_t&=\sigma_g(W_ox_t+U_o h_{t-1}+b_o) \tag*{(3)} \\ \tilde{c}_t&=\sigma_c(W_cx_t+U_c h_{t-1}+b_c) \tag*{(4)} \\ c_t&=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t \tag*{(5)} \\ h_t&=o_t\odot\sigma_h(c_t) \tag*{(6)} \end{align} \\ ftitotc~tctht=σg(Wfxt+Ufht1+bf)=σg(Wixt+Uiht1+bi)=σg(Woxt+Uoht1+bo)=σc(Wcxt+Ucht1+bc)=ftct1+itc~t=otσh(ct)(1)(2)(3)(4)(5)(6)
其中, xt x_txt是当前时间步的输入, h t − 1 h_{t-1}ht1是上一个时间步的隐藏状态,所有的 WWW UUU bbb都是可学习参数, σ\sigmaσ表示 sigmoid\text{sigmoid}sigmoid函数。 ft f_tft是“遗忘门”,用来控制前一个时间步上传递信息的比例; it i_tit是“输入门”,用于控制当前时间步保留的信息比例; ot o_tot是”输出门”,用于产生最终的输出。

2. Transformers和AFT

​ Transformer是NLP中主流的一种模型架构,其依赖于注意力机制来捕获所有输入和输出tokens的关系:
Attn ( Q , K , V ) = softmax ( Q K⊤) V(7) \text{Attn}(Q,K,V)=\text{softmax}(QK^\top)V \tag{7} \\ Attn(Q,K,V)=softmax(QK)V(7)
为了简洁,这里忽略了多头和缩放因子 1 dk\frac{1}{\sqrt{d_k}}dk 1 Q K⊤ QK^\topQK是序列中每个token之间的成对注意力分数,其能够被分解为向量表示:
Attn ( Q , K , V )t= ∑ i = 1T eq t ⊤ k i∑ i = 1Te qt⊤kivi=∑ i = 1Te qt⊤ki vi∑ i = 1Te qt⊤ki (8) \text{Attn}(Q,K,V)_t=\sum_{i=1}^T\frac{e^{q_t^\top k_i}}{\sum_{i=1}^T e^{q_t^\top k_i}}v_i=\frac{\sum_{i=1}^T e^{q_t^\top k_i}v_i}{\sum_{i=1}^T e^{q_t^\top k_i}}\tag{8} \\ Attn(Q,K,V)t=i=1Ti=1Teqtkieqtkivi=i=1Teqtkii=1Teqtkivi(8)
在AFT中,设计了一种注意力变体:
Attn+( W , K , V )t=∑ i = 1te w t , i+ ki vi∑ i = 1te w t , i+ ki (9) \text{Attn}^+(W,K,V)_t=\frac{\sum_{i=1}^t e^{w_{t,i}+k_i}v_i}{\sum_{i=1}^t e^{w_{t,i}+k_i}} \tag{9} \\ Attn+(W,K,V)t=i=1tewt,i+kii=1tewt,i+kivi(9)
其中, { w t , i} ∈ R T × T \{w_{t,i}\}\in R^{T\times T}{wt,i}RT×T是可学习的位置偏差,每个 w t , i w_{t,i}wt,i是一个标量。

​ 受AFT启发,在RWKV中的 w t , i w_{t,i}wt,i是一个乘以相对位置的时间衰减向量:
w t , i= − ( t − i ) w(10) w_{t,i}=-(t-i)w \tag{10} \\ wt,i=(ti)w(10)
其中, w ∈ ( R ≥ 0)d w\in (R_{\geq 0})^dw(R0)d ddd是通道数。这里需要 www是非负来保证 e w t,i ≤ 1e^{w_{t,i}}\leq 1ewt,i1并且每个信道随时间衰减。

二、RWKV(Receptance Weighted Key Value)

​ RWKV由一系列的基本Block组成,每个Block则由time-mixing block和channel-mixing block组成的(如上图所示)。

​ RWKV递归的形式可以看做是当前输入和前一个时间不输入的线性插值,如上图所示。

1. Time-mixing block

​ Time-mixing block的作用同Self-Attention相同,就是提供全局token的交互。细节如下:
r t = Wr⋅ ( μrxt+ ( 1 − ur) x t − 1)k t = Wk⋅ ( μkxt+ ( 1 − uk) x t − 1)v t = Wv⋅ ( μvxt+ ( 1 − μv) x t − 1)w k vt=∑ i = 1 t − 1e − ( t − 1 − i ) w + ki vi+ e u + kt vt∑ i = 1 t − 1e − ( t − 1 − i ) w + ki + e u + kt o t = Wo⋅ ( σ ( rt) ⊙ w k vt) \begin{align} r_t&=W_r\cdot(\mu_rx_t+(1-u_r)x_{t-1}) \tag*{(11)} \\ k_t&=W_k\cdot(\mu_kx_t+(1-u_k)x_{t-1}) \tag*{(12)} \\ v_t&=W_v\cdot(\mu_vx_t+(1-\mu_v)x_{t-1}) \tag*{(13)} \\ wkv_t&=\frac{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}+e^{u+k_t}} \tag*{(14)} \\ o_t&=W_o\cdot(\sigma(r_t)\odot wkv_t) \tag*{(15)} \end{align} \\ rtktvtwkvtot=Wr(μrxt+(1ur)xt1)=Wk(μkxt+(1uk)xt1)=Wv(μvxt+(1μv)xt1)=i=1t1e(t1i)w+ki+eu+kti=1t1e(t1i)w+kivi+eu+ktvt=Wo(σ(rt)wkvt)(11)(12)(13)(14)(15)
所有的 μ\muμ WWW都是可训练参数, rt r_trt kt k_tkt vt v_tvt是当前输入 xt x_txt和上一个时间步输入 x t − 1 x_{t-1}xt1的加权投影。

公式(14)中, www uuu是可训练参数,分子的第一项 ∑ i = 1 t − 1e − ( t − 1 − i ) w + ki vi \sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_ii=1t1e(t1i)w+kivi表示前 t − 1t-1t1步的加权结果, − ( t − 1 − i ) w + ki -(t-1-i)w+k_i(t1i)w+ki是随相对距离逐步衰减; e u + kt vt e^{u+k_t}v_teu+ktvt则是当前时间步的结果。

公式(15)中,则通过 σ ( rt)\sigma(r_t)σ(rt)控制最终输出的比例。

2. Channel-mixing block

​ Channel-mixing block类似于Transformer中的FFN部分,细节如下:
r t = Wr⋅ ( μrxt− ( 1 − μr) x t − 1)k t = Wk⋅ ( μkxt− ( 1 − μk) x t − 1)o t = σ ( rt) ⊙ ( Wv⋅ max ⁡ ( kt, 0 )2) \begin{align} r_t&=W_r\cdot(\mu_rx_t-(1-\mu_r)x_{t-1}) \tag*{(16)} \\ k_t&=W_k\cdot(\mu_kx_t-(1-\mu_k)x_{t-1}) \tag*{(17)} \\ o_t&=\sigma(r_t)\odot(W_v\cdot\max(k_t,0)^2) \tag*{(18)} \\ \end{align} \\ rtktot=Wr(μrxt(1μr)xt1)=Wk(μkxt(1μk)xt1)=σ(rt)(Wvmax(kt,0)2)(16)(17)(18)

三、并行训练和序列解码

​ RWKV可以类似Transformer那样高效的并行。设batch size为B、seq_length为T、channels为d,计算量主要来自于矩阵乘法 W□, □ ∈ { r , k , v , o }W_\square,\square\in \{r,k,v,o\}W,{r,k,v,o},单层的时间复杂度为 O ( B T d2)O(BTd^2)O(BTd2)。此外,更新注意力分数 w k vt wkv_twkvt需要顺序扫描,其时间复杂度为 O ( B T d )O(BTd)O(BTd)。矩阵乘法可以像Transformer那样并行,但是WKV的计算是依赖时间步的,所以只能在其他维度上并行。

​ RWKV具有类似RNN的结构,解码时将 ttt步的输出作为 t + 1t+1t+1步的输入。相比于自注意力机制随着序列长度,计算复杂度呈平方次增长,RWKV则是与序列长度呈线性关系。因此,RWKV能够更高效的处理更长的序列。