From standard softmax attention to FlashAttention & FlashDecoding (Japanese)

Attention calculation is a key component in transformer models. This blog post explains how attention calculation is accelerated, starting from naive softmax attention to FlashAttention and FlashDecoding. This post is written in Japanese.

このブログポストでは attention 計算がいかに高速化されているのかを step-by-step で説明する。具体的には、標準的な softmax attention からスタートし、FlashAttention と FlashDecoding に至るまでを説明する。

ゴール

FlashAttention
FlashAttention のイメージ図。key と value と block-wise に分割されているため、GPU の SM を効率的に利用できる。しかし、key & value の系列長方向(行方向)には並列化がなされておらず、query が少ない場合には効率が悪い。より引用
FlashAttention
FlashDecoding のイメージ図。同一の query に対して、異なる block にある key と value を並列に計算し、最後に結果をマージすることにより key & value の系列長方向(行方向)への並列化を実現している。より引用

これを理解する。そのためには block-wise に分割した attention 計算を、うまくマージする方法を考えれば良い。

Attention 計算

Attention は query, key から attention score を計算し、そのスコアで value を重み付けして出力する。

GPU の memory hierarchy

WIP

Naive な softmax 計算

WIP

FlashAttention:

WIP

FlashDecoding: 系列長方向への並列化

WIP