当序列长到装不下:高效 Attention 的效率战争
序列长度翻一倍,计算量翻四倍。当我们想从处理一段话变成处理一整本书时,这个数学现实变成了一道墙。
一道数学墙
让我们算一笔账。
Self-attention 需要让每个 token 和所有其他 token 算一次分数。$n$ 个 token 就是 $n \times n$ 个分数。
- 512 token(约 1 页文字):约 26 万个分数。轻松。
- 4096 token(约 10 页):约 1700 万个分数。还行。
- 32K token(一篇长论文):约 10 亿个分数。开始吃力了。
- 128K token(一本小书):约 160 亿个分数。GPU 显存要爆了。
- 1M token(多本书):1 万亿个分数。不可能直接算。
更痛的是:这些分数不只是算完就完了——标准实现需要把整个 $n \times n$ 的矩阵存下来。128K 的序列,光这个矩阵就需要约 64GB 显存(fp16)。连最贵的 GPU(80GB)都装不下。
这意味着即使你有无限算力,显存也成了硬限制。
怎么办?2019-2025 年间,研究者们从两个方向进攻这个问题:
方向一:不改数学,改工程
flowchart LR
subgraph 标准实现["标准 Attention"]
direction TB
S1["Q,K → 计算 n×n 分数矩阵"] --> S2["写入GPU主存 (慢)"]
S2 --> S3["读出做Softmax"] --> S4["写回主存"]
S4 --> S5["读出 × V"] --> S6["写回主存"]
end
subgraph Flash["FlashAttention"]
direction TB
F1["分块加载 Q,K,V 到快速缓存"] --> F2["在缓存内完成全部计算"]
F2 --> F3["只写最终结果到主存"]
end
FlashAttention:聪明地使用硬件 (2022)
Tri Dao(斯坦福博士生,后来成了 Mamba 的作者之一)做了一个精妙的观察:GPU 的真正瓶颈不是算力,是内存搬运。
GPU 的结构简单说是这样的:
- 有一个大但慢的”主存”(HBM,高带宽内存,几十 GB)
- 有一个小但快的”缓存”(SRAM,几 MB)
标准 attention 实现的问题是:它先算出完整的 $n \times n$ 分数矩阵写入主存,再从主存读出来做 softmax,再写回去,再读出来乘以 Value……每一步都在搬运巨大的数据进出那个慢速主存。
FlashAttention 的核心想法:永远不要把那个 $n \times n$ 矩阵完整地写出来。
怎么做到?把计算分成小块(tiles),每个小块完全在快速缓存中完成。通过一个巧妙的”在线 softmax”算法(softmax 通常需要看到所有数据才能归一化,但他们找到了一种分块递进计算的方法),整个 attention 在缓存中一块一块地处理,最终结果写出时已经是完成品了。
结果:
- 速度提升 2-4 倍
- 显存使用从 $O(n^2)$ 降到 $O(n)$——那个巨大的中间矩阵根本不需要存在
- 计算结果完全精确——和标准 attention 没有任何数值差异
这不是近似方法,不会损失精度。它只是用了一种更聪明的计算顺序。
FlashAttention 的意义在于:它把”能处理多长的序列”这个上限从硬件直接受限变成了”只受算力限制”。之后的 FlashAttention-2(2023)进一步优化,把 GPU 利用率推到了理论峰值的 50-73%。
一个类比
如果你在一个小厨房(缓存)做菜,但食材放在楼下仓库(主存):
- 笨办法:把所有食材全搬上来,做完再全送下去。厨房太小放不下。
- FlashAttention 的做法:只搬一道菜需要的食材,做完立刻出品。一道接一道,从不堆满厨房。
方向二:减少需要缓存的东西
KV Cache:推理时的隐性杀手
训练时,整个序列一起算,$O(n^2)$ 是一次性代价。但推理(生成文本)时有另一个问题:
生成是一个词一个词来的。生成第 $t$ 个词时,需要看前面 $t-1$ 个词的信息。如果每次都重新算前面所有词的 attention,那生成 $n$ 个词总共要算 $1 + 2 + 3 + … + n \approx n^2/2$ 次——效率极低。
解决方案:KV Cache。把之前每一步的 Key 和 Value 缓存下来,生成新词时只需要算新词的 Query 和所有缓存的 Key 的内积。这样每一步只需 $O(n)$ 而非 $O(n^2)$。
但这个缓存本身会线性增长:每一层、每个注意力头都要存 Key 和 Value。一个 70B 参数的模型处理 32K 序列,KV cache 可能需要 32GB 显存——比模型本身还大!
flowchart LR
MHA["MHA<br/>每头独立KV<br/>32组KV"] -->|"太占显存"| MQA["MQA<br/>所有头共享1组KV<br/>压缩32×"]
MQA -->|"质量有损"| GQA["GQA<br/>分组共享<br/>8组KV (压缩4×)"]
GQA -->|"更进一步"| MLA["MLA<br/>低秩压缩<br/>质量更好"]
从 MHA 到 MQA 到 GQA:共享 KV
既然 KV cache 太大,能不能共享?
Multi-Query Attention (MQA, 2019):最激进的做法——所有 query 头共享一组 Key 和 Value。如果有 32 个头,KV cache 一下缩小了 32 倍。但质量有损失——毕竟所有头看到的是完全相同的 KV。
Grouped-Query Attention (GQA, 2023):折中方案。把 32 个 query 头分成 8 组,每组共享一组 KV。KV cache 缩小 4 倍,质量损失很小。LLaMA 2 的 70B 模型采用了这个方案。
Multi-Head Latent Attention (MLA, 2024):DeepSeek-V2 的创新。与其直接共享 KV,不如把它们压缩到一个低维的”潜在空间”,需要时再解压。就像把一张高清照片压缩成 JPEG——文件小了很多,但质量损失几乎看不出来。MLA 比 GQA 压缩得更狠,质量反而更好。
这条演进线的逻辑是:信息有冗余,可以压缩。 32 个头的 KV 并不是完全独立的信息——它们之间有大量相关性,压缩不会损失太多有效信息。
PagedAttention:像操作系统一样管理内存
vLLM(2023)提出了另一个思路:KV cache 的问题不只是”太大”,还有”太浪费”。
传统实现为每个请求预分配能容纳最大长度的连续内存。但实际上大多数请求不会用满——就像预定了一张可坐 100 人的桌子,只来了 10 个人。
PagedAttention 借鉴了操作系统的虚拟内存思想:把 KV cache 分成固定大小的”页”,按需分配,不连续存储。就像操作系统不会为每个程序预留所有它可能用的内存——用到哪页分哪页。
结果:相同的 GPU 显存可以同时服务多 2-4 倍的用户请求。这对生产部署意义重大。
这些优化意味着什么
把这些技术叠加起来:
- FlashAttention 让训练和推理的计算效率最大化
- GQA/MLA 让 KV cache 缩小数倍
- PagedAttention 让内存利用率最大化
结合位置编码的 YaRN 扩展,2024 年的模型已经可以实际处理 128K 甚至 1M token 的上下文了。这在 2022 年还被认为是不可能的。
但这些都是在 attention 框架内部优化。$O(n^2)$ 的本质没有改变——只是常数因子被极大地压低了。
有没有可能从根本上改变复杂度?从 $O(n^2)$ 变成 $O(n)$?
有人在尝试——用完全不同的架构。这是第八篇的内容。但在那之前,我们先看看 Transformer 在架构层面产生了哪些变体,以及如何把模型做得又大又高效。