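Key points reviewed:
- Detecting overfitting: print the metrics for the training set and the test set side by side during training
- Saving and loading a model
  a. Save the weights only
  b. Save the weights together with the model
  c. Save a full checkpoint, which also includes the training state
- Early stopping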
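
The first point, comparing training and test metrics epoch by epoch, can be sketched with a small helper. This is only an illustration under stated assumptions: `find_overfitting_epoch` is a hypothetical name, the `window` rule is one simple heuristic rather than a standard definition, and the loss values are made up.

```python
from typing import List, Optional


def find_overfitting_epoch(train_losses: List[float],
                           test_losses: List[float],
                           window: int = 3) -> Optional[int]:
    """Return the 0-based epoch at which a divergence streak starts, or None.

    Heuristic (an assumption): overfitting is flagged when, for `window`
    consecutive epochs, the training loss goes down while the test loss
    goes up.
    """
    if len(train_losses) != len(test_losses):
        raise ValueError("both histories must have the same length")

    streak = 0
    for epoch in range(1, len(train_losses)):
        train_improving = train_losses[epoch] < train_losses[epoch - 1]
        test_worsening = test_losses[epoch] > test_losses[epoch - 1]
        streak = streak + 1 if (train_improving and test_worsening) else 0
        if streak >= window:
            # The streak began `window - 1` epochs before the current one.
            return epoch - window + 1
    return None


# Toy loss histories (made-up numbers for illustration only).
train_hist = [0.90, 0.70, 0.55, 0.45, 0.38, 0.33, 0.29]
test_hist = [0.92, 0.75, 0.66, 0.64, 0.67, 0.71, 0.76]
overfit_epoch = find_overfitting_epoch(train_hist, test_hist)
print(overfit_epoch)
```

In practice the two histories would be filled inside the training loop, which is why the notes recommend printing both metrics together.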
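**Assignment:** Train a model on the credit dataset and save its weights, then load the weights, continue training for 50 more epochs, and apply early stopping.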
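import torch

# Save the model weights (`model` is the network trained earlier)
torch.save(model.state_dict(), 'credit_model_weights.pth')
# Load the weights back into a model with the same architecture
model.load_state_dict(torch.load('credit_model_weights.pth'))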
# Number of additional epochs to train
additional_epochs = 50
model.train()  # make sure layers such as dropout are in training mode
for epoch in range(additional_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{additional_epochs}], Loss: {loss.item():.4f}')
# Save the weights after the additional training
torch.save(model.state_dict(), 'credit_model_weights_continued.pth')
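
The code above saves only the weights, so the optimizer state is lost between runs. A minimal sketch of the checkpoint approach follows; the stand-in model, the layer sizes, the random data, the dictionary keys, and the file name `checkpoint.pth` are all assumptions for illustration, not the credit model from these notes.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model and optimizer; the layer sizes are arbitrary assumptions.
model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One training step on random data so the optimizer has state worth saving.
X = torch.randn(16, 10)
y = torch.randint(0, 2, (16,))
loss = nn.CrossEntropyLoss()(model(X), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Hypothetical file name inside a temporary directory.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')

# Save the weights plus the training state needed to resume.
torch.save({
    'epoch': 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, ckpt_path)

# Restore into freshly constructed objects of the same architecture.
restored_model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))
restored_optimizer = optim.Adam(restored_model.parameters(), lr=1e-3)

checkpoint = torch.load(ckpt_path)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']  # resume counting from here
```

Restoring the optimizer matters mostly for optimizers such as Adam that keep per-parameter running statistics; with weights alone, those statistics restart from zero when training continues.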
# Early-stopping parameters
patience = 10  # maximum number of consecutive epochs without a drop in validation loss
best_val_loss = float('inf')
counter = 0
# num_epochs is assumed to be defined earlier (for the assignment, the 50 additional epochs)
for epoch in range(num_epochs):
    # Training step
    model.train()
    outputs = model(X_train_tensor)
    train_loss = criterion(outputs, y_train_tensor)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    # Validation step
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor).item()  # keep a plain float, not a tensor
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss:.4f}')
    # Early-stopping logic
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        # Save the best model weights
        torch.save(model.state_dict(), 'best_credit_model_weights.pth')
    else:
        counter += 1
        if counter >= patience:
            print('Early stopping!')
            break
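
The patience counter above can also be wrapped in a small reusable helper. The sketch below is one possible packaging, not part of the original notes: the class name `EarlyStopping`, its `step` method, and the `min_delta` threshold are assumptions introduced here, and the validation losses are made-up numbers.

```python
class EarlyStopping:
    """Track validation loss and signal when training should stop."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest decrease that counts as improvement
        self.best_loss = float('inf')
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Toy validation losses: improvement stalls after the third epoch.
val_losses = [0.80, 0.60, 0.50, 0.52, 0.51, 0.53, 0.55]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        print(f'Early stopping at epoch {epoch}, best val loss {stopper.best_loss:.2f}')
        break
```

In a real training loop, the best weights would still be saved whenever `best_loss` improves, and reloaded from `best_credit_model_weights.pth` after the loop ends so that evaluation uses the best epoch rather than the last one.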