RWKVという,Transformerの学習の並列性と,RNNのようにシーケンス長によらず対して一定の空間計算量で推論ができる,いいとこ取りをしたモデルが2023年5月にarxivに上がった.

自分の理解は,以下の方向の考え方がしやすかったが,あまりネットに(特に日本語で)落ちていなかったのでブログに書いておきます.

GPT modeとRNN mode

どうやって学習の並列性と推論の計算量を実現しているかというと,RWKVにおけるAttentionに相当する構造であるwkvを,a. GPT modeとb. RNN modeが使い分けれるからである.

それぞれ,a. 学習時には隠れ状態を一般項で計算,b. 推論時には隠れ状態を2項間の漸化式で計算 ,として使う.

具体的には,隠れ状態 h_t=(a_t,b_t)の漸化式と一般項は以下の式になる.

$$
a_t = e^{-w} a_{t-1} + e^{k_t} v_t=\sum_{i=1}^{t}e^{-(t-i)w+{k_i}}v_{i}
$$

$$
b_t = e^{-w} b_{t-1} + e^{k_t}=\sum_{i=1}^{t}e^{-(t-1)w+{k_i}}
$$

これにより,以下のメリットが得られる

a. 一般項で隠れ状態を計算→時刻1,2,..,t-1,tの隠れ状態は各々並列に計算可能

b. 2項間の漸化式で隠れ状態が計算→時刻tの隠れ状態はひとつ前(t-1)の隠れ状態から計算可能であり,計算に使うメモリを小さくできる.

注1) 計算量を気にしなければ,GPT modeで推論することも可能である.

注2) RWKVのAttentionに相当するWKVの計算は.Attention Free Transformerを元にしており,GPT modeの計算においてもTransformerよりも計算量が小さい.

RNN mode→GPT modeの導出

wkv_tの計算が漸化式でも一般項でも計算できることを示す.まずば漸化式から.
$$
wkv_t = \frac{a_{t-1} + e^{u+k_t} v_t}{b_{t-1} + e^{u+k_t}},・・・(1)
$$
$$
a_0, b_0 = 0
$$
$$
a_t = e^{-w} a_{t-1} + e^{k_t} v_t
$$
$$
b_t = e^{-w} b_{t-1} + e^{k_t}
$$

a_tを展開していくと

$$
a_t=e^{k_t}v_t+e^{-w}a_{t-1}
$$

$$
\text{       }=e^{k_t}v_t+e^{-w}(e^{k_{t-1}}v_{t-1}+e^{-w}a_{t-2})
$$

$$
\text{      }=e^{k_t}v_t+e^{k_{t-1}-w}v_{t-1}+e^{-2w}a_{t-2}
$$

$$
\text{          }=e^{k_{t}}v_te^{k_{t-1}-w}v_{t-1}+e^{-2w}(e^{k_{t-2}}v_{t-2}+e^{-w}a_{t-3})
$$

$$
\text{          }=e^{k_t}v_t+e^{k_{t-1}-w}v_{t-1}+e^{k_{t-2}-2w}v_{t-2}+e^{-3w}a_{t-3}\
$$

$$
\text{}=…
$$

$$
\text{          }=e^{k_t}v_t+e^{k_{t-1}-w}v_{t-1}+e^{k_{t-2}-2w}v_{t-2}+e^{k_{t-3}-3w}v_{t-3}+…+e^{k_{1}-(t-1)w}v_{1}+e^{-tw}a_{0}
$$

$$
\text{          }=e^{k_{t}}v_t+e^{k_{t-1}-w}v_{t-1}+e^{k_{t-2}-2w}v_{t-2}+e^{k_{t-3}-3w}v_{t-3}+…+e^{k_{1}-(t-1)w}v_{1}
$$

$$
\text{ }=\sum_{i=1}^{t}e^{-(t-i)w+{k_i}}v_{i}
$$

同様にb_tを展開していくと

$$
b_t=e^{k_t}+e^{-w}b_{t-1}
$$

$$
\text{     }=e^{k_t}+e^{-w}(e^{k_{t-1}}+e^{-w}b_{t-2})
$$

$$
\text{     }=e^{k_t}+e^{-w+{k_{t-1}}}+e^{-2w}b_{t-2}
$$

$$
\text{         }=e^{k_t}+e^{-w+{k_{t-1}}}+e^{-2w+{k_{t-2}}}+e^{-3w}b_{t-3}
$$

$$
\text{          }=e^{k_{t}}+e^{-w+k_{t-1}}+e^{-2w+k_{t-2}}+e^{-3w+{k_{t-3}}}+…+e^{-(t-1)w+{k_{1}}}+e^{-tw}b_{0}
$$

$$
\text{          }=e^{k_t}+e^{-w+{k_{t-1}}}+e^{-2w+{k_{t-1}}}+e^{-3w+{k_{t-2}}}+…+e^{-(t-1)w+k_1}
$$

$$
\text{ }=\sum_{i=1}^{t}e^{-(t-1)w+{k_i}}
$$

これらを 1≦i≦t- 1の範囲で(1)に代入すれば,Attentionに相当する計算である,wkv_t(あるいは隠れ状態)を漸化式(回帰的に計算)でも求まるし,任意の時刻tのwkv_t(あるいは隠れ状態)を一般項で求めることもできる.

$$
\begin{align*}
wkv_t = \frac{a_{t-1} + e^{u+k_t} v_t}{b_{t-1} + e^{u+k_t}}= \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w+{k_i}}v_i + e^{u+k_t} v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w+{k_i}} + e^{u+k_t}}
\end{align*}
$$

公式実装との対応

Pytorch実装では

a. GPT mode RWKV v2 model.py#L38(詳細はcuda実装)

b. RNN mode RWKV v2 model_run.py#L102

最後に

日本の高校数学の順序で勉強してきた自分には,漸化式→一般項の方向で解いた方が直感的なので,(簡単な計算ですが)このブログに書いておきます.どこかで誰かの助けになれればと思います.間違い等あったらご連絡ください.

参考文献

  • Bo Pen, et al. (2023). RWKV: Reinventing RNNs for the Transformer Era.

  • Shuangfei Zhai, et al. (2021). An Attention Free Transformer.