2、Batch Normalization

柔光的暖阳◎ 2022-08-31 04:47 261阅读 0赞

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。在网上虽然已经有很多相关文章,但基本都是摆上论文中的公式泛泛而谈,bn真正是如何运作的很少有提及。本文主要分为以下几个部分:

(1)BN的原理

(2)使用pytorch验证本文的观点

(3)使用BN需要注意的地方(BN没用好就是个坑)

1.Batch Normalization原理

我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律)。而我们Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70

看到这里应该还是蒙的,不要慌,喝口水,慢慢来。下面是从原论文中截取的原话,注意标黄的部分:

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 1

“对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。” 假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,x=(x^\{(1)\}, x^\{(2)\}, x^\{(3)\}),其中x^\{(1)\}就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。上面的公式不用看,原文提供了更加详细的计算公式:

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 2

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。我们根据上图的公式可以知道\\mu \_\{\\ss \}代表着我们计算的feature map每个维度(channel)的均值,注意\\mu \_\{\\ss \}是一个向量不是一个值\\mu \_\{\\ss \}向量的每一个元素代表着一个维度(channel)的均值。\\sigma\_\{\\ss \}^\{2\}代表着我们计算的feature map每个维度(channel)的方差,注意\\sigma\_\{\\ss \}^\{2\}是一个向量不是一个值\\sigma\_\{\\ss \}^\{2\}向量的每一个元素代表着一个维度(channel)的方差,然后根据\\mu \_\{\\ss \}\\sigma\_\{\\ss \}^\{2\}计算标准化处理后得到的值。下图给出了一个计算均值\\mu \_\{\\ss \}和方差\\sigma\_\{\\ss \}^\{2\}的示例:

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 3

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^\{(1)\}代表该batch的所有feature的channel1的数据,同理x^\{^\{(2)\}\}代表该batch的所有feature的channel2的数据。然后分别计算x^\{(1)\}x^\{^\{(2)\}\}的均值与方差,得到我们的\\mu \_\{\\ss \}\\sigma\_\{\\ss \}^\{2\}两个向量。然后在根据标准差计算公式分别计算每个channel的值(公式中的\\epsilon是一个很小的常量,防止分母为零的情况)。在我们训练网络的过程中,我们是通过一个batch一个batch的数据进行训练的,但是我们在预测过程中通常都是输入一张图片进行预测,此时batch size为1,如果在通过上述方法计算均值和方差就没有意义了。所以我们在训练过程中要去不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在我们训练完后我们可以近似认为我们所统计的均值和方差就等于我们整个训练集的均值和方差。然后在我们验证以及预测过程中,就使用我们统计得到的均值和方差进行标准化处理。

细心的同学会发现,在原论文公式中不是还有\\gamma\\beta两个参数吗?是的,\\gamma是用来调整数值分布的方差大小,\\beta是用来调节数值均值的位置。这两个参数是在反向传播过程中学习得到的,\\gamma的默认值是1,\\beta的默认值是0。

2.使用pytorch进行试验

你以为你都懂了?不一定哦。刚刚说了在我们训练过程中,均值\\mu \_\{\\ss \}和方差\\sigma\_\{\\ss \}^\{2\}是通过计算当前批次数据得到的记为为\\mu \_\{now\}\\sigma \_\{now\}^\{2\},而我们的验证以及预测过程中所使用的均值方差是一个统计量记为\\mu \_\{statistic\}\\sigma \_\{statistic\}^\{2\}\\mu \_\{statistic\}\\sigma \_\{statistic\}^\{2\}的具体更新策略如下,其中momentum默认取0.1:

\\large \\mu \_\{statistic+1\}=(1-momentum)\*\\mu \_\{statistic\}+momentum\*\\mu \_\{now\}

\\large \\sigma \_\{statistic+1\}^\{2\}=(1-momentum)\*\\sigma \_\{statistic\}^\{2\}+momentum\*\\sigma \_\{now\}^\{2\}

这里要注意一下,在pytorch中对当前批次feature进行bn处理时所使用的\\large \\sigma \_\{now\}^\{2\}总体标准差,计算公式如下:

\\bg\_white \\large \\sigma \_\{now\}^\{2\}=\\frac\{1\}\{m\}\\sum\_\{i=1\}^\{m\}(x\_\{i\}-\\mu \_\{now\})^\{2\}

在更新统计量\\large \\sigma \_\{statistic\}^\{2\}时采用的\\large \\sigma \_\{now\}^\{2\}样本标准差,计算公式如下:

\\bg\_white \\large \\sigma \_\{now\}^\{2\}=\\frac\{1\}\{m-1\}\\sum\_\{i=1\}^\{m\}(x\_\{i\}-\\mu \_\{now\})^\{2\}

下面是我使用pytorch做的测试,代码如下:

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\\gamma=1,\\beta=0。

  1. import numpy
  2. as np
  3. import torch.nn
  4. as nn
  5. import torch
  6. def bn_process(feature, mean, var):
  7. feature_shape = feature.shape
  8. for i
  9. in range(feature_shape[
  10. 1]):
  11. # [batch, channel, height, width]
  12. feature_t = feature[:, i, :, :]
  13. mean_t = feature_t.mean()
  14. # 总体标准差
  15. std_t1 = feature_t.std()
  16. # 样本标准差
  17. std_t2 = feature_t.std(ddof=
  18. 1)
  19. # bn process
  20. # 这里记得加上eps和pytorch保持一致
  21. feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 **
  22. 2 +
  23. 1e-5)
  24. # update calculating mean and var
  25. mean[i] = mean[i] *
  26. 0.9 + mean_t *
  27. 0.1
  28. var[i] = var[i] *
  29. 0.9 + (std_t2 **
  30. 2) *
  31. 0.1
  32. print(feature)
  33. # 随机生成一个batch为2,channel为2,height=width=2的特征向量
  34. # [batch, channel, height, width]
  35. feature1 = torch.randn(
  36. 2,
  37. 2,
  38. 2,
  39. 2)
  40. # 初始化统计均值和方差
  41. calculate_mean = [
  42. 0.0,
  43. 0.0]
  44. calculate_var = [
  45. 1.0,
  46. 1.0]
  47. # print(feature1.numpy())
  48. # 注意要使用copy()深拷贝
  49. bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)
  50. bn = nn.BatchNorm2d(
  51. 2, eps=
  52. 1e-5)
  53. output = bn(feature1)
  54. print(output)

首先我在最后设置了一个断点进行调试,查看下官方bn对feature处理后得到的统计均值和方差。我们可以发现官方提供的bn的running_mean和running_var和我们自己计算的calculate_mean和calculate_var是一模一样的(只是精度不同)。

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 4

然后我们打印出通过自定义bn_process函数得到的输出以及使用官方bn处理得到输出,明显结果是一样的(只是精度不同):

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 5

3.使用BN时需要注意的问题

(1)训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,即使使用了偏置bias求出的结果也是一样的\\bg\_white \\large y\_\{i\}^\{b\}=y\_\{i\}

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3NTQxMDk3_size_16_color_FFFFFF_t_70 6

最后给出李宏毅老师关于batch normalization的视频讲解:

https://www.bilibili.com/video/av9770302?p=10

发表评论

表情:
评论列表 (有 0 条评论,261人围观)

还没有评论,来说两句吧...

相关阅读