DAY49 CBAM注意力

@浙大疏锦行

完整 Python 代码如下：
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
import os
from pathlib import Path
import warnings
# Silence all warnings for the rest of the session (tutorial-style output).
warnings.filterwarnings('ignore')

# Matplotlib setup: use the SimHei font so the Chinese titles/labels render,
# and turn off the Unicode minus sign so negative tick labels display
# correctly with that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# 1. 定义CBAM模块
class ChannelAttention(nn.Module):
    """CBAM channel-attention branch.

    Squeezes the spatial dimensions with both global average pooling and
    global max pooling. Each pooled descriptor goes through one shared
    bottleneck MLP (implemented with 1x1 convolutions). The two results are
    summed and passed through a sigmoid. For a 4-D input ``(N, C, H, W)`` the
    returned weights have shape ``(N, C, 1, 1)``, so they can be
    broadcast-multiplied with the feature map.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        reduction_ratio: Reduction factor of the bottleneck. The hidden width
            is ``in_channels // reduction_ratio``, clamped to at least 1 so
            that narrow layers (``in_channels < reduction_ratio``) still get a
            valid, non-empty bottleneck.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Without the clamp, e.g. in_channels=8 with the default ratio 16
        # would build a zero-channel bottleneck.
        hidden_channels = max(1, in_channels // reduction_ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid-activated per-channel weights of shape (N, C, 1, 1)."""
        # The same MLP (shared weights) processes both pooled descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention branch.

    Collapses the channel axis into two maps, the per-pixel channel mean and
    the per-pixel channel maximum. It stacks them and applies a single
    ``kernel_size`` x ``kernel_size`` convolution (2 -> 1 channels) followed
    by a sigmoid. With an odd ``kernel_size`` the padding keeps the spatial
    size unchanged.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid-activated spatial weights with a single channel."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(pooled))

class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Rescales the feature map element-wise, first with channel attention and
    then with spatial attention computed from the channel-refined map. Both
    weight tensors are broadcast against the input, so (for an odd
    ``kernel_size``) the output has the input's shape.
    """

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Sequential refinement: channel weights first, then spatial weights
        # derived from the already channel-weighted features.
        refined = x * self.channel_attention(x)
        return refined * self.spatial_attention(refined)

# 2. 带有CBAM的残差块
class BasicBlockWithCBAM(nn.Module):
    """ResNet basic block (two 3x3 conv + BN) with CBAM applied before the residual add.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output.
        stride: Stride of the first 3x3 convolution (values > 1 downsample).

    When the stride or the channel count changes, the shortcut is a 1x1
    conv + BN projection. Otherwise it is an empty ``nn.Sequential``, which
    passes the input through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)

        # Shortcut path.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        shortcut = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        # Attention refines the residual branch before it is added back.
        y = self.cbam(self.bn2(self.conv2(y)))
        return F.relu(y + shortcut)

# 3. 完整的网络模型
class CBAMNet(nn.Module):
    """ResNet-18-like classifier with CBAM in every basic block.

    Layout:
    * 3x3 stem, stride 1, 3 -> 64 channels (expects 3-channel images).
    * Four stages of two ``BasicBlockWithCBAM`` each, with 64 / 128 / 256 /
      512 channels; stages 2-4 downsample by 2.
    * Global average pooling, then a linear head producing ``num_classes``
      logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages, registered in order as layer1..layer4:
        # (attribute name, input channels, output channels, first-block stride).
        stage_specs = (
            ('layer1', 64, 64, 1),
            ('layer2', 64, 128, 2),
            ('layer3', 128, 256, 2),
            ('layer4', 256, 512, 2),
        )
        for attr, c_in, c_out, first_stride in stage_specs:
            setattr(self, attr, self._make_layer(c_in, c_out, 2, stride=first_stride))

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, blocks, stride):
        """Build one stage.

        The first block maps ``in_channels`` -> ``out_channels`` with
        ``stride``. It is followed by ``blocks - 1`` stride-1 blocks at
        ``out_channels``.
        """
        first = BasicBlockWithCBAM(in_channels, out_channels, stride)
        rest = [BasicBlockWithCBAM(out_channels, out_channels, stride=1)
                for _ in range(blocks - 1)]
        return nn.Sequential(first, *rest)

    def _initialize_weights(self):
        """Initialize weights in place.

        Every nn.Conv2d in the network (including those inside CBAM) gets
        Kaiming-normal init. BatchNorm gets weight 1 and bias 0. The linear
        head keeps PyTorch's default init.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

# 4. 训练监控类
class TrainingMonitor:
    """Track per-epoch training metrics, log them, plot them and save checkpoints.

    Everything is written below ``log_dir``:

    * ``training_log.txt``: one line is appended per ``update`` call.
    * ``plots/training_metrics.png``: written by ``plot_metrics``.
    * ``models/``: checkpoints written by ``save_model``.

    Args:
        log_dir: Output directory (str or Path). Missing parent directories
            are created, so nested paths such as
            ``"logs/cbam_experiment_<timestamp>"`` work on a fresh run.
    """

    def __init__(self, log_dir="logs"):
        self.log_dir = Path(log_dir)
        # parents=True is required for nested paths. Without it, mkdir raises
        # FileNotFoundError when an intermediate directory (e.g. "logs") is missing.
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Metric histories, one entry per update() call, in call order.
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.lrs = []

        # Sub-directories for figures and checkpoints.
        (self.log_dir / "plots").mkdir(exist_ok=True)
        (self.log_dir / "models").mkdir(exist_ok=True)

    def update(self, epoch, train_loss, val_loss, train_acc, val_acc, lr):
        """Record one epoch's metrics and append a line to training_log.txt.

        Accuracies are formatted with a '%' suffix, so pass them in percent.
        """
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_accs.append(train_acc)
        self.val_accs.append(val_acc)
        self.lrs.append(lr)

        with open(self.log_dir / "training_log.txt", "a", encoding="utf-8") as f:
            f.write(f"Epoch {epoch}: "
                    f"Train Loss: {train_loss:.4f}, "
                    f"Val Loss: {val_loss:.4f}, "
                    f"Train Acc: {train_acc:.2f}%, "
                    f"Val Acc: {val_acc:.2f}%, "
                    f"LR: {lr:.6f}\n")

    def plot_metrics(self, show=True):
        """Draw a 2x2 overview of the recorded metrics and save it.

        Panels: loss curves, accuracy curves, learning-rate curve, and a
        loss-vs-accuracy scatter. The figure is saved to
        ``plots/training_metrics.png`` (dpi 300). If ``show`` is true it is
        displayed with ``plt.show()``; otherwise it is closed.
        """
        epochs = range(1, len(self.train_losses) + 1)

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Loss curves.
        axes[0, 0].plot(epochs, self.train_losses, 'b-', label='训练集')
        axes[0, 0].plot(epochs, self.val_losses, 'r-', label='验证集')
        axes[0, 0].set_xlabel('训练轮次')
        axes[0, 0].set_ylabel('损失值')
        axes[0, 0].set_title('训练和验证损失曲线')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Accuracy curves.
        axes[0, 1].plot(epochs, self.train_accs, 'b-', label='训练集')
        axes[0, 1].plot(epochs, self.val_accs, 'r-', label='验证集')
        axes[0, 1].set_xlabel('训练轮次')
        axes[0, 1].set_ylabel('准确率 (%)')
        axes[0, 1].set_title('训练和验证准确率曲线')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # Learning-rate curve.
        axes[1, 0].plot(epochs, self.lrs, 'g-')
        axes[1, 0].set_xlabel('训练轮次')
        axes[1, 0].set_ylabel('学习率')
        axes[1, 0].set_title('学习率变化曲线')
        axes[1, 0].grid(True)

        # Loss vs. accuracy scatter.
        axes[1, 1].scatter(self.train_losses, self.train_accs, alpha=0.5, label='训练集')
        axes[1, 1].scatter(self.val_losses, self.val_accs, alpha=0.5, label='验证集')
        axes[1, 1].set_xlabel('损失值')
        axes[1, 1].set_ylabel('准确率 (%)')
        axes[1, 1].set_title('损失与准确率关系')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        plt.tight_layout()
        fig.savefig(self.log_dir / "plots" / "training_metrics.png", dpi=300, bbox_inches='tight')

        if show:
            plt.show()
        else:
            # Close this specific figure rather than whatever is "current".
            plt.close(fig)

    def save_model(self, model, epoch, val_acc, filename=None):
        """Save a checkpoint dict under ``models/`` and return its path.

        The checkpoint holds ``epoch``, ``model_state_dict`` and ``val_acc``.
        When ``filename`` is None, the name is derived from epoch and accuracy.
        """
        if filename is None:
            filename = f"model_epoch_{epoch}_acc_{val_acc:.2f}.pth"

        model_path = self.log_dir / "models" / filename
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': val_acc,
        }, model_path)
        return model_path

    def print_summary(self):
        """Print best accuracies and final losses. Does nothing before the first update()."""
        if not self.train_losses:
            return
        print("\n" + "="*50)
        print("训练摘要:")
        print("="*50)
        print(f"最佳验证准确率: {max(self.val_accs):.2f}%")
        print(f"最佳训练准确率: {max(self.train_accs):.2f}%")
        print(f"最终验证损失: {self.val_losses[-1]:.4f}")
        print(f"最终训练损失: {self.train_losses[-1]:.4f}")
        print("="*50)

# 5. 训练函数
def _build_cifar10_loaders(batch_size, pin_memory):
    """Download (if needed) CIFAR-10 into ./data and return (train_loader, val_loader).

    The training split uses random-crop (padding 4) and horizontal-flip
    augmentation. The test split, used here for validation, is only
    normalized.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True,
                                     download=True, transform=transform_train)
    val_dataset = datasets.CIFAR10(root='./data', train=False,
                                   download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=2, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2, pin_memory=pin_memory)
    return train_loader, val_loader


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one optimization pass over ``loader`` with an in-place progress bar.

    Returns (mean per-batch loss, accuracy in percent).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = len(loader)
    bar_length = 30

    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = output.max(1)
        seen += target.size(0)
        correct += predicted.eq(target).sum().item()

        # Refresh the progress bar every 100 batches and on the last batch.
        done = batch_idx + 1
        if done % 100 == 0 or done == num_batches:
            progress = done / num_batches * 100
            filled_length = int(bar_length * done // num_batches)
            bar = '█' * filled_length + '░' * (bar_length - filled_length)
            print(f'\r轮次 [{epoch+1}/{num_epochs}] | {bar} | {progress:.1f}%', end='')

    return total_loss / num_batches, 100. * correct / seen


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns (mean per-batch loss, accuracy in percent).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            seen += target.size(0)
            correct += predicted.eq(target).sum().item()

    return total_loss / len(loader), 100. * correct / seen


def train_cbam_model(num_epochs=10, batch_size=64, lr=0.001, device='cuda'):
    """Train CBAMNet on CIFAR-10 (intended for use from a Jupyter notebook).

    The loop uses Adam with a StepLR schedule (x0.1 every 5 epochs). It logs
    every epoch through a TrainingMonitor under ``logs/cbam_experiment_<timestamp>``.
    It saves ``best_model.pth`` whenever validation accuracy improves and
    ``final_model.pth`` at the end, then plots the curves and prints a summary.

    Args:
        num_epochs: Number of epochs; must be >= 1.
        batch_size: Mini-batch size for both loaders.
        lr: Initial Adam learning rate.
        device: Preferred device string. Falls back to 'cpu' when CUDA is
            unavailable.

    Returns:
        (model, monitor, train_loader, val_loader, device)

    Raises:
        ValueError: If ``num_epochs < 1``. Checked before any data is
            downloaded; otherwise the final checkpoint would have no
            validation accuracy to record.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")
    if device.type == 'cuda':
        # Report the GPU actually selected, not unconditionally device 0.
        print(f"GPU型号: {torch.cuda.get_device_name(device)}")

    print("\n加载CIFAR-10数据集...")
    # Pinned host memory only helps host->GPU copies; skip it on CPU.
    train_loader, val_loader = _build_cifar10_loaders(
        batch_size, pin_memory=device.type == 'cuda')

    print("\n创建模型...")
    model = CBAMNet(num_classes=10).to(device)
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    # Ensure the parent log directory exists before creating the nested
    # experiment directory.
    log_root = Path("logs")
    log_root.mkdir(exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    monitor = TrainingMonitor(log_root / f"cbam_experiment_{timestamp}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    best_val_acc = 0.0

    print(f"\n开始训练 ({num_epochs} 轮次)...")
    print("="*80)

    for epoch in range(num_epochs):
        avg_train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, num_epochs)
        avg_val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        # Read before scheduler.step() so it is the LR used during this epoch.
        current_lr = scheduler.get_last_lr()[0]

        monitor.update(epoch + 1, avg_train_loss, avg_val_loss, train_acc, val_acc, current_lr)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            monitor.save_model(model, epoch + 1, val_acc, "best_model.pth")

        scheduler.step()

        print(f'\r轮次 [{epoch+1:3d}/{num_epochs}] | '
              f'训练损失: {avg_train_loss:.4f} | '
              f'训练准确率: {train_acc:6.2f}% | '
              f'验证损失: {avg_val_loss:.4f} | '
              f'验证准确率: {val_acc:6.2f}% | '
              f'学习率: {current_lr:.6f}')

    # val_acc is always bound here because num_epochs >= 1 was enforced above.
    monitor.save_model(model, num_epochs, val_acc, "final_model.pth")

    print("\n绘制训练曲线...")
    monitor.plot_metrics(show=True)

    monitor.print_summary()

    print(f"\n训练完成!")
    print(f"日志保存到: {monitor.log_dir}")

    return model, monitor, train_loader, val_loader, device

# 6. 可视化注意力特征图
def visualize_cbam_attention(model, data_loader, device, num_images=4):
    """Plot the channel-averaged output of each CBAM module for a few images.

    Takes the first batch of ``data_loader`` and runs it through ``model``
    with a forward hook on every ``CBAM`` submodule. It then plots the first
    ``min(num_images, 4, batch size)`` images. Each image gets a 2x3 grid:
    the de-normalized input plus up to five CBAM outputs, averaged over
    channels. Figures are saved to
    ``attention_visualizations/cbam_attention_<i>.png``.

    Note: the maps show the attention-*weighted features* (the CBAM module's
    output), not the raw attention weights.
    """
    save_path = Path("attention_visualizations")
    save_path.mkdir(exist_ok=True)

    class_names = ['飞机', '汽车', '鸟', '猫', '鹿',
                   '狗', '青蛙', '马', '船', '卡车']
    # De-normalization constants; these must match the Normalize() used to load the data.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2023, 0.1994, 0.2010])

    model.eval()
    # DataLoader iterators have no `.next()` method in current PyTorch; use the builtin.
    images, labels = next(iter(data_loader))
    images = images[:num_images].to(device)
    # Never index past the batch, even if num_images exceeds the batch size.
    n_show = min(num_images, 4, images.size(0))

    # Capture each CBAM module's output, keyed by its qualified module name.
    features = {}

    def get_features(name):
        def hook(module, inputs, output):
            features[name] = output.detach()
        return hook

    hooks = [module.register_forward_hook(get_features(name))
             for name, module in model.named_modules()
             if isinstance(module, CBAM)]
    try:
        with torch.no_grad():
            model(images)
    finally:
        # Always detach the hooks so the model is left unmodified, even on error.
        for hook in hooks:
            hook.remove()

    for i in range(n_show):
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Input image: CHW -> HWC, undo normalization, clip to displayable range.
        img_np = images[i].cpu().numpy().transpose(1, 2, 0)
        img_np = np.clip(std * img_np + mean, 0, 1)

        axes[0, 0].imshow(img_np)
        axes[0, 0].set_title(f"输入图像 {i+1}\n类别: {class_names[int(labels[i])]}")
        axes[0, 0].axis('off')

        # Grid slots 1..5 (row-major) hold at most five CBAM outputs.
        for idx, (name, feature) in enumerate(list(features.items())[:5]):
            ax = axes.flat[idx + 1]
            attn_map = feature[i].mean(dim=0).cpu().numpy()
            im = ax.imshow(attn_map, cmap='hot')
            ax.set_title(f"{name}\n通道: {feature.shape[1]}")
            ax.axis('off')
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Hide grid slots that received no CBAM output.
        for idx in range(len(features) + 1, 6):
            axes.flat[idx].axis('off')

        plt.suptitle(f"CBAM注意力特征图 (图像 {i+1})", fontsize=16)
        plt.tight_layout()
        plt.savefig(save_path / f"cbam_attention_{i}.png", dpi=150, bbox_inches='tight')
        plt.show()

    print(f"\n注意力特征图已保存到: {save_path}")

# 7. 测试简单模型
def test_simple_cbam_model():
    """Smoke-test a small CNN with CBAM.

    Builds the model, runs one random ``(4, 3, 32, 32)`` batch through it,
    and prints its structure, input/output shapes and parameter counts.
    Returns the model, which is left in training mode on the selected
    device.
    """
    banner = "=" * 50
    print(banner)
    print("测试简单CBAM模型")
    print(banner)

    class SimpleCBAMNet(nn.Module):
        """Two (conv-BN-ReLU-maxpool, CBAM) stages (3->32->64 channels) plus an MLP head.

        The 64*8*8 flatten size corresponds to 32x32 inputs halved twice by
        the 2x2 poolings.
        """

        def __init__(self, num_classes=10):
            super().__init__()

            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)
            self.cbam1 = CBAM(32)

            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.cbam2 = CBAM(64)

            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(0.5)

            self.fc1 = nn.Linear(64 * 8 * 8, 256)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            stages = (
                (self.conv1, self.bn1, self.cbam1),
                (self.conv2, self.bn2, self.cbam2),
            )
            for conv, bn, cbam in stages:
                x = cbam(self.pool(F.relu(bn(conv(x)))))

            flat = x.view(x.size(0), -1)
            hidden = self.dropout(F.relu(self.fc1(flat)))
            return self.fc2(hidden)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleCBAMNet(num_classes=10).to(device)

    # One forward pass with random data to check the shapes line up.
    dummy_input = torch.randn(4, 3, 32, 32).to(device)
    output = model(dummy_input)

    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    print("模型结构:")
    print(model)
    print(f"\n输入形状: {dummy_input.shape}")
    print(f"输出形状: {output.shape}")
    print(f"模型总参数: {total_params:,}")
    print(f"可训练参数: {trainable_params:,}")

    return model

# 8. 在Jupyter中直接运行的示例
if __name__ == "__main__":
    # Entry point: smoke-test the small model, then optionally train the full
    # CBAMNet and visualize its CBAM outputs.
    print("正在测试简单CBAM模型...")
    simple_model = test_simple_cbam_model()

    print("\n" + "="*50)
    train_choice = input("是否开始训练完整CBAM模型? (y/n): ")

    # Accept "y"/"yes" in any case and ignore surrounding whitespace.
    if train_choice.strip().lower() in ('y', 'yes'):
        num_epochs = 20
        batch_size = 64
        learning_rate = 0.001

        model, monitor, train_loader, val_loader, device = train_cbam_model(
            num_epochs=num_epochs,
            batch_size=batch_size,
            lr=learning_rate
        )

        print("\n" + "="*50)
        viz_choice = input("是否可视化注意力特征图? (y/n): ")

        if viz_choice.strip().lower() in ('y', 'yes'):
            visualize_cbam_attention(model, val_loader, device, num_images=4)
    else:
        print("跳过训练,仅测试了简单模型。")
相关推荐
阿龙AI日记2 小时前
YOLO26:全新的视觉模型来了
深度学习·神经网络·yolo·目标检测
jay神2 小时前
手势识别数据集 - 专业级目标检测训练数据
人工智能·深度学习·yolo·目标检测·计算机视觉
海绵宝宝de派小星2 小时前
AI发展简史与里程碑事件
人工智能·搜索引擎
海绵宝宝de派小星2 小时前
什么是人工智能?AI、机器学习、深度学习的关系
人工智能·深度学习·机器学习·ai
HaiLang_IT2 小时前
基于图像处理与注意力机制的输电线路绝缘子缺陷智能识别方法
图像处理·人工智能
棒棒的皮皮2 小时前
【深度学习】YOLO 进阶提升之算法改进(新型骨干网络 / 特征融合方法 / 损失函数设计)
深度学习·算法·yolo·计算机视觉
大山同学2 小时前
深度学习任务分类之图像超分辨率
人工智能·深度学习·分类
一招定胜负2 小时前
机器学习项目:矿物分类系统重制版
人工智能·机器学习·分类
koo3642 小时前
pytorch深度学习笔记17
pytorch·笔记·深度学习