# Train the main model (训练主模型)
# Main training entry point (optimized version).
def _epoch_pass(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader`; returns (mean_loss, accuracy_pct).

    Trains (forward + backward + step) when `optimizer` is given, otherwise
    evaluates under `torch.no_grad()` with the model in eval mode.
    """
    training = optimizer is not None
    # train(True) enables Dropout/BatchNorm training behavior; train(False) == eval().
    model.train(training)
    total_loss = 0.0   # accumulated per-batch loss
    correct = 0        # correctly classified samples
    total = 0          # samples seen
    grad_ctx = torch.enable_grad() if training else torch.no_grad()
    with grad_ctx:
        for inputs, targets in loader:
            # Move the batch to the device the model lives on.
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                # Clear stale gradients, backprop, and update the weights.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            _, predicted = outputs.max(1)  # argmax over the class dimension
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Mean loss per batch and accuracy as a percentage.
    return total_loss / len(loader), 100.0 * correct / total


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    # NOTE: annotated as ReduceLROnPlateau because step() below is fed the
    # validation loss — the metric-driven scheduler interface.
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    epochs: int,
    *,
    early_stop_patience: int = 5,
    checkpoint_path: str = 'best_model.pth',
) -> tuple[list[float], list[float], list[float], list[float]]:
    """Train `model` for up to `epochs` epochs with validation and early stopping.

    Each epoch runs a training pass then a validation pass, prints the metrics,
    steps the plateau scheduler on the validation loss, and checkpoints the
    best-so-far weights to `checkpoint_path`. Stops early after
    `early_stop_patience` consecutive epochs without validation improvement.

    Returns:
        (train_losses, val_losses, train_accuracies, val_accuracies) — one
        entry per completed epoch, for later analysis/plotting.
    """
    # Derive the compute device from the model itself instead of relying on
    # a module-level `device` global.
    device = next(model.parameters()).device

    train_losses: list[float] = []
    val_losses: list[float] = []
    train_accuracies: list[float] = []
    val_accuracies: list[float] = []

    best_val_loss: float = float('inf')
    early_stop_counter: int = 0

    for epoch in range(epochs):
        # Training pass.
        train_loss, train_accuracy = _epoch_pass(
            model, train_loader, criterion, device, optimizer
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Validation pass (no gradients, eval mode).
        val_loss, val_accuracy = _epoch_pass(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%')
        print('-' * 50)

        # Plateau scheduler: lowers the LR when val_loss stops improving.
        scheduler.step(val_loss)

        # Early stopping + best-model checkpointing.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return train_losses, val_losses, train_accuracies, val_accuracies
# Kick off training (call signature unchanged) and collect per-epoch metrics.
epochs = 20
history = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, epochs
)
train_losses, val_losses, train_accuracies, val_accuracies = history
# Visualize the training run (loss on the left, accuracy on the right).
def plot_training(train_losses, val_losses, train_accuracies, val_accuracies):
    """Draw side-by-side loss and accuracy curves for train/validation."""
    # (title, y-axis label, [(series, legend label), ...]) per subplot.
    panels = [
        ('Training and Validation Loss', 'Loss',
         [(train_losses, 'Train Loss'), (val_losses, 'Validation Loss')]),
        ('Training and Validation Accuracy', 'Accuracy (%)',
         [(train_accuracies, 'Train Accuracy'),
          (val_accuracies, 'Validation Accuracy')]),
    ]
    plt.figure(figsize=(12, 4))
    for position, (title, ylabel, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, label in series:
            plt.plot(values, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()
plot_training(train_losses, val_losses, train_accuracies, val_accuracies)