以下示例代码展示了如何在 PyTorch 中保存和加载模型状态和优化器状态,以便训练中断后可以继续训练。
1. 保存模型和优化器状态
假设模型训练了一段时间后,我们想要保存模型和优化器的状态,确保下次可以从同一位置继续训练。
2. 加载模型和优化器状态
加载保存的状态后,可以从保存的 epoch
继续训练。
示例代码
import torch
import torch.nn as nn
import torch.optim as optim
# 假设我们定义了一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
# 创建模型和优化器
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 模拟的训练代码片段
num_epochs = 20
checkpoint_path = "model_checkpoint.pth"
# 保存模型和优化器状态
def save_checkpoint(epoch, model, optimizer, path):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, path)
print(f"Checkpoint saved at epoch {epoch}.")
# 加载模型和优化器状态
def load_checkpoint(model, optimizer, path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
print(f"Checkpoint loaded, starting at epoch {start_epoch}.")
return start_epoch
# 尝试加载已保存的检查点
try:
start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
except FileNotFoundError:
start_epoch = 0
print("No checkpoint found, starting training from scratch.")
# 继续训练
for epoch in range(start_epoch, num_epochs):
# 模拟训练步骤
# output = model(input) ...
# loss = loss_fn(output, target) ...
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
print(f"Epoch {epoch+1}/{num_epochs} completed.")
# 每 5 个 epoch 保存一次模型状态
if (epoch + 1) % 5 == 0:
save_checkpoint(epoch + 1, model, optimizer, checkpoint_path)
解释
- 保存 :
save_checkpoint
函数会在指定的 epoch 保存模型和优化器状态。 - 加载 :
load_checkpoint
函数会加载模型和优化器状态,并返回上次的epoch
,以便继续训练。 - 训练控制 :
start_epoch
变量控制了是否继续从之前的检查点继续训练,确保模型在中断后可以接着训练。