深入PyTorch模型的训练与可视化 —— 掌握迁移学习等模型训练效果提升的办法

目录

前言

什么是迁移学习

为什么迁移学习效果好

迁移学习训练流程

使用PyTorch加载预训练模型

查看模型结构

替换分类层

冻结特征提取层

微调训练(Fine-Tuning)

迁移学习训练示例

数据增强的重要性

常见数据增强方法

PyTorch实现数据增强

[Batch Normalization优化训练](#Batch Normalization优化训练)

Dropout防止过拟合

类别不平衡处理

类别权重训练

学习率调度器

[Early Stopping](#Early Stopping)

混合精度训练

集成学习提升效果

模型训练效果提升路线图

可视化训练过程

项目实战推荐策略

常见面试题

什么是迁移学习?

为什么迁移学习效果好?

什么是Fine-Tuning?

Dropout作用是什么?

BatchNorm作用是什么?

为什么要使用数据增强?

总结


在前面的文章中,我们已经学习了:

复制代码
Tensor数据处理

神经网络搭建

损失函数与优化器

模型训练流程

模型权重保存与加载

但是在实际项目中,很多同学都会遇到一个问题:

复制代码
模型能训练

但是效果不好

例如:

复制代码
准确率上不去

训练速度慢

过拟合严重

数据量不足

模型泛化能力差

事实上,深度学习项目中:

复制代码
网络结构只占成功因素的一部分

训练策略和优化技巧同样重要

很多工业级项目并不会从零开始训练模型,而是采用:

复制代码
迁移学习

数据增强

预训练模型

微调训练

学习率调优

集成学习

等方式来提升模型效果。

本文将系统讲解:

复制代码
迁移学习原理

预训练模型使用

模型微调

数据增强

类别不平衡处理

训练技巧优化

模型效果提升方案

什么是迁移学习

迁移学习(Transfer Learning)是深度学习领域最重要的技术之一。

核心思想:

复制代码
利用已有模型学到的知识

解决新的任务

举个例子:

复制代码
ImageNet拥有1000万级图片

训练得到ResNet模型

已经学会识别:

边缘

纹理

形状

物体特征

那么:

复制代码
猫狗识别

水果分类

工业缺陷检测

这些任务完全没必要重新学习基础特征。

直接利用已有知识即可。

这就是迁移学习。


为什么迁移学习效果好

传统训练方式:

复制代码
随机初始化参数

从0开始学习

存在问题:

复制代码
训练慢

需要大量数据

容易过拟合

迁移学习:

复制代码
加载预训练模型

继承已有特征提取能力

优势:

复制代码
收敛更快

准确率更高

数据需求更少

特别适用于:

复制代码
中小规模数据集

迁移学习训练流程

整体流程如下:

例如:

复制代码
ResNet50

原本识别1000类

↓

修改最后一层

↓

识别5类水果

使用PyTorch加载预训练模型

PyTorch提供了大量预训练模型。

导入:

python 复制代码
import torchvision.models as models

加载ResNet50:

python 复制代码
model = models.resnet50(
    weights=models.ResNet50_Weights.DEFAULT
)

此时:

复制代码
模型参数已经训练完成

无需重新训练。


查看模型结构

查看网络:

复制代码
print(model)

输出:

复制代码
Conv

BatchNorm

Residual Block

FC

其中最后一层:

复制代码
model.fc

类似:

复制代码
Linear(
    in_features=2048,
    out_features=1000
)

替换分类层

假设:

复制代码
原任务1000分类

新任务5分类

修改:

python 复制代码
import torch.nn as nn

model.fc = nn.Linear(
    2048,
    5
)

这样:

复制代码
特征提取能力保留

分类能力重新学习

冻结特征提取层

小数据集训练时:

复制代码
不需要训练全部参数

冻结:

python 复制代码
for param in model.parameters():

    param.requires_grad = False

然后:

python 复制代码
for param in model.fc.parameters():

    param.requires_grad = True

结果:

复制代码
只训练最后分类层

优势:

复制代码
训练快

避免过拟合

微调训练(Fine-Tuning)

如果数据量较大:

复制代码
可以训练部分卷积层

例如:

python 复制代码
for param in model.layer4.parameters():

    param.requires_grad = True

这样:

复制代码
高级特征重新学习

低级特征保留

效果通常优于完全冻结。


迁移学习训练示例

定义优化器:

python 复制代码
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.0001
)

定义损失函数:

复制代码
criterion = nn.CrossEntropyLoss()

训练:

python 复制代码
for epoch in range(20):

    output = model(images)

    loss = criterion(
        output,
        labels
    )

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

数据增强的重要性

深度学习有一句经典名言:

复制代码
Garbage In

Garbage Out

数据质量决定模型上限。

数据增强可以:

复制代码
扩充训练样本

提高泛化能力

减少过拟合

常见数据增强方法

图像增强:

复制代码
随机翻转

随机旋转

随机裁剪

颜色扰动

随机缩放

示意:


PyTorch实现数据增强

python 复制代码
from torchvision import transforms

transform = transforms.Compose([

    transforms.RandomHorizontalFlip(),

    transforms.RandomRotation(15),

    transforms.RandomResizedCrop(224),

    transforms.ToTensor()

])

作用:

复制代码
每轮生成不同训练样本

Batch Normalization优化训练

BN层作用:

复制代码
稳定数据分布

加速训练

网络中:

复制代码
nn.BatchNorm2d(64)

优势:

复制代码
提高收敛速度

缓解梯度消失

现代CNN基本标配。


Dropout防止过拟合

原理:

复制代码
随机关闭神经元

代码:

复制代码
nn.Dropout(0.5)

表示:

复制代码
50%概率失活

效果:

复制代码
提高泛化能力

类别不平衡处理

实际数据:

复制代码
正常样本 9500

异常样本 500

训练后:

复制代码
模型倾向预测正常

解决方案:

复制代码
类别权重

过采样

欠采样

Focal Loss

类别权重训练

例如:

python 复制代码
weights = torch.tensor([1,5])

criterion = nn.CrossEntropyLoss(
    weight=weights
)

作用:

复制代码
增加少数类别重要性

学习率调度器

固定学习率通常不是最佳方案。

PyTorch:

python 复制代码
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.1
)

训练:

复制代码
scheduler.step()

效果:

复制代码
后期训练更加稳定

Early Stopping

训练过程中:

复制代码
验证集准确率长期不提升

则停止训练。

逻辑:

复制代码
连续10轮无提升

↓

停止训练

优势:

复制代码
节约资源

防止过拟合

混合精度训练

现代GPU支持:

复制代码
FP16训练

PyTorch:

复制代码
from torch.cuda.amp import autocast

使用:

python 复制代码
with autocast():

    output = model(x)

    loss = criterion(
        output,
        y
    )

优势:

复制代码
显存降低

训练加速

集成学习提升效果

多个模型共同预测:

复制代码
ResNet

DenseNet

EfficientNet

投票:

复制代码
模型A

模型B

模型C

↓

综合结果

通常:

复制代码
准确率更高

模型训练效果提升路线图


可视化训练过程

记录Loss:

python 复制代码
loss_list.append(
    loss.item()
)

记录Accuracy:

python 复制代码
acc_list.append(
    acc.item()
)

绘图:

python 复制代码
import matplotlib.pyplot as plt

plt.plot(loss_list)

plt.show()

观察:

复制代码
Loss下降

Accuracy上升

判断训练是否正常。


项目实战推荐策略

小数据集:

复制代码
迁移学习

冻结卷积层

数据增强

中型数据集:

复制代码
迁移学习

部分微调

学习率衰减

大数据集:

复制代码
全量训练

混合精度

多GPU训练

常见面试题

什么是迁移学习?

复制代码
利用已有模型知识

解决新任务

为什么迁移学习效果好?

复制代码
减少随机学习过程

保留通用特征

什么是Fine-Tuning?

复制代码
在预训练模型基础上继续训练

Dropout作用是什么?

复制代码
防止过拟合

BatchNorm作用是什么?

复制代码
稳定训练

加速收敛

为什么要使用数据增强?

复制代码
扩大数据规模

提高泛化能力

总结

在工业级深度学习项目中:

复制代码
网络结构 ≠ 最终效果

真正决定模型性能的往往是:

复制代码
迁移学习

数据增强

微调训练

正则化

学习率策略

训练技巧

其中:

复制代码
迁移学习
是提升模型效果性价比最高的方法

完整优化路线:

复制代码
预训练模型
      ↓
迁移学习
      ↓
微调训练
      ↓
数据增强
      ↓
正则化
      ↓
学习率调优
      ↓
模型集成

掌握这些方法之后,你将能够把一个普通模型训练到更高的准确率和更好的泛化能力,为后续学习:

复制代码
目标检测

语义分割

Transformer

大模型微调

多模态模型

打下坚实基础。

可以说:

深度学习项目的竞争力,很多时候并不来自于更复杂的网络结构,而来自于更合理的训练策略和更成熟的模型优化经验。

相关推荐
小和尚同志2 小时前
AI 自动化测试探索(二):Chrome-devtools MCP
人工智能·e2e·aigc
花酒锄作田2 小时前
Pydantic校验配置文件
python
hboot3 小时前
AI工程师第四课 - 深度学习入门
pytorch·python·神经网络
冬奇Lab4 小时前
Workflow 系列(02):设计范式——四层架构、三种 Context 传递模式与确认门设计
人工智能·agent·工作流引擎
冬奇Lab4 小时前
每日一个开源项目(第145篇):Trellis - 把项目记忆、规范和任务上下文持久化进代码仓库
人工智能·开源·资讯
有道AI情报局4 小时前
Harness即产品
人工智能·agent
罗西的思考5 小时前
机器人 / 强化学习】HIL-SERL:人类在环驱动的具身智能进化框架
人工智能·算法·机器学习
IT_陈寒6 小时前
SpringBoot自动配置的坑,我的API突然就404了
前端·人工智能·后端
笃行3507 小时前
从零到上线:用 EdgeOne Makers + CodeBuddy 搭一个「对账核对员」AI Agent
人工智能
用户6856326208697 小时前
Claude Code 乱猜字段名?我给它写了一个"数据库查询约束 Skill"
人工智能