DAY 37 早停策略和模型权重的保存

  1. 早停策略

Python 代码(运行前需已导入 `torch`,并已定义 `device`、`X_train`、`y_train`、`X_test`、`y_test`):
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

# Define the MLP model
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, returning raw logits.

    Args:
        input_dim: Number of input features. When ``None`` (the default) it
            is read from the module-level ``X_train`` as ``X_train.shape[1]``,
            so a bare ``MLP()`` call keeps working exactly as before.
        hidden_dim: Width of the single hidden layer.
        num_classes: Number of output logits (2 for binary classification).
    """

    def __init__(self, input_dim=None, hidden_dim=10, num_classes=2):
        super().__init__()
        if input_dim is None:
            # Backward-compatible fallback: size the input layer from the
            # training matrix defined at module level.
            input_dim = X_train.shape[1]
        # The attribute names fc1 / relu / fc2 are kept on purpose: they are
        # the keys of state_dict() and therefore of any saved checkpoint.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for input ``x``."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Build the network, then move it to the target device
model = MLP()
model = model.to(device)

# Update rule and objective: plain SGD minimising cross-entropy on the logits
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Loop bound and early-stopping bookkeeping
num_epochs = 20000
# Number of consecutive non-improving early-stopping checks tolerated before
# training is aborted. The counter is only advanced when the training loop
# runs its check, which happens once every 200 epochs.
early_stop_patience = 50
best_loss, best_epoch = float('inf'), 0
patience_counter = 0
early_stopped = False

# Loss history; the three lists are appended to in lockstep
train_losses, test_losses, epochs = [], [], []

# Start training
#
# Full-batch training with periodic evaluation. Every `eval_interval` epochs
# the train/test losses are logged, the best weights (lowest test loss so far)
# are written to 'best_model.pth', and the early-stopping counter is updated.
# NOTE(review): early stopping is driven by the test split, so the accuracy
# reported on that same split afterwards is optimistically biased; a separate
# validation split would avoid this - confirm whether that is intended here.
eval_interval = 200   # Evaluate / log / run the early-stopping check every N epochs
pbar_interval = 1000  # Advance the progress bar in steps of N epochs
epochs_run = 0        # Epochs actually executed (less than num_epochs on early stop)

start_time = time.time()
with tqdm(total=num_epochs, desc="Training Progress", unit="epoch") as pbar:
    for epoch in range(num_epochs):
        # One gradient step on the whole training set
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        train_loss = criterion(outputs, y_train)
        train_loss.backward()
        optimizer.step()
        epochs_run = epoch + 1

        if epochs_run % eval_interval == 0:
            # The test loss is only consumed at logging points, so the
            # forward pass on the test set is done here instead of on
            # every single epoch.
            model.eval()
            with torch.no_grad():
                outputs_test = model(X_test)
                test_loss = criterion(outputs_test, y_test)

            train_loss_value = train_loss.item()
            test_loss_value = test_loss.item()
            train_losses.append(train_loss_value)
            test_losses.append(test_loss_value)
            epochs.append(epochs_run)

            # Early stopping check
            if test_loss_value < best_loss:
                best_loss = test_loss_value
                best_epoch = epochs_run
                patience_counter = 0
                # Checkpoint the best weights seen so far
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                patience_counter += 1
                if patience_counter >= early_stop_patience:
                    # Patience is counted in checks, not epochs; report both.
                    # tqdm.write prints without corrupting the live bar.
                    tqdm.write(
                        f"Early stopping triggered! No improvement for "
                        f"{early_stop_patience} consecutive checks "
                        f"({early_stop_patience * eval_interval} epochs)."
                    )
                    tqdm.write(f"Best test loss was at epoch {best_epoch} with a loss of {best_loss:.4f}")
                    early_stopped = True
                    break  # Stop the training loop

            # Update the progress bar
            pbar.set_postfix({'Train Loss': f'{train_loss_value:.4f}', 'Test Loss': f'{test_loss_value:.4f}'})

        if epochs_run % pbar_interval == 0:
            pbar.update(pbar_interval)

    # Bring the bar in line with the epochs that really ran. This has to
    # happen inside the `with` block: once the block exits the bar is closed
    # and further update() calls are ignored. After an early stop the bar
    # deliberately ends below 100%.
    if pbar.n < epochs_run:
        pbar.update(epochs_run - pbar.n)

time_all = time.time() - start_time  # Calculate total training time
print(f'Training time: {time_all:.2f} seconds')

# If early stopping occurred, load the best model
if early_stopped:
    print(f"Loading best model from epoch {best_epoch} for final evaluation...")
    # map_location keeps the load working when the checkpoint was written on
    # a different device than the one the model currently lives on.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
# NOTE(review): when training ran to completion without early stopping, the
# current (last-epoch) weights are kept and 'best_model.pth' is not reloaded -
# confirm that this is the intended starting point for the extra epochs.

# Last epoch the main loop really executed. On early stop that is the last
# logged epoch (the stop always happens at a logging point), not num_epochs.
# Numbering the extra epochs from num_epochs would leave a gap of epochs that
# never ran in the log and on the loss plot. The restored weights themselves
# date from best_epoch; the axis counts training steps actually performed.
last_trained_epoch = epochs[-1] if early_stopped else num_epochs

# Continue training for 50 more epochs, logging every one of them
num_extra_epochs = 50
for epoch in range(num_extra_epochs):
    current_epoch = last_trained_epoch + epoch + 1

    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    train_loss = criterion(outputs, y_train)
    train_loss.backward()
    optimizer.step()

    # Evaluate on the test set
    model.eval()
    with torch.no_grad():
        outputs_test = model(X_test)
        test_loss = criterion(outputs_test, y_test)

    train_loss_value = train_loss.item()
    test_loss_value = test_loss.item()
    train_losses.append(train_loss_value)
    test_losses.append(test_loss_value)
    epochs.append(current_epoch)

    # Print progress for the extra epochs
    print(f"Epoch {current_epoch}: Train Loss = {train_loss_value:.4f}, Test Loss = {test_loss_value:.4f}")

# Draw both recorded loss histories against the logged epoch numbers
fig, ax = plt.subplots(figsize=(10, 6))
for history, label in ((train_losses, 'Train Loss'), (test_losses, 'Test Loss')):
    ax.plot(epochs, history, label=label)
ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Test Loss over Epochs')
ax.legend()
ax.grid(True)
plt.show()

# Final test-set accuracy: fraction of samples whose highest-scoring class
# matches the label
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    predicted = outputs.argmax(dim=1)
    correct = torch.eq(predicted, y_test).sum().item()
    accuracy = correct / y_test.shape[0]
    print(f'Test Accuracy: {accuracy * 100:.2f}%')

@浙大疏锦行

相关推荐
西柚小萌新1 分钟前
【深入浅出PyTorch】--上采样+下采样
人工智能·pytorch·python
丁学文武28 分钟前
大语言模型(LLM)是“预制菜”? 从应用到底层原理,在到中央厨房的深度解析
人工智能·语言模型·自然语言处理·大语言模型·大模型应用·预制菜
fie888933 分钟前
基于MATLAB的声呐图像特征提取与显示
开发语言·人工智能
未来之窗软件服务1 小时前
自己写算法(九)网页数字动画函数——东方仙盟化神期
前端·javascript·算法·仙盟创梦ide·东方仙盟·东方仙盟算法
豐儀麟阁贵1 小时前
基本数据类型
java·算法
文火冰糖的硅基工坊2 小时前
[嵌入式系统-100]:常见的IoT(物联网)开发板
人工智能·物联网·架构
刘晓倩2 小时前
实战任务二:用扣子空间通过任务提示词制作精美PPT
人工智能
shut up2 小时前
LangChain - 如何使用阿里云百炼平台的Qwen-plus模型构建一个桌面文件查询AI助手 - 超详细
人工智能·python·langchain·智能体
Hy行者勇哥2 小时前
公司全场景运营中 PPT 的类型、功能与作用详解
大数据·人工智能
FIN66683 小时前
昂瑞微:实现精准突破,攻坚射频“卡脖子”难题
前端·人工智能·安全·前端框架·信息与通信