この論文では、従来のTransformerのメモリ制約を長いシーケンスを扱う際に克服する新しいアプローチを紹介する。この革新的な方法は、ビデオや複雑な環境データなど、長大なシーケンスの分析が必要なAIアプリケーションにとって重要である。
Ring Attentionの核心となるアイデアは、自己注意とフィードフォワード処理をブロック単位で計算することだ。この方法により、長いシーケンスを複数のデバイスに分散させ、従来のメモリ効率の良いTransformerが扱えるシーケンスの長さをデバイス数倍に拡張することが可能になる 。特に、Ring Attentionは自己注意とフィードフォワードのブロック単位の計算を利用し、長いシーケンスを複数のデバイスにまたがって処理することで、キーとバリューのブロックの通信をブロック単位の注意計算と完全に重ね合わせることができる。このアプローチにより、個々のデバイスに課されるメモリ制約を効果的に排除し、言語モデリングタスクでの大規模シーケンス入力サイズの処理と性能の向上を実現する。
この論文では、各ホストが外部ループのブロック単位の注意の要素を担当し、そのブロックに特有のフィードフォワードネットワークを実行する分散入力シーケンスの処理方法が述べられる。ブロック間の相互作用は、各ホストが次のホストへのキーとバリューのブロックの送信と、前のホストからの受信を同時に行うことで、ブロック単位の計算とブロックの転送を効率的に重ね合わせる。この方法により、追加の通信コストなしで、デバイス間の算術強度を最適化し、メモリ要件を最小化する 。
複数のデバイスにまたがる長大なシーケンスのトレーニングと推論を可能にすることで、Ring Attentionは従来のメモリ効率の良いTransformerによって達成可能なシーケンスの長さをデバイス数倍に増加させることなく、近似や追加の通信および計算オーバーヘッドを伴わずに実現する。言語モデリングや強化学習タスクにおける広範な実験は、数百万トークンのコンテキストサイズを扱う能力と性能の向上を示す 。
通常のAttentionとRing Attentionの違い
通常のAttention(標準的なTransformerモデルで使用)
• 全結合: 入力シーケンスの全トークンが互いに関連付けられる。各トークンは、他の全トークンとの関係を計算し、その情報を用いて出力を生成する。
• メモリ要求: 入力シーケンスの長さに応じて、計算量とメモリ使用量が二次的に増加する。これは長いシーケンスを扱う際に問題となる。
• 計算の集中性: 一つのデバイスで全計算を行うため、大規模なシーケンスや複数のデバイスを使用する場合のスケーリングが難しい。
Ring Attention
• ブロック単位の処理: シーケンスを複数のブロックに分割し、それぞれのブロックが異なるデバイスで処理される。これにより、長いシーケンスをより効率的に扱える。
• リング構造: デバイス間でブロックがリング状に渡されることで、通信と計算を重ね合わせ、効率的に処理を行う。各デバイスは、次のデバイスへキーとバリューのブロックを送信しながら、自身の計算を進める。
• メモリと通信の効率化: ブロック単位での処理により、メモリ使用量が大幅に削減され、デバイス間通信も最適化される。これにより、従来の方法に比べて長いシーケンスを扱えるようになる。
Ring Attentionは、従来の全結合Attentionに比べて、長いシーケンスをメモリ効率よく、かつスケーラブルに処理できる点が大きな違いである。Ring Attentionは高いパフォーマンスを次のように実現する。
1. メモリ使用量の削減: 通常のAttentionでは、シーケンス内の全トークン間で相互作用を計算し保存する必要があるため、長さが増すとメモリ使用量も急増する。Ring Attentionではシーケンスを小ブロックに分割して異なるデバイスで処理するため、各デバイスで必要なメモリ量が減り、より長いシーケンスを効率的に扱える。
2. 通信と計算の効率化: デバイス間でブロックをリング状に渡すことで、キーとバリューのブロックの通信と自己注意計算を同時に行う。これにより、通信と計算が重複し、無駄な待ち時間が減少し、計算効率が向上する。
3. スケーラビリティの向上: シーケンスを扱うデバイスの数に比例して性能が向上する。デバイスを追加すれば、さらに長いシーケンスを処理できるため、大規模モデルや複雑なタスクへの柔軟な対応が可能になる。
4. 大規模シーケンスの扱いやすさ: メモリと計算の効率化により、特に大規模シーケンスや長期依存性があるデータを扱う際に、優れたパフォーマンスを発揮する。これは、言語モデリングや時系列分析など、長いコンテキストが重要なタスクでの性能向上を意味する。
通常のAttentionの数式は次のように表される。
[
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
]
この式は、(Q)(クエリ)、(K)(キー)、(V)(バリュー)の3つの行列を用いる。まず(Q)と(K^T)(キーの転置)の内積を取り、その後(d_k)(キーの次元)の平方根で割ってスケーリングする。これにより、各クエリに対する全キーの関連度を計算し、softmax関数を適用してこれらの関連度を正規化する。各バリュー(V)に対する重み付けを行い、最終的に加重和を取ることで出力を得る。
Ring Attentionに関して、具体的な数式は直接の記載がないため、通常のAttentionとの比較から導出する形で想定することになる。Ring Attentionの目的は、大規模なシーケンスを効率的に扱うことにあるため、シーケンスをブロックに分割し、それぞれのブロックでAttention計算を行いつつ、全体としてのコンテキストを保持するような処理が含まれると考えられる。したがって、Ring Attentionの一般的な概念を表す数式は以下のようになるかもしれない。
[
\text{RingAttention}(Q, K, V) = \sum_{i=1}^{N} \text{BlockAttention}(Q_i, K_i, V_i)
]
ここで、(Q_i, K_i, V_i)は、分割された各ブロック(i)に対するクエリ、キー、バリューを表し、(N)はブロックの総数である。各ブロックでのAttention計算(ここでは仮想の関数(\text{BlockAttention})として表現)を行い、その結果を合計または平均することで全体の出力を得る。これにより、各ブロック内での局所的な情報処理と、ブロック間の情報の統合が行われる。
重要な点は、この数式がRing Attentionの概念を説明するための仮想的なものであり、実際の実装や数式は更に複雑な通信パターンや最適化を含むことである。