pytorch 实战【以图像处理为例】

pytorch 实战【以图像处理为例】


训练过程中保存模型

在PyTorch中,模型训练过程中保存模型通常涉及以下几个步骤:

  1. 保存整个模型 :

    使用 torch.save 函数,你可以保存整个模型,包括模型的结构和参数。

    python 复制代码
    torch.save(model, 'model.pth')

    加载模型时,使用 torch.load 函数。

    python 复制代码
    model = torch.load('model.pth')
  2. 保存模型的参数 :

    这种方法通常更受欢迎,因为它只保存模型的参数,不保存模型的结构。这样,模型文件会比较小,并且在加载模型时可以更加灵活。

    python 复制代码
    torch.save(model.state_dict(), 'model_params.pth')

    加载模型时,首先创建模型的实例,然后加载参数。

    python 复制代码
    model = ModelClass()  # replace ModelClass with your model's class name
    model.load_state_dict(torch.load('model_params.pth'))
  3. 保存训练的检查点 :

    在训练过程中,除了保存模型或模型的参数,通常还会保存其他关键信息,例如优化器的状态、当前的epoch、最佳准确率等。这样,如果训练被中断,可以从检查点继续训练,而不是从头开始。

    python 复制代码
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        # ... any other relevant information
    }
    torch.save(checkpoint, 'checkpoint.pth')

    加载检查点时:

    python 复制代码
    checkpoint = torch.load('checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
  4. 在训练时定期保存模型 :

    通常,我们会在每个epoch结束时或在验证准确率提高时保存模型。这样,如果训练过程中出现任何问题,我们可以从最近的检查点恢复。

  • 保存检查点:

在训练循环中,你可能会在每个 epoch 结束时或在模型在验证集上达到新的最佳性能时保存检查点:

python 复制代码
# 假设以下变量已经定义:
# model: 你的模型
# optimizer: 你使用的优化器
# epoch: 当前的epoch
# loss: 最近的loss值
# best_accuracy: 迄今为止在验证集上的最佳准确率

# 在每个 epoch 结束时或在验证准确率提高时:
if current_accuracy > best_accuracy:  # current_accuracy是这个epoch在验证集上的准确率
    best_accuracy = current_accuracy
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'best_accuracy': best_accuracy
    }
    torch.save(checkpoint, 'best_checkpoint.pth')
  • 加载检查点:

当你希望从检查点继续训练或评估模型时,可以使用以下代码来加载检查点:

python 复制代码
# 假设以下变量已经定义:
# model: 你的模型 (需要先实例化)
# optimizer: 你使用的优化器 (需要先实例化)

checkpoint = torch.load('best_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
best_accuracy = checkpoint['best_accuracy']

# 如果继续训练,可以从上一个 epoch 开始
model.train()

这样,即使训练过程中断,你也可以从上次停止的地方继续,而不是重新开始。

  1. 保存在不同设备上的模型 :
    如果你在GPU上训练模型,但希望在CPU上加载模型,可以使用以下方式:

    python 复制代码
    torch.save(model.state_dict(), 'model_params.pth')
    # Loading on CPU
    model.load_state_dict(torch.load('model_params.pth', map_location=torch.device('cpu')))

总之,保存模型是训练深度学习模型的关键部分,它允许我们在训练中断时恢复,或在训练完成后部署模型。

具体在训练中断如何继续

如果训练过程中断并且你已经定期保存了检查点,那么你可以从最近的检查点恢复。以下是一个基本流程,描述如何在训练中断后从上次停止的地方继续:

  1. 加载检查点:

    在开始训练之前,首先加载保存的检查点。

    python 复制代码
    checkpoint = torch.load('best_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_accuracy = checkpoint.get('best_accuracy', -1)  # 默认为-1,假设你保存了这个值
  2. 恢复训练:

    使用从检查点中加载的 start_epoch 作为起始点,并从那里开始你的训练循环。

    python 复制代码
    for epoch in range(start_epoch, total_epochs):
        # 训练代码...
        train_one_epoch()
    
        # 验证代码...
        current_accuracy = validate()
    
        # 保存新的检查点,如果模型在验证集上有更好的性能
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_accuracy': best_accuracy
                # ... 你可以添加其他信息,如loss等
            }
            torch.save(checkpoint, 'best_checkpoint.pth')
  3. 注意点:

    • 学习率调整 :如果你使用了学习率调度器,例如 ReduceLROnPlateauStepLR,那么你也应该保存和加载它的状态。这样可以确保学习率调整策略在中断后正确地继续。
    • 随机种子:为了确保训练的可复现性,如果你设置了随机种子,那么在恢复训练之前,你可能需要重新设置相同的随机种子。

通过这种方式,你可以在训练中断后恢复并从上次停止的地方继续,而不会丢失任何进度。

相关推荐
埃菲尔铁塔_CV算法24 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR24 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️31 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子1 小时前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python1 小时前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠1 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨1 小时前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测