Day43 复习日

训练主模型

python
# Main training routine (optimized version).
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: "optim.Optimizer | None" = None,
) -> tuple[float, float]:
    """Run one full pass over `loader`; return (mean batch loss, accuracy %).

    Trains (forward + backward + optimizer step) when `optimizer` is given;
    otherwise evaluates under `torch.no_grad()` to save memory and compute.
    """
    training = optimizer is not None
    # train(True) enables Dropout/BatchNorm training behavior; train(False) == eval()
    model.train(training)
    running_loss = 0.0  # sum of per-batch losses
    correct = 0         # correctly predicted samples
    total = 0           # samples seen
    grad_ctx = torch.enable_grad() if training else torch.no_grad()
    with grad_ctx:
        for inputs, targets in loader:
            # Move the batch to the compute device (GPU or CPU)
            inputs, targets = inputs.to(device), targets.to(device)
            if training:
                # Gradients must not accumulate across batches
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()
            # Accumulate metrics
            running_loss += loss.item()
            _, predicted = outputs.max(1)  # argmax over class dimension
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # NOTE(review): raises ZeroDivisionError on an empty loader — same as the
    # original duplicated loops; kept to preserve behavior.
    return running_loss / len(loader), 100.0 * correct / total


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    epochs: int
) -> tuple[list[float], list[float], list[float], list[float]]:
    """Train `model` with per-epoch validation, LR scheduling and early stopping.

    Args:
        model: network to optimize (assumed already on `device` — the
            module-level device defined elsewhere in this file).
        train_loader: batched training data.
        val_loader: batched validation data.
        criterion: loss function (e.g. cross entropy).
        optimizer: parameter optimizer (e.g. Adam).
        scheduler: plateau scheduler stepped with the validation loss.
            (Annotation fixed: the original hinted `_LRScheduler`, but
            `scheduler.step(val_loss)` is the `ReduceLROnPlateau` API.)
        epochs: maximum number of epochs; early stopping may end sooner.

    Returns:
        (train_losses, val_losses, train_accuracies, val_accuracies),
        one entry per completed epoch.
    """
    # Per-epoch metric history for later analysis/plotting
    train_losses: list[float] = []
    val_losses: list[float] = []
    train_accuracies: list[float] = []
    val_accuracies: list[float] = []

    # Early stopping: stop after `early_stop_patience` consecutive epochs
    # without a new best validation loss.
    best_val_loss: float = float('inf')
    early_stop_counter: int = 0
    early_stop_patience: int = 5

    for epoch in range(epochs):
        # --- training pass ---
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # --- validation pass (no optimizer -> eval mode, no gradients) ---
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%')
        print('-' * 50)

        # ReduceLROnPlateau expects the monitored metric (mode='min' -> val loss)
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            # Checkpoint the best weights. NOTE(review): they are not reloaded
            # before returning, so the in-memory model keeps the LAST epoch's
            # weights — load 'best_model.pth' afterwards if you want the best.
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Full metric history for downstream analysis and visualization
    return train_losses, val_losses, train_accuracies, val_accuracies
 
# Run the training loop (call signature unchanged)
epochs = 20
metrics = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, epochs
)
train_losses, val_losses, train_accuracies, val_accuracies = metrics
 
# 可视化训练过程(保持原函数不变)
def plot_training(train_losses, val_losses, train_accuracies, val_accuracies):
    """Draw side-by-side loss and accuracy curves for a training run."""
    plt.figure(figsize=(12, 4))

    # (subplot index, [(series, label), ...], y-label, title) per panel
    panels = [
        (1,
         [(train_losses, 'Train Loss'), (val_losses, 'Validation Loss')],
         'Loss',
         'Training and Validation Loss'),
        (2,
         [(train_accuracies, 'Train Accuracy'), (val_accuracies, 'Validation Accuracy')],
         'Accuracy (%)',
         'Training and Validation Accuracy'),
    ]
    for index, curves, y_label, title in panels:
        plt.subplot(1, 2, index)
        for series, label in curves:
            plt.plot(series, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.show()
 
# Render the recorded loss/accuracy curves for the completed run
plot_training(train_losses, val_losses, train_accuracies, val_accuracies)

@浙大疏锦行