FlashAttention 原理 | 深度学习算法

FlashAttention 的计算结果和原始算法是严格对齐的,不是对 attention 进行近似。FlashAttention 通过分块的方式减少对全局内存的读写,从而加速运算过程。本文仅包含对 FlashAttention 在前向推理上的优化,不包含反向传播相关内容。

softmax 分块

attention 的运算过程中包括矩阵乘法和 softmax 运算。矩阵乘法的分块方式比较成熟,但是 softmax 因为涉及到了全局信息所以分块计算比较困难,本节将推导 softmax 分块运算的方式。

original softmax

原始 softmax 的公式和伪代码如下:

$$y_i=\frac{e^{x_i}}{\sum_{j=1}^Ve^{x_j}}$$

可以看出,访存的开销是 2 次 load 和 1 次 store。

safe softmax

在实际硬件上,所表示的数字范围是有限的,Algorithm 1 的第 3 行可能会由于指数而上溢或下溢。safe softmax 的公式和伪代码如下:

$$y_i=\frac{e^{x_i-\max_{k=1}^Vx_k}}{\sum_{j=1}^Ve^{x_j-\max_{k=1}^Vx_k}}$$

可以看出,访存的开销是 3 次 load 和 1 次 store。

online softmax

online softmax 算法可以减少 safe softmax 算法的访存次数,先直接看伪代码:

Algorithm 3 的 1~6 行在一次访存中同事计算出了最大值 $m_V$ 和指数和 $d_V$,其中第 5 行可以通过数学归纳法证明:

可以看出,访存的开销是 2 次 load 和 1 次 store。

tiled softmax

QKV 矩阵乘法分块

整体流程

先回顾一下 attention 的计算方式:

换成流程图描述下:

其中 $P=softmax(S)$,示意图如下:

注意这里矩阵每一行是一个样本,每个样本之间是独立的,所以在求最大值和以及求和的时候是针对每一行的。

在 GPT 类的模型中,还需要对 P 做 mask 处理。为了表达方便,诸如 mask、 dropout 之类的操作都忽略掉。

分块运算

j = 0 时,遍历 i:

j = 1 时,遍历 i:

整个计算过程分为两层循环,内层循环是 i,外层循环是 j。在 $S_{ij}=Q_iK_j^T$ 的计算过程中,$S_{ij}$ 是不需要反复读写的(因为没有在 k 方向上做切分),进而 $\widetilde{P}{ij}$ 也不需要反复读写。而在 $O{ij}=\widetilde{P}{ij}V_j$ 的计算过程中,$O{ij}$ 是需要反复读写进行累加的。虽然 $\widetilde{P}{ij}$ 不需要反复读写,但是其每个分块的结果不是我们最终想要的结果(因为使用的是每个分块的局部 softmax 结果),又因为 $O{ij}$ 需要反复读写,所以将局部 softmax 结果更新成全局结果的操作就放在 $O_{ij}$ 的计算中一起完成。所以,分块计算 softmax 的意义就是省去对 $S$、$P$ 的读写。

更新输出

和前面提到的 online softmax(注意不是 tiled softmax)类似,我们希望能够找出一个方法,可以在迭代中更新 $O_i$,即通过 $O_{i-1}$ 得到 $O_i$。

注意这里的 $O_i^{(j+1)}$ 是 $O$ 每个分块的最终结果,而不是需要加在一期的部分结果。三个变量对应的部分如下所示:

现在再回头看原论文中的伪代码就比较清楚了:

GPU 显存分布情况

上图说明了 FlashAttention 如何在硬件上计算。蓝色块代表驻留在 SRAM 中的块,而红色块对应于第 $i$ 行。 $L$ 表示序列长度(即前文中的 $N$),可以很大(例如 16k),$D$ 表示注意力头的尺寸(即前文中的 $d$),在 Transformers 中通常很小(例如 GPT-3 为 128),$B$(即前文中的 $B_r$)是可以分块的大小。值得注意的是,总体 SRAM 内存占用仅取决于 $B$ 和 $D$,与 $L$ 无关。因此,该算法可以扩展到长上下文而不会遇到内存问题(GPU 共享内存很小,H100 架构为 228kb/SM)。在计算过程中,从左到右扫描 $K^T$ 和 $A$,从上到下扫描 $V$,并相应地更新 $m$、$d$(即前文中的 $l$) 和 $O$ 的值。

复杂度分析

计算复杂度

FlashAttention 的计算复杂度可以通过原论文中的伪代码进行分析:

  • 根据伪代码第 9 行 $S_{ij}=Q_iK_j^T$,其中 $Q_i\in\mathbb{R}^{B_rd},K_j^T\in\mathbb{R}^{dB_c}$,可以得出 $S_{ij}$ 的计算复杂度是 $O(B_rB_cd)$
  • 同理根据伪代码第 12 行,可以得出 $\tilde{P}_{ij}V_j$ 的计算复杂度也是 $O(B_rB_cd)$
  • 循环一共执行了 $T_cT_r=\frac N{B_c}\frac N{B_r}$ 次

综上可以得出 FlashAttention 的计算复杂度是 $O(\frac{N^2}{B_cB_r}B_rB_cd)=O(N^2d)$。

IO 复杂度

标准 attention 的 IO 复杂度:

FlashAttention 的 IO 复杂度:

参考

Online normalizer calculation for softmax

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

From Online Softmax to FlashAttention

FlashAttention 计算过程梳理

FlashAttention 的速度优化原理是怎样的?

FlashAttention 原理 | 深度学习算法

/posts/49b73eba.html

作者

学习提升网

发布于

2024-01-13

更新于

2024-02-25

许可协议

评论