4 池化层¶
学习目标¶
- 掌握池化计算过程
- 掌握PyTorch池化层API
池化层 (Pooling) 降低维度, 缩减模型大小,提高计算速度. 即: 主要对卷积层学习到的特征图进行下采样(SubSampling)处理.
池化层主要有两种:
- 最大池化
- 平均池化
1. 池化层计算¶
最大池化:
- max(0, 1, 3, 4)
- max(1, 2, 4, 5)
- max(3, 4, 6, 7)
- max(4, 5, 7, 8)
平均池化:
- mean(0, 1, 3, 4)
- mean(1, 2, 4, 5)
- mean(3, 4, 6, 7)
- mean(4, 5, 7, 8)
2. Stride¶
最大池化:
- max(0, 1, 4, 5)
- max(2, 3, 6, 7)
- max(8, 9, 12, 13)
- max(10, 11, 14, 15)
平均池化:
- mean(0, 1, 4, 5)
- mean(2, 3, 6, 7)
- mean(8, 9, 12, 13)
- mean(10, 11, 14, 15)
3. Padding¶
最大池化:
- max(0, 0, 0, 0)
- max(0, 0, 0, 1)
- max(0, 0, 1, 2)
- max(0, 0, 2, 0)
- ... 以此类推
平均池化:
- mean(0, 0, 0, 0)
- mean(0, 0, 0, 1)
- mean(0, 0, 1, 2)
- mean(0, 0, 2, 0)
- ... 以此类推
4. 多通道池化计算¶
在处理多通道输入数据时,池化层对每个输入通道分别池化,而不是像卷积层那样将各个通道的输入相加。这意味着池化层的输出和输入的通道数是相等。
5. PyTorch 池化 API 使用¶
import torch
import torch.nn as nn
# 1. API 基本使用
def test01():
inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
# 输入形状: (N, C, H, W)
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
# 2. stride 步长
def test02():
inputs = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
output = polling(inputs)
print(output)
# 3. padding 填充
def test03():
inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float()
inputs = inputs.unsqueeze(0).unsqueeze(0)
# 1. 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
output = polling(inputs)
print(output)
# 2. 平均池化
polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=1)
output = polling(inputs)
print(output)
# 4. 多通道池化
def test04():
inputs = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[10, 20, 30], [40, 50, 60], [70, 80, 90]],
[[11, 22, 33], [44, 55, 66], [77, 88, 99]]]).float()
inputs = inputs.unsqueeze(0)
# 最大池化
polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0)
output = polling(inputs)
print(output)
if __name__ == '__main__':
test04()
5. 小节¶
本小节主要学习了池化层的相关知识,池化层主要用于减少数据的维度。其主要分为: 最大池化、平均池化,我们在进行图像分类任务时,可以使用最大池化。