import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
# Configure a matplotlib font with CJK glyph coverage (only needed if labels contain CJK text)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with this font

# 1. Define the CBAM modules
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # Channel attention
        x = x * self.channel_attention(x)
        # Spatial attention
        x = x * self.spatial_attention(x)
        return x
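
# Optional sanity check, a minimal sketch: both attention branches only rescale features,
# so a CBAM block is shape-preserving and can be inserted into an existing network without
# changing downstream layer sizes. The function name and test sizes here are arbitrary.
def check_cbam_shapes(batch=2, channels=64, size=32):
    cbam = CBAM(channels)
    x = torch.randn(batch, channels, size, size)
    out = cbam(x)
    assert out.shape == x.shape, "CBAM should not change the feature map shape"
    return out.shape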

# 2. Residual block with CBAM
class BasicBlockWithCBAM(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlockWithCBAM, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Apply CBAM attention before the residual addition
        out = self.cbam(out)
        out += identity
        out = F.relu(out)
        return out

# 3. The complete network
class CBAMNet(nn.Module):
    def __init__(self, num_classes=10):
        super(CBAMNet, self).__init__()
        # Initial convolution layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Residual layers
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        # Global average pooling and fully connected classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, blocks, stride):
        layers = []
        layers.append(BasicBlockWithCBAM(in_channels, out_channels, stride))
        for _ in range(1, blocks):
            layers.append(BasicBlockWithCBAM(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
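
# Optional sanity check, a minimal sketch: the full network maps a batch of
# CIFAR-10-sized images (3x32x32) to one logit per class. The function name
# and test batch size are arbitrary.
def check_cbamnet_output(batch=2, num_classes=10):
    net = CBAMNet(num_classes=num_classes)
    logits = net(torch.randn(batch, 3, 32, 32))
    assert logits.shape == (batch, num_classes)
    return logits.shape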

# 4. Training monitor class
class TrainingMonitor:
    def __init__(self, log_dir="logs"):
        self.log_dir = Path(log_dir)
        # parents=True so nested paths such as "logs/experiment_x" can be created
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.lrs = []
        # Create subdirectories
        (self.log_dir / "plots").mkdir(exist_ok=True)
        (self.log_dir / "models").mkdir(exist_ok=True)

    def update(self, epoch, train_loss, val_loss, train_acc, val_acc, lr):
        """Record the metrics for one epoch."""
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_accs.append(train_acc)
        self.val_accs.append(val_acc)
        self.lrs.append(lr)
        # Append to the text log file
        with open(self.log_dir / "training_log.txt", "a") as f:
            f.write(f"Epoch {epoch}: "
                    f"Train Loss: {train_loss:.4f}, "
                    f"Val Loss: {val_loss:.4f}, "
                    f"Train Acc: {train_acc:.2f}%, "
                    f"Val Acc: {val_acc:.2f}%, "
                    f"LR: {lr:.6f}\n")

    def plot_metrics(self, show=True):
        """Plot the recorded training metrics."""
        epochs = range(1, len(self.train_losses) + 1)
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        # Loss curves
        axes[0, 0].plot(epochs, self.train_losses, 'b-', label='Train')
        axes[0, 0].plot(epochs, self.val_losses, 'r-', label='Validation')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        # Accuracy curves
        axes[0, 1].plot(epochs, self.train_accs, 'b-', label='Train')
        axes[0, 1].plot(epochs, self.val_accs, 'r-', label='Validation')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        # Learning-rate curve
        axes[1, 0].plot(epochs, self.lrs, 'g-')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning rate')
        axes[1, 0].set_title('Learning Rate Schedule')
        axes[1, 0].grid(True)
        # Loss vs. accuracy scatter plot
        axes[1, 1].scatter(self.train_losses, self.train_accs, alpha=0.5, label='Train')
        axes[1, 1].scatter(self.val_losses, self.val_accs, alpha=0.5, label='Validation')
        axes[1, 1].set_xlabel('Loss')
        axes[1, 1].set_ylabel('Accuracy (%)')
        axes[1, 1].set_title('Loss vs. Accuracy')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
        plt.tight_layout()
        plt.savefig(self.log_dir / "plots" / "training_metrics.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
        else:
            plt.close()

    def save_model(self, model, epoch, val_acc, filename=None):
        """Save a model checkpoint."""
        if filename is None:
            filename = f"model_epoch_{epoch}_acc_{val_acc:.2f}.pth"
        model_path = self.log_dir / "models" / filename
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': val_acc,
        }, model_path)
        return model_path

    def print_summary(self):
        """Print a summary of the training run."""
        if len(self.train_losses) > 0:
            print("\n" + "=" * 50)
            print("Training summary:")
            print("=" * 50)
            print(f"Best validation accuracy: {max(self.val_accs):.2f}%")
            print(f"Best training accuracy: {max(self.train_accs):.2f}%")
            print(f"Final validation loss: {self.val_losses[-1]:.4f}")
            print(f"Final training loss: {self.train_losses[-1]:.4f}")
            print("=" * 50)

# 5. Training function
def train_cbam_model(num_epochs=10, batch_size=64, lr=0.001, device='cuda'):
    """Train the CBAM model (also suitable for running in a Jupyter notebook)."""
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Data augmentation and loading
    print("\nLoading the CIFAR-10 dataset...")
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    # Load CIFAR-10
    train_dataset = datasets.CIFAR10(root='./data', train=True,
                                     download=True, transform=transform_train)
    val_dataset = datasets.CIFAR10(root='./data', train=False,
                                   download=True, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2, pin_memory=True)

    # Create the model
    print("\nBuilding the model...")
    model = CBAMNet(num_classes=10).to(device)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create the training monitor
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    monitor = TrainingMonitor(f"logs/cbam_experiment_{timestamp}")

    # Loss function, optimizer, and learning-rate scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    best_val_acc = 0.0
    print(f"\nStarting training ({num_epochs} epochs)...")
    print("=" * 80)
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            # Accumulate training statistics
            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += target.size(0)
            train_correct += predicted.eq(target).sum().item()
            # Show progress
            if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(train_loader):
                progress = (batch_idx + 1) / len(train_loader) * 100
                bar_length = 30
                filled_length = int(bar_length * (batch_idx + 1) // len(train_loader))
                bar = '█' * filled_length + '░' * (bar_length - filled_length)
                print(f'\rEpoch [{epoch+1}/{num_epochs}] | {bar} | {progress:.1f}%', end='')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                val_loss += loss.item()
                _, predicted = output.max(1)
                val_total += target.size(0)
                val_correct += predicted.eq(target).sum().item()

        # Compute epoch metrics
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total
        current_lr = scheduler.get_last_lr()[0]
        # Update the monitor
        monitor.update(epoch + 1, avg_train_loss, avg_val_loss, train_acc, val_acc, current_lr)
        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            monitor.save_model(model, epoch + 1, val_acc, "best_model.pth")
        # Step the learning-rate scheduler
        scheduler.step()
        # Print epoch results
        print(f'\rEpoch [{epoch+1:3d}/{num_epochs}] | '
              f'Train loss: {avg_train_loss:.4f} | '
              f'Train acc: {train_acc:6.2f}% | '
              f'Val loss: {avg_val_loss:.4f} | '
              f'Val acc: {val_acc:6.2f}% | '
              f'LR: {current_lr:.6f}')

    # Save the final model
    monitor.save_model(model, num_epochs, val_acc, "final_model.pth")
    # Plot training curves
    print("\nPlotting training curves...")
    monitor.plot_metrics(show=True)
    # Print the training summary
    monitor.print_summary()
    print("\nTraining finished!")
    print(f"Logs saved to: {monitor.log_dir}")
    return model, monitor, train_loader, val_loader, device

# 6. Visualize CBAM attention feature maps
def visualize_cbam_attention(model, data_loader, device, num_images=4):
    """Visualize the feature maps produced by the CBAM blocks."""
    save_path = Path("attention_visualizations")
    save_path.mkdir(exist_ok=True)
    model.eval()
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    # CIFAR-10 class names
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    # Capture feature maps
    with torch.no_grad():
        images = images[:num_images].to(device)
        # Register forward hooks to capture the output of each CBAM block
        features = {}

        def get_features(name):
            def hook(model, input, output):
                features[name] = output.detach()
            return hook

        hooks = []
        for name, module in model.named_modules():
            if isinstance(module, CBAM):
                hook = module.register_forward_hook(get_features(name))
                hooks.append(hook)
        _ = model(images)

    # Visualize
    for i in range(min(num_images, 4)):
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        # Original image
        img_np = images[i].cpu().numpy().transpose(1, 2, 0)
        # Undo normalization
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        img_np = std * img_np + mean
        img_np = np.clip(img_np, 0, 1)
        axes[0, 0].imshow(img_np)
        axes[0, 0].set_title(f"Input image {i+1}\nClass: {class_names[labels[i]]}")
        axes[0, 0].axis('off')
        # Visualize the output of each CBAM layer
        for idx, (name, feature) in enumerate(features.items()):
            if idx >= 5:  # show at most 5 CBAM layers
                break
            row = (idx + 1) // 3
            col = (idx + 1) % 3
            if col >= 3:  # guard against exceeding 3 columns
                continue
            # Average over channels to get an attention map
            attn_map = feature[i].mean(dim=0).cpu().numpy()
            im = axes[row, col].imshow(attn_map, cmap='hot')
            axes[row, col].set_title(f"{name}\nChannels: {feature.shape[1]}")
            axes[row, col].axis('off')
            plt.colorbar(im, ax=axes[row, col], fraction=0.046, pad=0.04)
        # Hide any unused subplots
        for idx in range(len(features) + 1, 6):
            row = idx // 3
            col = idx % 3
            if row < 2 and col < 3:
                axes[row, col].axis('off')
        plt.suptitle(f"CBAM Attention Feature Maps (image {i+1})", fontsize=16)
        plt.tight_layout()
        plt.savefig(save_path / f"cbam_attention_{i}.png", dpi=150, bbox_inches='tight')
        plt.show()

    # Remove the hooks
    for hook in hooks:
        hook.remove()
    print(f"\nAttention feature maps saved to: {save_path}")

# 7. Test a simple model
def test_simple_cbam_model():
    """Quick smoke test of a small CBAM model."""
    print("=" * 50)
    print("Testing a simple CBAM model")
    print("=" * 50)

    # A small CNN with a CBAM block after each convolution stage
    class SimpleCBAMNet(nn.Module):
        def __init__(self, num_classes=10):
            super(SimpleCBAMNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)
            self.cbam1 = CBAM(32)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.cbam2 = CBAM(64)
            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(64 * 8 * 8, 256)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            x = self.pool(F.relu(self.bn1(self.conv1(x))))
            x = self.cbam1(x)
            x = self.pool(F.relu(self.bn2(self.conv2(x))))
            x = self.cbam2(x)
            x = x.view(x.size(0), -1)
            x = self.dropout(F.relu(self.fc1(x)))
            x = self.fc2(x)
            return x

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the model
    model = SimpleCBAMNet(num_classes=10).to(device)
    # Test the forward pass
    dummy_input = torch.randn(4, 3, 32, 32).to(device)
    output = model(dummy_input)
    print("Model structure:")
    print(model)
    print(f"\nInput shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    return model

# 8. Example entry point (can also be run directly in a Jupyter notebook)
if __name__ == "__main__":
    # Test the simple model
    print("Testing the simple CBAM model...")
    simple_model = test_simple_cbam_model()

    # Ask whether to train the full model
    print("\n" + "=" * 50)
    train_choice = input("Train the full CBAM model? (y/n): ")

    if train_choice.lower() == 'y':
        # Training hyperparameters
        num_epochs = 20
        batch_size = 64
        learning_rate = 0.001

        # Start training
        model, monitor, train_loader, val_loader, device = train_cbam_model(
            num_epochs=num_epochs,
            batch_size=batch_size,
            lr=learning_rate
        )

        # Ask whether to visualize attention maps
        print("\n" + "=" * 50)
        viz_choice = input("Visualize attention feature maps? (y/n): ")
        if viz_choice.lower() == 'y':
            visualize_cbam_attention(model, val_loader, device, num_images=4)
    else:
        print("Skipped training; only the simple model was tested.")