深度学习基础:图解 Batch Norm 与 Layer Norm 的维度切分逻辑
从数据结构角度出发,用'班级成绩单'的比喻,彻底搞懂 BN 和 LN 的维度切分逻辑
做深度学习的时候,Batch Normalization (BN) 和 Layer Normalization (LN) 几乎是绕不开的坎。
这两个概念,公式看一眼都懂——减均值、除方差,把数据拉回标准分布。但一到具体代码层面,或者被面试官问到深处时,很多人就容易”短路”:
- 这所谓的均值 $\mu$ 和方差 $\sigma$,到底是对谁算出来的?
- 为什么 BERT/Transformer 必须用 LN,而 ResNet 却用 BN?
- 维度切片到底怎么切?
今天不堆砌复杂的数学推导,我们从数据结构的角度出发,用一张“班级成绩单”的比喻,帮你彻底终结对这两个概念的混淆。
01 核心逻辑:方向决定一切
Normalization(归一化)的通用公式我们都知道:
\(y = \\frac{x - \\mu}{\\sigma} \\times \\gamma + \\beta\) 困扰我们的永远是:这个 $\mu$ 和 $\sigma$,究竟是把哪些数据捏在一起算出来的?
请记住这句核心口诀:
🟢 BN 是”竖着切”,看重群体的共性(特征)。
🔵 LN 是”横着切”,看重个体的特质(样本)。
02 极简比喻:班级成绩单
为了讲清楚,我们先忽略长宽 $(H, W)$,把输入数据 $[N, C]$ 想象成一张班级期末考试成绩单。
- N (Batch Size) = 班里有 N 个学生(样本)。
- C (Channel) = 这次考了 C 门科目(特征,如语文、数学、英语…)。
🟢 1. Batch Normalization (BN):纵向比较
BN 的核心逻辑: 抹平不同科目(特征)之间的量纲/难度差异。
- 怎么算? 此时不区分具体的学生(张三还是李四),而是把全班 N 个人的“数学成绩”这一整列拿出来,算一个平均分和方差。
- 为什么要这么做? 也许数学满分是 150,体育满分是 100。如果不归一化,模型会误以为数学比体育更重要(因为数值大)。BN 就是为了让”数学列”和”体育列”都变成均值为 0、方差为 1 的标准分布,让不同特征在同一起跑线上。
- 局限性: 强依赖 Batch Size。如果班里只有一个人(Batch Size=1),或者人很少,算出来的分布就没有统计意义了。
(图注:Batch Norm 就像是教导主任在分析单科成绩,纵向计算整列数据的分布)
🔵 2. Layer Normalization (LN):横向比较
LN 的核心逻辑: 抹平不同样本之间的基准差异。
- 怎么算? 此时不看其他同学考多少分,只把张三这一行的“所有 C 门科目”拿出来,算一个张三自己的平均分和方差。
- 为什么要这么做? 在 NLP 中,我们更关注一个样本内部的相对关系。不管张三整体分数偏高还是偏低,我们把他”标准化”后,只看他哪门课相对更好。
- 优势: 既然是自己跟自己算,那哪怕班里只有张三一个人,LN 照样能算。它完全独立于 Batch Size。
(图注:Layer Norm 就像是班主任在分析学生个人素质,横向计算整行数据的分布)
03 进阶视角:立体”切蛋糕” (4D)
在实际的 CV 任务中,数据通常是 4 维的 $[N, C, H, W]$。我们可以把 $(H, W)$ 拍扁,看作是特征的延伸。
此时,把数据想象成一块厚厚的切片面包:
- 面包的厚度 = Batch Size (N)
- 每一片面包 = 一个样本 (Sample)
- 面包上的纹理 = Channels (C)
🔪 BN 的切法(一刀插到底)
BN 是一把长刀,垂直切下去。
它固定住纹理(Channel),跨越了所有的面包片(Samples)。
- 适用场景:CV(计算机视觉)。
因为不同图片在同一个 Channel(比如 RGB 的 R 通道)提取的信息物理意义是一致的,大家放在一起统计是有意义的。
🔪 LN 的切法(只切一片)
LN 是拿一把小刀,水平切一片。
它把一片面包(Sample)单独拿出来,把这片面包上所有的纹理 $(C)$ 和像素 $(H, W)$ 全部揉在一起算均值。
- 适用场景:NLP(自然语言处理)。
因为句子有长有短,且句子之间相互独立。强行把”第一句话的第一个词”和”第二句话的第一个词”放在一起比,通常没啥意义。
04 代码视角的直觉
如果你习惯看代码,这两行 PyTorch 风格的伪代码最能说明问题。
假设输入 x 的形状是 $[N, C, H, W]$:
1
2
3
4
5
6
7
8
9
10
11
# Batch Normalization
# 只要是同一个 Channel (dim=1)
# 不管你是哪个样本(N),哪个位置(H,W),统统揉在一起算!
# 也就是在 (N, H, W) 这三个维度上求均值
mean_bn = x.mean(dim=(0, 2, 3), keepdim=True)
# Layer Normalization
# 只要是同一个样本 (dim=0)
# 不管你是哪个 Channel(C),哪个位置(H,W),统统揉在一起算!
# 也就是在 (C, H, W) 这三个维度上求均值
mean_ln = x.mean(dim=(1, 2, 3), keepdim=True)
05 总结
一张表总结它们的本质区别:
| 维度 | Batch Normalization (BN) | Layer Normalization (LN) |
|---|---|---|
| 计算视角 | 看列 (Column) | 看行 (Row) |
| 归一化对象 | 同一特征,跨所有样本 | 同一样本,跨所有特征 |
| Batch 依赖 | 高 (Batch 太小效果差) | 无 (单样本即可计算) |
| 主场领域 | CNN / 图像 (ResNet等) | RNN / Transformer (LLM基石) |
💡 为什么 Transformer/LLM 都用 LN?
- 序列长度不一: 文本数据通常长短不一,如果用 BN,需要对齐长度(Padding),Padding 部分的 0 会严重干扰 BN 的统计量。
- Batch Size 限制: 训练大模型时显存吃紧,Batch Size 往往很小,BN 在小 Batch 下表现极其不稳定,而 LN 稳如磐石。
希望这篇简单的图解,能帮你建立起直观的维度切分视角,下次再看到这两个词,脑海里直接浮现”切蛋糕”的画面就对了。