序列长度翻一倍,计算量翻四倍。当我们想从处理一段话变成处理一整本书时,这个数学现实变成了一道墙。

一道数学墙

让我们算一笔账。

Self-attention 需要让每个 token 和所有其他 token 算一次分数。$n$ 个 token 就是 $n \times n$ 个分数。

  • 512 token(约 1 页文字):约 26 万个分数。轻松。
  • 4096 token(约 10 页):约 1700 万个分数。还行。
  • 32K token(一篇长论文):约 10 亿个分数。开始吃力了。
  • 128K token(一本小书):约 160 亿个分数。GPU 显存要爆了。
  • 1M token(多本书):1 万亿个分数。不可能直接算。

更痛的是:这些分数不只是算完就完了——标准实现需要把整个 $n \times n$ 的矩阵存下来。128K 的序列,光这个矩阵就需要约 64GB 显存(fp16)。连最贵的 GPU(80GB)都装不下。

这意味着即使你有无限算力,显存也成了硬限制。

怎么办?2019-2025 年间,研究者们从两个方向进攻这个问题:

方向一:不改数学,改工程

flowchart LR
    subgraph 标准实现["标准 Attention"]
        direction TB
        S1["Q,K → 计算 n×n 分数矩阵"] --> S2["写入GPU主存 (慢)"]
        S2 --> S3["读出做Softmax"] --> S4["写回主存"]
        S4 --> S5["读出 × V"] --> S6["写回主存"]
    end
    subgraph Flash["FlashAttention"]
        direction TB
        F1["分块加载 Q,K,V 到快速缓存"] --> F2["在缓存内完成全部计算"]
        F2 --> F3["只写最终结果到主存"]
    end

FlashAttention:聪明地使用硬件 (2022)

Tri Dao(斯坦福博士生,后来成了 Mamba 的作者之一)做了一个精妙的观察:GPU 的真正瓶颈不是算力,是内存搬运

GPU 的结构简单说是这样的:

  • 有一个大但慢的”主存”(HBM,高带宽内存,几十 GB)
  • 有一个小但快的”缓存”(SRAM,几 MB)

标准 attention 实现的问题是:它先算出完整的 $n \times n$ 分数矩阵写入主存,再从主存读出来做 softmax,再写回去,再读出来乘以 Value……每一步都在搬运巨大的数据进出那个慢速主存。

FlashAttention 的核心想法:永远不要把那个 $n \times n$ 矩阵完整地写出来。

怎么做到?把计算分成小块(tiles),每个小块完全在快速缓存中完成。通过一个巧妙的”在线 softmax”算法(softmax 通常需要看到所有数据才能归一化,但他们找到了一种分块递进计算的方法),整个 attention 在缓存中一块一块地处理,最终结果写出时已经是完成品了。

结果:

  • 速度提升 2-4 倍
  • 显存使用从 $O(n^2)$ 降到 $O(n)$——那个巨大的中间矩阵根本不需要存在
  • 计算结果完全精确——和标准 attention 没有任何数值差异

这不是近似方法,不会损失精度。它只是用了一种更聪明的计算顺序。

FlashAttention 的意义在于:它把”能处理多长的序列”这个上限从硬件直接受限变成了”只受算力限制”。之后的 FlashAttention-2(2023)进一步优化,把 GPU 利用率推到了理论峰值的 50-73%。

一个类比

如果你在一个小厨房(缓存)做菜,但食材放在楼下仓库(主存):

  • 笨办法:把所有食材全搬上来,做完再全送下去。厨房太小放不下。
  • FlashAttention 的做法:只搬一道菜需要的食材,做完立刻出品。一道接一道,从不堆满厨房。

方向二:减少需要缓存的东西

KV Cache:推理时的隐性杀手

训练时,整个序列一起算,$O(n^2)$ 是一次性代价。但推理(生成文本)时有另一个问题:

生成是一个词一个词来的。生成第 $t$ 个词时,需要看前面 $t-1$ 个词的信息。如果每次都重新算前面所有词的 attention,那生成 $n$ 个词总共要算 $1 + 2 + 3 + … + n \approx n^2/2$ 次——效率极低。

解决方案:KV Cache。把之前每一步的 Key 和 Value 缓存下来,生成新词时只需要算新词的 Query 和所有缓存的 Key 的内积。这样每一步只需 $O(n)$ 而非 $O(n^2)$。

但这个缓存本身会线性增长:每一层、每个注意力头都要存 Key 和 Value。一个 70B 参数的模型处理 32K 序列,KV cache 可能需要 32GB 显存——比模型本身还大!

flowchart LR
    MHA["MHA<br/>每头独立KV<br/>32组KV"] -->|"太占显存"| MQA["MQA<br/>所有头共享1组KV<br/>压缩32×"]
    MQA -->|"质量有损"| GQA["GQA<br/>分组共享<br/>8组KV (压缩4×)"]
    GQA -->|"更进一步"| MLA["MLA<br/>低秩压缩<br/>质量更好"]

从 MHA 到 MQA 到 GQA:共享 KV

既然 KV cache 太大,能不能共享

Multi-Query Attention (MQA, 2019):最激进的做法——所有 query 头共享一组 Key 和 Value。如果有 32 个头,KV cache 一下缩小了 32 倍。但质量有损失——毕竟所有头看到的是完全相同的 KV。

Grouped-Query Attention (GQA, 2023):折中方案。把 32 个 query 头分成 8 组,每组共享一组 KV。KV cache 缩小 4 倍,质量损失很小。LLaMA 2 的 70B 模型采用了这个方案。

Multi-Head Latent Attention (MLA, 2024):DeepSeek-V2 的创新。与其直接共享 KV,不如把它们压缩到一个低维的”潜在空间”,需要时再解压。就像把一张高清照片压缩成 JPEG——文件小了很多,但质量损失几乎看不出来。MLA 比 GQA 压缩得更狠,质量反而更好。

这条演进线的逻辑是:信息有冗余,可以压缩。 32 个头的 KV 并不是完全独立的信息——它们之间有大量相关性,压缩不会损失太多有效信息。

PagedAttention:像操作系统一样管理内存

vLLM(2023)提出了另一个思路:KV cache 的问题不只是”太大”,还有”太浪费”。

传统实现为每个请求预分配能容纳最大长度的连续内存。但实际上大多数请求不会用满——就像预定了一张可坐 100 人的桌子,只来了 10 个人。

PagedAttention 借鉴了操作系统的虚拟内存思想:把 KV cache 分成固定大小的”页”,按需分配,不连续存储。就像操作系统不会为每个程序预留所有它可能用的内存——用到哪页分哪页。

结果:相同的 GPU 显存可以同时服务多 2-4 倍的用户请求。这对生产部署意义重大。

这些优化意味着什么

把这些技术叠加起来:

  • FlashAttention 让训练和推理的计算效率最大化
  • GQA/MLA 让 KV cache 缩小数倍
  • PagedAttention 让内存利用率最大化

结合位置编码的 YaRN 扩展,2024 年的模型已经可以实际处理 128K 甚至 1M token 的上下文了。这在 2022 年还被认为是不可能的。

但这些都是在 attention 框架内部优化。$O(n^2)$ 的本质没有改变——只是常数因子被极大地压低了。

有没有可能从根本上改变复杂度?从 $O(n^2)$ 变成 $O(n)$?

有人在尝试——用完全不同的架构。这是第八篇的内容。但在那之前,我们先看看 Transformer 在架构层面产生了哪些变体,以及如何把模型做得又大又高效。