Day 50: Pretrained Model + CBAM Module

@浙大疏锦行

Python:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time

# CBAM definition (reused from earlier notes)
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.shape
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * attention

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pool_out = torch.cat([avg_out, max_out], dim=1)
        attention = self.conv(pool_out)
        return x * self.sigmoid(attention)

class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x
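
# Quick shape check (added for illustration; not part of the original notes).
# CBAM is shape-preserving, so it can be dropped between any two conv layers:
_x = torch.randn(2, 64, 32, 32)
assert CBAM(64)(_x).shape == _x.shape
del _x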

# Data preprocessing and loading (reused from earlier notes)
plt.rcParams["font.family"] = ["SimHei"]  # CJK font config kept from the original notes
plt.rcParams['axes.unicode_minus'] = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
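
# Sanity check (illustrative; left commented out so it doesn't spin up loader workers here):
# images, labels = next(iter(train_loader))
# print(images.shape, labels.shape)  # expected: torch.Size([64, 3, 32, 32]) torch.Size([64])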

# Custom VGG16 + CBAM model
class VGG16_CBAM(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7):
        super().__init__()
        # Load VGG16 (ImageNet weights when pretrained=True); newer torchvision
        # prefers the weights= API, but pretrained= still works with a deprecation warning
        self.backbone = models.vgg16(pretrained=pretrained)
        
        # Adapt for CIFAR-10: replace avgpool with global average pooling and shrink the classifier
        self.backbone.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 1x1 output to match the small 32x32 input
        
        # Replace the classifier with a single Linear layer (avoids VGG's oversized FC head)
        self.backbone.classifier = nn.Sequential(
            nn.Linear(512, num_classes)
        )
        
        # Insert a CBAM block after each conv stage (channels: 64, 128, 256, 512, 512)
        self.cbam1 = CBAM(64, cbam_ratio, cbam_kernel)   # after block 1
        self.cbam2 = CBAM(128, cbam_ratio, cbam_kernel)  # after block 2
        self.cbam3 = CBAM(256, cbam_ratio, cbam_kernel)  # after block 3
        self.cbam4 = CBAM(512, cbam_ratio, cbam_kernel)  # after block 4
        self.cbam5 = CBAM(512, cbam_ratio, cbam_kernel)  # after block 5

    def forward(self, x):
        features = self.backbone.features
        
        # Block 1: features[0:4] (two 64-channel convs + ReLUs), then CBAM, then MaxPool
        x = features[0:4](x)  # everything before the MaxPool
        x = self.cbam1(x)
        x = features[4](x)    # MaxPool
        
        # Block 2: features[5:9], then CBAM, then MaxPool (features[9])
        x = features[5:9](x)
        x = self.cbam2(x)
        x = features[9](x)
        
        # Block 3: features[10:16], then CBAM, then MaxPool (features[16])
        x = features[10:16](x)
        x = self.cbam3(x)
        x = features[16](x)
        
        # Block 4: features[17:23], then CBAM, then MaxPool (features[23])
        x = features[17:23](x)
        x = self.cbam4(x)
        x = features[23](x)
        
        # Block 5: features[24:30], then CBAM, then MaxPool (features[30])
        x = features[24:30](x)
        x = self.cbam5(x)
        x = features[30](x)
        
        # Head: AvgPool + Flatten + Classifier
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.backbone.classifier(x)
        return x
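
# Smoke-test sketch (illustrative; pretrained=False avoids the weight download).
# A CIFAR-sized batch should yield logits of shape (batch, 10):
# _m = VGG16_CBAM(pretrained=False)
# print(_m(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])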

# set_trainable_layers, adjusted for VGG: match parameter-name substrings such as
# "cbam", "backbone.classifier", "backbone.features.<idx>"
def set_trainable_layers(model, trainable_parts):
    print(f"\n---> Unfreezing and making trainable: {trainable_parts}")
    for name, param in model.named_parameters():
        param.requires_grad = False
        for part in trainable_parts:
            if part in name:
                param.requires_grad = True
                break
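
# Helper added for illustration (not in the original notes): report how many
# parameters each stage actually leaves trainable.
def count_trainable_params(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable:,} / {total:,} ({100.0 * trainable / total:.1f}%)")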

# train_staged_finetuning (reused from earlier notes; defines the fine-tuning stages)
def train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs):
    optimizer = None
    all_iter_losses, iter_indices = [], []
    train_acc_history, test_acc_history = [], []
    train_loss_history, test_loss_history = [], []
    
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        
        # Stage schedule (adapted for VGG: the "high-level" layers are features[10:], i.e. blocks 3-5)
        if epoch == 1:
            print("\n" + "="*50 + "\n🚀 Stage 1: train the attention modules and the classifier head\n" + "="*50)
            set_trainable_layers(model, ["cbam", "backbone.classifier"])
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
        elif epoch == 6:
            print("\n" + "="*50 + "\n✈️ Stage 2: unfreeze the high-level conv layers (features[10:], blocks 3-5)\n" + "="*50)
            # Substring prefixes like "backbone.features.1"/"backbone.features.2" would also
            # match block-1's conv at features[2], so unfreeze the high-level convs directly.
            set_trainable_layers(model, ["cbam", "backbone.classifier"])
            for param in model.backbone.features[10:].parameters():
                param.requires_grad = True
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
        elif epoch == 21:
            print("\n" + "="*50 + "\n🛰️ Stage 3: unfreeze all layers for full fine-tuning\n" + "="*50)
            for param in model.parameters(): param.requires_grad = True
            optimizer = optim.Adam(model.parameters(), lr=1e-5)
        
        # Training loop (reused from earlier notes)
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append((epoch - 1) * len(train_loader) + batch_idx + 1)
            
            running_loss += iter_loss
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch: {epoch}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
                      f'| batch loss: {iter_loss:.4f} | running avg loss: {running_loss/(batch_idx+1):.4f}')
        
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct / total
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)
        
        # Evaluation loop (reused from earlier notes)
        model.eval()
        test_loss, correct_test, total_test = 0, 0, 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)
        
        print(f'Epoch {epoch}/{epochs} done | time: {time.time() - epoch_start_time:.2f}s | train acc: {epoch_train_acc:.2f}% | test acc: {epoch_test_acc:.2f}%')
    
    print("\n训练完成! 开始绘制结果图表...")
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    
    return epoch_test_acc

# Plotting helpers (reused from earlier notes)
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration (batch index)')
    plt.ylabel('Loss')
    plt.title('Training loss per iteration')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    epochs = range(1, len(train_acc) + 1)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='Train accuracy')
    plt.plot(epochs, test_acc, 'r-', label='Test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Train and test accuracy')
    plt.legend(); plt.grid(True)
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='Train loss')
    plt.plot(epochs, test_loss, 'r-', label='Test loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train and test loss')
    plt.legend(); plt.grid(True)
    plt.tight_layout()
    plt.show()

# Run training
model = VGG16_CBAM().to(device)
criterion = nn.CrossEntropyLoss()
epochs = 50
print("开始使用带分阶段微调策略的 VGG16 + CBAM 模型进行训练...")
final_accuracy = train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
# torch.save(model.state_dict(), 'vgg16_cbam_finetuned.pth')
# print("模型已保存为: vgg16_cbam_finetuned.pth")