Day 40: Early Stopping and Saving Model Weights

@浙大疏锦行

Assignment: train on the credit dataset and save the weights, then continue training for 50 more epochs, applying an early-stopping strategy.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")


# Check whether a GPU is available and prefer it over the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# With multiple GPUs, a specific one can be selected, e.g. cuda:1
# Optionally verify that the GPU is actually being used
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()  # clear the GPU memory cache

# Load the dataset (iris serves as a stand-in here; the assignment's credit
# dataset would be loaded in its place)
iris = load_iris()
X = iris.data  # features
y = iris.target  # labels
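
# A hedged sketch of swapping in a real credit dataset from a CSV file
# (the file name and the 'default' label column are hypothetical placeholders):
#
# import pandas as pd
# data = pd.read_csv('credit.csv')            # hypothetical path
# X = data.drop(columns=['default']).values   # assumed label column name
# y = data['default'].values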

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Min-max scale the features to [0, 1]
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors and move them to the target device (GPU/CPU)
X_train = torch.FloatTensor(X_train).to(device, non_blocking=True)
y_train = torch.LongTensor(y_train).to(device, non_blocking=True)
X_test = torch.FloatTensor(X_test).to(device, non_blocking=True)
y_test = torch.LongTensor(y_test).to(device, non_blocking=True)


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)  # input layer (change the input dimension for the credit dataset)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)  # output layer (change the output dimension for the credit dataset)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Instantiate the model and move it to the device
model = MLP().to(device)
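
# Since the layer sizes above must change for the credit dataset, a hedged
# sketch of a dimension-parameterized variant (hidden width 10 kept from above):
#
# class MLP(nn.Module):
#     def __init__(self, in_features, n_classes, hidden=10):
#         super().__init__()
#         self.fc1 = nn.Linear(in_features, hidden)
#         self.relu = nn.ReLU()
#         self.fc2 = nn.Linear(hidden, n_classes)
#     def forward(self, x):
#         return self.fc2(self.relu(self.fc1(x)))
#
# model = MLP(X_train.shape[1], int(y_train.max()) + 1).to(device)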


criterion = nn.CrossEntropyLoss()  # cross-entropy loss for classification
optimizer = optim.SGD(model.parameters(), lr=0.01)  # plain SGD optimizer

# Parameters for the initial training run
first_train_epochs = 20000
train_losses = []  # training losses of the initial run
test_losses = []
epochs = []

# Early-stopping parameters (the same strategy is reused for continued training)
best_test_loss = float('inf')
best_epoch = 0
patience = 50  # consecutive evaluations (one every 200 epochs below) without improvement before stopping
counter = 0
early_stopped = False
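
# The early-stopping logic below is written inline; a minimal sketch of the
# same idea wrapped in a reusable helper (illustrative only, not used here):
#
# class EarlyStopping:
#     def __init__(self, patience=50):
#         self.patience = patience
#         self.best = float('inf')
#         self.counter = 0
#     def step(self, loss):
#         """Return True when training should stop."""
#         if loss < self.best:
#             self.best = loss
#             self.counter = 0
#             return False
#         self.counter += 1
#         return self.counter >= self.patience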

print("\n===== 开始首次训练 =====")
start_time = time.time()

with tqdm(total=first_train_epochs, desc="Initial training", unit="epoch") as pbar:
    for epoch in range(first_train_epochs):
        model.train()
        # Forward pass
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)

        # Backward pass and parameter update
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Every 200 epochs, record losses and check early stopping
        if (epoch + 1) % 200 == 0:
            model.eval()
            with torch.no_grad():
                test_outputs = model(X_test)
                test_loss = criterion(test_outputs, y_test)
            
            train_losses.append(train_loss.item())
            test_losses.append(test_loss.item())
            epochs.append(epoch + 1)
            
            # Update the progress-bar postfix
            pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
            
            # Early-stopping check
            if test_loss.item() < best_test_loss:
                best_test_loss = test_loss.item()
                best_epoch = epoch + 1
                counter = 0
                # Save the best weights so far
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                counter += 1
                if counter >= patience:
                    print(f"\n首次训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")
                    print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}")
                    early_stopped = True
                    break

        # Advance the progress bar in blocks of 1000 epochs
        if (epoch + 1) % 1000 == 0:
            pbar.update(1000)
    
    # Fill the bar to completion (e.g. after an early stop)
    if pbar.n < first_train_epochs:
        pbar.update(first_train_epochs - pbar.n)

# Save a full checkpoint at the end of the initial run (key point 1).
# Note: this stores the final weights; the best weights are in best_model.pth.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch + 1,
    'best_loss': best_test_loss
}, 'trained_model.pth')
print(f"\n首次训练完成,权重已保存至 trained_model.pth")
print(f"首次训练总耗时: {time.time() - start_time:.2f} 秒")


print("\n===== 加载权重并开始继续训练 =====")
# 加载保存的权重(核心修改点2)
checkpoint = torch.load('trained_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"加载了首次训练至第{checkpoint['epoch']}轮的权重,最佳损失: {checkpoint['best_loss']:.4f}")

# Re-initialize the optimizer (key point 3: here we continue with a fresh optimizer rather than the saved state)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# To carry over the optimizer state instead, uncomment the next line (choose per use case)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
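# Note: plain SGD without momentum keeps no internal state, so a fresh
# optimizer behaves identically here. For SGD with momentum or for Adam, the
# momentum buffers / moment estimates live in optimizer.state_dict(), and
# restoring it (with the same optimizer class as at save time) usually gives
# a smoother continuation.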

# Parameters for continued training
continue_train_epochs = 50  # target of 50 additional epochs
continue_train_losses = []  # losses of the continued run
continue_test_losses = []
continue_epochs = []

# Reset the early-stopping state for the continued run, starting from the
# best loss seen so far
continue_best_loss = checkpoint['best_loss']
continue_counter = 0
continue_early_stop = False

# Save a baseline copy so that continue_best_model.pth always exists, even if
# the continued run never improves on the loaded best loss (otherwise the
# final evaluation below would fail to find the file)
torch.save(model.state_dict(), 'continue_best_model.pth')

start_continue_time = time.time()

with tqdm(total=continue_train_epochs, desc="Continued training", unit="epoch") as pbar:
    for epoch in range(continue_train_epochs):
        model.train()
        # Forward pass
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)

        # Backward pass and parameter update
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Evaluate and check early stopping every epoch (the continued run is short, so no interval is needed)
        model.eval()
        with torch.no_grad():
            test_outputs = model(X_test)
            test_loss = criterion(test_outputs, y_test)
        
        continue_train_losses.append(train_loss.item())
        continue_test_losses.append(test_loss.item())
        continue_epochs.append(epoch + 1)
        
        # Update the progress bar
        pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
        pbar.update(1)

        # Early-stopping check for the continued run. Note: with patience=50
        # and only 50 epochs, this can fire only on the very last epoch; a
        # smaller patience (e.g. 5) would be more meaningful for such a short run.
        if test_loss.item() < continue_best_loss:
            continue_best_loss = test_loss.item()
            continue_counter = 0
            # Save the best model of the continued run
            torch.save(model.state_dict(), 'continue_best_model.pth')
        else:
            continue_counter += 1
            if continue_counter >= patience:
                print(f"\n继续训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")
                print(f"继续训练最佳损失: {continue_best_loss:.4f}")
                continue_early_stop = True
                break

print(f"继续训练完成,总耗时: {time.time() - start_continue_time:.2f} 秒")
print(f"继续训练实际轮数: {len(continue_epochs)} 轮(早停触发则少于50轮)")


print("\n===== 最终模型评估 =====")
model.load_state_dict(torch.load('continue_best_model.pth', map_location=device))
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == y_test).sum().item()
    accuracy = correct / y_test.size(0)
    print(f'Final test accuracy: {accuracy * 100:.2f}%')
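
# Optionally, a per-class breakdown via scikit-learn (predictions moved back
# to the CPU first); an extra check beyond the original assignment:
from sklearn.metrics import classification_report
print(classification_report(y_test.cpu().numpy(), predicted.cpu().numpy()))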

# ====================== Visualization ======================
plt.figure(figsize=(12, 6))

# Plot losses from the initial run
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Initial Training Loss')
plt.legend()
plt.grid(True)

# Plot losses from the continued run
plt.subplot(1, 2, 2)
plt.plot(continue_epochs, continue_train_losses, label='Train Loss')
plt.plot(continue_epochs, continue_test_losses, label='Test Loss')
plt.xlabel('Continue Epoch')
plt.ylabel('Loss')
plt.title('Continued Training Loss (50 Epochs)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
```
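
For reference, restoring the saved weights in a fresh session only requires rebuilding the architecture and loading the state dict. A minimal sketch, assuming the MLP class, device, and test tensors from the script above are available:

```python
import torch

# Rebuild the architecture and load the best weights from the continued run
model = MLP().to(device)
model.load_state_dict(torch.load('continue_best_model.pth', map_location=device))
model.eval()  # switch to inference mode

with torch.no_grad():
    logits = model(X_test)
    preds = logits.argmax(dim=1)  # predicted class indices
```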