Day 40: Early Stopping and Saving Model Weights

@浙大疏锦行

Assignment: train on the credit dataset and save the weights, then continue training for 50 more epochs with an early stopping strategy. (The code below uses the iris dataset as a stand-in for the credit data; the comments mark where the input/output dimensions would need to change.)

python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import os
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")


# Check whether a GPU is available and prefer it over the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# With multiple GPUs you can target a specific one, e.g. cuda:1
# Optionally verify that the GPU is actually in use
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()  # clear the GPU cache

# Load the iris dataset (a stand-in here for the credit dataset)
iris = load_iris()
X = iris.data  # feature matrix
y = iris.target  # labels

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
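# Note: for a small or imbalanced credit dataset, passing stratify=y to
# train_test_split keeps class ratios consistent across the split (not done here)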

# Scale each feature to [0, 1]
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
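# Note: the scaler is fit on the training set only and then applied to the test
# set, which avoids leaking test-set statistics into training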

# Convert to PyTorch tensors and move them to the target device (GPU/CPU)
X_train = torch.FloatTensor(X_train).to(device, non_blocking=True)
y_train = torch.LongTensor(y_train).to(device, non_blocking=True)
X_test = torch.FloatTensor(X_test).to(device, non_blocking=True)
y_test = torch.LongTensor(y_test).to(device, non_blocking=True)
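# Note: with only 150 iris samples, the loops below use full-batch gradient descent
# (no DataLoader); a larger credit dataset would typically be trained in mini-batches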


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)  # input layer (adjust the input dimension for the credit dataset)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)  # output layer (adjust the output dimension for the credit dataset)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Instantiate the model and move it to the device
model = MLP().to(device)


criterion = nn.CrossEntropyLoss()  # classification loss
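# CrossEntropyLoss applies log-softmax internally, which is why the MLP's forward()
# returns raw logits with no softmax layer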
optimizer = optim.SGD(model.parameters(), lr=0.01)  # optimizer

# First-run training parameters
first_train_epochs = 20000
train_losses = []  # first-run training losses
test_losses = []
epochs = []

# Early stopping parameters (the same strategy is applied to the continued run)
best_test_loss = float('inf')
best_epoch = 0
patience = 50  # counted in validation checks (one check every 200 epochs below), not raw epochs
counter = 0
early_stopped = False

print("\n===== 开始首次训练 =====")
start_time = time.time()

with tqdm(total=first_train_epochs, desc="First training run", unit="epoch") as pbar:
    for epoch in range(first_train_epochs):
        model.train()
        # Forward pass
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Every 200 epochs, record the losses and check for early stopping
        if (epoch + 1) % 200 == 0:
            model.eval()
            with torch.no_grad():
                test_outputs = model(X_test)
                test_loss = criterion(test_outputs, y_test)

            train_losses.append(train_loss.item())
            test_losses.append(test_loss.item())
            epochs.append(epoch + 1)

            # Refresh the progress bar readout
            pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})

            # Early stopping logic
            if test_loss.item() < best_test_loss:
                best_test_loss = test_loss.item()
                best_epoch = epoch + 1
                counter = 0
                # Save the best model so far
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                counter += 1
                if counter >= patience:
                    print(f"\nEarly stopping triggered at epoch {epoch+1}: the test loss has not improved for {patience} consecutive checks.")
                    print(f"Best test loss was {best_test_loss:.4f} at epoch {best_epoch}.")
                    early_stopped = True
                    break

        # Advance the progress bar in chunks of 1000 epochs
        if (epoch + 1) % 1000 == 0:
            pbar.update(1000)

    # Fill the bar to its total so it closes cleanly (also after an early stop)
    if pbar.n < first_train_epochs:
        pbar.update(first_train_epochs - pbar.n)

# Save the model weights after the first training run (key modification 1)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch + 1,
    'best_loss': best_test_loss
}, 'trained_model.pth')
print("\nFirst training run finished; weights saved to trained_model.pth")
print(f"Total time for the first run: {time.time() - start_time:.2f} s")


print("\n===== 加载权重并开始继续训练 =====")
# 加载保存的权重(核心修改点2)
checkpoint = torch.load('trained_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"加载了首次训练至第{checkpoint['epoch']}轮的权重,最佳损失: {checkpoint['best_loss']:.4f}")

# 重新初始化优化器(核心修改点3:继续训练必须重置优化器)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 若需要延续优化器状态,可取消下面注释(视场景选择)
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Parameters for continued training
continue_train_epochs = 50  # target: 50 more epochs
continue_train_losses = []  # continued-training losses
continue_test_losses = []
continue_epochs = []

# Reset the early stopping state for the continued run. Note: with the shared
# patience of 50, a 50-epoch run could only stop on its very last epoch, so a
# smaller patience is used here to make early stopping meaningful
continue_patience = 10
continue_best_loss = checkpoint['best_loss']
continue_counter = 0
continue_early_stop = False

start_continue_time = time.time()

with tqdm(total=continue_train_epochs, desc="Continued training", unit="epoch") as pbar:
    for epoch in range(continue_train_epochs):
        model.train()
        # Forward pass
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Evaluate every epoch (the run is short, so no interval is needed)
        model.eval()
        with torch.no_grad():
            test_outputs = model(X_test)
            test_loss = criterion(test_outputs, y_test)

        continue_train_losses.append(train_loss.item())
        continue_test_losses.append(test_loss.item())
        continue_epochs.append(epoch + 1)

        # Update the progress bar
        pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
        pbar.update(1)

        # Early stopping logic for the continued run
        if test_loss.item() < continue_best_loss:
            continue_best_loss = test_loss.item()
            continue_counter = 0
            # Save the best model from the continued run
            torch.save(model.state_dict(), 'continue_best_model.pth')
        else:
            continue_counter += 1
            if continue_counter >= continue_patience:
                print(f"\nEarly stopping triggered at continued epoch {epoch+1}: the test loss has not improved for {continue_patience} epochs.")
                print(f"Best loss of the continued run: {continue_best_loss:.4f}")
                continue_early_stop = True
                break

print(f"继续训练完成,总耗时: {time.time() - start_continue_time:.2f} 秒")
print(f"继续训练实际轮数: {len(continue_epochs)} 轮(早停触发则少于50轮)")


print("\n===== 最终模型评估 =====")
model.load_state_dict(torch.load('continue_best_model.pth', map_location=device))
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == y_test).sum().item()
    accuracy = correct / y_test.size(0)
    print(f'测试集最终准确率: {accuracy * 100:.2f}%')
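# Caveat: the test set doubles as the early stopping validation set here, so this
# accuracy is somewhat optimistic; a stricter setup would hold out a separate
# validation split for early stopping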

# ====================== Visualization ======================
plt.figure(figsize=(12, 6))

# Loss curves of the first training run
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('First Training Run')
plt.legend()
plt.grid(True)

# Loss curves of the continued run
plt.subplot(1, 2, 2)
plt.plot(continue_epochs, continue_train_losses, label='Train Loss')
plt.plot(continue_epochs, continue_test_losses, label='Test Loss')
plt.xlabel('Continued Epoch')
plt.ylabel('Loss')
plt.title('Continued Training (50 Epochs)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
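
The early stopping bookkeeping above (best loss, patience counter, checkpointing) is written out twice, once per training loop. As a variation, it can be factored into a small reusable helper. The sketch below is my own illustration rather than part of the original assignment; the class name, save path, and patience value are made up:

python
class EarlyStopping:
    """Stop training when the monitored loss fails to improve for `patience` checks."""
    def __init__(self, patience=10, save_path='es_best_model.pth'):
        self.patience = patience
        self.save_path = save_path
        self.best_loss = float('inf')
        self.counter = 0

    def step(self, loss, model):
        """Record one evaluation; return True when training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.counter = 0
            torch.save(model.state_dict(), self.save_path)  # checkpoint the best weights
        else:
            self.counter += 1
        return self.counter >= self.patience

Usage inside either loop: create `stopper = EarlyStopping(patience=10)` once before training, then after each evaluation call `if stopper.step(test_loss.item(), model): break`.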