Python day37

@浙大疏锦行 python day37.

内容:

  • 保存模型只需要保存模型的参数即可,使用的时候直接构建模型再导入参数即可
python 复制代码
# 保存模型参数
torch.save(model.state_dict(), "model_weights.pth")

# 加载参数(需先定义模型结构)
model = MLP()  # 初始化与训练时相同的模型结构
model.load_state_dict(torch.load("model_weights.pth"))
# model.eval()  # 切换至推理模式(可选)
  • 也可以同时保存模型 + 参数
python 复制代码
# 保存整个模型
torch.save(model, "full_model.pth")

# 加载模型(无需提前定义类,但需确保环境一致)
model = torch.load("full_model.pth")
model.eval()  # 切换至推理模式(可选)
  • 保存训练状态,用于在训练过程中保存中间状态
python 复制代码
# # 保存训练状态
# checkpoint = {
#     "model_state_dict": model.state_dict(),
#     "optimizer_state_dict": optimizer.state_dict(),
#     "epoch": epoch,
#     "loss": best_loss,
# }
# torch.save(checkpoint, "checkpoint.pth")

# # 加载并续训
# model = MLP()
# optimizer = torch.optim.Adam(model.parameters())
# checkpoint = torch.load("checkpoint.pth")

# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# start_epoch = checkpoint["epoch"] + 1  # 从下一轮开始训练
# best_loss = checkpoint["loss"]

# # 继续训练循环
# for epoch in range(start_epoch, num_epochs):
#     train(model, optimizer, ...)
  • 对于跨框架保存时,需要保存为onnx文件
  • 早停策略:在训练过程中,训练集的损失不断下降,但是验证集或者测试集的损失反而上升,此时这种情况称为过拟合;针对这种情况,我们可以引入早停策略,用于提前终止训练;
  • 可以在训练达到某一个epoch次数时进行一次验证,将此次结果和历史最佳结果进行比较,如果结果更好则保留,如果不好则会使临时记录变量自增,一旦连续自增到预设值,直接停止。这种设计是为了防止出现震荡波动情况,避免偶然性
  • 可以结合上面的保存断点 breakpoint
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm  # 导入tqdm库用于进度条显示
import warnings
warnings.filterwarnings("ignore")  # 忽略警告信息

# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 归一化数据
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 将数据转换为PyTorch张量并移至GPU
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 实例化模型并移至GPU
model = MLP().to(device)

# 分类问题使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 20000  # 训练的轮数

# 用于存储每200个epoch的损失值和对应的epoch数
train_losses = []  # 存储训练集损失
test_losses = []   # 存储测试集损失
epochs = []

# ===== 新增早停相关参数 =====
best_test_loss = float('inf')  # 记录最佳测试集损失
best_epoch = 0                 # 记录最佳epoch
patience = 50                # 早停耐心值(连续多少轮测试集损失未改善时停止训练)
counter = 0                    # 早停计数器
early_stopped = False          # 是否早停标志
# ==========================

start_time = time.time()  # 记录开始时间

# 创建tqdm进度条
with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:
    # 训练模型
    for epoch in range(num_epochs):
        # 前向传播
        outputs = model(X_train)  # 隐式调用forward函数
        train_loss = criterion(outputs, y_train)

        # 反向传播和优化
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # 记录损失值并更新进度条
        if (epoch + 1) % 200 == 0:
            # 计算测试集损失
            model.eval()
            with torch.no_grad():
                test_outputs = model(X_test)
                test_loss = criterion(test_outputs, y_test)
            model.train()
            
            train_losses.append(train_loss.item())
            test_losses.append(test_loss.item())
            epochs.append(epoch + 1)
            
            # 更新进度条的描述信息
            pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
            
            # ===== 新增早停逻辑 =====
            if test_loss.item() < best_test_loss: # 如果当前测试集损失小于最佳损失
                best_test_loss = test_loss.item() # 更新最佳损失
                best_epoch = epoch + 1 # 更新最佳epoch
                counter = 0 # 重置计数器
                # 保存最佳模型
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                counter += 1
                if counter >= patience:
                    print(f"早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")
                    print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}")
                    early_stopped = True
                    break  # 终止训练循环
            # ======================

        # 每1000个epoch更新一次进度条
        if (epoch + 1) % 1000 == 0:
            pbar.update(1000)  # 更新进度条

    # 确保进度条达到100%
    if pbar.n < num_epochs:
        pbar.update(num_epochs - pbar.n)  # 计算剩余的进度并更新

time_all = time.time() - start_time  # 计算训练时间
print(f'Training time: {time_all:.2f} seconds')

# ===== 新增:加载最佳模型用于最终评估 =====
if early_stopped:
    print(f"加载第{best_epoch}轮的最佳模型进行最终评估...")
    model.load_state_dict(torch.load('best_model.pth'))
# ================================

# 可视化损失曲线
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()

# 在测试集上评估模型
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == y_test).sum().item()
    accuracy = correct / y_test.size(0)
    print(f'测试集准确率: {accuracy * 100:.2f}%')    
相关推荐
xwill*几秒前
wandb的使用方法,以navrl为例
开发语言·python·深度学习
紧固件研究社几秒前
从标准件到复杂异形件,紧固件设备如何赋能制造升级
人工智能·制造·紧固件
木头左几秒前
贝叶斯深度学习在指数期权风险价值VaR估计中的实现与应用
人工智能·深度学习
反向跟单策略1 分钟前
期货反向跟单—高频换人能够提高跟单效率?
大数据·人工智能·学习·数据分析·区块链
哎吆我呸1 分钟前
Android studio 安装Claude Code GUI 插件报错无法找到Node.js解决方案
人工智能
咕噜企业分发小米2 分钟前
独立IP服务器有哪些常见的应用场景?
人工智能·阿里云·云计算
rgeshfgreh6 分钟前
解决Windows系统Python命令无效问题
python
测试者家园7 分钟前
AI 智能体如何构建模拟真实用户行为的复杂负载场景?
人工智能·压力测试·性能测试·智能体·用户行为·智能化测试·软件开发和测试
MF_AI8 分钟前
苹果病害检测识别数据集:1w+图像,5类,yolo标注
图像处理·人工智能·深度学习·yolo·计算机视觉
jinglong.zha11 分钟前
AScript游戏进阶课程 - 实战课表(0基础小白从入门到精通系列课程)
python·自动化·懒人精灵·ascript·游戏脚本