训练流程完整实现

训练流程完整实现

1. 优化器(SGD/Adam)与学习率调度

1.1 优化器算法对比

1.1.1 SGD与Adam数学原理

随机梯度下降(SGD)
<math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − η ∇ θ J ( θ ) \theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta) </math>θt+1=θt−η∇θJ(θ)
带动量的SGD
<math xmlns="http://www.w3.org/1998/Math/MathML"> v t + 1 = γ v t + η ∇ θ J ( θ ) v_{t+1} = \gamma v_t + \eta \nabla_\theta J(\theta) </math>vt+1=γvt+η∇θJ(θ)
<math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − v t + 1 \theta_{t+1} = \theta_t - v_{t+1} </math>θt+1=θt−vt+1

Adam优化器
<math xmlns="http://www.w3.org/1998/Math/MathML"> m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t </math>mt=β1mt−1+(1−β1)gt
<math xmlns="http://www.w3.org/1998/Math/MathML"> v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 </math>vt=β2vt−1+(1−β2)gt2
<math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1-\beta_1^t} </math>m^t=1−β1tmt
<math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t = v t 1 − β 2 t \hat{v}t = \frac{v_t}{1-\beta_2^t} </math>v^t=1−β2tvt
<math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − η v ^ t + ϵ m ^ t \theta
{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\hat{m}_t </math>θt+1=θt−v^t +ϵηm^t

python 复制代码
from torch import optim

# 优化器初始化
sgd_optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
adam_optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
1.1.2 优化器选择指南
场景 推荐优化器 学习率范围
小数据集/简单模型 SGD + Momentum 0.01-0.1
大规模数据/复杂模型 Adam 0.0001-0.001
需要精细调参 SGD 0.1-1.0

1.2 学习率调度策略

1.2.1 常用调度器实现
python 复制代码
from torch.optim.lr_scheduler import (
    StepLR, 
    CosineAnnealingLR,
    ReduceLROnPlateau
)

# 阶梯式下降
scheduler1 = StepLR(optimizer, step_size=30, gamma=0.1)

# 余弦退火
scheduler2 = CosineAnnealingLR(optimizer, T_max=100)

# 自适应调整
scheduler3 = ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.1, 
    patience=5
)
1.2.2 调度器组合策略
graph LR A[初始学习率] --> B[预热阶段] B --> C[余弦退火] C --> D[稳定微调] style A fill:#9f9,stroke:#333 style D fill:#f99,stroke:#333
python 复制代码
# 自定义组合调度器
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr
    if epoch < 5:  # 前5epoch线性预热
        lr = lr * (epoch + 1) / 5
    elif epoch >= 30:
        lr = lr * 0.1
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

2. 训练-验证-测试循环代码模板

2.1 完整训练流程架构

graph TD A[初始化模型] --> B[训练循环] B --> C[前向传播] C --> D[计算损失] D --> E[反向传播] E --> F[参数更新] F --> G[验证集评估] G --> H{达到早停?} H -->|否| B H -->|是| I[最终测试]

2.2 代码模板实现

2.2.1 训练阶段
python 复制代码
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Train Batch: {batch_idx} Loss: {loss.item():.4f}')
    
    return total_loss / len(loader)
2.2.2 验证阶段
python 复制代码
def validate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    return 100.0 * correct / total
2.2.3 主训练循环
python 复制代码
best_acc = 0.0
for epoch in range(1, args.epochs+1):
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_acc = validate(model, val_loader, device)
    
    # 学习率调度
    scheduler.step(val_loss)  
    
    # 早停与模型保存
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= args.patience:
            print("Early stopping!")
            break
    
    print(f'Epoch {epoch}: Loss={train_loss:.4f}, Val Acc={val_acc:.2f}%')

3. 模型保存与加载(.pt/.pth

3.1 保存方式对比

3.1.1 全模型保存
python 复制代码
# 保存整个模型(包含结构)
torch.save(model, 'full_model.pth')

# 加载方式(需保证类定义存在)
loaded_model = torch.load('full_model.pth')
3.1.2 参数保存(推荐)
python 复制代码
# 仅保存参数(state_dict)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}, 'checkpoint.pth')

# 加载恢复
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

3.2 跨设备加载策略

3.2.1 GPU保存→CPU加载
python 复制代码
# 保存时指定设备
torch.save(model.state_dict(), 'gpu_model.pth')

# 加载时映射
state_dict = torch.load('gpu_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
3.2.2 多GPU训练模型加载
python 复制代码
# 原始保存(多卡并行)
model = nn.DataParallel(model)
torch.save(model.module.state_dict(), 'multi_gpu.pth')

# 单卡加载
single_model = ModelClass()
single_model.load_state_dict(torch.load('multi_gpu.pth'))

3.3 模型版本控制

3.3.1 检查点管理
python 复制代码
def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth')
3.3.2 版本兼容处理
python 复制代码
# 处理参数名称变化
state_dict = torch.load('old_model.pth')
new_state_dict = {}
for k, v in state_dict.items():
    name = k.replace('conv1', 'backbone.conv1')  # 替换旧参数名
    new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=False)

附录:训练优化技巧

混合精度训练

python 复制代码
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

梯度累积实现

python 复制代码
accumulation_steps = 4

for i, (data, target) in enumerate(loader):
    outputs = model(data)
    loss = criterion(outputs, target)
    loss = loss / accumulation_steps
    loss.backward()
    
    if (i+1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

训练监控可视化

TensorBoard集成

python 复制代码
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

for epoch in range(epochs):
    # ...训练步骤...
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Accuracy/val', val_acc, epoch)
    writer.add_histogram('weights/fc1', model.fc1.weight, epoch)
graph LR A[训练指标] --> B[TensorBoard] C[模型参数分布] --> B D[验证结果] --> B B --> E[Web可视化]

说明 :本文代码已在PyTorch 2.1 + CUDA 11.8环境验证,建议使用torch.save保存模型时统一使用.pt扩展名。下一章将深入讲解模型调试与优化技巧! 🚀

复制代码
相关推荐
@心都13 分钟前
机器学习数学基础:44.多元线性回归
人工智能·机器学习·线性回归
说私域14 分钟前
基于开源AI大模型的精准零售模式创新——融合AI智能名片与S2B2C商城小程序源码的“人工智能 + 线下零售”路径探索
人工智能·搜索引擎·小程序·开源·零售
熊文豪17 分钟前
Windows本地部署OpenManus并接入Mistral模型的实践记录
人工智能·llm·mistral·manus·openmanus·openmanus开源替代方案·本地llm部署实践
IT猿手18 分钟前
2025最新群智能优化算法:海市蜃楼搜索优化(Mirage Search Optimization, MSO)算法求解23个经典函数测试集,MATLAB
开发语言·人工智能·算法·机器学习·matlab·机器人
IT猿手2 小时前
2025最新群智能优化算法:山羊优化算法(Goat Optimization Algorithm, GOA)求解23个经典函数测试集,MATLAB
人工智能·python·算法·数学建模·matlab·智能优化算法
Jet45053 小时前
玩转ChatGPT:GPT 深入研究功能
人工智能·gpt·chatgpt·deep research·深入研究
毕加锁3 小时前
chatgpt完成python提取PDF简历指定内容的案例
人工智能·chatgpt
Wis4e5 小时前
基于PyTorch的深度学习3——基于autograd的反向传播
人工智能·pytorch·深度学习
西猫雷婶6 小时前
神经网络|(十四)|霍普菲尔德神经网络-Hebbian训练
人工智能·深度学习·神经网络
梦丶晓羽7 小时前
自然语言处理:文本分类
人工智能·python·自然语言处理·文本分类·朴素贝叶斯·逻辑斯谛回归