MiniMind 03: GQA | Feixiang Tao
MiniMind Reproduction 2026-03-18 · 6 min read

MiniMind 03: GQA

Part 3 GQA

这一部分讲 GQA,也就是 Grouped Query Attention。它看起来像是多头注意力的一个小变体,但在大模型里非常关键,因为它正好打在一个很实际的瓶颈上:

KV cache 太贵了。

所以这一部分我们会从普通多头注意力出发,讲为什么 KV 会成为负担,GQA 到底做了什么折中,以及在当前这份 minimind / HollowStoneMind 代码里,它是怎样落地的。

1. 先回忆普通多头注意力在做什么

普通多头注意力里,每个头都会各自拥有自己的:

  • Query
  • Key
  • Value

如果有 H 个头,那么通常:

  • Q 有 H 个头
  • K 有 H 个头
  • V 也有 H 个头

所以对于每个 token,模型会同时维护很多组独立的比较视角。

这很好,因为表示能力很强;但坏处也很直接:

K/V 的缓存也要按 H 份来存。

而在推理阶段,真正越存越大的正是 KV cache。


2. 为什么 KV cache 会成为瓶颈

在自回归生成里,每生成一个新 token,都会把当前层的 K/V 缓存起来,留给后面的 token 用。

如果某一层的 K 形状是:

[B, H, S, D]

那随着 S 变长,缓存显存会线性增长。

更进一步,整网一共有 L 层,所以总缓存量近似是:

O(L * B * H * S * D)

这就是为什么大模型推理时,真正贵的常常不是参数本身,而是:

长上下文下的 KV cache

于是一个自然的问题就是:

Q 一定要有很多头,那 K/V 也一定要同样多吗?

GQA 的答案是:

不一定。

3. GQA 到底做了什么

GQA 的核心思想可以一句话概括:

让很多 query 头共享一组较少的 key/value 头。

也就是说:

  • Query 头数保持较大
  • Key / Value 头数变少
  • 多个 Q 头分组共享同一组 K/V

这就是 Grouped Query Attention 的“group”含义。

例如你当前默认配置里:

  • num_attention_heads = 8
  • num_key_value_heads = 2

那就意味着:

  • Q 有 8 个头
  • K/V 只有 2 个头
  • 每 4 个 Q 头共享 1 个 KV 头组

也就是:

8 个 query heads
共享 2 个 key/value heads

4. GQA、MHA、MQA 三者的关系

可以把它们看成同一条线上不同位置的折中。

4.1 MHA:Multi-Head Attention

  • Q 有 H
  • K/V 也有 H
  • 表示最丰富
  • cache 最贵

4.2 MQA:Multi-Query Attention

  • Q 有 H
  • K/V 只有 1
  • cache 最省
  • 但 K/V 表达能力压缩得最狠

4.3 GQA:Grouped Query Attention

  • Q 还是 H
  • K/V 是 H_kv 头,其中 1 < H_kv < H
  • 在表示能力和 cache 成本之间取中间折中

所以 GQA 可以理解成:

MHA 和 MQA 之间的工程平衡点

5. 当前代码里 GQA 的形状是什么

在 attention 初始化时:

self.num_key_value_heads = args.num_key_value_heads if args.num_key_value_heads is not None else args.num_attention_heads
self.n_local_heads = args.num_attention_heads
self.num_key_value_groups = self.n_local_heads // self.num_key_value_heads
self.head_dim = args.hidden_size // args.num_attention_heads

如果默认配置是:

  • hidden_size = 512
  • num_attention_heads = 8
  • num_key_value_heads = 2

那就有:

head_dim = 512 // 8 = 64
num_key_value_groups = 8 // 2 = 4

于是 attention 里:

Query

xq = self.q_proj(x)

投影后 reshape 成:

[B, 8, S, 64]

Key / Value

xk = self.k_proj(x)
xv = self.v_proj(x)

投影后 reshape 成:

[B, 2, S, 64]

所以当前模型里最核心的 GQA 形状就是:

  • Q: [B, 8, S, 64]
  • K: [B, 2, S, 64]
  • V: [B, 2, S, 64]

6. 那最后 attention 不是还得一头对一头吗

是的,所以你代码里有一个关键函数:

def repeat_kv(x: torch.Tensor, rep: int)

它的作用就是:

把较少的 KV heads 复制成和 Q 一样多的 heads

如果输入:

x.shape = [B, H_kv, S, D]
rep = H_q // H_kv

输出就是:

[B, H_q, S, D]

默认配置下:

[B, 2, S, 64] -> [B, 8, S, 64]

你代码里的实现是:

x = x.unsqueeze(2).expand(B, H, rep, S, D).reshape(B, H * rep, S, D)

本质上就是在 head 维做复制。


7. 那为什么这还能省内存

这是很多人第一次看 GQA 时最容易困惑的点:

既然最后还是 repeat_kv 到和 Q 一样多,那省在哪?

关键在于:

真正长期缓存的是“未重复前”的 K/V。

也就是缓存里存的是:

[B, H_kv, S, D]

而不是:

[B, H_q, S, D]

repeat_kv 只是 attention 计算时临时扩出来用。

所以长期显存成本按 H_kv 计算,而不是按 H_q 计算。

这就是 GQA 节省 cache 的关键。


8. 从成本角度看,GQA 到底省了什么

假设:

  • H_q = 32
  • H_kv = 8

那么 K/V cache 大小大致就缩成原来的:

8 / 32 = 1/4

也就是说:

  • attention 的“感受头数”仍然可以维持很多 query heads
  • 但缓存只需要存四分之一的 KV heads

这在长上下文生成里非常值钱。

Insight:GQA 省的不是一点线性层参数,而是推理系统里最贵的那块运行态状态

很多结构改动只是在模型参数量上做文章,但 GQA 打的是更实际的问题:

推理时,随着上下文增长不断膨胀的 KV cache。

所以它的价值不只在数学上,而在系统层面上非常直接:

  • 更长上下文更可行
  • 更大 batch 推理更可行
  • 同样显存下吞吐更容易做上去

这就是为什么它在大模型里几乎成了常见配置。


9. GQA 会不会损失表达能力

会有折中,但通常是值得的。

因为你确实减少了 K/V 的多样性:

  • 原来每个 query 头都可以对应自己独立的 K/V 视角
  • 现在很多 query 头要共享同一组 K/V

但经验上,这个损失往往没有想象中大。

一种很实用的理解方式是:

注意力的“选择性”很多时候主要体现在 query 上,而不一定必须让每个头都拥有完全独立的 key/value 度量。

换句话说:

  • Query 头仍然很多,模型依然可以从很多不同角度去问问题
  • 只是这些问题不再都需要各自完全独立的 K/V 存储

这就是 GQA 成立的经验基础。

Insight:GQA 的核心假设不是“K/V 不重要”,而是“Q 的多样性比 K/V 的完全独立更重要”

这个角度很值得记住。

GQA 不是说:

K/V 可以随便糊弄。

而是在说:

模型的选择自由度,很多时候更多来自 query 端;
K/V 端适度共享,损失可能比我们直觉里小。

这其实是一种对注意力结构的“职责重分配”。


10. 在当前代码里的完整数据流

假设输入到 attention 的 x 是:

[B, S, hidden_size]

那 GQA 相关的数据流是:

x
-> q_proj -> [B, S, H_q * D]
-> view/transpose -> [B, H_q, S, D]

x
-> k_proj -> [B, S, H_kv * D]
-> view/transpose -> [B, H_kv, S, D]

x
-> v_proj -> [B, S, H_kv * D]
-> view/transpose -> [B, H_kv, S, D]

然后:

  • Q/K 加 RoPE
  • K/V 与 cache 拼接
  • repeat_kv(K/V)
  • 变成 [B, H_q, S, D]
  • 再和 Q 做真正的 attention

所以你可以把 GQA 的实现记成一句话:

先用少头数生成 K/V,最后在计算 attention 前把它们广播回 query 头数。

11. 一个更高层的系统视角

如果你把 Transformer 看成一个推理系统,而不只是一个数学模块,那 GQA 的价值会更清楚。

推理阶段真正决定成本的,不只是:

  • 参数量
  • FLOPs

还有:

  • cache 大小
  • 显存带宽
  • 长上下文扩展时的状态管理

从这个角度看,GQA 很像一种“结构化压缩”:

不是把模型整体缩小,而是只压缩最贵的那部分运行态状态。

这非常符合现代大模型工程的思路:

  • 不一定到处都省
  • 但一定要省在瓶颈上

12. 这一部分最值得记住的话

如果只记四句,建议记这四句:

  1. GQA 的动机是减轻 KV cache 成本,而不是改 attention 的基本数学形式。
  2. Q 头数保持较多,K/V 头数减少,多个 query heads 共享同一组 K/V。
  3. repeat_kv 只是计算前临时复制,真正缓存的是较少头数的 K/V。
  4. GQA 是 MHAMQA 之间非常实用的工程折中。

一句适合以后回忆的总结:

GQA 不是在削弱 attention,而是在把“昂贵的独立性”只保留给最值得保留的那一边。

量化分析

设模型有 HqH_q 个 query 头、HkvH_{kv} 个 KV 头,每个头维度 dd,序列长度 SS,层数 LL,batch size BB

KV cache 内存Cache=2×L×B×Hkv×S×d×sizeof(dtype)\text{Cache} = 2 \times L \times B \times H_{kv} \times S \times d \times \text{sizeof(dtype)}

配置HqH_qHkvH_{kv}Cache 相对 MHA
MHA3232100%
GQA-832825%
GQA-1 (=MQA)3213.1%

对于 7B 参数模型(L=32,d=128,Hq=32L=32, d=128, H_q=32),在 S=4096S=4096 时,MHA 的 KV cache 约 4 GB(fp16),GQA-8 仅需约 1 GB。


参考

  • Ainslie, J. et al. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” EMNLP, 2023. — GQA 的原始论文
  • Shazeer, N. “Fast Transformer Decoding: One Write-Head is All You Need.” arXiv:1911.02150, 2019. — MQA 的原始论文
  • Touvron, H. et al. “Llama 2: Open Foundation and Fine-Tuned Chat Models.” arXiv:2307.09288, 2023. — GQA 在大模型中的应用
END

Series: MiniMind Reproduction

  1. 1. MiniMind 总架构图
  2. 2. MiniMind 01: RMSNorm
  3. 3. MiniMind 02: RoPE & YaRN
  4. 4. MiniMind 03: GQA
  5. 5. MiniMind 04: FFN
  6. 6. MiniMind 05: 拼装 Model

Comments