python学习打卡:DAY 40 训练和测试的规范写法

浙大疏锦行-CSDN博客
知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。


在图像数据的格式以及模型定义的过程中,发现和之前结构化数据略有不同,主要差异体现在2处:

  1. 模型定义时需要展平图像

  2. 由于数据过大,需要将数据集进行分批次处理,往往涉及到dataset和dataloader来规范代码组织

python 复制代码
# 先继续之前的代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
import warnings
# 忽略警告信息
warnings.filterwarnings("ignore")
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
device = torch.device("mps")
print(f"使用设备: {device}")
 
# 1. 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])
 
# 2. 加载MNIST数据集
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
 
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)
 
# 3. 创建数据加载器
batch_size = 64  # 每批处理64个样本
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
 
# 4. 定义模型、损失函数和优化器
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量
        self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元
        self.relu = nn.ReLU()  # 激活函数
        self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)
        
    def forward(self, x):
        x = self.flatten(x)  # 展平图像
        x = self.layer1(x)   # 第一层线性变换
        x = self.relu(x)     # 应用ReLU激活函数
        x = self.layer2(x)   # 第二层线性变换,输出logits
        return x
 
# 初始化模型
model = MLP()
model = model.to(device)  # 将模型移至GPU(如果可用)
 
# from torchsummary import summary  # 导入torchsummary库
# print("\n模型结构信息:")
# summary(model, input_size=(1, 28, 28))  # 输入尺寸为MNIST图像尺寸
 
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器
 
# 5. 训练模型(记录每个 iteration 的损失)
def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):
    model.train()  # 设置为训练模式
    
    # 新增:记录每个 iteration 的损失
    all_iter_losses = []  # 存储所有 batch 的损失
    iter_indices = []     # 存储 iteration 序号(从1开始)
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            # enumerate() 是 Python 内置函数,用于遍历可迭代对象(如列表、元组)并同时获取索引和值。
            # batch_idx:当前批次的索引(从 0 开始)
            # (data, target):当前批次的样本数据和对应的标签,是一个元组,这是因为dataloader内置的getitem方法返回的是一个元组,包含数据和标签。
            # 只需要记住这种固定写法即可
            data, target = data.to(device), target.to(device)  # 移至GPU(如果可用)
            
            optimizer.zero_grad()  # 梯度清零
            output = model(data)  # 前向传播
            loss = criterion(output, target)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            
            # 记录当前 iteration 的损失(注意:这里直接使用单 batch 损失,而非累加平均)
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)  # iteration 序号从1开始
            
            # 统计准确率和损失
            running_loss += loss.item() #将loss转化为标量值并且累加到running_loss中,计算总损失
            _, predicted = output.max(1) # output:是模型的输出(logits),形状为 [batch_size, 10](MNIST 有 10 个类别)
            # 获取预测结果,max(1) 返回每行(即每个样本)的最大值和对应的索引,这里我们只需要索引
            total += target.size(0) # target.size(0) 返回当前批次的样本数量,即 batch_size,累加所有批次的样本数,最终等于训练集的总样本数
            correct += predicted.eq(target).sum().item() # 返回一个布尔张量,表示预测是否正确,sum() 计算正确预测的数量,item() 将结果转换为 Python 数字
            
            
            # 每100个批次打印一次训练信息(可选:同时打印单 batch 损失)
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
                      f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
        
        # 测试、打印 epoch 结果
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct / total
        epoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)
        
        print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')
    
    # 绘制所有 iteration 的损失曲线
    plot_iter_losses(all_iter_losses, iter_indices)
    # 保留原 epoch 级曲线(可选)
    # plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies, epochs)
    
    return epoch_test_acc  # 返回最终测试准确率

这里先不写早停策略,因为规范的早停策略需要用到验证集,一般还需要划分测试集

  1. 划分数据集:训练集、验证集、测试集。

  2. 训练过程中,使用验证集触发早停。

  3. 训练结束后,仅用测试集运行一次测试函数,得到最终准确率。

python 复制代码
# 6. 测试模型(不变)
def test(model, test_loader, criterion, device):
    model.eval()  # 设置为评估模式
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():  # 不计算梯度,节省内存和计算资源
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    avg_loss = test_loss / len(test_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy  # 返回损失和准确率
 
# 7. 绘制每个 iteration 的损失曲线
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration(Batch序号)')
    plt.ylabel('损失值')
    plt.title('每个 Iteration 的训练损失')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
 
# 8. 执行训练和测试(设置 epochs=2 验证效果)
epochs = 2  
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")
相关推荐
西岸行者4 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意4 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码4 天前
嵌入式学习路线
学习
毛小茛5 天前
计算机系统概论——校验码
学习
babe小鑫5 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms5 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下5 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。5 天前
2026.2.25监控学习
学习
im_AMBER5 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J5 天前
从“Hello World“ 开始 C++
c语言·c++·学习