Day 42 简单CNN

对比不同学习率调度器与不同结构的 CNN 组合时的训练效果（测试损失与学习率变化曲线）。

Python 代码如下：
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# =========================
# 数据加载
# =========================
# ToTensor yields float images; Normalize then maps (x - 0.5) / 0.5 per channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Only the training split passes download=True; the test split is built
# without it and so relies on files already present under ./data
# (presumably fetched by the call just above).
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffle only the training data; evaluation iterates in fixed order with
# larger batches.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# =========================
# CNN 模型定义
# =========================
class CNN_A(nn.Module):
    def __init__(self):
        super(CNN_A, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.fc1 = nn.Linear(4608, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class CNN_B(nn.Module):
    """Three-convolution network with dropout before the classifier head.

    The first linear layer is sized for single-channel 28x28 inputs
    (no padding, stride 1):
    28 -> 26 -> 24 -> 22 -> max-pool(2) -> 11,
    so the flattened feature vector has 64 * 11 * 11 = 7744 entries.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        # Applied to the flattened conv features, before fc1.
        self.dropout = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(in_features=7744, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2, self.conv3):
            out = F.relu(conv(out))
        out = F.max_pool2d(out, kernel_size=2)
        out = self.dropout(out.flatten(start_dim=1))
        out = F.relu(self.fc1(out))
        return self.fc2(out)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return F.relu(out)

class CNN_C(nn.Module):
    """Residual CNN: two (conv -> ReLU -> ResidualBlock -> 2x2 max-pool) stages.

    Every convolution uses padding=1, so only the pools shrink the feature map.
    For the 28x28 inputs the head is sized for: 28 -> 14 -> 7, giving
    64 * 7 * 7 flattened features for the single linear classifier.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.resblock1 = ResidualBlock(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.resblock2 = ResidualBlock(64)
        self.fc1 = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        stages = ((self.conv1, self.resblock1), (self.conv2, self.resblock2))
        for conv, block in stages:
            x = F.max_pool2d(block(F.relu(conv(x))), kernel_size=2)
        return self.fc1(torch.flatten(x, start_dim=1))

# =========================
# 训练与评估函数
# =========================
def _run_train_epoch(model, optimizer, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for data, target in loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader):
    """Return (mean per-sample cross-entropy, accuracy in percent) over ``loader``."""
    model.eval()
    correct, summed_loss = 0, 0.0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # Sum per-sample losses so the final division is by sample count,
            # not by number of batches (the last batch may be smaller).
            summed_loss += F.cross_entropy(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
    n_samples = len(loader.dataset)
    return summed_loss / n_samples, 100. * correct / n_samples


def train_model(model, optimizer, scheduler, scheduler_name, epochs=5, *,
                train_data=None, test_data=None):
    """Train ``model``, evaluating and stepping ``scheduler`` once per epoch.

    Args:
        model: network to train in place.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler wrapping ``optimizer``; stepped once per epoch.
        scheduler_name: selects how the scheduler is stepped.
            ``"ReduceLROnPlateau"`` receives the epoch's test loss; any other
            name is stepped without arguments.
        epochs: number of passes over the training data.
        train_data: optional iterable of ``(data, target)`` batches supporting
            ``len``. Defaults to the module-level ``train_loader``.
        test_data: optional evaluation loader exposing ``.dataset``. Defaults
            to the module-level ``test_loader``.

    Returns:
        ``(train_losses, test_losses, lrs)``, one entry per epoch:
        - the mean training batch loss;
        - the mean per-sample test cross-entropy;
        - the learning rate of ``param_groups[0]`` that was in effect during
          that epoch, i.e. read before the scheduler steps.
    """
    if train_data is None:
        train_data = train_loader
    if test_data is None:
        test_data = test_loader

    train_losses, test_losses, lrs = [], [], []
    for epoch in range(epochs):
        # Capture the LR *before* stepping the scheduler: this is the rate the
        # optimizer actually used for this epoch's updates. Reading it after
        # step() would report (and plot) the next epoch's LR instead.
        lrs.append(optimizer.param_groups[0]['lr'])

        train_losses.append(_run_train_epoch(model, optimizer, train_data))
        test_loss, acc = _evaluate(model, test_data)
        test_losses.append(test_loss)

        if scheduler_name == "ReduceLROnPlateau":
            scheduler.step(test_loss)
        else:
            scheduler.step()

        print(f"Epoch {epoch+1}: Train Loss={train_losses[-1]:.4f}, Test Loss={test_loss:.4f}, Acc={acc:.2f}%, LR={lrs[-1]:.6f}")

    return train_losses, test_losses, lrs

# =========================
# 调度器与模型组合实验
# =========================
# Epochs per run. CosineAnnealingLR's T_max is tied to it so the cosine
# schedule spans exactly one training run (the scheduler steps once per epoch).
EPOCHS = 5

# One freshly initialised instance per architecture. Every scheduler run below
# trains a deep copy of it, so all schedulers for a given architecture start
# from identical weights and curve differences come from the scheduler, not
# from a different random initialisation.
models = {
    'CNN_A': CNN_A(),
    'CNN_B': CNN_B(),
    'CNN_C': CNN_C()
}

# Factories rather than instances: each scheduler must wrap the optimizer of
# the run it belongs to.
scheduler_factories = {
    'StepLR': lambda opt: StepLR(opt, step_size=3, gamma=0.5),
    'ExponentialLR': lambda opt: ExponentialLR(opt, gamma=0.9),
    'CosineAnnealingLR': lambda opt: CosineAnnealingLR(opt, T_max=EPOCHS),
    'ReduceLROnPlateau': lambda opt: ReduceLROnPlateau(opt, patience=2)
}

# (model_name, scheduler_name) -> (train_losses, test_losses, lrs)
results = {}

for model_name, model in models.items():
    for sched_name, sched_factory in scheduler_factories.items():
        print(f"\n===== Training {model_name} with {sched_name} =====")
        # Previously model.__class__() re-drew random weights for every run,
        # confounding the scheduler comparison; copy the shared init instead.
        model_copy = copy.deepcopy(model)
        optimizer = optim.Adam(model_copy.parameters(), lr=0.001)
        scheduler = sched_factory(optimizer)
        train_losses, test_losses, lrs = train_model(model_copy, optimizer, scheduler, sched_name, epochs=EPOCHS)
        results[(model_name, sched_name)] = (train_losses, test_losses, lrs)

# =========================
# 可视化结果
# =========================
# One figure per tracked curve. Each results value is
# (train_losses, test_losses, lrs); curve_idx selects which entry to plot.
for curve_idx, fig_title, y_label in (
    (1, 'Test Loss Comparison', 'Loss'),
    (2, 'Learning Rate Schedules', 'LR'),
):
    plt.figure(figsize=(12, 6))
    for (model_name, sched_name), curves in results.items():
        plt.plot(curves[curve_idx], label=f'{model_name}-{sched_name}')
    plt.title(fig_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

@浙大疏锦行

相关推荐
The_Ticker22 分钟前
印度股票实时行情API(低成本方案)
python·websocket·算法·金融·区块链
ZC跨境爬虫29 分钟前
Scrapy工作空间搭建与目录结构解析:从初始化到基础配置全流程
前端·爬虫·python·scrapy·自动化
EAIReport32 分钟前
国外网站数据批量采集技术实现路径
开发语言·python
Ulyanov36 分钟前
基于ttk的现代化Python音视频播放器:UI设计与可视化技术深度解析
python·ui·音视频
Freak嵌入式44 分钟前
MicroPython LVGL基础知识和概念:时序与动态效果
开发语言·python·github·php·gui·lvgl·micropython
zhangzeyuaaa1 小时前
Python 中的 Map 和 Reduce 详解
开发语言·python
Dfreedom.1 小时前
【实战篇】图像分割-计算图中不同颜色区域的面积比
图像处理·人工智能·深度学习·计算机视觉·图像分割·otsu
七夜zippoe2 小时前
Java技术未来展望:GraalVM、Quarkus、Helidon等新趋势探讨
java·开发语言·python·quarkus·graaivm·helidon
m0_738120722 小时前
网络安全编程——Python编写基于UDP的主机发现工具(解码IP header)
python·网络协议·tcp/ip·安全·web安全·udp
北冥有羽Victoria2 小时前
OpenCLI 操作网页 从0到1完整实操指南
vscode·爬虫·python·github·api·ai编程·opencli