MiniMind 05: 拼装 Model | Feixiang Tao
MiniMind Reproduction 2026-03-18 · 11 min read

MiniMind 05: 拼装 Model

Part 5 拼装Model

这一部分从 HollowStoneMindModel 出发,把整个模型的数据流完整串起来:一个 token id 是怎么一步一步变成隐藏状态的,每个函数的签名是什么,输入输出张量形状怎么变化。

1. 总入口:HollowStoneMindModel.forward(...)

对应代码里的主入口:

def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_kv: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    use_cache: bool = False,
    **kwargs,
)

这里最核心的输入有三个:

  • input_ids:形状 [B, S]
    • B 是 batch size
    • S 是 sequence length
    • 每个元素还是一个离散的 token id
  • attention_mask:通常形状 [B, S]
    • 一般 1 表示有效 token
    • 0 表示 padding
  • past_kv:每一层的 KV cache
    • 类型是一个长度等于层数的 list
    • 每个元素要么是 None
    • 要么是 (past_k, past_v)

当前这个 Model 的输出是:

return hidden_states, presents_kv

也就是:

  • hidden_states:形状 [B, S, hidden_size]
  • presents_kv:每层一个 (k, v) 的 list

注意:现在这份代码还没有 lm_head,所以它还只是 Transformer 主干,还不会输出词表 logits。


2. 配置先决定所有维度

配置类是:

class HollowStoneMindConfig(PretrainedConfig)

默认比较关键的配置有:

  • hidden_size = 512
  • num_attention_heads = 8
  • num_key_value_heads = 2
  • num_hidden_layers = 8
  • vocab_size = 6400

于是可以推出:

head_dim = hidden_size // num_attention_heads = 512 // 8 = 64

所以后面 attention 里常见形状会是:

  • Q: [B, 8, S, 64]
  • K: [B, 2, S, 64]
  • V: [B, 2, S, 64]

这里已经能看出来你在用 GQA:Q 的 head 数多,K/V 的 head 数少。


3. token id 先变成 embedding 向量

HollowStoneMindModel.forward(...) 里,首先做的是:

hidden_states = self.dropout(self.embed_tokens(input_ids))

其中:

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

所以形状变化是:

input_ids:      [B, S]
-> embed_tokens
hidden_states:  [B, S, hidden_size]

也就是说:

  • 之前每个 token 只是一个整数编号
  • 到这里它第一次变成一个长度为 hidden_size 的连续向量

如果 hidden_size = 512,那一个 token 现在就对应一个 512 维向量。


4. 模型初始化时就把 RoPE 表算好了

HollowStoneMindModel.__init__(...) 里:

freqs_cos, freqs_sin = precompute_freqs_cis(
    dim=config.hidden_size // config.num_attention_heads,
    end=config.max_position_embeddings,
    rope_base=config.rope_theta,
    rope_scaling=config.rope_scaling,
)

对应函数签名:

def precompute_freqs_cis(
    dim: int,
    end: int = 32 * 1024,
    rope_base: float = 1e6,
    rope_scaling: Optional[dict] = None,
)

这个函数返回:

  • cos
  • sin

它们的形状都是:

[max_position_embeddings, head_dim]

以当前默认配置来说就是:

[32768, 64]

然后注册成 buffer:

self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

所以 forward 里只是拿来用,不用每次重新算。


5. 进入每一层 MindBlock

模型主干是:

self.layers = nn.ModuleList(
    [MindBlock(config, i) for i in range(config.num_hidden_layers)]
)

在 forward 中:

for layer, layer_past_kv in zip(self.layers, past_kv):
    hidden_states, present_kv = layer(...)

所以每一层 block 的输入输出都是:

  • 输入 hidden_states: [B, S, hidden_size]
  • 输出 hidden_states: [B, S, hidden_size]
  • 另外返回当前层的 present_kv

MindBlock.forward(...) 的签名是:

def forward(
    self,
    hidden_states,
    position_embeding,
    past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    use_cache: bool = False,
    attention_mask: Optional[torch.Tensor] = None,
)

它内部做的事情非常标准:

  1. input_layernorm
  2. self-attention
  3. residual add
  4. post_attention_layernorm
  5. MLP
  6. residual add

具体是:

residual = hidden_states
hidden_states, present_kv = self.attention(self.input_layernorm(hidden_states), ...)
hidden_states = residual + hidden_states
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

所以整个 block 不改外部形状,只改变语义表示。


6. RMSNorm 在做什么

函数签名:

class RMSNorm(nn.Module):
    def forward(self, x)

输入输出形状都不变:

[B, S, hidden_size] -> [B, S, hidden_size]

它做的是对最后一维做 RMS 归一化:

x_float = x.float()
x_float * rsqrt(mean(x_float^2) + eps)

再乘以一个可学习参数 weight

所以它的作用不是改变 token 数量,也不是改变 hidden size,而是重新调整每个 token 向量的尺度。


7. Attention:先把 hidden states 投影成 Q/K/V

attention 的签名:

def forward(
    self,
    x: torch.Tensor,
    position_embeding: Tuple[torch.Tensor, torch.Tensor],
    past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    use_cache: bool = False,
    attention_mask: Optional[torch.Tensor] = None,
)

这里输入的 x 是:

[B, S, hidden_size]

先做三次线性投影:

xq = self.q_proj(x)
xk = self.k_proj(x)
xv = self.v_proj(x)

此时形状分别是:

  • xq: [B, S, num_attention_heads * head_dim]
  • xk: [B, S, num_key_value_heads * head_dim]
  • xv: [B, S, num_key_value_heads * head_dim]

然后:

.view(B, S, H, D).transpose(1, 2)

所以最终变成统一约定的 B H S D

  • xq: [B, H_q, S, D]
  • xk: [B, H_kv, S, D]
  • xv: [B, H_kv, S, D]

默认配置下是:

  • xq: [B, 8, S, 64]
  • xk: [B, 2, S, 64]
  • xv: [B, 2, S, 64]

8. RoPE 是怎么加到 Q/K 上的

在 attention 里:

cos, sin = position_embeding
position_ids = torch.arange(past_len, past_len + S).unsqueeze(0)
xq, xk = apply_rope(cos, sin, xq, xk, position_ids=position_ids, unsqueeze_dim=1)

这里:

  • cos: [max_pos, D]
  • sin: [max_pos, D]
  • position_ids: [1, S]
  • xq: [B, H_q, S, D]
  • xk: [B, H_kv, S, D]

为什么 position_ids 是从 past_len 开始?

因为增量解码时,新 token 的真实位置不再是 0,1,2...,而是接在历史序列后面。

举个例子:

  • 历史缓存长度 past_len = 20
  • 当前来了 S = 3 个新 token
  • 那这三个 token 的位置就应该是 [20, 21, 22]

apply_rope(...) 的签名是:

def apply_rope(cos, sin, Q, K, position_ids=None, unsqueeze_dim=1)

它内部会先做:

cos = cos[position_ids]
sin = sin[position_ids]

于是形状变成:

[1, S, D]

然后再 unsqueeze(1),变成:

[1, 1, S, D]

这样就能和:

Q/K: [B, H, S, D]

做广播相乘。

最后返回:

  • q_embed: [B, H_q, S, D]
  • k_embed: [B, H_kv, S, D]

也就是说:RoPE 只改数值,不改形状。


9. KV cache 是怎么拼接的

在 attention 中:

if past_kv is not None:
    xk = torch.cat([past_kv[0], xk], dim=2)
    xv = torch.cat([past_kv[1], xv], dim=2)

由于你的统一约定是 B H S D,所以:

  • dim=2 正好是序列维

如果:

  • past_k: [B, H_kv, S_past, D]
  • xk: [B, H_kv, S_cur, D]

拼接后:

  • xk: [B, H_kv, S_past + S_cur, D]

xv 同理。

如果 use_cache=True,函数最后会把更新后的 (xk, xv) 返回出去。


10. 为什么还要 repeat_kv

你这里的 Q 头数和 K/V 头数不一样。

  • Q 有 num_attention_heads = 8
  • K/V 有 num_key_value_heads = 2

所以在真正做 attention 前,要先把 K/V 在 head 维复制到和 Q 一样多。

调用:

xk = repeat_kv(xk, self.num_key_value_groups)
xv = repeat_kv(xv, self.num_key_value_groups)

函数签名:

def repeat_kv(x: torch.Tensor, rep: int)

输入:

[B, H_kv, S, D]

输出:

[B, H_q, S, D]

在默认配置下:

  • rep = 8 // 2 = 4
  • 所以 2 个 KV heads 会被复制成 8

于是 attention 前的形状统一成:

  • xq: [B, 8, S_q, 64]
  • xk: [B, 8, S_k, 64]
  • xv: [B, 8, S_k, 64]

11. Attention score 是怎么得到的

attention 有两条实现路径:

  • Flash Attention 路径
  • 手写 Attention 路径

不管哪条路,进入时张量形状都是:

  • xq: [B, H, S_q, D]
  • xk: [B, H, S_k, D]
  • xv: [B, H, S_k, D]

11.1 手写路径

在代码中:

scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)

形状变化是:

[B, H, S_q, D] @ [B, H, D, S_k]
-> [B, H, S_q, S_k]

这就是每个 query token 对所有 key token 的打分。

然后会加两类 mask:

causal mask

作用:不让当前位置看到未来 token。

形状:

[S_q, S_k]

再广播到:

[B, H, S_q, S_k]

attention mask

原始输入:

attention_mask: [B, S_k]

扩展后:

[B, 1, 1, S_k]

再和 scores: [B, H, S_q, S_k] 广播相加。

接着:

attn_weights = softmax(scores, dim=-1)
out_put = attn_weights @ xv

所以:

[B, H, S_q, S_k] @ [B, H, S_k, D]
-> [B, H, S_q, D]

11.2 Flash Attention 路径

如果环境支持:

F.scaled_dot_product_attention(...)

那它内部也是同样的数学逻辑,只是实现更高效。

输出依然是:

[B, H, S_q, D]

12. 多头结果如何拼回 [B, S, hidden_size]

attention 输出还是:

[B, H, S, D]

接下来:

out_put = out_put.transpose(1, 2).contiguous().view(B, S, H * D)

所以:

[B, H, S, D]
-> [B, S, H, D]
-> [B, S, hidden_size]

然后经过输出线性层:

self.o_proj(out_put)

以及 dropout。

因此 attention 子层整体效果是:

[B, S, hidden_size] -> [B, S, hidden_size]

外加一个 present_kv


13. MLP / FeedForward 在做什么

函数签名:

def forward(self, x: torch.Tensor)

输入:

[B, S, hidden_size]

内部做:

up_proj(x)    -> [B, S, intermediate_size]
gate_proj(x)  -> [B, S, intermediate_size]
act_fn(gate_proj(x)) * up_proj(x)
             -> [B, S, intermediate_size]
down_proj(...) -> [B, S, hidden_size]

最后再 dropout。

所以 MLP 子层整体是:

[B, S, hidden_size]
-> [B, S, intermediate_size]
-> [B, S, hidden_size]

它的作用是把每个 token 的表示先升维做非线性混合,再压回 hidden size。


14. 一个 MindBlock 的完整流向

如果把一个 block 展开来看,流程是:

hidden_states: [B, S, hidden_size]
-> input_layernorm
-> self-attention
-> residual add
-> post_attention_layernorm
-> FeedForward
-> residual add
-> output hidden_states: [B, S, hidden_size]

所以 block 的本质是:

  • attention 负责 token 与 token 之间的信息交互
  • MLP 负责每个 token 内部表示的非线性变换
  • residual 负责保留原始信息并稳定训练

15. 整个 HollowStoneMindModel 的完整数据流

如果从最顶层看,一个 batch 的 token 进来之后,会经历:

input_ids: [B, S]
-> embedding
hidden_states: [B, S, hidden_size]
-> dropout
-> layer 1
-> layer 2
-> ...
-> layer N
-> final RMSNorm
-> 输出最终 hidden_states: [B, S, hidden_size]

同时如果开启 use_cache=True

  • 每层都会返回自己的 present_kv
  • 最终聚合成 presents_kv

16. 一句话总结 token 的旅程

一个 token 在当前这份代码里经历的是:

token id
-> embedding 向量
-> 进入每一层 block
-> 在 attention 中变成 Q/K/V
-> Q/K 加上 RoPE 位置编码
-> 和上下文做注意力交互
-> 拼回 hidden_size
-> 再经过 MLP
-> 多层重复
-> 最后得到最终 hidden state

但要注意:

当前这份代码还没有:

hidden_states -> lm_head -> logits -> loss / generate

所以它目前还是一个 Transformer backbone,还不是完整的语言模型头。


17. 当前这份代码里最值得记住的几个形状

建议直接背下来:

  • input_ids: [B, S]
  • hidden_states: [B, S, hidden_size]
  • Q: [B, H_q, S, D]
  • K/V: [B, H_kv, S, D]
  • repeat_kv(K/V): [B, H_q, S, D]
  • scores: [B, H, S_q, S_k]
  • attn output: [B, H, S, D]
  • 拼回后:[B, S, hidden_size]
外层主干一直是 [B, S, hidden_size]
attention 内部临时切成 [B, H, S, D]

总体架构对比

当前 HollowStoneMind 的架构选择与主流开源模型高度一致:

组件HollowStoneMindLLaMA 2GPT-2
归一化RMSNorm (pre-norm)RMSNorm (pre-norm)LayerNorm (post-norm)
位置编码RoPE + YaRNRoPE学习式绝对位置
注意力GQAGQAMHA
FFNSwiGLU 门控SwiGLU 门控GELU MLP
偏置项

这套”RMSNorm + RoPE + GQA + SwiGLU”的组合已经成为 2023 年以来 LLM 架构的事实标准(de facto standard),被 LLaMA、Mistral、Qwen、DeepSeek 等主流模型采用。


参考

  • Vaswani, A. et al. “Attention Is All You Need.” NeurIPS, 2017. — Transformer 架构的原始论文
  • Touvron, H. et al. “LLaMA: Open and Efficient Foundation Language Models.” arXiv:2302.13971, 2023. — 当前架构选择的主要参考
  • Radford, A. et al. “Language Models are Unsupervised Multitask Learners.” OpenAI, 2019. — GPT-2 作为对比基准
END

Series: MiniMind Reproduction

  1. 1. MiniMind 总架构图
  2. 2. MiniMind 01: RMSNorm
  3. 3. MiniMind 02: RoPE & YaRN
  4. 4. MiniMind 03: GQA
  5. 5. MiniMind 04: FFN
  6. 6. MiniMind 05: 拼装 Model

Comments