Deep Learning Series: The SENet Architecture

Implementing SENet (Squeeze-and-Excitation Networks) in PyTorch comes down to implementing its channel-attention mechanism: the SE module assigns an importance weight to every feature channel, amplifying useful features and suppressing useless ones. We start from the basic SE module, then embed it step by step into ResNet (yielding SE-ResNet), so that the role of each piece is clear.

I. The Core Structure of SENet

SENet is essentially "an existing CNN with SE modules embedded in it." Taking the classic SE-ResNet as an example, the structure can be summarized as:

Input image →
initial conv layer → pooling layer →
several residual blocks (an SE module embedded in each) →
global average pooling → fully connected layer (class scores)

The core component is the SE module (Squeeze-and-Excitation Module), which consists of three steps: squeeze → excitation → scale (re-calibration).
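
In the notation of the SENet paper, for an input feature map $U$ with per-channel maps $u_c$ (spatial size $H \times W$, $C$ channels), the three steps are:

$$z_c = F_{sq}(u_c) = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) \quad \text{(squeeze)}$$
$$s = F_{ex}(z, W) = \sigma\!\left(W_2\, \delta(W_1 z)\right) \quad \text{(excitation)}$$
$$\tilde{x}_c = F_{scale}(u_c, s_c) = s_c \cdot u_c \quad \text{(scale)}$$

where $\delta$ is ReLU, $\sigma$ is the sigmoid, and $W_1 \in \mathbb{R}^{(C/r) \times C}$, $W_2 \in \mathbb{R}^{C \times (C/r)}$ are the two fully connected layers ($r$ is the reduction ratio).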

II. Implementing SENet in PyTorch, Step by Step

Step 1: Import the required libraries

import torch  # core library
import torch.nn as nn  # neural-network layers
import torch.optim as optim  # optimizers
from torch.utils.data import DataLoader  # batched data loading
from torchvision import datasets, transforms  # image datasets and preprocessing

Step 2: Implement the core component, the SE module

The SE module is the soul of SENet. It performs three key steps: Squeeze, Excitation, and Scale (re-calibration):

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # 1. Squeeze: global average pooling compresses H×W×C down to 1×1×C
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 2. Excitation: two fully connected layers learn the channel weights
        self.fc = nn.Sequential(
            # reduce dimensionality to cut computation (channels: C → C/reduction)
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            # restore dimensionality (channels: C/reduction → C)
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()  # squash the weights into the range 0~1
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()  # batch size b and channel count c
        
        # Squeeze: global average pooling → output shape: (b, c, 1, 1)
        y = self.avg_pool(x)
        # flatten into a vector: (b, c)
        y = y.view(b, c)
        
        # Excitation: learn the weights → output shape: (b, c), one weight per channel
        y = self.fc(y)
        # reshape to (b, c, 1, 1) so the multiply below can broadcast
        y = y.view(b, c, 1, 1)
        
        # 3. Scale: multiply the weights into the original feature map (channel-wise weighting)
        return x * y  # broadcasting: every pixel in a channel is scaled by that channel's weight

Parameter notes

  • channel: the number of channels in the input feature map;
  • reduction: the reduction ratio (the paper recommends 16), used to shrink the fully connected layers' computation (e.g., 256 channels → 256/16 = 16).

Intuition

  • Squeeze: give each channel a "global average score" (say, a "cat eye" channel scores high while a "background noise" channel scores low);
  • Excitation: from those scores, learn an "importance weight" between 0 and 1 for each channel;
  • Scale: reweight the original features with those weights (important channels amplified, unimportant ones shrunk).
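
As a quick sanity check (a minimal sketch; the tensor sizes are illustrative), you can push a random tensor through the module and confirm that the shape is unchanged while the channels get reweighted:

se = SEBlock(channel=256, reduction=16)
x = torch.randn(2, 256, 14, 14)  # (batch, channels, H, W)
out = se(x)
print(out.shape)  # torch.Size([2, 256, 14, 14]): same shape, channels rescaled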

Step 3: Implement the SE-ResNet residual block (SE module embedded)

SENet is usually built on top of ResNet. Taking ResNet's bottleneck residual block as the example, we embed an SE module inside it:

class SEBottleneck(nn.Module):
    expansion = 4  # the block's output has 4× the channels of its middle 3×3 stage (ResNet bottleneck convention)
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        # 1×1 conv: reduce channels (cuts computation)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # 3×3 conv: the main feature-extraction layer
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, 
            padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 1×1 conv: expand channels back
        self.conv3 = nn.Conv2d(
            out_channels, out_channels * self.expansion, kernel_size=1, 
            stride=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        
        self.relu = nn.ReLU(inplace=True)
        
        # embed the SE module (after feature extraction, before the residual add)
        self.se = SEBlock(out_channels * self.expansion, reduction)
        
        # downsampling (used when input/output spatial size or channel count differ, so the residual add has matching dimensions)
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        identity = x  # the shortcut branch of the residual connection
        
        # main branch: feature extraction
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        
        # the key step: reweight the features through the SE module
        out = self.se(out)
        
        # residual connection: main-branch output + shortcut branch
        if self.downsample is not None:
            identity = self.downsample(x)  # adjust the shortcut branch's dimensions
        
        out += identity
        out = self.relu(out)
        
        return out

Structure notes

  • This is ResNet's bottleneck residual block (1×1 conv → 3×3 conv → 1×1 conv) with an SE module appended at the end;
  • The SE module sits after feature extraction and before the residual add, so the reweighted features are what get fused with the shortcut branch;
  • expansion=4 means the output channel count is 4× the middle 3×3 convolution's channels (e.g., a 64-channel 3×3 stage yields a 256-channel output).
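
To confirm the dimensions line up, here is a minimal sketch (sizes are illustrative, mirroring the first block of layer1, where the shortcut needs a 1×1 projection from 64 to 256 channels):

downsample = nn.Sequential(
    nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(256),
)
block = SEBottleneck(in_channels=64, out_channels=64, stride=1, downsample=downsample)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # torch.Size([2, 256, 56, 56]): channels expanded 4×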

Step 4: Assemble the full SE-ResNet

Taking SE-ResNet50 (50 layers) as the example: it consists of 4 stages of residual blocks, containing 3, 4, 6, and 3 SEBottleneck blocks respectively:

class SEResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, reduction=16):
        super(SEResNet, self).__init__()
        self.in_channels = 64  # channel count after the stem convolution
        
        # stem convolution
        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halves the spatial size
        
        # 4 stages of residual blocks (output channels double at each stage)
        self.layer1 = self._make_layer(
            block, 64, layers[0], reduction=reduction
        )  # output channels: 64×4=256
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, reduction=reduction
        )  # output channels: 128×4=512
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, reduction=reduction
        )  # output channels: 256×4=1024
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, reduction=reduction
        )  # output channels: 512×4=2048
        
        # classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.fc = nn.Linear(512 * block.expansion, num_classes)  # fully connected layer
    
    def _make_layer(self, block, out_channels, blocks, stride=1, reduction=16):
        """创建一个残差块组"""
        downsample = None
        # when stride > 1 or the channel counts differ, the shortcut branch needs downsampling
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels, out_channels * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        
        layers = []
        # the first block of the stage (may need downsampling)
        layers.append(
            block(self.in_channels, out_channels, stride, downsample, reduction)
        )
        self.in_channels = out_channels * block.expansion  # update the running input channel count
        
        # the remaining blocks (stride=1, no downsampling needed)
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, reduction=reduction))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        # stem convolution and pooling
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # 4 stages of residual blocks (every block contains an SE module)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # classification
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten into a vector
        x = self.fc(x)
        
        return x

SE-ResNet50 configuration

layers=[3,4,6,3] sets the number of blocks in each of the 4 stages. The layer count works out as:
3+4+6+3 = 16 residual blocks × 3 conv layers per block = 48 layers, plus the stem convolution and the final fully connected layer ≈ 50 layers.

Step 5: Instantiate the SE-ResNet50 model

Assemble SE-ResNet50 from the modules defined above:

def se_resnet50(num_classes=1000, reduction=16):
    """创建SE-ResNet50模型"""
    return SEResNet(SEBottleneck, [3, 4, 6, 3], num_classes, reduction)
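
A quick smoke test (a sketch; the batch size is illustrative) confirms the network produces logits of the right shape, and lets you verify the layer arithmetic above:

net = se_resnet50(num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(net(x).shape)  # torch.Size([2, 10])

# Module counts: 53 Conv2d (1 stem + 48 in bottlenecks + 4 shortcut projections)
# and 33 Linear (2 per SE block × 16 blocks + the classifier)
print(sum(isinstance(m, nn.Conv2d) for m in net.modules()),
      sum(isinstance(m, nn.Linear) for m in net.modules()))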

Step 6: Prepare the data (demonstrated on CIFAR-10)

SE-ResNet suits high-accuracy classification tasks. We demonstrate on CIFAR-10 (10 classes); its native 32×32 images are upscaled to 224×224 to fit the ImageNet-style stem:

# preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),  # upscale to 256×256
    transforms.RandomCrop(224),  # random crop to 224×224
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# batch the data
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
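
A one-line check (a sketch; it assumes the dataset has downloaded) that the loader yields what the network expects:

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([32, 3, 224, 224]) torch.Size([32])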

Step 7: Initialize the model, loss function, and optimizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# SE-ResNet50 with 10 output classes (CIFAR-10)
model = se_resnet50(num_classes=10, reduction=16).to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# optimizer: SGD with momentum is the recommended choice
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
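
Optionally (not part of the minimal recipe, and the milestones below are illustrative), a step-decay learning-rate schedule often helps SGD converge:

# decay the learning rate 10× at epochs 15 and 25; call scheduler.step() once per epoch
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1)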

Step 8: Training and test functions

The training loop for SE-ResNet is the same as for a plain ResNet, and because the SE module adds very little computation (about 0.26% extra GFLOPs for SE-ResNet50, per the paper), training speed is barely affected:

def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()  # clear accumulated gradients
        output = model(data)   # forward pass
        loss = criterion(output, target)  # compute the loss
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters
        
        # report progress
        if batch_idx % 50 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

Step 9: Train and evaluate

SE-ResNet50 converges at about the same rate as plain ResNet50; 30-50 epochs is a reasonable budget (note: if you keep num_workers > 0 on Windows or macOS, wrap this loop in an if __name__ == '__main__': guard):

for epoch in range(1, 31):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)

Trained this way on CIFAR-10, SE-ResNet50 typically scores roughly 1-2 percentage points higher than plain ResNet50, showing the feature-enhancing effect of the SE module.

III. Complete Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. The SE module (Squeeze-and-Excitation)
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Squeeze: global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two fully connected layers learn the channel weights
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: (b, c, h, w) → (b, c, 1, 1) → (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excitation: learn the weights → (b, c) → (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # Scale: channel-wise reweighting
        return x * y

# 2. The SE-ResNet residual block (SE module embedded)
class SEBottleneck(nn.Module):
    expansion = 4  # output channels are 4× the middle channel count
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride,
            padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.conv3 = nn.Conv2d(
            out_channels, out_channels * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(out_channels * self.expansion, reduction)  # embed the SE module
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        identity = x
        
        # main branch: feature extraction
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        
        # SE reweighting
        out = self.se(out)
        
        # residual connection
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out

# 3. The full SE-ResNet network
class SEResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, reduction=16):
        super(SEResNet, self).__init__()
        self.in_channels = 64
        
        # stem convolution
        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # 4 stages of residual blocks
        self.layer1 = self._make_layer(block, 64, layers[0], reduction=reduction)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, reduction=reduction)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, reduction=reduction)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, reduction=reduction)
        
        # classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
    
    def _make_layer(self, block, out_channels, blocks, stride=1, reduction=16):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels, out_channels * block.expansion,
                    kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample, reduction))
        self.in_channels = out_channels * block.expansion
        
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, reduction=reduction))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

# 4. SE-ResNet50 factory
def se_resnet50(num_classes=1000, reduction=16):
    return SEResNet(SEBottleneck, [3, 4, 6, 3], num_classes, reduction)

# 5. CIFAR-10 data
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# 6. Model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = se_resnet50(num_classes=10, reduction=16).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 50 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

# 8. Test function
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# 9. Train and evaluate
for epoch in range(1, 31):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
    

IV. Key Takeaways

  1. SE module core logic: three steps, Squeeze (global average pooling) → Excitation (fully connected layers learn weights) → Scale (channel-wise reweighting), let the model learn to focus on the important feature channels;
  2. Placement: the SE module usually sits after feature extraction and before the residual add, so the reweighted features feed into the rest of the computation;
  3. Parameter settings:
    • reduction=16: the reduction ratio, balancing computation against accuracy (larger values mean less computation but possibly lower accuracy);
    • layers=[3,4,6,3] for SE-ResNet50: sets the block count of the 4 stages, about 50 layers in total;
  4. Advantages: low overhead (for SE-ResNet50, roughly 0.26% extra GFLOPs and about 10% extra parameters), broad compatibility (it can be embedded in virtually any CNN), and a clear accuracy gain (typically 1-2 points over plain ResNet).

By working through this code you can build this "smart feature gatekeeper" with your own hands and see how an attention mechanism lifts model performance at very little cost!
