1.3 批归一化方法
前面讲了两个在CNN发展过程中有着重要地位的模型,在本节中,我们换一种角度,从结构上优化模型,也就是下面要讲的批归一化。
1.3.1 批归一化简介
批归一化(batch normalization,BN)是深度学习中常用的一种技术,它实际上是一种数据标准化方法。通过对每批输入数据进行标准化,可以提升神经网络的训练速度和精度。通过这种方法计算数据的均值和方差,将数据进行缩放和平移,使其分布在指定的范围内,有助于缓解深度神经网络中梯度消失和梯度爆炸的问题,通常被插入卷积层或全连接层的输入或输出之间,有助于模型的收敛。BN处理示意如图1-7所示。
图1-7 BN处理示意
BN是由谷歌的两位工程师Sergey Ioffe和Christian Szegedy在2015年共同提出的。截至目前,其论文已经被引用超过4万次。
这个BN到底是如何计算的呢?下面是论文给出的计算公式。
上述公式看起来好像很复杂,其实很简单。公式经过一定变形并推导后,输出数据y的每个元素通过下面的公式计算即可。
其中,x表示输入数据,和分别表示输入数据的均值和方差, 和表示BN层的两个可训练参数, 是一个很小的常数,用于避免方差为0的情况。这么看起来是不是就好理解多了?通过上述变换可以有效减小数据的发散程度,从而降低学习难度。
BN在实现时一般还有两个参数,即移动均值和方差,用于描述整个数据集的情况,这里就不做详细介绍了,感兴趣的读者可以自行查阅相关资料。
那么BN为什么有效呢?主要包括下面几个原因。
● 通过对输入和中间网络层的输出进行标准化处理,减少了内部神经元分布的改变,降低了不同样本值域的差异性,使得大部分数据处在非饱和区域,保证了梯度能够很好地回传,缓解了梯度消失和梯度爆炸。
● 通过减少梯度对参数或其初始值尺度的依赖,使得可以用较大的学习率对网络进行训练,从而加快了收敛速度。
● BN本质上是一种正则化手段,能够提升网络的泛化能力,可以减少或者去除Dropout机制,从而优化网络结构。
小 白:如果不进行BN,会有什么问题?
梗老师:问题太多了,比如训练效果可能会不稳定,甚至出现梯度消失和梯度爆炸。不过,如果模型本身能正常收敛的话,收敛速度往往要比经BN的快。
1.3.2 代码实现
了解BN的基本思想之后,下面我们来看如何用代码实现。
# 导入必要的库,torchinfo用于查看模型结构 import torch import torch.nn as nn from torchinfo import summary
接下来我们以《破解深度学习(基础篇):模型算法与实现》中的LeNet模型为例,对其用BN层进行改造。PyTorch框架提供了相关方法,直接使用就行。这部分改动量不大,只需要在卷积层和前两个全连接层后面都加上一个BatchNorm层,传入对应的通道数即可。
需要特别注意的是输入维度略有不同,卷积层后面用的是BatchNorm2d,而全连接层后面用的是BatchNorm1d。再下面的forward()函数不变,在卷积层后面接最大池化层,激活函数都是ReLU,依次处理后输出。
# 定义LeNet的网络结构 class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() # 卷积层1:输入1个通道,输出6个通道,卷积核大小为5×5,后接BN self.conv1 = nn.Sequential( nn.Conv2d(1, 6, 5), nn.BatchNorm2d(6) ) # 卷积层2:输入6个通道,输出16个通道,卷积核大小为5×5,后接BN self.conv2 = nn.Sequential( nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16) ) # 全连接层1:输入16×4×4=256个节点,输出120个节点,由于输入数据略有差异,修改为16×4×4 self.fc1 = nn.Sequential( nn.Linear(16 * 4 * 4, 120), nn.BatchNorm1d(120) ) # 全连接层2:输入120个节点,输出84个节点 self.fc2 = nn.Sequential( nn.Linear(120, 84), nn.BatchNorm1d(84) ) # 输出层:输入84个节点,输出10个节点 self.fc3 = nn.Linear(84, 10) def forward(self, x): # 使用ReLU激活函数,并进行最大池化 x = torch.relu(self.conv1(x)) x = nn.functional.max_pool2d(x, 2) # 使用ReLU激活函数,并进行最大池化 x = torch.relu(self.conv2(x)) x = nn.functional.max_pool2d(x, 2) # 将多维张量展平为一维张量 x = x.view(-1, 16 * 4 * 4) # 全连接层 x = torch.relu(self.fc1(x)) # 全连接层 x = torch.relu(self.fc2(x)) # 全连接层 x = self.fc3(x) return x
再往下来看网络结构。调用torchinfo.summary()可以查看刚刚改造后的模型信息。新增的与池化层不同,BatchNorm层有可训练参数,因此改造后的参数量也会有变化。
# 查看模型结构及参数量,input_size表示示例输入数据的维度信息 summary(LeNet(), input_size=(1, 1, 28, 28)) ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== LeNet [1, 10] -- ├─Sequential: 1-1 [1, 6, 24, 24] -- │ └─Conv2d: 2-1 [1, 6, 24, 24] 156 │ └─BatchNorm2d: 2-2 [1, 6, 24, 24] 12 ├─Sequential: 1-2 [1, 16, 8, 8] -- │ └─Conv2d: 2-3 [1, 16, 8, 8] 2,416 │ └─BatchNorm2d: 2-4 [1, 16, 8, 8] 32 ├─Sequential: 1-3 [1, 120] -- │ └─Linear: 2-5 [1, 120] 30,840 │ └─BatchNorm1d: 2-6 [1, 120] 240 ├─Sequential: 1-4 [1, 84] -- │ └─Linear: 2-7 [1, 84] 10,164 │ └─BatchNorm1d: 2-8 [1, 84] 168 ├─Linear: 1-5 [1, 10] 850 ========================================================================================== Total params: 44,878 Trainable params: 44,878 Non-trainable params: 0 Total mult-adds (M): 0.29 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.08 Params size (MB): 0.18 Estimated Total Size (MB): 0.26 ==========================================================================================
1.3.3 模型训练
最后替换LeNet模型结构重新进行训练,代码和《破解深度学习(基础篇):模型算法与实现》中讲过的LeNet模型训练部分完全一致。损失和准确率曲线如图1-8所示,可以看出,在增加了BN层后,准确率较原始的LeNet又有一定提升,而且收敛得更快了。
# 代码部分与前面章节的LeNet代码一致 # ... Epoch: 0 Loss: 2.2088656986293165 Acc: 0.9578 Epoch: 2 Loss: 1.4376559603001913 Acc: 0.979 Epoch: 4 Loss: 1.228520721191793 Acc: 0.9803 Epoch: 6 Loss: 1.106042682456222 Acc: 0.9859 Epoch: 8 Loss: 1.0158490855052476 Acc: 0.9883 100%|██████████| 10/10 [01:56<00:00, 11.62s/it]
图1-8 损失和准确率曲线
1.3.4 小结
本节重点讲解了BN的特性及其计算方式,探讨了这种方法为什么能有效缓解梯度消失和梯度爆炸并加快收敛速度。最后,我们对LeNet模型进行了改造,并以MNIST数据集为例,对比了有无BN层情况下的训练效果。