MiniMind 01: RMSNorm | Feixiang Tao
MiniMind Reproduction 2026-03-18 · 6 min read

MiniMind 01: RMSNorm

Part 1 RMS-norm

这一部分讲 RMSNorm 为什么会出现在 Transformer 里,它和 LayerNorm 的差别是什么,在当前这份 minimind / HollowStoneMind 代码里具体是怎么实现的,以及它在整个模型里到底扮演什么角色。

1. 先说结论:Norm 到底想解决什么

在很深的网络里,最容易失控的不是“某一层会不会算错”,而是表示的尺度会不会越来越漂。

更具体一点:

  • 某一层输出太大,下一层的线性层和 softmax 会很难受
  • 某一层输出太小,残差和梯度的有效信号会被淹没
  • 层数一深,表示的尺度如果不受控,训练会变得很脆

所以 Norm 的核心任务,不是“让数据更漂亮”,而是:

把每一层看到的输入尺度稳定住

这样深层网络才更容易训练。


2. LayerNorm 和 RMSNorm 的核心差别

大家更熟悉的通常是 LayerNorm。它对最后一维做两件事:

  1. 减去均值
  2. 除以标准差

也就是大致做:

x -> (x - mean(x)) / std(x)

RMSNorm 更轻,它不减均值,只看均方根(root mean square):

x -> x / RMS(x)

其中:

RMS(x) = sqrt(mean(x^2))

所以它和 LayerNorm 的最大区别就是:

  • LayerNorm 会把均值也消掉
  • RMSNorm 只控制尺度,不主动把均值移到 0

这也是它名字的来源:

Root Mean Square Normalization

3. 为什么很多 LLM 喜欢 RMSNorm

直觉上可以这么理解:

  • 对 Transformer 来说,最重要的问题常常不是“均值偏了一点”
  • 而是“向量整体长度有没有失控”

所以很多时候,我们真正想要的是:

只把向量长度管住,不要过度干预向量本身的方向和偏移

这就给了 RMSNorm 一个很自然的位置。

它比 LayerNorm 少做了一步减均值,因此:

  • 计算更简单
  • 数学操作更少
  • 对表示的干预更小

这里有一个容易忽略但很重要的 insight:

Insight:RMSNorm 的价值不只是“省一点算力”

很多人第一次看 RMSNorm,会觉得:

不就是少算个均值吗?省不了多少吧。

如果只看单次浮点运算,这个判断不算错;它确实不是那种“能把训练速度翻倍”的改动。

但更高一层看,它真正提供的是一种更温和的归一化偏置:

模型只需要被约束“尺度”,不一定需要每层都被强行居中。

这在 pre-norm Transformer 里尤其自然,因为残差流里本来就很希望原始信息尽量顺畅地穿过去。

所以 RMSNorm 更像是:

  • 一个优化稳定器
  • 而不是一个强力重整器

4. 当前代码里的实现长什么样

你现在的实现本质是:

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        x_float = x.float()
        return x_float * torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x).type_as(x)

先看签名:

def forward(self, x)

输入形状通常是:

[B, S, hidden_size]

输出形状仍然是:

[B, S, hidden_size]

所以它不改 token 数量,不改 hidden size,只改变每个 token 向量内部的尺度。


5. 这段实现到底在算什么

第一步:转成 float 计算

x_float = x.float()

这是一个很常见的小技巧。

原因是:

  • 训练时可能会用 fp16bf16
  • 归一化涉及平方、均值、开方这类对数值稳定性比较敏感的操作
  • 所以先转成 float32 计算更稳

第二步:求每个 token 向量的 RMS

x_float.pow(2).mean(dim=-1, keepdim=True)

如果输入是:

[B, S, hidden_size]

那这一步得到的是:

[B, S, 1]

含义是:

  • 对每个 batch
  • 对每个 token
  • 在最后一维 hidden_size 上求均方值

第三步:取倒数平方根

torch.rsqrt(mean_square + eps)

这相当于:

1 / sqrt(mean_square + eps)

其中 eps 是为了防止分母太小导致数值不稳定。

第四步:乘回原向量

x_float * rsqrt(...)

于是每个 token 的向量都会被缩放到一个相对稳定的尺度上。

第五步:乘可学习参数 weight

self.weight * normalized_x

这里的 weight 形状是:

[hidden_size]

会广播到:

[B, S, hidden_size]

它相当于给每个通道一个单独的可学习放缩系数。


6. gamma / weight 为什么是必须的

如果只有纯归一化,那么模型每层的表示尺度会被钉得很死。

但神经网络往往希望:

  • 某些维度更强
  • 某些维度更弱
  • 某些特征需要被额外放大

所以 weight 的作用就是:

在“整体尺度受控”的前提下,恢复模型对各个通道幅度的表达自由

很多人第一次看 norm 会把它理解成“把信息洗平了”,其实不是。

更准确地说:

  • 归一化负责把表示带回稳定区域
  • 可学习缩放负责把模型需要的结构再长回来

7. 在整个模型里,RMSNorm 出现在哪里

在你当前的 MindBlock 里,RMSNorm 出现两次:

self.input_layernorm = RMSNorm(...)
self.post_attention_layernorm = RMSNorm(...)

forward 里是:

hidden_states, present_kv = self.attention(
    self.input_layernorm(hidden_states),
    ...
)
hidden_states = residual + hidden_states
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

这是一种典型的 pre-norm 结构。

也就是说:

  • 先 norm
  • 再进 attention / MLP
  • 然后残差相加

这种结构在大模型里非常常见,因为训练稳定性通常更好。


8. 从张量形状角度看,RMSNorm 做了什么

假设输入:

x.shape = [B, S, hidden_size]

那么:

x.pow(2)                 -> [B, S, hidden_size]
mean(dim=-1, keepdim=True) -> [B, S, 1]
rsqrt(...)               -> [B, S, 1]
乘回 x                   -> [B, S, hidden_size]
再乘 weight              -> [B, S, hidden_size]

所以一句话概括:

RMSNorm 不改变形状,只在最后一维上重标定每个 token 向量的长度。

9. 一个非常实用的理解方式

如果把 Transformer 的表示空间想成“很多向量在流动”,那:

  • attention 更像是在重新路由信息
  • MLP 更像是在局部改写特征
  • RMSNorm 更像是在保证这些向量不会因为层数加深而越跑越飘

这也是它为什么看起来不起眼,但实际上不可少。

你很少会说:

这个模型之所以强,是因为它有一个超惊艳的 RMSNorm。

但你很容易遇到:

去掉 norm 之后,整个训练稳定性都坏掉了。

所以它是那种“不是主角,但没有它整个系统会很难受”的组件。


10. 高屋建瓴的视角:RMSNorm 在深层模型里到底意味着什么

一个很值得记住的 insight 是:

深层网络里,表示能力和优化稳定性是两股经常互相拉扯的力量。
  • 你希望模型足够自由,能学到复杂表示
  • 你又希望它别因为太自由而数值失控

RMSNorm 就是在这两者之间取的一个平衡点:

  • 它不像某些更强的归一化那样“把表示纠正得太狠”
  • 也不像完全不管那样让深层残差流自由漂移

所以从系统设计上看,RMSNorm 提供的是:

一种“轻约束、高兼容”的稳定机制

这非常符合现代 LLM 的整体风格:

  • 主干尽量简单
  • 残差尽量顺滑
  • 稳定性靠少量但关键的结构来托底

11. 这一部分最值得记住的话

如果只记三句,建议记这三句:

  1. RMSNorm 只控制尺度,不主动减均值。
  2. 它不改变张量形状,只在最后一维重标定向量长度。
  3. 它在 Transformer 里更像是“优化稳定器”,而不是“表示主角”。

以及一句很适合以后回忆的总结:

RMSNorm 做的不是“重写信息”,而是“让信息在深层网络里别跑偏”。

数学形式化

给定输入向量 xRd\mathbf{x} \in \mathbb{R}^d,RMSNorm 的计算为:

RMSNorm(x)=xRMS(x)γ,RMS(x)=1di=1dxi2+ϵ\text{RMSNorm}(\mathbf{x}) = \frac{\mathbf{x}}{\text{RMS}(\mathbf{x})} \odot \boldsymbol{\gamma}, \quad \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon}

其中 γRd\boldsymbol{\gamma} \in \mathbb{R}^d 是可学习的缩放参数,ϵ\epsilon 是数值稳定常数。

与 LayerNorm 的对比:LayerNorm 额外计算均值 μ=1dxi\mu = \frac{1}{d}\sum x_i 并做 xxμ\mathbf{x} \leftarrow \mathbf{x} - \mu,因此有两个可学习参数(γ\gammaβ\beta)。RMSNorm 去掉了均值偏移和偏置项,在实践中通常不会降低模型质量,但减少了约 7-10% 的归一化层计算量。


参考

  • Zhang, B. & Sennrich, R. “Root Mean Square Layer Normalization.” NeurIPS, 2019. — RMSNorm 的原始论文
  • Ba, J. L., Kiros, J. R., & Hinton, G. E. “Layer Normalization.” arXiv:1607.06450, 2016. — LayerNorm 的原始论文
  • Xiong, R. et al. “On Layer Normalization in the Transformer Architecture.” ICML, 2020. — Pre-norm vs Post-norm 的分析
END

Series: MiniMind Reproduction

  1. 1. MiniMind 总架构图
  2. 2. MiniMind 01: RMSNorm
  3. 3. MiniMind 02: RoPE & YaRN
  4. 4. MiniMind 03: GQA
  5. 5. MiniMind 04: FFN
  6. 6. MiniMind 05: 拼装 Model

Comments