论文阅读笔记:看完也许能进一步了解Batch Normalization
提示:阅读论文时进行相关思想、结构、优缺点,内容进行提炼和记录,论文和相关引用会标明出处。
文章目录
- 前言
- 介绍
- BN之前的一些减少Covariate Shift的方法
- BN算法描述
- Batch Normalization的反向传播
- Batch Normalization的预测阶段
- 网络inference阶段conv层和BN层的融合
- 实验结果
- 关于BN的几个讨论
- 总结
- 写在最后
前言
标题:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
原文链接:Link
Github:NLP相关Paper笔记和代码复现
说明:阅读论文时进行相关思想、结构、优缺点,内容进行提炼和记录,论文和相关引用会标明出处,引用之处如有侵权,烦请告知删除。
转载请注明:DengBoCong
训练深度神经网络非常复杂,因为在训练过程中,随着先前各层的参数发生变化,各层输入的分布也会发生变化,导致调参工作要做的很小心,训练更加困难,论文中将这种现象称为“internal covariate shift”,而Batch Normalization正式用来解决深度神经网络中internal covariate shift现象的方法。有关covariate shift的内容,可以参阅我另一篇论文阅读笔记。
介绍
Batch Normalization是在每个mini-batch进行归一化操作,并将归一化操作作为模型体系结构的一部分,使用BN可以获得如下的好处:
- 可以使用更大的学习率,训练过程更加稳定,极大提高了训练速度。
- 可以将bias置为0,因为Batch Normalization的Standardization过程会移除直流分量,所以不再需要bias。
- 对权重初始化不再敏感,通常权重采样自0均值某方差的高斯分布,以往对高斯分布的方差设置十分重要,有了Batch Normalization后,对与同一个输出节点相连的权重进行放缩,其标准差也会放缩同样的倍数,相除抵消。
- 对权重的尺度不再敏感。
- 深层网络可以使用sigmoid和tanh了,BN抑制了梯度消失。
- Batch Normalization具有某种正则作用,不需要太依赖dropout,减少过拟合。
我们从梯度计算开始看起,如在SGD中是优化参数 θ \theta θ,从而最小化损失,如下公式:
θ = a r g m i n θ 1 N ∑ i = 1 N l ( x i , θ ) \theta=arg\underset{\theta}{min}\frac{1}{N}\sum_{i=1}^{N}l(x_i,\theta) θ=argθminN1i=1∑Nl(xi,θ)
其中, x 1 . . . x N x_1…x_N x1…xN是训练数据集。使用SGD,训练将逐步进行,并且在每个步骤中,我们考虑大小为 m m m 的mini-batch,即 x 1 . . . m x_1…m x1…m,通过计算 1 m ∂ ( x i , θ ) ∂ θ \frac{1}{m}\frac{\partial(x_i,\theta)}{\partial\theta} m1∂θ∂(xi,θ),使用小批量数据来近似损失函数关于参数的梯度。使用小批量样本,而不是一次一个样本,在一些方面是有帮助的。首先,小批量数据的梯度损失是训练集上的梯度估计,其质量随着批量增加而改善。第二,由于现代计算平台提供的并行性,对一个批次的计算比单个样本计算 m m m 次效率更高。
虽然随机梯度是简单有效的,但它需要仔细调整模型的超参数,特别是优化中使用的学习速率以及模型参数的初始值。训练的复杂性在于每层的输入受到前面所有层的参数的影响——因此当网络变得更深时,网络参数的微小变化就会被放大。如果我们能保证非线性输入的分布在网络训练时保持更稳定,那么优化器将不太可能陷入饱和状态,训练将加速。
BN之前的一些减少Covariate Shift的方法
对网络的输入进行白化,网络训练将会收敛的更快——即输入线性变换为具有零均值和单位方差,并去相关。当每一层观察下面的层产生的输入时,实现每一层输入进行相同的白化将是有利的。通过白化每一层的输入,我们将采取措施实现输入的固定分布,消除Internal Covariate Shift的不良影响。那么如何消除呢?考虑在每个训练步骤或在某些间隔来白化激活值,通过直接修改网络或根据网络激活值来更改优化方法的参数,但这样会弱化梯度下降步骤。
例如:例如,考虑一个层,其输入u加上学习到的偏置 b b b,通过减去在训练集上计算的激活值的均值对结果进行归一化: x ^ = x − E [ x ] \hat x=x - E[x] x^=x−E[x], x = u + b x = u+b x=u+b, X = x 1 … N X={x_{1\ldots N}} X=x1…N 是训练集上 x x x 值的集合, E [ x ] = 1 N ∑ i = 1 N x i E[x] = \frac{1}{N}\sum_{i=1}^N x_i E[x]=N1∑i=1Nxi。如果梯度下降步骤忽略了 E [ x ] E[x] E[x] 对 b b b的依赖,那它将更新 b ← b + Δ b b\leftarrow b+\Delta b b←b+Δb,其中 Δ b ∝ − ∂ ℓ / ∂ x ^ \Delta b\propto -\partial{\ell}/\partial{\hat x} Δb∝−∂ℓ/∂x^。然后 u + ( b + Δ b ) − E [ u + ( b + Δ b ) ] = u + b − E [ u + b ] u+(b+\Delta b) -E[u+(b+\Delta b)] = u+b-E[u+b] u+(b+Δb)−E[u+(b+Δb)]=u+b−E[u+b]。因此,结合 b b b 的更新和接下来标准化中的改变会导致层的输出没有变化,从而导致损失没有变化。随着训练的继续, b b b 将无限增长而损失保持不变。如果标准化不仅中心化而且缩放了激活值,问题会变得更糟糕。在最初的实验中,当标准化参数在梯度下降步骤之外计算时,模型会爆炸。
总结而言就是使用白话来缓解ICS问题,白化是机器学习里面常用的一种规范化数据分布的方法,主要是PCA白化与ZCA白化。白化是对输入数据分布进行变换,进而达到以下两个目的:
- 使得输入特征分布具有相同的均值与方差,其中PCA白化保证了所有特征分布均值为0,方差为1,而ZCA白化则保证了所有特征分布均值为0,方差相同。
- 去除特征之间的相关性。
通过白化操作,我们可以减缓ICS的问题,进而固定了每一层网络输入分布,加速网络训练过程的收敛。但是白话过程的计算成本太高,并且在每一轮训练中的每一层我们都需要做如此高成本计算的白化操作,这未免过于奢侈。而且白化过程由于改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力,底层网络学习到的参数信息会被白化操作丢失掉。
BN算法描述
文中使用了类似z-score的归一化方式:每一维度减去自身均值,再除以自身标准差,由于使用的是随机梯度下降法,这些均值和方差也只能在当前迭代的batch中计算,故作者给这个算法命名为Batch Normalization。BN变换的算法如下所示,其中,为了数值稳定, ϵ \epsilon ϵ 是一个加到小批量数据方差上的常量。
我们可以将上面的算法总结为两步:
- Standardization:首先对 m m m 个 x x x 进行Standardization,得到 zero mean unit variance的分布 x ^ \hat{x} x^。
- scale and shift:然后再对 x ^ \hat{x} x^ 进行scale and shift,缩放并平移到新的分布 y y y,具有新的均值 β \beta β 方差 γ \gamma γ。
更形象一点,假设BN层有 d d d 个输入节点,则 x x x 可构成 d × m d\times m d×m大小的矩阵 X X X,BN层相当于通过行操作将其映射为另一个 d × m d\times m d×m 大小的矩阵 Y Y Y,如下所示:
将2个过程写在一个公式里如下:
y i ( b ) = B N ( x i ) ( b ) = γ ( x i ( b ) − μ ( x i ) σ ( x i ) 2 + ϵ ) + β y_i^{(b)}=BN(x_i)^{(b)}=\gamma (\frac{x_i^{(b)}-\mu(x_i)}{\sqrt{\sigma(x_i)^2+\epsilon}})+\beta yi(b)=BN(xi)(b)=γ(σ(xi)2+ϵxi(b)−μ(xi))+β
其中, x i ( b ) x_i^{(b)} xi(b) 表示输入当前batch的b-th样本时该层i-th输入节点的值, x i x_i xi 为 [ x i ( 1 ) , x i ( 2 ) , … , x i ( m ) ] [x_i^{(1)},x_i^{(2)},…,x_i^{(m)}] [xi(1),xi(2),…,xi(m)] 构成的行向量,长度为batch size m m m, μ \mu μ和 σ \sigma σ为该行的均值和标准差, ϵ \epsilon ϵ 为防止除零引入的极小量(可忽略), γ \gamma γ和 β \beta β为该行的scale和shift参数,可知
- μ \mu μ 和 σ \sigma σ 为当前行的统计量,不可学习。
- γ \gamma γ 和 β \beta β 为待学习的scale和shift参数,用于控制 y i y_i yi 的方差和均值。
- BN层中, x i x_i xi 和 x j x_j xj 之间不存在信息交流 ( i ≠ j ) (i\neq j) (i=j)
可见,无论xi原本的均值和方差是多少,通过BatchNorm后其均值和方差分别变为待学习的 γ \gamma γ 和 β \beta β。为什么需要 γ \gamma γ 和 β \beta β 的可训练参数?Normalization操作我们虽然缓解了ICS问题,让每一层网络的输入数据分布都变得稳定,但却导致了数据表达能力的缺失。也就是我们通过变换操作改变了原有数据的信息表达(representation ability of the network),使得底层网络学习到的参数信息丢失。另一方面,单纯通过让每一层的输入分布均值为0,方差为1,而不做缩放和移位,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域。
在训练初期,分界面还在剧烈变化时,计算出的参数不稳定,所以退而求其次,在 W x + b Wx+b Wx+b 之后,ReLU激活层前面进行归一化。因为初始的 W W W 是从标准高斯分布中采样得到的,而 W W W 中元素的数量远大于 x x x, W x + b Wx+b Wx+b 每维的均值本身就接近 0 0 0、方差接近 1 1 1,所以在 W x + b Wx+b Wx+b 后使用Batch Normalization能得到更稳定的结果,如下图所示:
Batch Normalization的反向传播
讲反向传播之前,我们先来简单的写一下正向传递的代码,如下:
def batchnorm_forward(x, gamma, beta, eps):
N, D = x.shape
# 第一步:计算平均
mu = 1./N * np.sum(x, axis=0)
# 第二步:每个训练样本减去平均
xmu = x - mu
# 第三步:计算分母
sq = xmu ** 2
# 第四步:计算方差
var = 1./N * np.sum(sq, axis=0)
# 第五步:加上eps保证数值稳定性,然后计算开方
sqrtvar = np.sqrt(var + eps)
# 第六步:倒转sqrtvar
ivar = 1./sqrtvar
# 第七步:计算归一化
xhat = xmu * ivar
# 第八步:加上两个参数
gammax = gamma * xhat
out = gammax + beta
# cache储存计算反向传递所需要的一些内容
cache = (xhat, gamma, xmu, ivar, sqrtvar, var, eps)
return out, cache
我们都知道,对于目前的神经网络计算框架,一个层要想加入到网络中,要保证其是可微的,即可以求梯度。BatchNorm的梯度该如何求取?反向传播求梯度只需抓住一个关键点,如果一个变量对另一个变量有影响,那么他们之间就存在偏导数,找到直接相关的变量,再配合链式法则,公式就很容易写出了。
根据反向传播的顺序,首先求取损失 l l l 对BN层输出 y i y_i yi 的偏导 ∂ l ∂ y i \frac{\partial l}{\partial y_i} ∂yi∂l,然后是对可学习参数的偏导 ∂ l ∂ γ \frac{\partial l}{\partial \gamma} ∂γ∂l 和 ∂ l ∂ β \frac{\partial l}{\partial \beta} ∂β∂l,用于对参数进行更新,想继续回传的话还需要求对输入 x x x 偏导,于是引出对变量 μ \mu μ、 σ 2 \sigma^2 σ2 和 x ^ \hat{x} x^ 的偏导,根据链式法则再求这些变量对 x x x 的偏导,计算图如下:
通过链式法则,我们可以对上面的正向传递的代码进行运算,得到反向传播的代码,如下(结合代码理解更方便):
def batchnorm_backward(dout, cache):
# 展开存储在cache中的变量
xhat, gamma, xmu, ivar, sqrtvar, var, eps = cache
# 获得输入输出的维度
N, D = dout.shape
dbeta = np.sum(dout, axis=0)
dgammax = dout
dgamma = np.sum(dgammax * xhat, axis=0)
dxhat = dgammax * gamma
divar = np.sum(dxhat * xmu, axis=0)
dxmu1 = dxhat * ivar
dsqrtvar = -1./(sqrtvar ** 2) * divar
dvar = 0.5 * 1. / np.sqrt(var + eps) * dsqrtvar
dsq = 1. / N * np.ones((N, D)) * dvar
dxmu2 = 2 * xmu * dsq
dx1 = (dxmu1 + dxmu2)
dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
dx2 = 1. / N * np.ones((N, D)) * dmu
dx = dx1 + dx2
return dx, dgamma, dbeta
Batch Normalization的预测阶段
在预测阶段,所有参数的取值是固定的,对BN层而言,意味着 μ \mu μ、 σ \sigma σ、 γ \gamma γ、 β \beta β 都是固定值。 γ \gamma γ和 β \beta β 比较好理解,随着训练结束,两者最终收敛,预测阶段使用训练结束时的值即可。对于 μ \mu μ 和 σ \sigma σ,在训练阶段,它们为当前mini batch的统计量,随着输入batch的不同, μ \mu μ 和 σ \sigma σ 一直在变化。在预测阶段,输入数据可能只有1条,该使用哪个 μ \mu μ 和 σ \sigma σ ,或者说,每个BN层的 μ \mu μ 和 σ \sigma σ 该如何取值?可以采用训练收敛最后几批mini batch的 μ \mu μ 和 σ \sigma σ 的期望,作为预测阶段的 μ \mu μ 和 σ \sigma σ ,如下所示:
因为Standardization和scale and shift均为线性变换,在预测阶段所有参数均固定的情况下,参数可以合并成 y = k x + b y=kx+b y=kx+b 的形式,如上图中行号11所示。
这里多说一句,BN在卷积中使用时,1个卷积核产生1个feature map,1个feature map有1对 γ \gamma γ和 β \beta β 参数,同一batch同channel的feature map共享同一对 γ \gamma γ和 β \beta β 参数,若卷积层有 n n n 个卷积核,则有 n n n 对 γ \gamma γ和 β \beta β 参数。
对于测试集均值和方差已经不是针对某一个Batch了,而是针对整个数据集而言。因此,在训练过程中除了正常的前向传播和反向求导之外,我们还要记录每一个Batch的均值和方差,以便训练完成之后按照下式计算整体的均值和方差:
E [ x ] ← E β [ μ β ] E[x]\leftarrow E_\beta[\mu_\beta] E[x]←Eβ[μβ] V a r [ x ] ← m m − 1 E β [ σ β 2 ] Var[x]\leftarrow \frac{m}{m-1}E_\beta[\sigma_\beta^2] Var[x]←m−1mEβ[σβ2]
上面简单理解就是:对于均值来说直接计算所有batch u值的平均值;然后对于标准偏差采用每个batch σB的无偏估计。最后测试阶段,BN的使用公式就是行号11所示。
网络inference阶段conv层和BN层的融合
现在很多的网络结构都将BN层直接放在卷积层和激活层之间,这种做法可以在网络的inference阶段,将BN层的运算直接嵌入到卷积层中,减少运算量,提升网络的运行速度。在inference阶段,已知某层卷积层的kernel参数 w w w, b b b ,以及输入 x x x ,紧随其后的BN层参数(已学习到):尺度参数 γ \gamma γ 、偏移参数 β \beta β 、以及样本均值 μ ^ \hat{\mu} μ^ 和标准差 σ ^ \hat{\sigma} σ^ 。
- 卷积层输出为: w ∗ x + b w*x+b w∗x+b
bn层输出为: γ w ∗ x + b − μ ^ σ 2 + ϵ + β \gamma \frac{w*x+b-\hat{\mu}}{\sqrt{\sigma^2+\epsilon}}+\beta γσ2+ϵw∗x+b−μ^+β
bn层的输出可以化为如下形式:
γ w ∗ x + b − μ ^ σ 2 + ϵ + β = ( γ σ 2 + ϵ w ) ∗ x + γ σ 2 + ϵ b − γ σ 2 + ϵ μ ^ + β = k w ∗ x + b ′ \gamma \frac{w*x+b-\hat{\mu}}{\sqrt{\sigma^2+\epsilon}}+\beta=(\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}w)*x+\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}b-\frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\hat{\mu}+\beta=kw*x+b^{‘} γσ2+ϵw∗x+b−μ^+β=(σ2+ϵγw)∗x+σ2+ϵγb−σ2+ϵγμ^+β=kw∗x+b′
所以,在inference阶段,如果BN直接跟在卷积层后,可以将BN层直接嵌入到卷积层的计算中,相当于将卷积核缩放一定倍数,并对偏置进行一定改变。
将BN层融合到卷积层中,就相当于对卷积核进行一定的修改,并没有增加卷积层的计算量,同时整个BN层的计算量都省去了。
实验结果
下图是使用三层全连接层,在每层之后添加BN以及无添加的实验对比:
下图是训练步和精度的实验结果:
下图是使用BN在Inception上的相关实验结果:
关于BN的几个讨论
- 没有scale and shift过程可不可以?
BatchNorm有两个过程,Standardization和scale and shift,前者是机器学习常用的数据预处理技术,在浅层模型中,只需对数据进行Standardization即可,Batch Normalization可不可以只有Standardization呢?答案是可以,但网络的表达能力会下降。直觉上理解,浅层模型中,只需要模型适应数据分布即可。对深度神经网络,每层的输入分布和权重要相互协调,强制把分布限制在zero mean unit variance并不见得是最好的选择,加入参数 γ \gamma γ和 β \beta β ,对输入进行scale and shift,有利于分布与权重的相互协调,特别地,令 γ = 1 \gamma=1 γ=1和 β = 0 \beta=0 β=0 等价于只用Standardization,令 γ = σ \gamma=\sigma γ=σ和 β = μ \beta=\mu β=μ 等价于没有BN层,scale and shift涵盖了这2种特殊情况,在训练过程中决定什么样的分布是适合的,所以使用scale and shift增强了网络的表达能力。表达能力更强,在实践中性能就会更好吗?并不见得,就像曾经参数越多不见得性能越好一样。在caffenet-benchmark-batchnorm中,作者实验发现没有scale and shift性能可能还更好一些,如下图: - BN层放在ReLU前面还是后面?实验表明,放在前后的差异似乎不大,甚至放在ReLU后还好一些(如上图),放在ReLU后相当于直接对每层的输入进行归一化,这与浅层模型的Standardization是一致的。caffenet-benchmark-batchnorm中有很多组合实验结果,可以看看。BN究竟应该放在激活的前面还是后面?以及,BN与其他变量,如激活函数、初始化方法、dropout等,如何组合才是最优?可能只有直觉和经验性的指导意见,具体问题的具体答案可能还是得实验说了算
总结
Batch Normalization的加速作用体现在两个方面:一是归一化了每层和每维度的scale,所以可以整体使用一个较高的学习率,而不必像以前那样迁就小scale的维度;二是归一化后使得更多的权重分界面落在了数据中,降低了overfit的可能性,因此一些防止overfit但会降低速度的方法,例如dropout和权重衰减就可以不使用或者降低其权重。
写在最后
BN层的有效性已有目共睹,但为什么有效可能还需要进一步研究,还需要进一步研究,这里整理了一些关于BN为什么有效的论文,贴在这:
- How Does Batch Normalization Help Optimization?:BN层让损失函数更平滑。
论文中通过分析训练过程中每步梯度方向上步长变化引起的损失变化范围、梯度幅值的变化范围、光滑度的变化,认为添加BN层后,损失函数的landscape(loss surface)变得更平滑,相比高低不平上下起伏的loss surface,平滑loss surface的梯度预测性更好,可以选取较大的步长。如下图所示: - An empirical analysis of the optimization of deep network loss surfaces:BN更有利于梯度下降。
论文中绘制了VGG和NIN网络在有无BN层的情况下,loss surface的差异,包含初始点位置以及不同优化算法最终收敛到的local minima位置,如下图所示。没有BN层的,其loss surface存在较大的高原,有BN层的则没有高原,而是山峰,因此更容易下降。
参考文献:
- caffenet-benchmark-batchnorm
- Understanding the backward pass through Batch Normalization Layer
- Why Does Batch Normalization Work?
- Batch Normalization详解
- Batch Normalization — What the hey
- How does Batch Normalization Help Optimization?
- How does Batch Normalization Help Optimization?
还没有评论,来说两句吧...