本文深入剖析卷积神经网络中池化层的核心原理,涵盖最大池化、平均池化、全局池化、自适应池化等多种变体,从数学原理到完整代码实现,帮你彻底理解这个CNN的关键组件。
一、池化层概述
1.1 什么是池化
池化(Pooling) 是一种下采样操作,用于减小特征图的空间尺寸,同时保留重要特征。
池化操作示意:
输入特征图 4×4 输出特征图 2×2
┌────┬────┬────┬────┐ ┌────┬────┐
│ 1 │ 3 │ 2 │ 1 │ │ │ │
├────┼────┼────┼────┤ 2×2池化 │ 4 │ 6 │
│ 4 │ 2 │ 1 │ 6 │ ───────→ ├────┼────┤
├────┼────┼────┼────┤ │ │ │
│ 1 │ 5 │ 3 │ 2 │ │ 5 │ 8 │
├────┼────┼────┼────┤ └────┴────┘
│ 2 │ 1 │ 8 │ 4 │
└────┴────┴────┴────┘
每个2×2区域取最大值(最大池化)
1.2 为什么需要池化
┌─────────────────────────────────────────────────────────────────┐
│ 池化层的核心作用 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 降维(Dimensionality Reduction) │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 输入: 224×224 → 池化后: 112×112 │ │
│ │ 特征图尺寸减半,计算量减少到1/4 │ │
│ │ 让网络可以更深,感受野更大 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 2. 平移不变性(Translation Invariance) │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 目标轻微移动时,池化后的特征保持不变 │ │
│ │ 提高模型对位置变化的鲁棒性 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 3. 防止过拟合(Regularization) │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 减少参数数量 │ │
│ │ 丢弃部分空间信息,保留最重要的特征 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
│ 4. 扩大感受野(Receptive Field) │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ 池化后每个像素对应原图更大的区域 │ │
│ │ 让后续卷积能"看到"更大范围的信息 │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
1.3 池化在CNN中的位置
典型CNN架构中池化的位置:
Input (224×224×3)
│
▼
┌─────────────┐
│ Conv1 │ → 224×224×64
│ ReLU │
│ Pool │ → 112×112×64 ← 池化1:尺寸减半
└─────────────┘
│
▼
┌─────────────┐
│ Conv2 │ → 112×112×128
│ ReLU │
│ Pool │ → 56×56×128 ← 池化2:尺寸再减半
└─────────────┘
│
▼
┌─────────────┐
│ Conv3 │ → 56×56×256
│ ReLU │
│ Pool │ → 28×28×256 ← 池化3
└─────────────┘
│
▼
...更多层...
│
▼
┌─────────────┐
│ Global Pool │ → 1×1×512 ← 全局池化:变成向量
└─────────────┘
│
▼
┌─────────────┐
│ FC │ → 1000 ← 分类输出
└─────────────┘
二、最大池化(Max Pooling)
2.1 原理详解
最大池化取池化窗口内的最大值,保留最显著的特征。
最大池化示意(2×2池化,步长2):
输入 4×4:
┌────┬────┬────┬────┐
│ 1 │ 3 │ 2 │ 1 │
├────┼────┼────┼────┤ 窗口1: max(1,3,4,2) = 4
│ 4 │ 2 │ 1 │ 6 │ 窗口2: max(2,1,1,6) = 6
├────┼────┼────┼────┤ 窗口3: max(1,5,2,1) = 5
│ 1 │ 5 │ 3 │ 2 │ 窗口4: max(3,2,8,4) = 8
├────┼────┼────┼────┤
│ 2 │ 1 │ 8 │ 4 │
└────┴────┴────┴────┘
输出 2×2:
┌────┬────┐
│ 4 │ 6 │
├────┼────┤
│ 5 │ 8 │
└────┴────┘
2.2 数学定义
最大池化公式:
对于输入特征图 X,池化窗口大小 k×k,步长 s
输出位置 (i,j) 的值:
Y[i,j] = max { X[i×s+m, j×s+n] | m,n ∈ [0, k-1] }
例如 2×2池化,步长2:
Y[0,0] = max(X[0,0], X[0,1], X[1,0], X[1,1])
Y[0,1] = max(X[0,2], X[0,3], X[1,2], X[1,3])
...
2.3 最大池化的特性
优点:
┌─────────────────────────────────────────────────────────────┐
│ • 保留最显著的特征(边缘、纹理等激活最强的位置) │
│ • 提供一定的平移不变性 │
│ • 对噪声有一定的抑制作用 │
│ • 计算简单高效 │
└─────────────────────────────────────────────────────────────┘
直觉理解:
"这个区域有没有检测到边缘?" → 取最大值
只要区域内有一个位置检测到边缘,整个区域就"有边缘"
2.4 PyTorch实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class MaxPool2dManual:
"""
手动实现最大池化
"""
def __init__(self, kernel_size=2, stride=None, padding=0):
"""
Args:
kernel_size: 池化窗口大小
stride: 步长,默认等于kernel_size
padding: 填充大小
"""
self.kernel_size = kernel_size
self.stride = stride if stride is not None else kernel_size
self.padding = padding
def forward(self, x):
"""
前向传播
Args:
x: 输入张量 [B, C, H, W]
Returns:
output: 池化后的张量
indices: 最大值位置索引(用于反池化)
"""
B, C, H, W = x.shape
k = self.kernel_size
s = self.stride
p = self.padding
# 计算输出尺寸
H_out = (H + 2*p - k) // s + 1
W_out = (W + 2*p - k) // s + 1
# 填充
if p > 0:
x = F.pad(x, (p, p, p, p), mode='constant', value=float('-inf'))
# 初始化输出
output = torch.zeros(B, C, H_out, W_out, device=x.device)
indices = torch.zeros(B, C, H_out, W_out, dtype=torch.long, device=x.device)
# 滑动窗口池化
for i in range(H_out):
for j in range(W_out):
# 提取窗口
h_start = i * s
w_start = j * s
window = x[:, :, h_start:h_start+k, w_start:w_start+k]
# 取最大值
window_flat = window.reshape(B, C, -1)
max_vals, max_indices = window_flat.max(dim=2)
output[:, :, i, j] = max_vals
indices[:, :, i, j] = max_indices
return output, indices
def backward(self, grad_output, indices, input_shape):
"""
反向传播
最大池化的梯度只传给最大值位置,其他位置梯度为0
"""
B, C, H, W = input_shape
k = self.kernel_size
s = self.stride
grad_input = torch.zeros(input_shape, device=grad_output.device)
H_out, W_out = grad_output.shape[2], grad_output.shape[3]
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
for b in range(B):
for c in range(C):
# 找到最大值在窗口内的位置
idx = indices[b, c, i, j].item()
h_idx = idx // k
w_idx = idx % k
# 梯度只传给最大值位置
grad_input[b, c, h_start+h_idx, w_start+w_idx] += grad_output[b, c, i, j]
return grad_input
# 使用PyTorch内置实现
def max_pool_demo():
"""最大池化演示"""
# 创建输入
x = torch.tensor([
[1, 3, 2, 1],
[4, 2, 1, 6],
[1, 5, 3, 2],
[2, 1, 8, 4]
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
print("输入:")
print(x.squeeze())
# 最大池化
pool = nn.MaxPool2d(kernel_size=2, stride=2)
output = pool(x)
print("\n2×2最大池化输出:")
print(output.squeeze())
# 带索引的最大池化(用于反池化)
pool_with_indices = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
output, indices = pool_with_indices(x)
print("\n最大值索引:")
print(indices.squeeze())
max_pool_demo()
三、平均池化(Average Pooling)
3.1 原理详解
平均池化取池化窗口内的平均值,保留区域的整体特征。
平均池化示意(2×2池化,步长2):
输入 4×4:
┌────┬────┬────┬────┐
│ 1 │ 3 │ 2 │ 1 │
├────┼────┼────┼────┤ 窗口1: avg(1,3,4,2) = 2.5
│ 4 │ 2 │ 1 │ 6 │ 窗口2: avg(2,1,1,6) = 2.5
├────┼────┼────┼────┤ 窗口3: avg(1,5,2,1) = 2.25
│ 1 │ 5 │ 3 │ 2 │ 窗口4: avg(3,2,8,4) = 4.25
├────┼────┼────┼────┤
│ 2 │ 1 │ 8 │ 4 │
└────┴────┴────┴────┘
输出 2×2:
┌──────┬──────┐
│ 2.5 │ 2.5 │
├──────┼──────┤
│ 2.25 │ 4.25 │
└──────┴──────┘
3.2 数学定义
平均池化公式:
Y[i,j] = (1/k²) × Σ Σ X[i×s+m, j×s+n]
m n
其中 m,n ∈ [0, k-1]
例如 2×2池化:
Y[0,0] = (X[0,0] + X[0,1] + X[1,0] + X[1,1]) / 4
3.3 最大池化 vs 平均池化
┌─────────────────────────────────────────────────────────────────┐
│ 最大池化 vs 平均池化 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 最大池化 (Max Pooling): │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ • 保留最显著的特征 │ │
│ │ • 对噪声更鲁棒 │ │
│ │ • 适合检测"是否存在某特征" │ │
│ │ • 常用于卷积层之后 │ │
│ │ • 例:边缘检测、纹理识别 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 平均池化 (Average Pooling): │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ • 保留区域的整体信息 │ │
│ │ • 更平滑的下采样 │ │
│ │ • 适合需要保留背景信息的任务 │ │
│ │ • 常用于网络末端(全局平均池化) │ │
│ │ • 例:语义分割、全局特征提取 │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 可视化对比: │
│ │
│ 输入图像: ■ □ □ □ 最大池化: ■ □ 平均池化: ▧ □ │
│ □ □ □ □ □ □ □ □ │
│ □ □ ■ □ │
│ □ □ □ □ 保留突出特征 保留整体亮度 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.4 PyTorch实现
python
class AvgPool2dManual:
"""
手动实现平均池化
"""
def __init__(self, kernel_size=2, stride=None, padding=0):
self.kernel_size = kernel_size
self.stride = stride if stride is not None else kernel_size
self.padding = padding
def forward(self, x):
"""
前向传播
"""
B, C, H, W = x.shape
k = self.kernel_size
s = self.stride
p = self.padding
# 计算输出尺寸
H_out = (H + 2*p - k) // s + 1
W_out = (W + 2*p - k) // s + 1
# 填充
if p > 0:
x = F.pad(x, (p, p, p, p), mode='constant', value=0)
# 初始化输出
output = torch.zeros(B, C, H_out, W_out, device=x.device)
# 滑动窗口池化
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
window = x[:, :, h_start:h_start+k, w_start:w_start+k]
# 取平均值
output[:, :, i, j] = window.mean(dim=(2, 3))
return output
def backward(self, grad_output, input_shape):
"""
反向传播
平均池化的梯度均匀分配给窗口内所有位置
"""
B, C, H, W = input_shape
k = self.kernel_size
s = self.stride
grad_input = torch.zeros(input_shape, device=grad_output.device)
H_out, W_out = grad_output.shape[2], grad_output.shape[3]
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
# 梯度均匀分配
grad_input[:, :, h_start:h_start+k, w_start:w_start+k] += \
grad_output[:, :, i:i+1, j:j+1] / (k * k)
return grad_input
def avg_pool_demo():
"""平均池化演示"""
x = torch.tensor([
[1, 3, 2, 1],
[4, 2, 1, 6],
[1, 5, 3, 2],
[2, 1, 8, 4]
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
print("输入:")
print(x.squeeze())
# 平均池化
pool = nn.AvgPool2d(kernel_size=2, stride=2)
output = pool(x)
print("\n2×2平均池化输出:")
print(output.squeeze())
avg_pool_demo()
四、全局池化(Global Pooling)
4.1 原理详解
全局池化将整个特征图压缩成单个值,常用于CNN末端替代全连接层。
全局池化示意:
输入特征图 7×7×512:
┌─────────────────────┐
│ │
│ 7×7 特征图 │ × 512个通道
│ │
└─────────────────────┘
│
▼ 全局平均池化
┌───┐
│ 1 │ × 512 = 512维向量
└───┘
每个通道的7×7特征图 → 取平均 → 1个数
512个通道 → 512维向量
4.2 为什么使用全局池化
传统方式(全连接层):
7×7×512 → Flatten → 25088维 → FC(25088, 4096) → FC(4096, 1000)
问题:
• 参数量巨大:25088×4096 ≈ 1亿参数
• 容易过拟合
• 要求固定输入尺寸
全局池化方式:
7×7×512 → GlobalAvgPool → 512维 → FC(512, 1000)
优点:
• 参数量小:只需512×1000 ≈ 50万参数
• 不容易过拟合
• 可以处理任意输入尺寸
• 每个通道对应一个类别的响应(可解释性强)
4.3 全局平均池化 vs 全局最大池化
python
class GlobalPooling(nn.Module):
"""
全局池化层
将 [B, C, H, W] → [B, C, 1, 1] 或 [B, C]
"""
def __init__(self, pool_type='avg', flatten=True):
"""
Args:
pool_type: 'avg' 或 'max'
flatten: 是否展平为 [B, C]
"""
super().__init__()
self.pool_type = pool_type
self.flatten = flatten
def forward(self, x):
"""
Args:
x: [B, C, H, W]
Returns:
[B, C] 或 [B, C, 1, 1]
"""
if self.pool_type == 'avg':
# 全局平均池化
out = x.mean(dim=(2, 3))
elif self.pool_type == 'max':
# 全局最大池化
out = x.amax(dim=(2, 3))
elif self.pool_type == 'avg+max':
# 同时使用两者(拼接)
avg_out = x.mean(dim=(2, 3))
max_out = x.amax(dim=(2, 3))
out = torch.cat([avg_out, max_out], dim=1)
else:
raise ValueError(f"Unknown pool_type: {self.pool_type}")
if not self.flatten:
out = out.unsqueeze(-1).unsqueeze(-1)
return out
def global_pool_demo():
"""全局池化演示"""
# 模拟最后一层特征图
x = torch.randn(2, 512, 7, 7) # [B, C, H, W]
print(f"输入形状: {x.shape}")
# 全局平均池化
gap = GlobalPooling(pool_type='avg')
out_avg = gap(x)
print(f"全局平均池化后: {out_avg.shape}")
# 全局最大池化
gmp = GlobalPooling(pool_type='max')
out_max = gmp(x)
print(f"全局最大池化后: {out_max.shape}")
# 两者结合
gp_both = GlobalPooling(pool_type='avg+max')
out_both = gp_both(x)
print(f"avg+max池化后: {out_both.shape}")
global_pool_demo()
4.4 Network in Network (NIN) 中的全局池化
python
class NINClassifier(nn.Module):
"""
Network in Network 风格的分类器
使用全局平均池化替代全连接层
每个通道直接对应一个类别
"""
def __init__(self, in_channels, num_classes):
super().__init__()
# 1×1卷积将通道数变为类别数
self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
# 全局平均池化
self.gap = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
"""
Args:
x: [B, C, H, W] 特征图
Returns:
[B, num_classes] 分类logits
"""
# 1×1卷积:[B, C, H, W] → [B, num_classes, H, W]
x = self.conv(x)
# 全局平均池化:[B, num_classes, H, W] → [B, num_classes, 1, 1]
x = self.gap(x)
# 展平:[B, num_classes]
x = x.view(x.size(0), -1)
return x
# 可视化通道响应
def visualize_channel_response():
"""
可视化每个通道对应的类别响应
"""
import matplotlib.pyplot as plt
# 假设有10个类别
model = NINClassifier(in_channels=512, num_classes=10)
# 输入特征图
x = torch.randn(1, 512, 14, 14)
# 获取1×1卷积后的特征图(每个通道对应一个类别)
with torch.no_grad():
class_maps = model.conv(x) # [1, 10, 14, 14]
# 可视化每个类别的响应图
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
ax.imshow(class_maps[0, i].numpy(), cmap='hot')
ax.set_title(f'Class {i}')
ax.axis('off')
plt.tight_layout()
plt.savefig('class_activation_maps.png')
print("Class activation maps saved!")
五、自适应池化(Adaptive Pooling)
5.1 原理详解
自适应池化指定输出尺寸而非池化核大小,自动计算需要的参数。
自适应池化示意:
"我不管输入多大,输出必须是 3×3"
输入 7×7: 输入 13×13:
┌─────────────┐ ┌─────────────────┐
│ │ │ │
│ 7×7 │ Adaptive │ 13×13 │ Adaptive
│ │ Pool(3×3) │ │ Pool(3×3)
└─────────────┘ ↓ └─────────────────┘ ↓
┌─────────┐ ┌─────────┐
│ 3×3 │ │ 3×3 │
└─────────┘ └─────────┘
自动计算池化核大小和步长
5.2 为什么需要自适应池化
问题:传统池化要求固定输入尺寸
训练时: 输入224×224 → 固定的池化参数 → 输出7×7
测试时: 输入320×320 → 同样的池化参数 → 输出?(尺寸不对!)
解决方案:自适应池化
训练时: 输入224×224 → AdaptivePool(7×7) → 输出7×7
测试时: 输入320×320 → AdaptivePool(7×7) → 输出7×7 ✓
可以处理任意输入尺寸!
5.3 PyTorch实现
python
class AdaptiveAvgPool2dManual:
"""
手动实现自适应平均池化
"""
def __init__(self, output_size):
"""
Args:
output_size: 输出尺寸 (H_out, W_out) 或单个数
"""
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def forward(self, x):
"""
前向传播
自动计算每个输出位置对应的输入区域
"""
B, C, H, W = x.shape
H_out, W_out = self.output_size
output = torch.zeros(B, C, H_out, W_out, device=x.device)
for i in range(H_out):
for j in range(W_out):
# 计算输入区域的起始和结束位置
h_start = int(np.floor(i * H / H_out))
h_end = int(np.ceil((i + 1) * H / H_out))
w_start = int(np.floor(j * W / W_out))
w_end = int(np.ceil((j + 1) * W / W_out))
# 提取区域并取平均
region = x[:, :, h_start:h_end, w_start:w_end]
output[:, :, i, j] = region.mean(dim=(2, 3))
return output
def adaptive_pool_demo():
"""自适应池化演示"""
# 不同尺寸的输入
sizes = [(7, 7), (14, 14), (13, 13), (32, 32)]
# 自适应池化到3×3
adaptive_pool = nn.AdaptiveAvgPool2d((3, 3))
print("自适应池化演示(输出固定为3×3):")
for h, w in sizes:
x = torch.randn(1, 64, h, w)
out = adaptive_pool(x)
print(f" 输入 {h}×{w} → 输出 {out.shape[2]}×{out.shape[3]}")
adaptive_pool_demo()
5.4 SPP(空间金字塔池化)
python
class SpatialPyramidPooling(nn.Module):
"""
空间金字塔池化 (SPP)
多个尺度的自适应池化,拼接得到固定长度的特征
应用:
- SPPNet:可以处理任意尺寸输入
- 目标检测:提取不同尺度的特征
"""
def __init__(self, pool_sizes=[1, 2, 4]):
"""
Args:
pool_sizes: 不同尺度的池化输出尺寸
例如 [1, 2, 4] 表示输出 1×1, 2×2, 4×4
"""
super().__init__()
self.pool_sizes = pool_sizes
# 创建自适应池化层
self.pools = nn.ModuleList([
nn.AdaptiveMaxPool2d(size) for size in pool_sizes
])
def forward(self, x):
"""
Args:
x: [B, C, H, W]
Returns:
[B, C × (1×1 + 2×2 + 4×4)] = [B, C × 21]
"""
B, C, H, W = x.shape
features = []
for pool in self.pools:
# 池化
pooled = pool(x) # [B, C, size, size]
# 展平
pooled = pooled.view(B, -1) # [B, C × size × size]
features.append(pooled)
# 拼接所有尺度的特征
output = torch.cat(features, dim=1)
return output
def get_output_size(self, in_channels):
"""计算输出特征维度"""
total = sum(size * size for size in self.pool_sizes)
return in_channels * total
def spp_demo():
"""SPP演示"""
spp = SpatialPyramidPooling(pool_sizes=[1, 2, 4])
# 不同尺寸的输入,输出维度相同
print("SPP演示(输出维度固定):")
for h, w in [(7, 7), (14, 14), (28, 28)]:
x = torch.randn(2, 256, h, w)
out = spp(x)
print(f" 输入 {h}×{w}×256 → 输出 {out.shape}")
print(f"\n输出维度计算: 256 × (1×1 + 2×2 + 4×4) = {spp.get_output_size(256)}")
spp_demo()
六、其他池化变体
6.1 重叠池化(Overlapping Pooling)
python
class OverlappingPool(nn.Module):
"""
重叠池化
AlexNet中使用:kernel_size=3, stride=2
池化窗口有重叠,保留更多信息
"""
def __init__(self, kernel_size=3, stride=2):
super().__init__()
# stride < kernel_size 时产生重叠
self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
def forward(self, x):
return self.pool(x)
def overlapping_pool_demo():
"""重叠池化演示"""
x = torch.randn(1, 64, 13, 13)
# 非重叠池化:kernel_size=2, stride=2
non_overlap = nn.MaxPool2d(kernel_size=2, stride=2)
out1 = non_overlap(x)
# 重叠池化:kernel_size=3, stride=2
overlap = nn.MaxPool2d(kernel_size=3, stride=2)
out2 = overlap(x)
print(f"输入: {x.shape}")
print(f"非重叠池化 (k=2, s=2): {out1.shape}")
print(f"重叠池化 (k=3, s=2): {out2.shape}")
overlapping_pool_demo()
6.2 随机池化(Stochastic Pooling)
python
class StochasticPool2d(nn.Module):
"""
随机池化
训练时:按概率(归一化后的激活值)随机选择一个位置
测试时:使用期望值(等价于加权平均)
优点:
- 提供正则化效果
- 比最大池化更平滑
- 比平均池化保留更多显著特征
"""
def __init__(self, kernel_size=2, stride=2):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
def forward(self, x):
B, C, H, W = x.shape
k = self.kernel_size
s = self.stride
H_out = (H - k) // s + 1
W_out = (W - k) // s + 1
output = torch.zeros(B, C, H_out, W_out, device=x.device)
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
window = x[:, :, h_start:h_start+k, w_start:w_start+k]
if self.training:
# 训练时:随机采样
# 概率 = 归一化后的激活值
window_flat = window.reshape(B, C, -1)
# 确保非负
window_positive = F.relu(window_flat) + 1e-8
# 归一化为概率
probs = window_positive / window_positive.sum(dim=-1, keepdim=True)
# 按概率采样
indices = torch.multinomial(probs.view(-1, k*k), 1).view(B, C)
# 收集采样值
output[:, :, i, j] = window_flat.gather(
dim=-1,
index=indices.unsqueeze(-1)
).squeeze(-1)
else:
# 测试时:使用期望值(加权平均)
window_flat = window.reshape(B, C, -1)
window_positive = F.relu(window_flat) + 1e-8
probs = window_positive / window_positive.sum(dim=-1, keepdim=True)
# 期望值 = Σ p_i × x_i
output[:, :, i, j] = (probs * window_flat).sum(dim=-1)
return output
6.3 混合池化(Mixed Pooling)
python
class MixedPool2d(nn.Module):
"""
混合池化
结合最大池化和平均池化的优点
方式1:拼接
方式2:加权组合
方式3:可学习的组合权重
"""
def __init__(self, kernel_size=2, stride=2, mode='concat', alpha=0.5):
"""
Args:
mode: 'concat', 'weighted', 'learnable'
alpha: 加权模式的权重
"""
super().__init__()
self.mode = mode
self.max_pool = nn.MaxPool2d(kernel_size, stride)
self.avg_pool = nn.AvgPool2d(kernel_size, stride)
if mode == 'weighted':
self.alpha = alpha
elif mode == 'learnable':
# 可学习的权重
self.alpha = nn.Parameter(torch.tensor(0.5))
def forward(self, x):
max_out = self.max_pool(x)
avg_out = self.avg_pool(x)
if self.mode == 'concat':
# 拼接(通道数翻倍)
return torch.cat([max_out, avg_out], dim=1)
elif self.mode == 'weighted' or self.mode == 'learnable':
# 加权组合
alpha = torch.sigmoid(self.alpha) if self.mode == 'learnable' else self.alpha
return alpha * max_out + (1 - alpha) * avg_out
else:
raise ValueError(f"Unknown mode: {self.mode}")
def mixed_pool_demo():
"""混合池化演示"""
x = torch.randn(2, 64, 8, 8)
# 拼接模式
pool_concat = MixedPool2d(kernel_size=2, stride=2, mode='concat')
out_concat = pool_concat(x)
print(f"拼接模式: {x.shape} → {out_concat.shape}")
# 加权模式
pool_weighted = MixedPool2d(kernel_size=2, stride=2, mode='weighted', alpha=0.7)
out_weighted = pool_weighted(x)
print(f"加权模式: {x.shape} → {out_weighted.shape}")
# 可学习模式
pool_learnable = MixedPool2d(kernel_size=2, stride=2, mode='learnable')
out_learnable = pool_learnable(x)
print(f"可学习模式: {x.shape} → {out_learnable.shape}")
print(f" 学习到的alpha: {torch.sigmoid(pool_learnable.alpha).item():.4f}")
mixed_pool_demo()
6.4 RoI池化(Region of Interest Pooling)
python
class RoIPool(nn.Module):
"""
RoI池化
目标检测中使用,将不同大小的候选区域池化到固定大小
用于:Fast R-CNN, Faster R-CNN
"""
def __init__(self, output_size):
"""
Args:
output_size: 输出尺寸 (H, W)
"""
super().__init__()
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def forward(self, features, rois):
"""
Args:
features: [B, C, H, W] 特征图
rois: [N, 5] 候选区域 [batch_idx, x1, y1, x2, y2]
Returns:
[N, C, H_out, W_out] 池化后的RoI特征
"""
N = rois.size(0)
C = features.size(1)
H_out, W_out = self.output_size
output = torch.zeros(N, C, H_out, W_out, device=features.device)
for n in range(N):
batch_idx = int(rois[n, 0])
x1, y1, x2, y2 = rois[n, 1:5]
# 提取RoI区域
roi_feature = features[batch_idx, :, int(y1):int(y2), int(x1):int(x2)]
# 自适应池化到固定大小
output[n] = F.adaptive_max_pool2d(roi_feature.unsqueeze(0), self.output_size).squeeze(0)
return output
# PyTorch内置的RoI操作
from torchvision.ops import roi_pool, roi_align
def roi_pool_demo():
"""RoI池化演示"""
# 特征图
features = torch.randn(1, 256, 50, 50)
# RoI候选区域 [batch_idx, x1, y1, x2, y2]
rois = torch.tensor([
[0, 10, 10, 30, 30], # 第一个RoI
[0, 20, 20, 45, 45], # 第二个RoI
], dtype=torch.float32)
# RoI池化
output = roi_pool(features, rois, output_size=(7, 7), spatial_scale=1.0)
print(f"特征图: {features.shape}")
print(f"RoI数量: {rois.shape[0]}")
print(f"RoI池化输出: {output.shape}")
roi_pool_demo()
七、池化的反向传播
7.1 最大池化的梯度
python
"""
最大池化的反向传播:
梯度只传给最大值位置,其他位置梯度为0
前向:
┌────┬────┐
│ 1 │ 3 │ → max → 4
├────┼────┤
│ 4 │ 2 │
└────┴────┘
反向(假设输出梯度为δ):
┌────┬────┐
│ 0 │ 0 │ ← δ只传给4的位置
├────┼────┤
│ δ │ 0 │
└────┴────┘
"""
class MaxPool2dWithGrad(torch.autograd.Function):
"""
带自定义梯度的最大池化
"""
@staticmethod
def forward(ctx, x, kernel_size, stride):
B, C, H, W = x.shape
k = kernel_size
s = stride
H_out = (H - k) // s + 1
W_out = (W - k) // s + 1
output = torch.zeros(B, C, H_out, W_out, device=x.device)
indices = torch.zeros(B, C, H_out, W_out, dtype=torch.long, device=x.device)
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
window = x[:, :, h_start:h_start+k, w_start:w_start+k]
window_flat = window.reshape(B, C, -1)
max_vals, max_indices = window_flat.max(dim=-1)
output[:, :, i, j] = max_vals
indices[:, :, i, j] = max_indices
# 保存用于反向传播
ctx.save_for_backward(indices)
ctx.input_shape = x.shape
ctx.kernel_size = k
ctx.stride = s
return output
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
B, C, H, W = ctx.input_shape
k = ctx.kernel_size
s = ctx.stride
grad_input = torch.zeros(ctx.input_shape, device=grad_output.device)
H_out, W_out = grad_output.shape[2], grad_output.shape[3]
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
for b in range(B):
for c in range(C):
idx = indices[b, c, i, j].item()
h_idx = idx // k
w_idx = idx % k
grad_input[b, c, h_start+h_idx, w_start+w_idx] += grad_output[b, c, i, j]
return grad_input, None, None
7.2 平均池化的梯度
python
"""
平均池化的反向传播:
梯度均匀分配给窗口内所有位置
前向:
┌────┬────┐
│ 1 │ 3 │ → avg → 2.5
├────┼────┤
│ 4 │ 2 │
└────┴────┘
反向(假设输出梯度为δ):
┌──────┬──────┐
│ δ/4 │ δ/4 │ ← δ平均分配给4个位置
├──────┼──────┤
│ δ/4 │ δ/4 │
└──────┴──────┘
"""
class AvgPool2dWithGrad(torch.autograd.Function):
"""
带自定义梯度的平均池化
"""
@staticmethod
def forward(ctx, x, kernel_size, stride):
B, C, H, W = x.shape
k = kernel_size
s = stride
H_out = (H - k) // s + 1
W_out = (W - k) // s + 1
output = torch.zeros(B, C, H_out, W_out, device=x.device)
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
window = x[:, :, h_start:h_start+k, w_start:w_start+k]
output[:, :, i, j] = window.mean(dim=(2, 3))
ctx.input_shape = x.shape
ctx.kernel_size = k
ctx.stride = s
return output
@staticmethod
def backward(ctx, grad_output):
B, C, H, W = ctx.input_shape
k = ctx.kernel_size
s = ctx.stride
grad_input = torch.zeros(ctx.input_shape, device=grad_output.device)
H_out, W_out = grad_output.shape[2], grad_output.shape[3]
for i in range(H_out):
for j in range(W_out):
h_start = i * s
w_start = j * s
# 梯度均匀分配
grad_input[:, :, h_start:h_start+k, w_start:w_start+k] += \
grad_output[:, :, i:i+1, j:j+1] / (k * k)
return grad_input, None, None
八、池化在现代网络中的应用
8.1 经典网络中的池化
python
"""
不同网络中池化的使用方式:
LeNet-5 (1998):
Conv → AvgPool → Conv → AvgPool → FC
使用平均池化
AlexNet (2012):
Conv → MaxPool(3×3, s=2) → ...
使用重叠最大池化
VGGNet (2014):
Conv → Conv → MaxPool(2×2, s=2) → ...
标准最大池化
GoogLeNet/Inception (2014):
在Inception模块中使用1×1, 3×3, 5×5卷积 + 3×3 MaxPool
最后使用GlobalAvgPool
ResNet (2015):
开始用7×7 Conv + 3×3 MaxPool
最后用GlobalAvgPool
MobileNet (2017):
最后用GlobalAvgPool
中间用stride=2的深度可分离卷积代替池化
EfficientNet (2019):
最后用GlobalAvgPool
中间使用stride代替显式池化
"""
class ClassicPoolingPatterns(nn.Module):
"""展示经典网络中的池化模式"""
def __init__(self):
super().__init__()
# VGG风格:Conv-Conv-Pool
self.vgg_block = nn.Sequential(
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2) # 标准2×2池化
)
# AlexNet风格:重叠池化
self.alexnet_pool = nn.MaxPool2d(kernel_size=3, stride=2)
# ResNet风格:开头的大步长池化
self.resnet_stem = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
# 现代风格:全局池化
self.global_pool = nn.AdaptiveAvgPool2d(1)
8.2 无池化设计(Strided Convolution)
python
"""
现代趋势:用带步长的卷积代替池化
传统方式:
Conv(s=1) → Pool(s=2)
现代方式:
Conv(s=2) # 直接下采样
优点:
- 可学习的下采样
- 减少信息丢失
- 梯度流更顺畅
例如:ResNet的下采样、所有stride-2的卷积
"""
class StridedDownsample(nn.Module):
"""用步长卷积代替池化"""
def __init__(self, in_channels, out_channels):
super().__init__()
# 方式1:传统池化
self.pool_way = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
# 方式2:步长卷积
self.stride_way = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x, use_pool=True):
if use_pool:
return self.pool_way(x)
else:
return self.stride_way(x)
def compare_pool_vs_stride():
"""比较池化和步长卷积"""
x = torch.randn(1, 64, 32, 32)
block = StridedDownsample(64, 128)
out_pool = block(x, use_pool=True)
out_stride = block(x, use_pool=False)
print(f"输入: {x.shape}")
print(f"池化方式输出: {out_pool.shape}")
print(f"步长卷积输出: {out_stride.shape}")
# 参数量对比
pool_params = sum(p.numel() for p in block.pool_way.parameters())
stride_params = sum(p.numel() for p in block.stride_way.parameters())
print(f"\n参数量对比:")
print(f" 池化方式: {pool_params:,}")
print(f" 步长卷积: {stride_params:,}")
compare_pool_vs_stride()
8.3 注意力池化
python
class AttentionPool2d(nn.Module):
"""
注意力池化
使用注意力机制进行加权池化
用于CLIP等模型
"""
def __init__(self, in_channels, embed_dim, num_heads=8):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
# 位置编码
self.positional_embedding = nn.Parameter(
torch.randn(1, in_channels, 1, 1) / in_channels ** 0.5
)
# QKV投影
self.q_proj = nn.Linear(in_channels, embed_dim)
self.k_proj = nn.Linear(in_channels, embed_dim)
self.v_proj = nn.Linear(in_channels, embed_dim)
# 输出投影
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
"""
Args:
x: [B, C, H, W]
Returns:
[B, embed_dim]
"""
B, C, H, W = x.shape
# 添加位置编码
x = x + self.positional_embedding
# 展平空间维度:[B, C, H, W] → [B, H*W, C]
x = x.flatten(2).permute(0, 2, 1)
# 添加全局token(用于聚合信息)
global_token = x.mean(dim=1, keepdim=True) # [B, 1, C]
x = torch.cat([global_token, x], dim=1) # [B, 1+H*W, C]
# 计算Q, K, V
q = self.q_proj(x[:, :1]) # [B, 1, embed_dim] 只对全局token计算Q
k = self.k_proj(x) # [B, 1+H*W, embed_dim]
v = self.v_proj(x) # [B, 1+H*W, embed_dim]
# 注意力
attn_weights = torch.bmm(q, k.transpose(1, 2)) / (self.embed_dim ** 0.5)
attn_weights = F.softmax(attn_weights, dim=-1)
# 加权求和
out = torch.bmm(attn_weights, v) # [B, 1, embed_dim]
out = self.out_proj(out.squeeze(1))
return out
def attention_pool_demo():
"""注意力池化演示"""
x = torch.randn(2, 256, 7, 7)
attn_pool = AttentionPool2d(in_channels=256, embed_dim=512)
out = attn_pool(x)
print(f"输入: {x.shape}")
print(f"注意力池化输出: {out.shape}")
attention_pool_demo()
九、总结
9.1 池化层核心要点
┌─────────────────────────────────────────────────────────────────┐
│ 池化层总结 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 最大池化 (Max Pooling): │
│ • 取窗口内最大值 │
│ • 保留最显著特征 │
│ • 常用于卷积层之后 │
│ │
│ 平均池化 (Average Pooling): │
│ • 取窗口内平均值 │
│ • 保留整体信息 │
│ • 常用于全局池化 │
│ │
│ 全局池化 (Global Pooling): │
│ • 将特征图压缩为单个值 │
│ • 替代全连接层 │
│ • 支持任意输入尺寸 │
│ │
│ 自适应池化 (Adaptive Pooling): │
│ • 指定输出尺寸,自动计算参数 │
│ • 处理可变尺寸输入 │
│ │
└─────────────────────────────────────────────────────────────────┘
9.2 选择指南
| 场景 | 推荐池化 | 原因 |
|---|---|---|
| 卷积层后下采样 | MaxPool | 保留显著特征 |
| 网络末端 | GlobalAvgPool | 减少参数,支持可变尺寸 |
| 目标检测RoI | RoIPool/RoIAlign | 处理不同大小候选框 |
| 需要可学习下采样 | Strided Conv | 更灵活 |
| 多尺度特征 | SPP | 固定输出长度 |
9.3 一句话总结
池化是CNN的"压缩器":减少计算量、增大感受野、提供平移不变性,是特征提取的关键环节。
希望这篇文章帮助你深入理解了CNN中的池化层!如有问题,欢迎评论区交流。
参考文献:
- LeCun Y, et al. "Gradient-based learning applied to document recognition." 1998.
- Krizhevsky A, et al. "ImageNet Classification with Deep Convolutional Neural Networks." NeurIPS 2012.
- Lin M, et al. "Network In Network." ICLR 2014.
- He K, et al. "Spatial Pyramid Pooling in Deep Convolutional Networks." ECCV 2014.
作者:Jia
更多技术文章,欢迎关注我的CSDN博客!