CNN池化层深度解析:从原理到PyTorch实现

本文深入剖析卷积神经网络中池化层的核心原理,涵盖最大池化、平均池化、全局池化、自适应池化等多种变体,从数学原理到完整代码实现,帮你彻底理解这个CNN的关键组件。


一、池化层概述

1.1 什么是池化

池化(Pooling) 是一种下采样操作,用于减小特征图的空间尺寸,同时保留重要特征。

复制代码
池化操作示意:

输入特征图 4×4                    输出特征图 2×2
┌────┬────┬────┬────┐            ┌────┬────┐
│ 1  │ 3  │ 2  │ 1  │            │    │    │
├────┼────┼────┼────┤   2×2池化  │ 4  │ 6  │
│ 4  │ 2  │ 1  │ 6  │  ───────→  ├────┼────┤
├────┼────┼────┼────┤            │    │    │
│ 1  │ 5  │ 3  │ 2  │            │ 5  │ 8  │
├────┼────┼────┼────┤            └────┴────┘
│ 2  │ 1  │ 8  │ 4  │
└────┴────┴────┴────┘

每个2×2区域取最大值(最大池化)

1.2 为什么需要池化

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                     池化层的核心作用                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. 降维(Dimensionality Reduction)                            │
│     ┌─────────────────────────────────────────────────────┐    │
│     │ 输入: 224×224 → 池化后: 112×112                      │    │
│     │ 特征图尺寸减半,计算量减少到1/4                       │    │
│     │ 让网络可以更深,感受野更大                            │    │
│     └─────────────────────────────────────────────────────┘    │
│                                                                 │
│  2. 平移不变性(Translation Invariance)                        │
│     ┌─────────────────────────────────────────────────────┐    │
│     │ 目标轻微移动时,池化后的特征保持不变                   │    │
│     │ 提高模型对位置变化的鲁棒性                            │    │
│     └─────────────────────────────────────────────────────┘    │
│                                                                 │
│  3. 防止过拟合(Regularization)                                │
│     ┌─────────────────────────────────────────────────────┐    │
│     │ 减少参数数量                                          │    │
│     │ 丢弃部分空间信息,保留最重要的特征                     │    │
│     └─────────────────────────────────────────────────────┘    │
│                                                                 │
│  4. 扩大感受野(Receptive Field)                               │
│     ┌─────────────────────────────────────────────────────┐    │
│     │ 池化后每个像素对应原图更大的区域                       │    │
│     │ 让后续卷积能"看到"更大范围的信息                      │    │
│     └─────────────────────────────────────────────────────┘    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.3 池化在CNN中的位置

复制代码
典型CNN架构中池化的位置:

Input (224×224×3)
       │
       ▼
┌─────────────┐
│   Conv1     │  → 224×224×64
│   ReLU      │
│   Pool      │  → 112×112×64    ← 池化1:尺寸减半
└─────────────┘
       │
       ▼
┌─────────────┐
│   Conv2     │  → 112×112×128
│   ReLU      │
│   Pool      │  → 56×56×128     ← 池化2:尺寸再减半
└─────────────┘
       │
       ▼
┌─────────────┐
│   Conv3     │  → 56×56×256
│   ReLU      │
│   Pool      │  → 28×28×256     ← 池化3
└─────────────┘
       │
       ▼
    ...更多层...
       │
       ▼
┌─────────────┐
│ Global Pool │  → 1×1×512       ← 全局池化:变成向量
└─────────────┘
       │
       ▼
┌─────────────┐
│    FC       │  → 1000          ← 分类输出
└─────────────┘

二、最大池化(Max Pooling)

2.1 原理详解

最大池化取池化窗口内的最大值,保留最显著的特征。

复制代码
最大池化示意(2×2池化,步长2):

输入 4×4:
┌────┬────┬────┬────┐
│ 1  │ 3  │ 2  │ 1  │
├────┼────┼────┼────┤       窗口1: max(1,3,4,2) = 4
│ 4  │ 2  │ 1  │ 6  │       窗口2: max(2,1,1,6) = 6
├────┼────┼────┼────┤       窗口3: max(1,5,2,1) = 5
│ 1  │ 5  │ 3  │ 2  │       窗口4: max(3,2,8,4) = 8
├────┼────┼────┼────┤
│ 2  │ 1  │ 8  │ 4  │
└────┴────┴────┴────┘

输出 2×2:
┌────┬────┐
│ 4  │ 6  │
├────┼────┤
│ 5  │ 8  │
└────┴────┘

2.2 数学定义

复制代码
最大池化公式:

对于输入特征图 X,池化窗口大小 k×k,步长 s

输出位置 (i,j) 的值:

Y[i,j] = max { X[i×s+m, j×s+n] | m,n ∈ [0, k-1] }

例如 2×2池化,步长2:
Y[0,0] = max(X[0,0], X[0,1], X[1,0], X[1,1])
Y[0,1] = max(X[0,2], X[0,3], X[1,2], X[1,3])
...

2.3 最大池化的特性

复制代码
优点:
┌─────────────────────────────────────────────────────────────┐
│ • 保留最显著的特征(边缘、纹理等激活最强的位置)            │
│ • 提供一定的平移不变性                                      │
│ • 对噪声有一定的抑制作用                                    │
│ • 计算简单高效                                              │
└─────────────────────────────────────────────────────────────┘

直觉理解:
"这个区域有没有检测到边缘?" → 取最大值
只要区域内有一个位置检测到边缘,整个区域就"有边缘"

2.4 PyTorch实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MaxPool2dManual:
    """
    手动实现最大池化
    """
    
    def __init__(self, kernel_size=2, stride=None, padding=0):
        """
        Args:
            kernel_size: 池化窗口大小
            stride: 步长,默认等于kernel_size
            padding: 填充大小
        """
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding
    
    def forward(self, x):
        """
        前向传播
        
        Args:
            x: 输入张量 [B, C, H, W]
        Returns:
            output: 池化后的张量
            indices: 最大值位置索引(用于反池化)
        """
        B, C, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        p = self.padding
        
        # 计算输出尺寸
        H_out = (H + 2*p - k) // s + 1
        W_out = (W + 2*p - k) // s + 1
        
        # 填充
        if p > 0:
            x = F.pad(x, (p, p, p, p), mode='constant', value=float('-inf'))
        
        # 初始化输出
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        indices = torch.zeros(B, C, H_out, W_out, dtype=torch.long, device=x.device)
        
        # 滑动窗口池化
        for i in range(H_out):
            for j in range(W_out):
                # 提取窗口
                h_start = i * s
                w_start = j * s
                window = x[:, :, h_start:h_start+k, w_start:w_start+k]
                
                # 取最大值
                window_flat = window.reshape(B, C, -1)
                max_vals, max_indices = window_flat.max(dim=2)
                
                output[:, :, i, j] = max_vals
                indices[:, :, i, j] = max_indices
        
        return output, indices
    
    def backward(self, grad_output, indices, input_shape):
        """
        反向传播
        
        最大池化的梯度只传给最大值位置,其他位置梯度为0
        """
        B, C, H, W = input_shape
        k = self.kernel_size
        s = self.stride
        
        grad_input = torch.zeros(input_shape, device=grad_output.device)
        
        H_out, W_out = grad_output.shape[2], grad_output.shape[3]
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                
                for b in range(B):
                    for c in range(C):
                        # 找到最大值在窗口内的位置
                        idx = indices[b, c, i, j].item()
                        h_idx = idx // k
                        w_idx = idx % k
                        
                        # 梯度只传给最大值位置
                        grad_input[b, c, h_start+h_idx, w_start+w_idx] += grad_output[b, c, i, j]
        
        return grad_input


# 使用PyTorch内置实现
def max_pool_demo():
    """最大池化演示"""
    
    # 创建输入
    x = torch.tensor([
        [1, 3, 2, 1],
        [4, 2, 1, 6],
        [1, 5, 3, 2],
        [2, 1, 8, 4]
    ], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    
    print("输入:")
    print(x.squeeze())
    
    # 最大池化
    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    output = pool(x)
    
    print("\n2×2最大池化输出:")
    print(output.squeeze())
    
    # 带索引的最大池化(用于反池化)
    pool_with_indices = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
    output, indices = pool_with_indices(x)
    
    print("\n最大值索引:")
    print(indices.squeeze())


max_pool_demo()

三、平均池化(Average Pooling)

3.1 原理详解

平均池化取池化窗口内的平均值,保留区域的整体特征。

复制代码
平均池化示意(2×2池化,步长2):

输入 4×4:
┌────┬────┬────┬────┐
│ 1  │ 3  │ 2  │ 1  │
├────┼────┼────┼────┤       窗口1: avg(1,3,4,2) = 2.5
│ 4  │ 2  │ 1  │ 6  │       窗口2: avg(2,1,1,6) = 2.5
├────┼────┼────┼────┤       窗口3: avg(1,5,2,1) = 2.25
│ 1  │ 5  │ 3  │ 2  │       窗口4: avg(3,2,8,4) = 4.25
├────┼────┼────┼────┤
│ 2  │ 1  │ 8  │ 4  │
└────┴────┴────┴────┘

输出 2×2:
┌──────┬──────┐
│ 2.5  │ 2.5  │
├──────┼──────┤
│ 2.25 │ 4.25 │
└──────┴──────┘

3.2 数学定义

复制代码
平均池化公式:

Y[i,j] = (1/k²) × Σ Σ X[i×s+m, j×s+n]
                  m  n

其中 m,n ∈ [0, k-1]

例如 2×2池化:
Y[0,0] = (X[0,0] + X[0,1] + X[1,0] + X[1,1]) / 4

3.3 最大池化 vs 平均池化

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                最大池化 vs 平均池化                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  最大池化 (Max Pooling):                                        │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ • 保留最显著的特征                                       │   │
│  │ • 对噪声更鲁棒                                           │   │
│  │ • 适合检测"是否存在某特征"                              │   │
│  │ • 常用于卷积层之后                                       │   │
│  │ • 例:边缘检测、纹理识别                                 │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  平均池化 (Average Pooling):                                    │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ • 保留区域的整体信息                                     │   │
│  │ • 更平滑的下采样                                         │   │
│  │ • 适合需要保留背景信息的任务                             │   │
│  │ • 常用于网络末端(全局平均池化)                         │   │
│  │ • 例:语义分割、全局特征提取                             │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  可视化对比:                                                   │
│                                                                 │
│  输入图像:  ■ □ □ □     最大池化: ■ □    平均池化: ▧ □        │
│            □ □ □ □               □ □              □ □        │
│            □ □ ■ □                                            │
│            □ □ □ □     保留突出特征      保留整体亮度          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

3.4 PyTorch实现

python 复制代码
class AvgPool2dManual:
    """
    手动实现平均池化
    """
    
    def __init__(self, kernel_size=2, stride=None, padding=0):
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding
    
    def forward(self, x):
        """
        前向传播
        """
        B, C, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        p = self.padding
        
        # 计算输出尺寸
        H_out = (H + 2*p - k) // s + 1
        W_out = (W + 2*p - k) // s + 1
        
        # 填充
        if p > 0:
            x = F.pad(x, (p, p, p, p), mode='constant', value=0)
        
        # 初始化输出
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        
        # 滑动窗口池化
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                window = x[:, :, h_start:h_start+k, w_start:w_start+k]
                
                # 取平均值
                output[:, :, i, j] = window.mean(dim=(2, 3))
        
        return output
    
    def backward(self, grad_output, input_shape):
        """
        反向传播
        
        平均池化的梯度均匀分配给窗口内所有位置
        """
        B, C, H, W = input_shape
        k = self.kernel_size
        s = self.stride
        
        grad_input = torch.zeros(input_shape, device=grad_output.device)
        
        H_out, W_out = grad_output.shape[2], grad_output.shape[3]
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                
                # 梯度均匀分配
                grad_input[:, :, h_start:h_start+k, w_start:w_start+k] += \
                    grad_output[:, :, i:i+1, j:j+1] / (k * k)
        
        return grad_input


def avg_pool_demo():
    """平均池化演示"""
    
    x = torch.tensor([
        [1, 3, 2, 1],
        [4, 2, 1, 6],
        [1, 5, 3, 2],
        [2, 1, 8, 4]
    ], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    
    print("输入:")
    print(x.squeeze())
    
    # 平均池化
    pool = nn.AvgPool2d(kernel_size=2, stride=2)
    output = pool(x)
    
    print("\n2×2平均池化输出:")
    print(output.squeeze())


avg_pool_demo()

四、全局池化(Global Pooling)

4.1 原理详解

全局池化将整个特征图压缩成单个值,常用于CNN末端替代全连接层。

复制代码
全局池化示意:

输入特征图 7×7×512:
┌─────────────────────┐
│                     │
│    7×7 特征图       │  × 512个通道
│                     │
└─────────────────────┘
          │
          ▼ 全局平均池化
┌───┐
│ 1 │ × 512 = 512维向量
└───┘

每个通道的7×7特征图 → 取平均 → 1个数
512个通道 → 512维向量

4.2 为什么使用全局池化

复制代码
传统方式(全连接层):
7×7×512 → Flatten → 25088维 → FC(25088, 4096) → FC(4096, 1000)

问题:
• 参数量巨大:25088×4096 ≈ 1亿参数
• 容易过拟合
• 要求固定输入尺寸

全局池化方式:
7×7×512 → GlobalAvgPool → 512维 → FC(512, 1000)

优点:
• 参数量小:只需512×1000 ≈ 50万参数
• 不容易过拟合
• 可以处理任意输入尺寸
• 每个通道对应一个类别的响应(可解释性强)

4.3 全局平均池化 vs 全局最大池化

python 复制代码
class GlobalPooling(nn.Module):
    """
    全局池化层
    
    将 [B, C, H, W] → [B, C, 1, 1] 或 [B, C]
    """
    
    def __init__(self, pool_type='avg', flatten=True):
        """
        Args:
            pool_type: 'avg' 或 'max'
            flatten: 是否展平为 [B, C]
        """
        super().__init__()
        self.pool_type = pool_type
        self.flatten = flatten
    
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W]
        Returns:
            [B, C] 或 [B, C, 1, 1]
        """
        if self.pool_type == 'avg':
            # 全局平均池化
            out = x.mean(dim=(2, 3))
        elif self.pool_type == 'max':
            # 全局最大池化
            out = x.amax(dim=(2, 3))
        elif self.pool_type == 'avg+max':
            # 同时使用两者(拼接)
            avg_out = x.mean(dim=(2, 3))
            max_out = x.amax(dim=(2, 3))
            out = torch.cat([avg_out, max_out], dim=1)
        else:
            raise ValueError(f"Unknown pool_type: {self.pool_type}")
        
        if not self.flatten:
            out = out.unsqueeze(-1).unsqueeze(-1)
        
        return out


def global_pool_demo():
    """全局池化演示"""
    
    # 模拟最后一层特征图
    x = torch.randn(2, 512, 7, 7)  # [B, C, H, W]
    
    print(f"输入形状: {x.shape}")
    
    # 全局平均池化
    gap = GlobalPooling(pool_type='avg')
    out_avg = gap(x)
    print(f"全局平均池化后: {out_avg.shape}")
    
    # 全局最大池化
    gmp = GlobalPooling(pool_type='max')
    out_max = gmp(x)
    print(f"全局最大池化后: {out_max.shape}")
    
    # 两者结合
    gp_both = GlobalPooling(pool_type='avg+max')
    out_both = gp_both(x)
    print(f"avg+max池化后: {out_both.shape}")


global_pool_demo()

4.4 Network in Network (NIN) 中的全局池化

python 复制代码
class NINClassifier(nn.Module):
    """
    Network in Network 风格的分类器
    
    使用全局平均池化替代全连接层
    每个通道直接对应一个类别
    """
    
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        # 1×1卷积将通道数变为类别数
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        
        # 全局平均池化
        self.gap = nn.AdaptiveAvgPool2d(1)
    
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W] 特征图
        Returns:
            [B, num_classes] 分类logits
        """
        # 1×1卷积:[B, C, H, W] → [B, num_classes, H, W]
        x = self.conv(x)
        
        # 全局平均池化:[B, num_classes, H, W] → [B, num_classes, 1, 1]
        x = self.gap(x)
        
        # 展平:[B, num_classes]
        x = x.view(x.size(0), -1)
        
        return x


# 可视化通道响应
def visualize_channel_response():
    """
    可视化每个通道对应的类别响应
    """
    import matplotlib.pyplot as plt
    
    # 假设有10个类别
    model = NINClassifier(in_channels=512, num_classes=10)
    
    # 输入特征图
    x = torch.randn(1, 512, 14, 14)
    
    # 获取1×1卷积后的特征图(每个通道对应一个类别)
    with torch.no_grad():
        class_maps = model.conv(x)  # [1, 10, 14, 14]
    
    # 可视化每个类别的响应图
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    for i, ax in enumerate(axes.flat):
        ax.imshow(class_maps[0, i].numpy(), cmap='hot')
        ax.set_title(f'Class {i}')
        ax.axis('off')
    
    plt.tight_layout()
    plt.savefig('class_activation_maps.png')
    print("Class activation maps saved!")

五、自适应池化(Adaptive Pooling)

5.1 原理详解

自适应池化指定输出尺寸而非池化核大小,自动计算需要的参数。

复制代码
自适应池化示意:

"我不管输入多大,输出必须是 3×3"

输入 7×7:                    输入 13×13:
┌─────────────┐              ┌─────────────────┐
│             │              │                 │
│    7×7      │ Adaptive     │     13×13       │ Adaptive
│             │ Pool(3×3)    │                 │ Pool(3×3)
└─────────────┘     ↓        └─────────────────┘     ↓
              ┌─────────┐                      ┌─────────┐
              │   3×3   │                      │   3×3   │
              └─────────┘                      └─────────┘

自动计算池化核大小和步长

5.2 为什么需要自适应池化

复制代码
问题:传统池化要求固定输入尺寸

训练时: 输入224×224 → 固定的池化参数 → 输出7×7
测试时: 输入320×320 → 同样的池化参数 → 输出?(尺寸不对!)

解决方案:自适应池化

训练时: 输入224×224 → AdaptivePool(7×7) → 输出7×7
测试时: 输入320×320 → AdaptivePool(7×7) → 输出7×7 ✓

可以处理任意输入尺寸!

5.3 PyTorch实现

python 复制代码
class AdaptiveAvgPool2dManual:
    """
    手动实现自适应平均池化
    """
    
    def __init__(self, output_size):
        """
        Args:
            output_size: 输出尺寸 (H_out, W_out) 或单个数
        """
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
    
    def forward(self, x):
        """
        前向传播
        
        自动计算每个输出位置对应的输入区域
        """
        B, C, H, W = x.shape
        H_out, W_out = self.output_size
        
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        
        for i in range(H_out):
            for j in range(W_out):
                # 计算输入区域的起始和结束位置
                h_start = int(np.floor(i * H / H_out))
                h_end = int(np.ceil((i + 1) * H / H_out))
                w_start = int(np.floor(j * W / W_out))
                w_end = int(np.ceil((j + 1) * W / W_out))
                
                # 提取区域并取平均
                region = x[:, :, h_start:h_end, w_start:w_end]
                output[:, :, i, j] = region.mean(dim=(2, 3))
        
        return output


def adaptive_pool_demo():
    """自适应池化演示"""
    
    # 不同尺寸的输入
    sizes = [(7, 7), (14, 14), (13, 13), (32, 32)]
    
    # 自适应池化到3×3
    adaptive_pool = nn.AdaptiveAvgPool2d((3, 3))
    
    print("自适应池化演示(输出固定为3×3):")
    for h, w in sizes:
        x = torch.randn(1, 64, h, w)
        out = adaptive_pool(x)
        print(f"  输入 {h}×{w} → 输出 {out.shape[2]}×{out.shape[3]}")


adaptive_pool_demo()

5.4 SPP(空间金字塔池化)

python 复制代码
class SpatialPyramidPooling(nn.Module):
    """
    空间金字塔池化 (SPP)
    
    多个尺度的自适应池化,拼接得到固定长度的特征
    
    应用:
    - SPPNet:可以处理任意尺寸输入
    - 目标检测:提取不同尺度的特征
    """
    
    def __init__(self, pool_sizes=[1, 2, 4]):
        """
        Args:
            pool_sizes: 不同尺度的池化输出尺寸
                       例如 [1, 2, 4] 表示输出 1×1, 2×2, 4×4
        """
        super().__init__()
        self.pool_sizes = pool_sizes
        
        # 创建自适应池化层
        self.pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d(size) for size in pool_sizes
        ])
    
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W]
        Returns:
            [B, C × (1×1 + 2×2 + 4×4)] = [B, C × 21]
        """
        B, C, H, W = x.shape
        
        features = []
        for pool in self.pools:
            # 池化
            pooled = pool(x)  # [B, C, size, size]
            # 展平
            pooled = pooled.view(B, -1)  # [B, C × size × size]
            features.append(pooled)
        
        # 拼接所有尺度的特征
        output = torch.cat(features, dim=1)
        
        return output
    
    def get_output_size(self, in_channels):
        """计算输出特征维度"""
        total = sum(size * size for size in self.pool_sizes)
        return in_channels * total


def spp_demo():
    """SPP演示"""
    
    spp = SpatialPyramidPooling(pool_sizes=[1, 2, 4])
    
    # 不同尺寸的输入,输出维度相同
    print("SPP演示(输出维度固定):")
    for h, w in [(7, 7), (14, 14), (28, 28)]:
        x = torch.randn(2, 256, h, w)
        out = spp(x)
        print(f"  输入 {h}×{w}×256 → 输出 {out.shape}")
    
    print(f"\n输出维度计算: 256 × (1×1 + 2×2 + 4×4) = {spp.get_output_size(256)}")


spp_demo()

六、其他池化变体

6.1 重叠池化(Overlapping Pooling)

python 复制代码
class OverlappingPool(nn.Module):
    """
    重叠池化
    
    AlexNet中使用:kernel_size=3, stride=2
    池化窗口有重叠,保留更多信息
    """
    
    def __init__(self, kernel_size=3, stride=2):
        super().__init__()
        # stride < kernel_size 时产生重叠
        self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
    
    def forward(self, x):
        return self.pool(x)


def overlapping_pool_demo():
    """重叠池化演示"""
    
    x = torch.randn(1, 64, 13, 13)
    
    # 非重叠池化:kernel_size=2, stride=2
    non_overlap = nn.MaxPool2d(kernel_size=2, stride=2)
    out1 = non_overlap(x)
    
    # 重叠池化:kernel_size=3, stride=2
    overlap = nn.MaxPool2d(kernel_size=3, stride=2)
    out2 = overlap(x)
    
    print(f"输入: {x.shape}")
    print(f"非重叠池化 (k=2, s=2): {out1.shape}")
    print(f"重叠池化 (k=3, s=2): {out2.shape}")


overlapping_pool_demo()

6.2 随机池化(Stochastic Pooling)

python 复制代码
class StochasticPool2d(nn.Module):
    """
    随机池化
    
    训练时:按概率(归一化后的激活值)随机选择一个位置
    测试时:使用期望值(等价于加权平均)
    
    优点:
    - 提供正则化效果
    - 比最大池化更平滑
    - 比平均池化保留更多显著特征
    """
    
    def __init__(self, kernel_size=2, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
    
    def forward(self, x):
        B, C, H, W = x.shape
        k = self.kernel_size
        s = self.stride
        
        H_out = (H - k) // s + 1
        W_out = (W - k) // s + 1
        
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                window = x[:, :, h_start:h_start+k, w_start:w_start+k]
                
                if self.training:
                    # 训练时:随机采样
                    # 概率 = 归一化后的激活值
                    window_flat = window.reshape(B, C, -1)
                    
                    # 确保非负
                    window_positive = F.relu(window_flat) + 1e-8
                    
                    # 归一化为概率
                    probs = window_positive / window_positive.sum(dim=-1, keepdim=True)
                    
                    # 按概率采样
                    indices = torch.multinomial(probs.view(-1, k*k), 1).view(B, C)
                    
                    # 收集采样值
                    output[:, :, i, j] = window_flat.gather(
                        dim=-1, 
                        index=indices.unsqueeze(-1)
                    ).squeeze(-1)
                else:
                    # 测试时:使用期望值(加权平均)
                    window_flat = window.reshape(B, C, -1)
                    window_positive = F.relu(window_flat) + 1e-8
                    probs = window_positive / window_positive.sum(dim=-1, keepdim=True)
                    
                    # 期望值 = Σ p_i × x_i
                    output[:, :, i, j] = (probs * window_flat).sum(dim=-1)
        
        return output

6.3 混合池化(Mixed Pooling)

python 复制代码
class MixedPool2d(nn.Module):
    """
    混合池化
    
    结合最大池化和平均池化的优点
    
    方式1:拼接
    方式2:加权组合
    方式3:可学习的组合权重
    """
    
    def __init__(self, kernel_size=2, stride=2, mode='concat', alpha=0.5):
        """
        Args:
            mode: 'concat', 'weighted', 'learnable'
            alpha: 加权模式的权重
        """
        super().__init__()
        self.mode = mode
        
        self.max_pool = nn.MaxPool2d(kernel_size, stride)
        self.avg_pool = nn.AvgPool2d(kernel_size, stride)
        
        if mode == 'weighted':
            self.alpha = alpha
        elif mode == 'learnable':
            # 可学习的权重
            self.alpha = nn.Parameter(torch.tensor(0.5))
    
    def forward(self, x):
        max_out = self.max_pool(x)
        avg_out = self.avg_pool(x)
        
        if self.mode == 'concat':
            # 拼接(通道数翻倍)
            return torch.cat([max_out, avg_out], dim=1)
        
        elif self.mode == 'weighted' or self.mode == 'learnable':
            # 加权组合
            alpha = torch.sigmoid(self.alpha) if self.mode == 'learnable' else self.alpha
            return alpha * max_out + (1 - alpha) * avg_out
        
        else:
            raise ValueError(f"Unknown mode: {self.mode}")


def mixed_pool_demo():
    """混合池化演示"""
    
    x = torch.randn(2, 64, 8, 8)
    
    # 拼接模式
    pool_concat = MixedPool2d(kernel_size=2, stride=2, mode='concat')
    out_concat = pool_concat(x)
    print(f"拼接模式: {x.shape} → {out_concat.shape}")
    
    # 加权模式
    pool_weighted = MixedPool2d(kernel_size=2, stride=2, mode='weighted', alpha=0.7)
    out_weighted = pool_weighted(x)
    print(f"加权模式: {x.shape} → {out_weighted.shape}")
    
    # 可学习模式
    pool_learnable = MixedPool2d(kernel_size=2, stride=2, mode='learnable')
    out_learnable = pool_learnable(x)
    print(f"可学习模式: {x.shape} → {out_learnable.shape}")
    print(f"  学习到的alpha: {torch.sigmoid(pool_learnable.alpha).item():.4f}")


mixed_pool_demo()

6.4 RoI池化(Region of Interest Pooling)

python 复制代码
class RoIPool(nn.Module):
    """
    RoI池化
    
    目标检测中使用,将不同大小的候选区域池化到固定大小
    
    用于:Fast R-CNN, Faster R-CNN
    """
    
    def __init__(self, output_size):
        """
        Args:
            output_size: 输出尺寸 (H, W)
        """
        super().__init__()
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
    
    def forward(self, features, rois):
        """
        Args:
            features: [B, C, H, W] 特征图
            rois: [N, 5] 候选区域 [batch_idx, x1, y1, x2, y2]
        Returns:
            [N, C, H_out, W_out] 池化后的RoI特征
        """
        N = rois.size(0)
        C = features.size(1)
        H_out, W_out = self.output_size
        
        output = torch.zeros(N, C, H_out, W_out, device=features.device)
        
        for n in range(N):
            batch_idx = int(rois[n, 0])
            x1, y1, x2, y2 = rois[n, 1:5]
            
            # 提取RoI区域
            roi_feature = features[batch_idx, :, int(y1):int(y2), int(x1):int(x2)]
            
            # 自适应池化到固定大小
            output[n] = F.adaptive_max_pool2d(roi_feature.unsqueeze(0), self.output_size).squeeze(0)
        
        return output


# PyTorch内置的RoI操作
from torchvision.ops import roi_pool, roi_align

def roi_pool_demo():
    """RoI池化演示"""
    
    # 特征图
    features = torch.randn(1, 256, 50, 50)
    
    # RoI候选区域 [batch_idx, x1, y1, x2, y2]
    rois = torch.tensor([
        [0, 10, 10, 30, 30],  # 第一个RoI
        [0, 20, 20, 45, 45],  # 第二个RoI
    ], dtype=torch.float32)
    
    # RoI池化
    output = roi_pool(features, rois, output_size=(7, 7), spatial_scale=1.0)
    
    print(f"特征图: {features.shape}")
    print(f"RoI数量: {rois.shape[0]}")
    print(f"RoI池化输出: {output.shape}")


roi_pool_demo()

七、池化的反向传播

7.1 最大池化的梯度

python 复制代码
"""
最大池化的反向传播:

梯度只传给最大值位置,其他位置梯度为0

前向:
┌────┬────┐
│ 1  │ 3  │  → max → 4
├────┼────┤
│ 4  │ 2  │
└────┴────┘

反向(假设输出梯度为δ):
┌────┬────┐
│ 0  │ 0  │  ← δ只传给4的位置
├────┼────┤
│ δ  │ 0  │
└────┴────┘
"""

class MaxPool2dWithGrad(torch.autograd.Function):
    """
    带自定义梯度的最大池化
    """
    
    @staticmethod
    def forward(ctx, x, kernel_size, stride):
        B, C, H, W = x.shape
        k = kernel_size
        s = stride
        
        H_out = (H - k) // s + 1
        W_out = (W - k) // s + 1
        
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        indices = torch.zeros(B, C, H_out, W_out, dtype=torch.long, device=x.device)
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                window = x[:, :, h_start:h_start+k, w_start:w_start+k]
                window_flat = window.reshape(B, C, -1)
                
                max_vals, max_indices = window_flat.max(dim=-1)
                output[:, :, i, j] = max_vals
                indices[:, :, i, j] = max_indices
        
        # 保存用于反向传播
        ctx.save_for_backward(indices)
        ctx.input_shape = x.shape
        ctx.kernel_size = k
        ctx.stride = s
        
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        B, C, H, W = ctx.input_shape
        k = ctx.kernel_size
        s = ctx.stride
        
        grad_input = torch.zeros(ctx.input_shape, device=grad_output.device)
        
        H_out, W_out = grad_output.shape[2], grad_output.shape[3]
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                
                for b in range(B):
                    for c in range(C):
                        idx = indices[b, c, i, j].item()
                        h_idx = idx // k
                        w_idx = idx % k
                        
                        grad_input[b, c, h_start+h_idx, w_start+w_idx] += grad_output[b, c, i, j]
        
        return grad_input, None, None

7.2 平均池化的梯度

python 复制代码
"""
平均池化的反向传播:

梯度均匀分配给窗口内所有位置

前向:
┌────┬────┐
│ 1  │ 3  │  → avg → 2.5
├────┼────┤
│ 4  │ 2  │
└────┴────┘

反向(假设输出梯度为δ):
┌──────┬──────┐
│ δ/4  │ δ/4  │  ← δ平均分配给4个位置
├──────┼──────┤
│ δ/4  │ δ/4  │
└──────┴──────┘
"""

class AvgPool2dWithGrad(torch.autograd.Function):
    """
    带自定义梯度的平均池化
    """
    
    @staticmethod
    def forward(ctx, x, kernel_size, stride):
        B, C, H, W = x.shape
        k = kernel_size
        s = stride
        
        H_out = (H - k) // s + 1
        W_out = (W - k) // s + 1
        
        output = torch.zeros(B, C, H_out, W_out, device=x.device)
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                window = x[:, :, h_start:h_start+k, w_start:w_start+k]
                output[:, :, i, j] = window.mean(dim=(2, 3))
        
        ctx.input_shape = x.shape
        ctx.kernel_size = k
        ctx.stride = s
        
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        B, C, H, W = ctx.input_shape
        k = ctx.kernel_size
        s = ctx.stride
        
        grad_input = torch.zeros(ctx.input_shape, device=grad_output.device)
        
        H_out, W_out = grad_output.shape[2], grad_output.shape[3]
        
        for i in range(H_out):
            for j in range(W_out):
                h_start = i * s
                w_start = j * s
                
                # 梯度均匀分配
                grad_input[:, :, h_start:h_start+k, w_start:w_start+k] += \
                    grad_output[:, :, i:i+1, j:j+1] / (k * k)
        
        return grad_input, None, None

八、池化在现代网络中的应用

8.1 经典网络中的池化

python 复制代码
"""
不同网络中池化的使用方式:

LeNet-5 (1998):
    Conv → AvgPool → Conv → AvgPool → FC
    使用平均池化

AlexNet (2012):
    Conv → MaxPool(3×3, s=2) → ...
    使用重叠最大池化

VGGNet (2014):
    Conv → Conv → MaxPool(2×2, s=2) → ...
    标准最大池化

GoogLeNet/Inception (2014):
    在Inception模块中使用1×1, 3×3, 5×5卷积 + 3×3 MaxPool
    最后使用GlobalAvgPool

ResNet (2015):
    开始用7×7 Conv + 3×3 MaxPool
    最后用GlobalAvgPool
    
MobileNet (2017):
    最后用GlobalAvgPool
    中间用stride=2的深度可分离卷积代替池化

EfficientNet (2019):
    最后用GlobalAvgPool
    中间使用stride代替显式池化
"""

class ClassicPoolingPatterns(nn.Module):
    """展示经典网络中的池化模式"""
    
    def __init__(self):
        super().__init__()
        
        # VGG风格:Conv-Conv-Pool
        self.vgg_block = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # 标准2×2池化
        )
        
        # AlexNet风格:重叠池化
        self.alexnet_pool = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # ResNet风格:开头的大步长池化
        self.resnet_stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # 现代风格:全局池化
        self.global_pool = nn.AdaptiveAvgPool2d(1)

8.2 无池化设计(Strided Convolution)

python 复制代码
"""
现代趋势:用带步长的卷积代替池化

传统方式:
    Conv(s=1) → Pool(s=2)
    
现代方式:
    Conv(s=2)  # 直接下采样

优点:
- 可学习的下采样
- 减少信息丢失
- 梯度流更顺畅

例如:ResNet的下采样、所有stride-2的卷积
"""

class StridedDownsample(nn.Module):
    """用步长卷积代替池化"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        # 方式1:传统池化
        self.pool_way = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        # 方式2:步长卷积
        self.stride_way = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    
    def forward(self, x, use_pool=True):
        if use_pool:
            return self.pool_way(x)
        else:
            return self.stride_way(x)


def compare_pool_vs_stride():
    """比较池化和步长卷积"""
    
    x = torch.randn(1, 64, 32, 32)
    
    block = StridedDownsample(64, 128)
    
    out_pool = block(x, use_pool=True)
    out_stride = block(x, use_pool=False)
    
    print(f"输入: {x.shape}")
    print(f"池化方式输出: {out_pool.shape}")
    print(f"步长卷积输出: {out_stride.shape}")
    
    # 参数量对比
    pool_params = sum(p.numel() for p in block.pool_way.parameters())
    stride_params = sum(p.numel() for p in block.stride_way.parameters())
    
    print(f"\n参数量对比:")
    print(f"  池化方式: {pool_params:,}")
    print(f"  步长卷积: {stride_params:,}")


compare_pool_vs_stride()

8.3 注意力池化

python 复制代码
class AttentionPool2d(nn.Module):
    """
    注意力池化
    
    使用注意力机制进行加权池化
    用于CLIP等模型
    """
    
    def __init__(self, in_channels, embed_dim, num_heads=8):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        
        # 位置编码
        self.positional_embedding = nn.Parameter(
            torch.randn(1, in_channels, 1, 1) / in_channels ** 0.5
        )
        
        # QKV投影
        self.q_proj = nn.Linear(in_channels, embed_dim)
        self.k_proj = nn.Linear(in_channels, embed_dim)
        self.v_proj = nn.Linear(in_channels, embed_dim)
        
        # 输出投影
        self.out_proj = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, x):
        """
        Args:
            x: [B, C, H, W]
        Returns:
            [B, embed_dim]
        """
        B, C, H, W = x.shape
        
        # 添加位置编码
        x = x + self.positional_embedding
        
        # 展平空间维度:[B, C, H, W] → [B, H*W, C]
        x = x.flatten(2).permute(0, 2, 1)
        
        # 添加全局token(用于聚合信息)
        global_token = x.mean(dim=1, keepdim=True)  # [B, 1, C]
        x = torch.cat([global_token, x], dim=1)     # [B, 1+H*W, C]
        
        # 计算Q, K, V
        q = self.q_proj(x[:, :1])    # [B, 1, embed_dim] 只对全局token计算Q
        k = self.k_proj(x)           # [B, 1+H*W, embed_dim]
        v = self.v_proj(x)           # [B, 1+H*W, embed_dim]
        
        # 注意力
        attn_weights = torch.bmm(q, k.transpose(1, 2)) / (self.embed_dim ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        # 加权求和
        out = torch.bmm(attn_weights, v)  # [B, 1, embed_dim]
        out = self.out_proj(out.squeeze(1))
        
        return out


def attention_pool_demo():
    """注意力池化演示"""
    
    x = torch.randn(2, 256, 7, 7)
    
    attn_pool = AttentionPool2d(in_channels=256, embed_dim=512)
    out = attn_pool(x)
    
    print(f"输入: {x.shape}")
    print(f"注意力池化输出: {out.shape}")


attention_pool_demo()

九、总结

9.1 池化层核心要点

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                        池化层总结                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  最大池化 (Max Pooling):                                        │
│  • 取窗口内最大值                                               │
│  • 保留最显著特征                                               │
│  • 常用于卷积层之后                                             │
│                                                                 │
│  平均池化 (Average Pooling):                                    │
│  • 取窗口内平均值                                               │
│  • 保留整体信息                                                 │
│  • 常用于全局池化                                               │
│                                                                 │
│  全局池化 (Global Pooling):                                     │
│  • 将特征图压缩为单个值                                         │
│  • 替代全连接层                                                 │
│  • 支持任意输入尺寸                                             │
│                                                                 │
│  自适应池化 (Adaptive Pooling):                                 │
│  • 指定输出尺寸,自动计算参数                                   │
│  • 处理可变尺寸输入                                             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

9.2 选择指南

场景 推荐池化 原因
卷积层后下采样 MaxPool 保留显著特征
网络末端 GlobalAvgPool 减少参数,支持可变尺寸
目标检测RoI RoIPool/RoIAlign 处理不同大小候选框
需要可学习下采样 Strided Conv 更灵活
多尺度特征 SPP 固定输出长度

9.3 一句话总结

池化是CNN的"压缩器":减少计算量、增大感受野、提供平移不变性,是特征提取的关键环节。

希望这篇文章帮助你深入理解了CNN中的池化层!如有问题,欢迎评论区交流。


参考文献

  1. LeCun Y, et al. "Gradient-based learning applied to document recognition." 1998.
  2. Krizhevsky A, et al. "ImageNet Classification with Deep Convolutional Neural Networks." NeurIPS 2012.
  3. Lin M, et al. "Network In Network." ICLR 2014.
  4. He K, et al. "Spatial Pyramid Pooling in Deep Convolutional Networks." ECCV 2014.

作者:Jia

更多技术文章,欢迎关注我的CSDN博客!

相关推荐
2501_948120152 小时前
语音识别在儿科医疗语音交互中的应用
人工智能·交互·语音识别
星爷AG I2 小时前
9-4 大小知觉(AGI基础理论)
人工智能·agi
User_芊芊君子2 小时前
听歌不再只存于耳机!MusicCard 解锁音乐分享新方式,cpolar局域网外访问更自由
人工智能·ai·测评
小柔说科技2 小时前
AI销售机器人助理是做什么的?AI销售客服源码系统怎么收费?销冠留不住?
人工智能·ai·软件开发
小北方城市网2 小时前
微服务接口熔断降级与限流实战:保障系统高可用
java·spring boot·python·rabbitmq·java-rabbitmq·数据库架构
love530love2 小时前
【避坑指南】提示词“闹鬼”?Stable Diffusion 自动注入神秘词汇 xiao yi xian 排查全记录
人工智能·windows·stable diffusion·model keyword
2401_841495642 小时前
【强化学习】DQN 改进算法
人工智能·python·深度学习·强化学习·dqn·double dqn·dueling dqn
幸福清风2 小时前
【Python】实战记录:从零搭建 Django + Vue 全栈应用 —— 用户认证篇
vue.js·python·django
故乡de云2 小时前
Gemini API的数据隔离:企业级AI应用的安全感从哪来?
大数据·人工智能