理论基础(7)BatchNormalization
归一化的作用:
- 可解释性:回归模型中自变量X的量纲不一致导致了回归系数无法直接解读或者错误解读;需要将X都处理到统一量纲下,这样才可比【可解释性】;取决于我们的逻辑回归是不是用了正则化。如果你不用正则,标准化并不是必须的,如果用正则,那么标准化是必须的。
- 距离计算:机器学习任务和统计学任务中有很多地方要用到“距离”的计算,比如PCA,比如KNN,比如kmeans等等,假使算欧式距离,不同维度量纲不同可能会导致距离的计算依赖于量纲较大的那些特征而得到不合理的结果;
- 加速收敛(BN):参数估计时使用梯度下降,在使用梯度下降的方法求解最优化问题时, 归一化/标准化后可以加快梯度下降的求解速度,即提升模型的收敛速度。
Batch Normalization 作用:
- 更好的尺度不变性:也就是说不管低层的参数如何变化,逐层的输入分布都保持相对稳定。
- 尺度不变性能够提高梯度下降算法的效率,从而加快收敛;
- 归一化到均值为0,方差为1的分布也能够使得经过sigmoid,tanh等激活函数以后,尽可能落在梯度非饱和区,缓解梯度消失的问题。【bn和ln都可以比较好的抑制梯度消失和梯度爆炸的情况】;
- 更平滑的优化地形:更平滑的优化地形意味着局部最小值的点更少,能够使得梯度更加reliable和predictive,从而让我们有更大的”信心”迈出更大的step来优化,即可以使用更大的学习率来加速收敛。
- 对参数初始化和学习率大小不太敏感:BN操作可以抑制参数微小变化随网络加深的影响,使网络可以对参数初始化和尺度变化适应性更强,从而可以使用更大的学习率而不用担心参数更新step过大带来的训练不稳定。
- 隐性的正则化效果:(Batch)训练时采用随机选取mini-batch来计算均值和方差,不同mini-batch的均值和方差不同,近似于引入了随机噪音,使得模型不会过拟合到某一特定的均值和方差参数下,提高网络泛化能力。
一、BatchNormalization的原理、作用和实现
归一化(白化)的方法有很多,为什么要设计BN这个样子的?
概要
上一节介绍了归一化方法(Batch Normalization)在深度神经网络中的作用:有人认为是更好的尺度不变性来缓解ICS现象。有人认为是更平滑的优化地形。但实际上我们还没有介绍Batch Normalization究竟是什么东西。这一节我们从缓解ICS现象的角度来引出Batch Normalization,并介绍其原理和实现。
必须要说明的是,这个出发点在现在看来很可能已经有问题了。。(参见上一篇文章对优化地形的讨论)
1.1 从缓解ICS现象出发
ICS现象指的是该层的输入分布会因为之前层的参数更新而发生改变。为了缓解这一问题,我们需要对输入分布进行归一化。上上节讲了白化(其实还没填坑)是一种机器学习中常见的归一化手段,其好处在于:
- 能够使得逐层的输入分布具有相同的均值和方差(PCA白化能够使得所有特征分布均值为0,方差为1)
- 同时去除特征之间的相关性
通过白化这一方法,可以有效缓解ICS现象,加速收敛。然而呢,其存在一些问题:
- 计算成本高(参见上上节还没填坑的计算过程,需要涉及到协方差,奇异值分解等)
- 由于对输入分布进行了限制,会损害输入数据原本的表达能力(其实就是说原本数据的分布信息丢失了)
- 均值为0,方差为1的输入分布容易使得经过sigmoid或者tanh的激活函数时,落入梯度饱和区
解决思路很简单:设计一种简化计算的白化操作,归一化后让数据尽量保留原始的表达能力。
- 单独对每一维特征进行归一化,使其满足均值为0,方差为1
- 增加线性变换操作,让数据能够尽量恢复本身表达能力
1.2 Batch Normalization的算法过程
首先是对每一维特征进行归一化,可以借助上上节介绍的归一化方法:本质上是减去一个统计量,再除以一个统计量。另一方面,BN的操作是在mini-batch层面进行计算,而不是full batch。具体来说:
假设输入样本的形状是 \(m \times d\), 其中 \(m\) 指batch size。
- 计算第 \(i\) 个样本的第 \(j\) 个维度上的均值: \(\mu_j=\frac{1}{m} \sum_{i=1}^m Z_j^{(i)}\)
- 计算第 \(i\) 个样本的第 \(j\) 个维度上的方差: \(\frac{1}{m} \sum_{i=1}^m\left(Z_j^{(i)}-\mu_j\right)^2\)
- 归一化: \(\hat{Z}_j=\frac{Z_j-\mu_j}{\sqrt{\sigma^2+\epsilon}}\) (加 \(\epsilon\) 防止分母为 0 )
通过上述变换实现每个特征维度上的均值和方差为0和1。
进一步的, 为了保证输入数据的表达能力, 引入两个可学习参数 \(\gamma\) 和 \(\beta\) (都是 \(\mathrm{d}\) 维向量) 来对归一化后的数据进行线性变换: \(\tilde{Z}_j=\gamma_j \hat{Z}_j+\beta_j\) 。特别地, 当 \(\gamma^2=\sigma^2, \beta=\mu\) 时, 即可实现identity transform, 并保留了原始输入特征的分布信息。
通过上述变换,在一定程度上保证了输入数据的表达能力。
综上所述: \[ \tilde{Z}_j=\gamma_j \cdot \frac{Z_j-\mu_j}{\sqrt{\sigma^2+\epsilon}}+\beta_j \]
补充:在进行归一化过程中, 由于归一化操作会减去均值, 所以偏置项可以忽略或者置0, 即 \(BN(W x+b)=B N(Wx)\)
1.3 Batch Normalization在测试阶段
首先, 在训练阶段, 我们是对一个mini-batch的数据计算每个维度上的均值和方差。为什么不用全量数据的均值 和方差呢? 这样一来, 不管哪个batch都用的一种分布了,会降低模型的鲁棒性。
在测试阶段, 有可能只需要预测一个样本或者很少样本, 不足以拼成一个mini-batch, 此时计算得到的均值和方 差一定是有偏估计。为了解决这一问题, BN的原论文中提出下面的方法:
保留训练阶段, 每个mini-batch的均值和方差信息: \(\mu_{\text {batch }}, \sigma_{\text {batch }}^2\) 。对测试集的数据, 计算均值和方差的无偏 估计:
\[ \begin{gathered} \mu_{\text {test }}=\mathbb{E}\left(\mu_{\text {batch }}\right), \sigma_{\text {test }}^2=\frac{m}{m-1} \mathbb{E}\left(\sigma_{\text {batch }}^2\right) \\ B N\left(X_{\text {test }}\right)=\gamma \cdot \frac{X_{\text {test }}-\mu_{\text {test }}}{\sqrt{\sigma^2+\epsilon}}+\beta \end{gathered} \]
为什么训练和测试的时候计算方差不一样呢?
训练时,计算当前batch var时,当前batch的样本就是该随机变量的所有样本了,因此除以n就好了。而测试时是全局样本的var,因此当前batch的样本只是该随机变量的部分采样样本,为了是无偏估计,必须乘以 n/n-1。
在计算随机变量的均值和方差时, 一般情况下无法知道该随机变量的分布公式, 因此我们通常会采样一些样本, 然后计算这些样本的均值和方差作为该随机变量的均值和方差。由于计算这些样本的方差时减的是样本均值而不 是随机变量的均值, 而样本均值是和采样的样本有关的, 是有偏估计, 如果要得到无偏估计, 需要乘以 \(n / n-1\) 。 \[ \begin{array}{r} E\left(\frac{1}{n} \sum_{i=1}^n\left(X_i-\bar{X}\right)^2\right)=\frac{1}{n} E\left(\sum_{i=1}^n\left(X_i-\mu+\mu-\bar{X}\right)^2\right) \\ =\frac{1}{n}\left(\sum_{i=1}^n E\left(\left(x_i-\mu\right)^2\right)-n E\left((\bar{X}-\mu)^2\right)\right) \\ =\frac{1}{n}(n \operatorname{Var}(X)-n \operatorname{Var}(\bar{X}))=\operatorname{Var}(X)-\operatorname{Var}(\bar{X}) \\ =\sigma^2-\frac{\sigma^2}{n}=\frac{n-1}{n} \sigma^2 \end{array} \] 在实际中,采用的是moving average的方式来实现,也就是在每个batch训练时,用当前batch计算出的均值和方差(即sample_mean和sample_var) 来更新 running_mean和runing_var。最后测试使用的实际是running_mean和running_var。
1 | running_mean = momentum * running_mean + (1 - momentum) * sample_mean |
1.4 Batch Normalization的作用
总结起来就是为了稳定训练,加速收敛。具体来说,有下面几种作用:
- 更好的尺度不变性:也就是说不管低层的参数如何变化,逐层的输入分布都保持相对稳定。
- 尺度不变性能够提高梯度下降算法的效率,从而加快收敛;
- 归一化到均值为0,方差为1的分布也能够使得经过sigmoid,tanh等激活函数以后,尽可能落在梯度非饱和区,缓解梯度消失的问题。【bn和ln都可以比较好的抑制梯度消失和梯度爆炸的情况】;
- 更平滑的优化地形:更平滑的优化地形意味着局部最小值的点更少,能够使得梯度更加reliable和predictive,从而让我们有更大的”信心”迈出更大的step来优化,即可以使用更大的学习率来加速收敛。
- 隐性的正则化效果:训练时采用随机选取mini-batch来计算均值和方差,不同mini-batch的均值和方差不同,近似于引入了随机噪音,使得模型不会过拟合到某一特定的均值和方差参数下,提高网络泛化能力。
- 对参数初始化和学习率大小不太敏感:BN操作可以抑制参数微小变化随网络加深的影响,使网络可以对参数初始化和尺度变化适应性更强,从而可以使用更大的学习率而不用担心参数更新step过大带来的训练不稳定。
假设对网络参数 \(W\) 进行缩放得到 \(a W\) 。对于缩放前的值 \(W x\), 设其均值为 \(\mu_1\), 方差为 \(\sigma_1^2\); 对于缩放值 \(a W x\), 设其均值为 \(\mu_2\), 方差为 \(\sigma_2^2\), 则有: \(\mu_2=a \mu_1, \sigma_2^2=a^2 \sigma_1^2\) 。忽略 \(\epsilon\), 则有: \[ \begin{aligned} & B N(a W x)=\gamma \cdot \frac{a W x-\mu_2}{\sqrt{\sigma_2^2}}+\beta=\gamma \cdot \frac{a W x-a \mu_1}{\sqrt{a^2 \sigma_1^2}}+\beta=\gamma \cdot \frac{W x-\mu_1}{\sqrt{\sigma_1^2}}+\beta=B N(W x) \\ & \frac{\partial B N(a W x)}{\partial x}=\gamma \cdot \frac{a W}{\sqrt{\sigma_2^2}}=\gamma \cdot \frac{a W}{\sqrt{a^2 \sigma_1^2}}=\gamma \cdot \frac{W}{\sqrt{\sigma_1^2}}=\frac{\partial B N(W x)}{\partial x} \\ & \frac{\partial B N(a W x)}{\partial(a W)}=\gamma \cdot \frac{x}{\sqrt{\sigma_2^2}}=\gamma \cdot \frac{x}{a \sqrt{\sigma_1^2}}=\frac{1}{a} \frac{\partial B N(W x)}{\partial W} \end{aligned} \] 经过BN操作后, 权重的缩放会被抺去, 保证输入分布稳定, 同时权重的缩放也不会改变对输入的梯度, 而当权 重越大时 (即a越大时),权重的梯度越小,即变化越小,保证了梯度不会依赖于参数的尺度。 注意一个问题是: 计算特征 \(x\) 的统计量的时候, 都是在统计每个特征维度上统计量, 也就是对该维度上的所有样 本求和取均值,或者求方差, axis=bsz那个维度!
还要注意一个问题: (一个简单的MLP) 上面讨论的所有情况的shape都是:(batch_size,hidden_dim), 此时针对每个特征维度, 我们对整个batch的样本在这个维度上计算统计量。但实际情况是, CV和NLP在应用 normalization时, shape并没有这么简单。
1.5 CNN中的Batch Normalization实现
针对CV,一个CNN的常见的输出shape是:(N, C, H, W) ,针对BN的话,这个特征维度是Channel,也就是我们需要在一个batch中所有样本所有H所有W上进行统计。最后得到的均值和方差向量都是(C), 因此两个参数也是C维的向量。
卷积操作:最初输入的图片样本的 channels ,取决于图片类型,比如RGB;常见的图像就是(N,3,H,W)。一开始,卷积核的shape可以是(3,3,3),第一个维度表示3个通道,也就是说通道有多少个,卷积核就有多少个。注意不同图像,图像不同位置使用的卷积核的参数都共享的。一个卷积核可以输出一个特征图(h1,w1),多个卷积核可以得到多个特征图,从而形成特征图的shape是(C,H,W),也就是说卷积核数= 通道数。
因为卷积核数=通道数, 所以一个卷积核可以得到一个通道的特征图数据, 我们希望不同图像, 图像不同位置用 这个卷积核执行卷积以后的数据分布是稳定的,所以需要在通道维度执行normalization。
对于输入的特征图: \(x \in \mathbb{R}^{N \times C \times H \times W}\) 包含 \(\mathrm{N}\) 个样本, 每个样本的通道数为 \(\mathrm{C}\), 高为 \(\mathrm{H}\), 宽为 \(\mathrm{W}\) 。求均值和方差 时, 是在 \(N, H, W\) 上操作, 保留 \(\mathrm{C}\) 的维度, 最后形成维度为 \(\mathrm{C}\) 的均值和方差向量。 \[ \begin{gathered} \mu_c(x)=\frac{1}{N H W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W x_{n c h w} \\ \sigma_c(x)=\sqrt{\frac{1}{N H W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W\left(x_{n c h w}-\mu_c(x)\right)^2+\epsilon} \end{gathered} \] 在实现中需要注意的一点是:到底是对哪个维度求均值和方差
- 对于shape为 \(b x d\) 的张量来说, 特征维度是最后一维: \(d\) 。求均值和方差实际就是:对于 \(d\) 中的每一维, 统 计 \(b\) 个样本的均值和方差, 均值和方差向量的形状为( \(d)\) 。实现:x.mean(dim=0)。
- 对于shape为BCHW的张量来说, 如果是batch Normalization, 特征维度是channel: C。求均值和方差 实际就是:先reshape:C, BHW, 然后统计BHW个样本的均值和方差, 均值和方差向量形状为(C)。实现: x.permute \((1,0,2,3) \cdot v i e w(3,-1) \cdot\) mean(dim=1)
- 总结来说: batch Normalization实际是对特征的每一维统计所有样本的均值和方差, CNN里面特征维度是 channel维, 所以最后向量形状就是(C)。
- 提前说一下: Layer Normalization实际是对每个样本的所有维度统计均值和方差, 所以求和取平均的是d 维度, 最后向量形状就是(B, max_len)。
二、Batch Normalization VS Layer Normalization
概要
上一节介绍了Batch Normalization的原理,作用和实现(既讲了MLP的情况,又讲了CNN的情况)。然而我们知道,Transformer里面实际使用的Layer Normalization。因此,本文将对比Batch Normalization介绍Layer Normalization。
2.1 Batch Normalization的些许缺陷
要讲Layer Normalization,先讲讲Batch Normalization存在的一些问题:即不适用于什么场景。
- BN在mini-batch较小的情况下不太适用。BN是对整个mini-batch的样本统计均值和方差,当训练样本数很少时,样本的均值和方差不能反映全局的统计分布信息,从而导致效果下降。
- BN无法应用于RNN(Sq2Sq),RNN实际是共享的MLP,在时间维度上展开,每个step的输出是(bsz, hidden_dim)。由于不同句子的同一位置的分布大概率是不同的,所以应用BN来约束是没意义的。注:而BN应用在CNN可以的原因是同一个channel的特征图都是由同一个卷积核产生的。
LN原文的说法是:在训练时,对BN来说需要保存每个step的统计信息(均值和方差)。在测试时,由于变长句子的特性,测试集可能出现比训练集更长的句子,所以对于后面位置的step,是没有训练的统计量使用的。(不过实践中的话都是固定了maxlen,然后padding的。)不同句子的长度不一样,对所有的样本统计均值是无意义的,因为某些样本在后面的timestep时其实是padding。
2.2 Layer Normalization的原理
BN是对batch的维度去做归一化,也就是针对不同样本的同一特征做操作。LN是对hidden的维度去做归一化,也就是针对单个样本的不同特征做操作。因此LN可以不受样本数的限制。
BN就是在每个特征维度上统计所有样本的值,计算均值和方差;LN就是在每个样本上统计所有维度的值,计算均值和方差(注意,这里都是指的简单的MLP情况,输入特征是(bsz,hidden_dim))。所以BN在每个特征维度上分布是稳定的,LN是每个样本的分布是稳定的。
2.3 Transformer中Layer Normalization的实现
对于一个输入tensor:(batch_size, max_len, hidden_dim) 应该如何应用LN层呢?
注意,和Batch Normalization一样,同样会施以线性映射的。区别就是操作的维度不同而已!公式都是统一的:减去均值除以标准差,施以线性映射。同时LN也有BN的那些个好处!
1 | # features: (bsz, max_len, hidden_dim) |
2.4 讨论:Transformer 为什么使用 Layer normalization,而不是其他的归一化方法?
当然这个问题还没有啥定论,包括BN和LN为啥能work也众说纷纭。这里先列出一些相关的研究论文。
- Leveraging Batch Normalization for Vision Transformers
- PowerNorm: Rethinking Batch Normalization in Transformers
- Understanding and Improving Layer Normalization
(1) Understanding and Improving Layer Normalization
这篇文章主要研究LN为啥work,除了一般意义上认为可以稳定前向输入分布,加快收敛快,还有没有啥原因。最后的结论有:
- 相比于稳定前向输入分布,反向传播时mean和variance计算引入的梯度更有用,可以稳定反向传播的梯度(让\(\frac{\partial l o s s}{\partial x}\) 梯度的均值趋于0,同时降低其方差,相当于re-zeros和re-scales操作),起名叫gradient normalization(其实就是ablation了下,把mean和variance的梯度断掉,看看效果)
- 去掉 gain和bias这两个参数可以在很多数据集上有提升,可能是因为这两个参数会带来过拟合,因为这两个参数是在训练集上学出来的
注:Towards Stabilizing Batch Statistics in Backward Propagation 也讨论了额外两个统计量:mean和variance的梯度的影响。实验中看到了对于小的batch size,在反向传播中这两个统计量的方差甚至大于前向输入分布的统计量的方差,其实说白了就是这两个与梯度相关的统计量的不稳定是BN在小batch size下不稳定的关键原因之一。
(2) PowerNorm: Rethinking Batch Normalization in Transformers
这篇文章就主要研究Transformer中BN为啥表现不太好。研究了训练中的四个统计量:batch的均值和方差,以及他们的梯度的均值和方差。对于batch的均值和方差,计算了他们和running statistics(就是用移动平均法累积的均值和方差,见前面的文章)的欧氏距离。可以看到NLP任务上(IWSLT14)batch的均值和方差一直震荡,偏离全局的running statistics,而CV任务也相对稳定。对于他们梯度的均值和方差,研究了其magnitude(绝对值),可以看到CV任务上震荡更小,且训练完成后,也没有离群点。
总结来说,Transformer中BN表现不太好的原因可能在于CV和NLP数据特性的不同,对于NLP数据,前向和反向传播中,batch统计量及其梯度都不太稳定。
(3) Leveraging Batch Normalization for Vision Transformers
刚刚讲了对于NLP data,为啥Transformer的BN表现不好。这篇文章就是去研究对于CV data,VIT中能不能用BN呢。有一些有意思的观点:
- LN特别适合处理变长数据,因为是对channel维度做操作(这里指NLP中的hidden维度),和句子长度和batch大小无关
- BN比LN在inference的时候快,因为不需要计算mean和variance,直接用running mean和running variance就行
- 直接把VIT中的LN替换成BN,容易训练不收敛,原因是FFN没有被Normalized,所以还要在FFN block里面的两层之间插一个BN层。(可以加速20% VIT的训练)
总结
- Layer Normalization和Batch Normalization一样都是一种归一化方法,因此,BatchNorm的好处LN也有
- 然而BN无法胜任mini-batch size很小的情况,也很难应用于RNN。
- LN特别适合处理变长数据,因为是对channel维度做操作(这里指NLP中的hidden维度),和句子长度和batch大小无关。
- BN比LN在inference的时候快,因为不需要计算mean和variance,直接用running mean和running variance就行。
- BN和LN在实现上的区别仅仅是:BN是对batch的维度去做归一化,也就是针对不同样本的同一特征做操作。LN是对hidden的维度去做归一化,也就是针对单个样本的不同特征做操作。因此,他们都可以归结为:减去均值除以标准差,施以线性映射。
参考文献
- Transformer中的归一化(四):BatchNormalization的原理、作用和实现:https://zhuanlan.zhihu.com/p/481277619?utm_source=wechatMessage_undefined_bottom
- Transformer中的归一化(五):Layer Norm的原理和实现 & 为什么Transformer要用LayerNorm - Gordon Lee的文章 - 知乎 https://zhuanlan.zhihu.com/p/492803886