Day 42 简单CNN

对比不同学习率调度器与不同结构的CNN组合后的训练效果（损失曲线与学习率变化）

Python 代码：
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F

# =========================
# 数据加载
# =========================
# Preprocessing: convert images to tensors, then normalize each pixel with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train/test splits stored under ./data (download requested for the train split).
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffle only the training data; evaluation runs in large, fixed-order batches.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# =========================
# CNN 模型定义
# =========================
class CNN_A(nn.Module):
    """Two unpadded 3x3 convolutions, one 2x2 max-pool, then a two-layer classifier.

    For a 1x28x28 input the spatial size goes 28 -> 26 -> 24 (convs) -> 12 (pool),
    so the flattened feature vector has 32 * 12 * 12 = 4608 entries. Outputs 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(32 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        feature_map = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = F.max_pool2d(feature_map, 2)
        hidden = F.relu(self.fc1(pooled.flatten(start_dim=1)))
        return self.fc2(hidden)

class CNN_B(nn.Module):
    """Three unpadded 3x3 convolutions, one 2x2 max-pool, dropout, then a two-layer classifier.

    For a 1x28x28 input the spatial size goes 28 -> 26 -> 24 -> 22 (convs) -> 11 (pool),
    so the flattened feature vector has 64 * 11 * 11 = 7744 entries. Dropout (p=0.25)
    is applied to that flattened vector before fc1. Outputs 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.dropout = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(64 * 11 * 11, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = F.max_pool2d(x, 2)
        x = self.fc1(self.dropout(x.flatten(start_dim=1)))
        return self.fc2(F.relu(x))

class ResidualBlock(nn.Module):
    """Basic residual block: (conv3x3 -> BN -> ReLU -> conv3x3 -> BN) + identity, then ReLU.

    Both convolutions use padding=1 and keep the channel count, so the branch output
    has the same shape as the input and can be added to it directly.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.bn1(self.conv1(x)).relu()
        branch = self.bn2(self.conv2(branch))
        # Shortcut connection: add the unmodified input back before the final ReLU.
        return (branch + x).relu()

class CNN_C(nn.Module):
    """Small residual CNN: two stages of (conv -> ReLU -> ResidualBlock -> 2x2 max-pool), then a linear classifier.

    All convolutions use padding=1, so only the pools shrink the input:
    for a 1x28x28 input, 28 -> 14 -> 7, giving 64 * 7 * 7 classifier inputs. Outputs 10 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.resblock1 = ResidualBlock(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.resblock2 = ResidualBlock(64)
        self.fc1 = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        stages = ((self.conv1, self.resblock1), (self.conv2, self.resblock2))
        for conv, block in stages:
            x = F.max_pool2d(block(F.relu(conv(x))), 2)
        return self.fc1(torch.flatten(x, 1))

# =========================
# 训练与评估函数
# =========================
def train_model(model, optimizer, scheduler, scheduler_name, epochs=5, *,
                train_data_loader=None, test_data_loader=None, device=None):
    """Train ``model`` for ``epochs`` epochs, evaluating and stepping the LR scheduler after each.

    Args:
        model: network to train; its parameters are updated in place.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler attached to ``optimizer``; ``None`` disables scheduling.
        scheduler_name: label of the scheduler. ``"ReduceLROnPlateau"`` (or any actual
            ``ReduceLROnPlateau`` instance) is stepped with the epoch's test loss;
            every other scheduler is stepped without arguments.
        epochs: number of passes over the training data.
        train_data_loader: training batches; defaults to the module-level ``train_loader``.
        test_data_loader: evaluation batches; defaults to the module-level ``test_loader``.
        device: device batches are moved to; defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(train_losses, test_losses, lrs)``, one entry per epoch: the mean of the
        per-batch training losses, the per-sample test cross-entropy, and the
        learning rate of the first param group read after the scheduler step.
    """
    # Fall back to the module-level loaders lazily so existing callers keep working.
    if train_data_loader is None:
        train_data_loader = train_loader
    if test_data_loader is None:
        test_data_loader = test_loader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Plateau schedulers require a metric in step(); detect them by type as well as
    # by label so a mislabelled scheduler does not crash with a TypeError.
    needs_metric = scheduler_name == "ReduceLROnPlateau" or isinstance(scheduler, ReduceLROnPlateau)

    train_losses, test_losses, lrs = [], [], []
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        for data, target in train_data_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        train_losses.append(total_loss / len(train_data_loader))

        # ---- evaluation pass ----
        model.eval()
        correct, test_loss = 0, 0.0
        with torch.no_grad():
            for data, target in test_data_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Sum per-sample losses so the final division yields a true per-sample mean.
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                correct += output.argmax(dim=1).eq(target).sum().item()
        num_test = len(test_data_loader.dataset)
        test_loss /= num_test
        acc = 100. * correct / num_test
        test_losses.append(test_loss)

        # ---- LR schedule ----
        # NOTE(review): the plateau metric is the *test* loss, so the test set also
        # influences training decisions; a separate validation split would be cleaner.
        if scheduler is not None:
            if needs_metric:
                scheduler.step(test_loss)
            else:
                scheduler.step()

        lrs.append(optimizer.param_groups[0]['lr'])

        print(f"Epoch {epoch+1}: Train Loss={train_losses[-1]:.4f}, Test Loss={test_loss:.4f}, Acc={acc:.2f}%, LR={lrs[-1]:.6f}")

    return train_losses, test_losses, lrs

# =========================
# 调度器与模型组合实验
# =========================
# Epochs per run. CosineAnnealingLR's T_max is tied to it so the cosine schedule
# reaches its minimum exactly at the last epoch instead of silently drifting
# out of sync if the epoch count is changed.
EPOCHS = 5
# The global RNG is re-seeded before every run so all schedulers for a given
# architecture start from identical weights and see the same shuffle order;
# otherwise scheduler differences are confounded with random initialization.
SEED = 0

# Only the class of each instance is used below: every run builds a fresh model.
models = {
    'CNN_A': CNN_A(),
    'CNN_B': CNN_B(),
    'CNN_C': CNN_C()
}

scheduler_factories = {
    'StepLR': lambda opt: StepLR(opt, step_size=3, gamma=0.5),
    'ExponentialLR': lambda opt: ExponentialLR(opt, gamma=0.9),
    'CosineAnnealingLR': lambda opt: CosineAnnealingLR(opt, T_max=EPOCHS),
    'ReduceLROnPlateau': lambda opt: ReduceLROnPlateau(opt, patience=2)
}

# (model_name, scheduler_name) -> (train_losses, test_losses, lrs)
results = {}

for model_name, model in models.items():
    for sched_name, sched_factory in scheduler_factories.items():
        print(f"\n===== Training {model_name} with {sched_name} =====")
        torch.manual_seed(SEED)
        model_copy = type(model)()
        optimizer = optim.Adam(model_copy.parameters(), lr=0.001)
        scheduler = sched_factory(optimizer)
        train_losses, test_losses, lrs = train_model(model_copy, optimizer, scheduler, sched_name, epochs=EPOCHS)
        results[(model_name, sched_name)] = (train_losses, test_losses, lrs)

# =========================
# 可视化结果
# =========================
def _plot_per_epoch(metric_index, title, ylabel):
    """Plot one curve per (model, scheduler) run for ``results[key][metric_index]``.

    The x-axis uses 1-based epoch numbers so it matches the "Epoch N" lines
    printed during training.
    """
    plt.figure(figsize=(12, 6))
    for (model_name, sched_name), run in results.items():
        values = run[metric_index]
        plt.plot(range(1, len(values) + 1), values, label=f'{model_name}-{sched_name}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    # Up to 12 runs share one axes; two legend columns keep the legend compact.
    plt.legend(ncol=2)
    plt.show()


# Index 1 of each results entry is the per-epoch test loss, index 2 the learning rate.
_plot_per_epoch(1, 'Test Loss Comparison', 'Loss')
_plot_per_epoch(2, 'Learning Rate Schedules', 'LR')

@浙大疏锦行

相关推荐
VCR__14 小时前
python第三次作业
开发语言·python
韩立学长14 小时前
【开题答辩实录分享】以《助农信息发布系统设计与实现》为例进行选题答辩实录分享
python·web
小白狮ww14 小时前
Ovis-Image:卓越的图像生成模型
人工智能·深度学习·目标检测·机器学习·cpu·gpu·视觉分割模型
滴啦嘟啦哒14 小时前
【机械臂】【LLM】一、接入千问LLM实现自然语言指令解析
深度学习·ros·vla
工程师老罗14 小时前
Pytorch完整的模型训练流程
人工智能·pytorch·深度学习
2401_8384725115 小时前
使用Scikit-learn构建你的第一个机器学习模型
jvm·数据库·python
u01092727115 小时前
使用Python进行网络设备自动配置
jvm·数据库·python
工程师老罗15 小时前
优化器、反向传播、损失函数之间是什么关系,Pytorch中如何使用和设置?
人工智能·pytorch·python
Fleshy数模15 小时前
我的第一只Python爬虫:从Requests库到爬取整站新书
开发语言·爬虫·python
CoLiuRs15 小时前
Image-to-3D — 让 2D 图片跃然立体*
python·3d·flask