深入PyTorch模型的训练与可视化 —— 掌握迁移学习等模型训练效果提升的办法

目录

前言

什么是迁移学习

为什么迁移学习效果好

迁移学习训练流程

使用PyTorch加载预训练模型

查看模型结构

替换分类层

冻结特征提取层

微调训练(Fine-Tuning)

迁移学习训练示例

数据增强的重要性

常见数据增强方法

PyTorch实现数据增强

[Batch Normalization优化训练](#Batch Normalization优化训练)

Dropout防止过拟合

类别不平衡处理

类别权重训练

学习率调度器

[Early Stopping](#Early Stopping)

混合精度训练

集成学习提升效果

模型训练效果提升路线图

可视化训练过程

项目实战推荐策略

常见面试题

什么是迁移学习?

为什么迁移学习效果好?

什么是Fine-Tuning?

Dropout作用是什么?

BatchNorm作用是什么?

为什么要使用数据增强?

总结


在前面的文章中,我们已经学习了:

复制代码
Tensor数据处理

神经网络搭建

损失函数与优化器

模型训练流程

模型权重保存与加载

但是在实际项目中,很多同学都会遇到一个问题:

复制代码
模型能训练

但是效果不好

例如:

复制代码
准确率上不去

训练速度慢

过拟合严重

数据量不足

模型泛化能力差

事实上,深度学习项目中:

复制代码
网络结构只占成功因素的一部分

训练策略和优化技巧同样重要

很多工业级项目并不会从零开始训练模型,而是采用:

复制代码
迁移学习

数据增强

预训练模型

微调训练

学习率调优

集成学习

等方式来提升模型效果。

本文将系统讲解:

复制代码
迁移学习原理

预训练模型使用

模型微调

数据增强

类别不平衡处理

训练技巧优化

模型效果提升方案

什么是迁移学习

迁移学习(Transfer Learning)是深度学习领域最重要的技术之一。

核心思想:

复制代码
利用已有模型学到的知识

解决新的任务

举个例子:

复制代码
ImageNet拥有1000万级图片

训练得到ResNet模型

已经学会识别:

边缘

纹理

形状

物体特征

那么:

复制代码
猫狗识别

水果分类

工业缺陷检测

这些任务完全没必要重新学习基础特征。

直接利用已有知识即可。

这就是迁移学习。


为什么迁移学习效果好

传统训练方式:

复制代码
随机初始化参数

从0开始学习

存在问题:

复制代码
训练慢

需要大量数据

容易过拟合

迁移学习:

复制代码
加载预训练模型

继承已有特征提取能力

优势:

复制代码
收敛更快

准确率更高

数据需求更少

特别适用于:

复制代码
中小规模数据集

迁移学习训练流程

整体流程如下:

例如:

复制代码
ResNet50

原本识别1000类

↓

修改最后一层

↓

识别5类水果

使用PyTorch加载预训练模型

PyTorch提供了大量预训练模型。

导入:

python 复制代码
import torchvision.models as models

加载ResNet50:

python 复制代码
model = models.resnet50(
    weights=models.ResNet50_Weights.DEFAULT
)

此时:

复制代码
模型参数已经训练完成

无需重新训练。


查看模型结构

查看网络:

复制代码
print(model)

输出:

复制代码
Conv

BatchNorm

Residual Block

FC

其中最后一层:

复制代码
model.fc

类似:

复制代码
Linear(
    in_features=2048,
    out_features=1000
)

替换分类层

假设:

复制代码
原任务1000分类

新任务5分类

修改:

python 复制代码
import torch.nn as nn

model.fc = nn.Linear(
    2048,
    5
)

这样:

复制代码
特征提取能力保留

分类能力重新学习

冻结特征提取层

小数据集训练时:

复制代码
不需要训练全部参数

冻结:

python 复制代码
for param in model.parameters():

    param.requires_grad = False

然后:

python 复制代码
for param in model.fc.parameters():

    param.requires_grad = True

结果:

复制代码
只训练最后分类层

优势:

复制代码
训练快

避免过拟合

微调训练(Fine-Tuning)

如果数据量较大:

复制代码
可以训练部分卷积层

例如:

python 复制代码
for param in model.layer4.parameters():

    param.requires_grad = True

这样:

复制代码
高级特征重新学习

低级特征保留

效果通常优于完全冻结。


迁移学习训练示例

定义优化器:

python 复制代码
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.0001
)

定义损失函数:

复制代码
criterion = nn.CrossEntropyLoss()

训练:

python 复制代码
for epoch in range(20):

    output = model(images)

    loss = criterion(
        output,
        labels
    )

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

数据增强的重要性

深度学习有一句经典名言:

复制代码
Garbage In

Garbage Out

数据质量决定模型上限。

数据增强可以:

复制代码
扩充训练样本

提高泛化能力

减少过拟合

常见数据增强方法

图像增强:

复制代码
随机翻转

随机旋转

随机裁剪

颜色扰动

随机缩放

示意:


PyTorch实现数据增强

python 复制代码
from torchvision import transforms

transform = transforms.Compose([

    transforms.RandomHorizontalFlip(),

    transforms.RandomRotation(15),

    transforms.RandomResizedCrop(224),

    transforms.ToTensor()

])

作用:

复制代码
每轮生成不同训练样本

Batch Normalization优化训练

BN层作用:

复制代码
稳定数据分布

加速训练

网络中:

复制代码
nn.BatchNorm2d(64)

优势:

复制代码
提高收敛速度

缓解梯度消失

现代CNN基本标配。


Dropout防止过拟合

原理:

复制代码
随机关闭神经元

代码:

复制代码
nn.Dropout(0.5)

表示:

复制代码
50%概率失活

效果:

复制代码
提高泛化能力

类别不平衡处理

实际数据:

复制代码
正常样本 9500

异常样本 500

训练后:

复制代码
模型倾向预测正常

解决方案:

复制代码
类别权重

过采样

欠采样

Focal Loss

类别权重训练

例如:

python 复制代码
weights = torch.tensor([1,5])

criterion = nn.CrossEntropyLoss(
    weight=weights
)

作用:

复制代码
增加少数类别重要性

学习率调度器

固定学习率通常不是最佳方案。

PyTorch:

python 复制代码
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.1
)

训练:

复制代码
scheduler.step()

效果:

复制代码
后期训练更加稳定

Early Stopping

训练过程中:

复制代码
验证集准确率长期不提升

则停止训练。

逻辑:

复制代码
连续10轮无提升

↓

停止训练

优势:

复制代码
节约资源

防止过拟合

混合精度训练

现代GPU支持:

复制代码
FP16训练

PyTorch:

复制代码
from torch.cuda.amp import autocast

使用:

python 复制代码
with autocast():

    output = model(x)

    loss = criterion(
        output,
        y
    )

优势:

复制代码
显存降低

训练加速

集成学习提升效果

多个模型共同预测:

复制代码
ResNet

DenseNet

EfficientNet

投票:

复制代码
模型A

模型B

模型C

↓

综合结果

通常:

复制代码
准确率更高

模型训练效果提升路线图


可视化训练过程

记录Loss:

python 复制代码
loss_list.append(
    loss.item()
)

记录Accuracy:

python 复制代码
acc_list.append(
    acc.item()
)

绘图:

python 复制代码
import matplotlib.pyplot as plt

plt.plot(loss_list)

plt.show()

观察:

复制代码
Loss下降

Accuracy上升

判断训练是否正常。


项目实战推荐策略

小数据集:

复制代码
迁移学习

冻结卷积层

数据增强

中型数据集:

复制代码
迁移学习

部分微调

学习率衰减

大数据集:

复制代码
全量训练

混合精度

多GPU训练

常见面试题

什么是迁移学习?

复制代码
利用已有模型知识

解决新任务

为什么迁移学习效果好?

复制代码
减少随机学习过程

保留通用特征

什么是Fine-Tuning?

复制代码
在预训练模型基础上继续训练

Dropout作用是什么?

复制代码
防止过拟合

BatchNorm作用是什么?

复制代码
稳定训练

加速收敛

为什么要使用数据增强?

复制代码
扩大数据规模

提高泛化能力

总结

在工业级深度学习项目中:

复制代码
网络结构 ≠ 最终效果

真正决定模型性能的往往是:

复制代码
迁移学习

数据增强

微调训练

正则化

学习率策略

训练技巧

其中:

复制代码
迁移学习
是提升模型效果性价比最高的方法

完整优化路线:

复制代码
预训练模型
      ↓
迁移学习
      ↓
微调训练
      ↓
数据增强
      ↓
正则化
      ↓
学习率调优
      ↓
模型集成

掌握这些方法之后,你将能够把一个普通模型训练到更高的准确率和更好的泛化能力,为后续学习:

复制代码
目标检测

语义分割

Transformer

大模型微调

多模态模型

打下坚实基础。

可以说:

深度学习项目的竞争力,很多时候并不来自于更复杂的网络结构,而来自于更合理的训练策略和更成熟的模型优化经验。

相关推荐
程序员差不多先生1 小时前
Copilot 取消年费改按量计费:AI Coding 工具进入了什么新阶段?
人工智能·copilot·github copilot
猿粪已尽1 小时前
cc switch+codex+米醋 实现AI办公
人工智能·codex·cc switch·米醋·micu
段一凡-华北理工大学1 小时前
工业领域的Hadoop架构学习~系列文章20:故障诊断与根因分析 - 从表象到本质的智能推理
大数据·人工智能·hadoop·学习·架构·高炉炼铁·工业智能体
凌云拓界1 小时前
状态机与思考循环 ——CogitoAgent开发实战(一)
javascript·人工智能·架构·node.js·设计规范
无心水1 小时前
【OpenClaw:赚钱】案例19、内容产量5倍、广告收入翻4倍:播客转多平台内容矩阵全自动化实战(OpenAI Whisper + Claude)
java·人工智能·python·ai编程·openclaw·养龙虾·java.time
寻道模式1 小时前
【时间之外】AI不懂Mac吗?
人工智能·macos
挂科边缘1 小时前
手把手教你使用 Faster-Whisper 实时语音输入转文本,本地部署教程
人工智能·语言模型·whisper·faster-whisper·实时语音输入转文本
逗逗班学Python1 小时前
基于 Faster-Whisper 的本地语音转字幕与会议纪要系统:从音频转写到 SRT 字幕与 Markdown 纪要完整项目实战
python·语音识别·faster-whisper·字幕生成·会议纪要
SUNNY_SHUN1 小时前
把 Whisper、Moonshine、SenseVoice 统统装进手机:sherpa-onnx 离线语音部署框架,GitHub 10.9K Star
人工智能·智能手机·whisper·github