python打卡day37

@疏锦行

知识点回顾:

  1. 过拟合的判断:测试集和训练集同步打印指标

  2. 模型的保存和加载

a. 仅保存权重

b. 保存权重和模型

c. 保存全部信息checkpoint,还包含训练状态

  1. 早停策略

**作业:**对信贷数据集训练后保存权重,加载权重后继续训练50轮,并采取早停策略

复制代码
# 保存模型权重
torch.save(model.state_dict(), 'credit_model_weights.pth')

# 加载模型权重
model.load_state_dict(torch.load('credit_model_weights.pth'))

# 设置继续训练的轮数
additional_epochs = 50

for epoch in range(additional_epochs):
    # 前向传播
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{additional_epochs}], Loss: {loss.item():.4f}')

# 保存继续训练后的模型权重
torch.save(model.state_dict(), 'credit_model_weights_continued.pth')
# 早停策略参数
patience = 10  # 容忍验证集损失不下降的最大轮数
best_val_loss = float('inf')
counter = 0

for epoch in range(num_epochs):
    # 训练代码
    model.train()
    outputs = model(X_train_tensor)
    train_loss = criterion(outputs, y_train_tensor)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    # 验证代码
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

    # 早停策略逻辑
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        # 保存最佳模型权重
        torch.save(model.state_dict(), 'best_credit_model_weights.pth')
    else:
        counter += 1
        if counter >= patience:
            print('Early stopping!')
            break
相关推荐
happyness442 分钟前
2026 主流 AI 编码全景对比表
人工智能·ai编程
智慧医养结合软件开源2 分钟前
数智协同,赋能康养服务高效升级
大数据·人工智能·云计算·生活
And_Ii3 分钟前
leetCode 146. LRU 缓存
python·链表
SEO_juper3 分钟前
行业白皮书 GEO 化转 HTML + 结构化,AI 引用率提升 50%
人工智能·chatgpt·seo·白皮书·独立站·外贸电商·谷歌geo
萤萤七悬7 分钟前
【AI精彩BUG汇总】一、yolo图像训练截图蓝色变橙色
人工智能·yolo·bug
思维新观察8 分钟前
从 AI 工具到音乐生态:可酷加速布局,构建数字音乐全新基础设施
人工智能
Miss roro9 分钟前
法律文书信息自动提取:OCR识别与AI技术在案件管理中的应用
人工智能·ocr·法律科技·律所管理系统·案件管理系统
乐迪信息13 分钟前
乐迪信息:港口夜间船舶巡查难,AI摄像机法全天候监测
人工智能·物联网·算法·计算机视觉·目标跟踪
sali-tec14 分钟前
C# 基于OpenCv的视觉工作流-章74-线-线距离
图像处理·人工智能·opencv·算法·计算机视觉
byte轻骑兵16 分钟前
【HID】规范精讲[17]: 蓝牙HID设备功耗优化秘籍——从Sniff模式到断连重连的省电之道
人工智能·人机交互·蓝牙键盘·蓝牙鼠标·蓝牙hid