Dark Dwarf Blog background

NiN 网络与 1x1 卷积核

NiN 网络与 1x1 卷积核

1. 1x1 卷积核

在讲解具体的 NiN 网络架构前,我们先谈谈 1x1 卷积核。

一般而言,对二维信号进行 1x1 卷积是没有意义的。但是在 CNN 中,整个空间是三维的,这个卷积核的实际大小为 1×1×num_channels1\times 1\times \text{num{\_}channels}。在这个卷积核在原有的 m×nm\times n 图像空间的每一个位置,都会与该点的 num_channels\text{num{\_}channels} 个通道进行点积。而这个过程可以看作是对所有通道的信息整合

同时,我们在 1x1 卷积核后面接上 ReLU 非线性层还可以增加我们网络的非线性。

上面这些内容是 NiN 网络设计的重要部分。

2. MLPConv

a.a. MLPConv 的引入

NiN 论文中提到了传统的“卷积+非线性激活”方法的问题:卷积的本质是线性操作,因此这样的方式只有在我们的实例是线性可分的时候才有效。而为了解决这个问题,传统 CNN 会不断引入新的卷积核、提取出不同的特征。但是这些特征往往过于零散、并未经过充分抽象,对接下来的卷积层带来了压力。

为了解决这个问题,NiN引入了 MLPConv,使用多层感知机来提取更准确的局部特征。

b.b. MLPConv 层的数学表达与理解

假设我们的 NiN 网络有 nn 层,对第一层:

fi,j,k11=max((wk11)xi,j+bk1, 0)f^{1}_{i,j,k1} = \max\left( (\mathbf{w}_{k1}^{1})^\top \mathbf{x}_{i,j} + b_{k1},\ 0 \right)

之后的每一层,我们都对前一层进行卷积然后 ReLU 非线性激活

fi,j,knn=max((wknn)fi,j(n1)+bkn, 0)f^{n}_{i,j,k_n} = \max\left( (\mathbf{w}^{n}_{k_n})^\top \mathbf{f}^{(n-1)}_{i,j} + b_{k_n},\ 0 \right)

我们来换个角度理解这个过程:第 nn 层对 n1n-1 层的结果使用卷积,实际上是在n1n-1 层的结果通道进行加权求和。而这正是 1x1 卷积可以使用的地方!因此我们可以大量使用我们的 1x1 卷积核来完成这个加权求和。

3. 全局平均池化

全局平均池化 (Global Average Pooling, GAP) 是针对传统 CNN 的又一优化。传统的卷积神经网络会将卷积 feature 层的输出展平到一维长向量,然后添加一些全连接层作为分类层来处理这个长向量。但是这样的网络架构的参数主要在全连接层部分,这很容易导致过拟合

为了解决这个问题,NiN 引入了全局平均池化来取代全连接层:在 MLPConv 的最后一层,我们对卷积网络输出的每个特征图,都计算其在空间维度上的平均值,得到一个向量。例如一个大小为 H×WH\times W 的特征图,它的每个像素的值会被加起来、然后除以 H×WH\times W。然后我们直接这些向量扔到 Softmax 中进行分类。

这样的做法有如下的好处:

  1. GAP 强制建立了特征图与类别的对应关系,这使得分类部分与前面的特征提取部分更加统一。
  2. GAP 本身只做了求平均值这一件事,没有任何需要学习的参数,也就没有过拟合的风险了。

在 NiN 框架中,在 GAP 前的最后一个卷积层输出的特征图被称作类别置信度图(Category Confidence Map)。例如我们在完成一个猫狗分类任务,对于猫的置信度图,这张图上每个像素点的激活值(也就是亮度)代表了模型认为原始图像对应位置存在“猫”这个概念的特征的可能性有多大。如果原始图像的左上角有一只猫,那么理想情况下,这张“猫”置信度图的左上角区域就会被高度激活。而 GAP 就是给每个置信度图计算平均亮度。这带来了很好的可解释性:我们可以直接将这些特征图可视化出来、判断模型作出决定的证据是什么。

而我们前面的 MLPConv 正好是一个对局部信息进行充分抽象的特征提取器,它能够输出一个足够精确的置信度图、让 GAP 能够得到准确的结果。

4. 简单实现

我们在 CIFAR-10 数据集上实现一个 NiN 架构网络。核心的组件是 MLPConv 模块:我们先在第一层放一个卷积层,然后使用 1x1 的卷积核不断整合信息、提取更精确的局部特征:

nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
nn.Dropout(0.2)

完整实现如下: