MiniMind 03: GQA
Part 3 GQA
这一部分讲 GQA,也就是 Grouped Query Attention。它看起来像是多头注意力的一个小变体,但在大模型里非常关键,因为它正好打在一个很实际的瓶颈上:
KV cache 太贵了。
所以这一部分我们会从普通多头注意力出发,讲为什么 KV 会成为负担,GQA 到底做了什么折中,以及在当前这份 minimind / HollowStoneMind 代码里,它是怎样落地的。
1. 先回忆普通多头注意力在做什么
普通多头注意力里,每个头都会各自拥有自己的:
- Query
- Key
- Value
如果有 H 个头,那么通常:
- Q 有
H个头 - K 有
H个头 - V 也有
H个头
所以对于每个 token,模型会同时维护很多组独立的比较视角。
这很好,因为表示能力很强;但坏处也很直接:
K/V 的缓存也要按 H 份来存。
而在推理阶段,真正越存越大的正是 KV cache。
2. 为什么 KV cache 会成为瓶颈
在自回归生成里,每生成一个新 token,都会把当前层的 K/V 缓存起来,留给后面的 token 用。
如果某一层的 K 形状是:
[B, H, S, D]
那随着 S 变长,缓存显存会线性增长。
更进一步,整网一共有 L 层,所以总缓存量近似是:
O(L * B * H * S * D)
这就是为什么大模型推理时,真正贵的常常不是参数本身,而是:
长上下文下的 KV cache
于是一个自然的问题就是:
Q 一定要有很多头,那 K/V 也一定要同样多吗?
GQA 的答案是:
不一定。
3. GQA 到底做了什么
GQA 的核心思想可以一句话概括:
让很多 query 头共享一组较少的 key/value 头。
也就是说:
- Query 头数保持较大
- Key / Value 头数变少
- 多个 Q 头分组共享同一组 K/V
这就是 Grouped Query Attention 的“group”含义。
例如你当前默认配置里:
num_attention_heads = 8num_key_value_heads = 2
那就意味着:
- Q 有 8 个头
- K/V 只有 2 个头
- 每 4 个 Q 头共享 1 个 KV 头组
也就是:
8 个 query heads
共享 2 个 key/value heads
4. GQA、MHA、MQA 三者的关系
可以把它们看成同一条线上不同位置的折中。
4.1 MHA:Multi-Head Attention
- Q 有
H头 - K/V 也有
H头 - 表示最丰富
- cache 最贵
4.2 MQA:Multi-Query Attention
- Q 有
H头 - K/V 只有
1头 - cache 最省
- 但 K/V 表达能力压缩得最狠
4.3 GQA:Grouped Query Attention
- Q 还是
H头 - K/V 是
H_kv头,其中1 < H_kv < H - 在表示能力和 cache 成本之间取中间折中
所以 GQA 可以理解成:
MHA 和 MQA 之间的工程平衡点
5. 当前代码里 GQA 的形状是什么
在 attention 初始化时:
self.num_key_value_heads = args.num_key_value_heads if args.num_key_value_heads is not None else args.num_attention_heads
self.n_local_heads = args.num_attention_heads
self.num_key_value_groups = self.n_local_heads // self.num_key_value_heads
self.head_dim = args.hidden_size // args.num_attention_heads
如果默认配置是:
hidden_size = 512num_attention_heads = 8num_key_value_heads = 2
那就有:
head_dim = 512 // 8 = 64
num_key_value_groups = 8 // 2 = 4
于是 attention 里:
Query
xq = self.q_proj(x)
投影后 reshape 成:
[B, 8, S, 64]
Key / Value
xk = self.k_proj(x)
xv = self.v_proj(x)
投影后 reshape 成:
[B, 2, S, 64]
所以当前模型里最核心的 GQA 形状就是:
Q:[B, 8, S, 64]K:[B, 2, S, 64]V:[B, 2, S, 64]
6. 那最后 attention 不是还得一头对一头吗
是的,所以你代码里有一个关键函数:
def repeat_kv(x: torch.Tensor, rep: int)
它的作用就是:
把较少的 KV heads 复制成和 Q 一样多的 heads
如果输入:
x.shape = [B, H_kv, S, D]
rep = H_q // H_kv
输出就是:
[B, H_q, S, D]
默认配置下:
[B, 2, S, 64] -> [B, 8, S, 64]
你代码里的实现是:
x = x.unsqueeze(2).expand(B, H, rep, S, D).reshape(B, H * rep, S, D)
本质上就是在 head 维做复制。
7. 那为什么这还能省内存
这是很多人第一次看 GQA 时最容易困惑的点:
既然最后还是
repeat_kv到和 Q 一样多,那省在哪?
关键在于:
真正长期缓存的是“未重复前”的 K/V。
也就是缓存里存的是:
[B, H_kv, S, D]
而不是:
[B, H_q, S, D]
repeat_kv 只是 attention 计算时临时扩出来用。
所以长期显存成本按 H_kv 计算,而不是按 H_q 计算。
这就是 GQA 节省 cache 的关键。
8. 从成本角度看,GQA 到底省了什么
假设:
H_q = 32H_kv = 8
那么 K/V cache 大小大致就缩成原来的:
8 / 32 = 1/4
也就是说:
- attention 的“感受头数”仍然可以维持很多 query heads
- 但缓存只需要存四分之一的 KV heads
这在长上下文生成里非常值钱。
Insight:GQA 省的不是一点线性层参数,而是推理系统里最贵的那块运行态状态
很多结构改动只是在模型参数量上做文章,但 GQA 打的是更实际的问题:
推理时,随着上下文增长不断膨胀的 KV cache。
所以它的价值不只在数学上,而在系统层面上非常直接:
- 更长上下文更可行
- 更大 batch 推理更可行
- 同样显存下吞吐更容易做上去
这就是为什么它在大模型里几乎成了常见配置。
9. GQA 会不会损失表达能力
会有折中,但通常是值得的。
因为你确实减少了 K/V 的多样性:
- 原来每个 query 头都可以对应自己独立的 K/V 视角
- 现在很多 query 头要共享同一组 K/V
但经验上,这个损失往往没有想象中大。
一种很实用的理解方式是:
注意力的“选择性”很多时候主要体现在 query 上,而不一定必须让每个头都拥有完全独立的 key/value 度量。
换句话说:
- Query 头仍然很多,模型依然可以从很多不同角度去问问题
- 只是这些问题不再都需要各自完全独立的 K/V 存储
这就是 GQA 成立的经验基础。
Insight:GQA 的核心假设不是“K/V 不重要”,而是“Q 的多样性比 K/V 的完全独立更重要”
这个角度很值得记住。
GQA 不是说:
K/V 可以随便糊弄。
而是在说:
模型的选择自由度,很多时候更多来自 query 端;
K/V 端适度共享,损失可能比我们直觉里小。
这其实是一种对注意力结构的“职责重分配”。
10. 在当前代码里的完整数据流
假设输入到 attention 的 x 是:
[B, S, hidden_size]
那 GQA 相关的数据流是:
x
-> q_proj -> [B, S, H_q * D]
-> view/transpose -> [B, H_q, S, D]
x
-> k_proj -> [B, S, H_kv * D]
-> view/transpose -> [B, H_kv, S, D]
x
-> v_proj -> [B, S, H_kv * D]
-> view/transpose -> [B, H_kv, S, D]
然后:
Q/K加 RoPEK/V与 cache 拼接repeat_kv(K/V)- 变成
[B, H_q, S, D] - 再和
Q做真正的 attention
所以你可以把 GQA 的实现记成一句话:
先用少头数生成 K/V,最后在计算 attention 前把它们广播回 query 头数。
11. 一个更高层的系统视角
如果你把 Transformer 看成一个推理系统,而不只是一个数学模块,那 GQA 的价值会更清楚。
推理阶段真正决定成本的,不只是:
- 参数量
- FLOPs
还有:
- cache 大小
- 显存带宽
- 长上下文扩展时的状态管理
从这个角度看,GQA 很像一种“结构化压缩”:
不是把模型整体缩小,而是只压缩最贵的那部分运行态状态。
这非常符合现代大模型工程的思路:
- 不一定到处都省
- 但一定要省在瓶颈上
12. 这一部分最值得记住的话
如果只记四句,建议记这四句:
- GQA 的动机是减轻 KV cache 成本,而不是改 attention 的基本数学形式。
- Q 头数保持较多,K/V 头数减少,多个 query heads 共享同一组 K/V。
repeat_kv只是计算前临时复制,真正缓存的是较少头数的 K/V。- GQA 是
MHA和MQA之间非常实用的工程折中。
一句适合以后回忆的总结:
GQA 不是在削弱 attention,而是在把“昂贵的独立性”只保留给最值得保留的那一边。
量化分析
设模型有 个 query 头、 个 KV 头,每个头维度 ,序列长度 ,层数 ,batch size 。
KV cache 内存:
| 配置 | Cache 相对 MHA | ||
|---|---|---|---|
| MHA | 32 | 32 | 100% |
| GQA-8 | 32 | 8 | 25% |
| GQA-1 (=MQA) | 32 | 1 | 3.1% |
对于 7B 参数模型(),在 时,MHA 的 KV cache 约 4 GB(fp16),GQA-8 仅需约 1 GB。
参考
- Ainslie, J. et al. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP, 2023. — GQA 的原始论文
- Shazeer, N. “Fast Transformer Decoding: One Write-Head is All You Need.” arXiv:1911.02150, 2019. — MQA 的原始论文
- Touvron, H. et al. “Llama 2: Open Foundation and Fine-Tuned Chat Models.” arXiv:2307.09288, 2023. — GQA 在大模型中的应用
Series: MiniMind Reproduction
- 1. MiniMind 总架构图
- 2. MiniMind 01: RMSNorm
- 3. MiniMind 02: RoPE & YaRN
- 4. MiniMind 03: GQA
- 5. MiniMind 04: FFN
- 6. MiniMind 05: 拼装 Model