迁移学习的案例

python 复制代码
from torch.utils.data import Dataset,DataLoader
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch import nn
from torchvision import models
import torch.nn.functional as F

# 模型加载与参数冻结
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False  # 冻结预训练权重

# 替换全连接层(适应20类输出)
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

# 收集需要更新的参数(仅全连接层参数)
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)


# 数据增强与预处理
data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([300, 300]),
            transforms.RandomRotation(45),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
}


# 自定义数据集类(此处省略具体实现,需继承Dataset并重写相关方法)
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path,encoding='utf-8') as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        return image, label


# 加载训练集和验证集
training_data = food_dataset(file_path =r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\train1.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\test1.txt', transform=data_transforms['valid'])

# 数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)


# 设备选择
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")


model = resnet_model.to(device)

def train(dataloader,model,loss_fn,optimizer):
    model.train()
    batch_size_num = 1
    for X ,y in dataloader:
        X,y = X.to(device),y.to(device)
        pred = model(X)
        loss = loss_fn(pred,y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 ==0:
            print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num +=1
best_acc = 0
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches= len(dataloader)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for X ,y in dataloader:
            X,y = X.to(device),y.to(device)

            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_pj_loss = test_loss / num_batches
    test_acy = correct / size * 100
    print(f"Avg loss: {test_pj_loss:>7f} \n Accuray: {test_acy:>5.2f}%")
    # 检查是否是最佳准确率
    if correct > best_acc:
        best_acc = correct
        print(f"保存最佳模型,新最佳准确率: {test_acy:>5.2f}%")
    else:
        # 打印当前最佳准确率
        best_accuracy_percent = (best_acc / size) * 100
        print(f"最佳准确率: {best_accuracy_percent:>5.2f}%")
    return test_pj_loss, test_acy
# 损失函数、优化器、学习率调度器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # 仅优化全连接层参数
# optimizer = torch.optim.Adam(resnet_model.parameters(), lr=0.001)  # 若要优化所有参数,可取消此注释
# scheduler = torch.optim.lr_scheduler.stepLR(optimizer,step_size=5,gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',  # 指标越小越好(如损失)
    factor=0.5,  # 学习率调整倍数
    patience=3   # 多少个epoch指标无改善则调整
)


i=10
for j in range(i):
    print(f"Epoch {j+1}\n----------")
    train(train_dataloader, model,loss_fn,optimizer)
    test_loss, test_acy = test(test_dataloader, model, loss_fn)
    # scheduler.step()
    scheduler.step(test_loss)

这段代码是基于 PyTorch 实现的迁移学习(特征提取法)实战案例,核心是用预训练的 ResNet18 模型解决 20 类食物分类任务。下面我们来逐部分解析代码:

一、核心思路回顾

代码用的是「预训练模型做特征提取」方法:

  1. 加载别人训好的 ResNet18(在 ImageNet 1000 类图片上预训练),冻结主干网络参数(不让它忘记 "识别通用图像特征" 的能力);
  2. 把 ResNet18 的 "分类头"(最后一层全连接层)换成适合自己任务的 "2 分类头"(这里是 20 类食物,所以输出维度是 20);
  3. 只训练新换的全连接层参数,用少量食物数据就能快速出效果。

二、代码模块逐段解析

1. 导入依赖库(基础工具准备)
复制代码
from torch.utils.data import Dataset,DataLoader  # 数据加载核心库(Dataset定义数据格式,DataLoader批量加载)
import torch  # PyTorch核心库(张量计算、模型训练)
import numpy as np  # 数值计算(处理标签转换)
from PIL import Image  # 图片读取库(加载图像文件)
from torchvision import transforms  # 图像预处理( resize、翻转、归一化等)
from torch import nn  # 神经网络层(定义全连接层、损失函数)
from torchvision import models  # 预训练模型库(直接调用ResNet18)
import torch.nn.functional as F  # 常用激活函数等(这里未直接用,预留)

作用:把后续需要的 "数据处理、模型构建、训练工具" 都导入,相当于 "备好工具箱"。

2. 模型加载与参数冻结(迁移学习核心步骤)
复制代码
# 1. 加载预训练ResNet18(带默认权重,即ImageNet预训练结果)
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# 2. 冻结主干网络所有参数(关键!不让预训练的特征提取层"失忆")
for param in resnet_model.parameters():
    print(param)  # (可选)打印参数,查看主干网络权重(可注释掉减少输出)
    param.requires_grad = False  # 设为False:训练时不更新这些参数

# 3. 替换全连接层(适配20类食物分类任务)
in_features = resnet_model.fc.in_features  # 获取原全连接层的输入维度(ResNet18默认是512)
resnet_model.fc = nn.Linear(in_features, 20)  # 新全连接层:输入512,输出20(对应20类)

# 4. 收集需要更新的参数(只保留新全连接层的参数,减少计算量)
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:  # 只有新全连接层的requires_grad是True
        params_to_update.append(param)

关键细节

weights=models.ResNet18_Weights.DEFAULT:自动下载预训练权重(如果本地没有),避免从头训练;

冻结参数的原因:ResNet18 的主干网络(卷积层)已经学会 "边缘、纹理、形状" 等通用图像特征,这些特征对 "食物分类" 也有用,冻结后只训分类头,既快又省数据;

替换全连接层:原 ResNet18 输出 1000 类(对应 ImageNet),我们要分 20 类,所以必须换最后一层,输入维度保持和主干网络输出一致(512)。

3. 数据增强与预处理(提升模型泛化能力)
复制代码
data_transforms = {
    'train':  # 训练集预处理(加数据增强,防止过拟合)
        transforms.Compose([
            transforms.Resize([300, 300]),  # 先放大到300x300(为后续裁剪留空间)
            transforms.RandomRotation(45),  # 随机旋转(-45°~45°)
            transforms.CenterCrop(224),  # 裁剪到224x224(ResNet要求的输入尺寸)
            transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转
            transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转
            transforms.RandomGrayscale(p=0.1),  # 10%概率转灰度图
            transforms.ToTensor(),  # 转成Tensor(PyTorch模型只能处理Tensor)
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 归一化(用ImageNet的均值/方差,预训练模型要求)
        ]),
    'valid':  # 验证集预处理(不加增强,真实评估模型效果)
        transforms.Compose([
            transforms.Resize([224, 224]),  # 直接缩放到224x224
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
}

核心逻辑

训练集加增强:通过 "旋转、翻转" 等增加数据多样性,让模型学到更通用的特征,避免 "只认训练集里的图片,换张图就错"(过拟合);

验证集不加增强:因为验证集要测 "模型在真实场景下的表现",不能用经过修改的图片;

归一化:预训练模型是用归一化后的 ImageNet 数据训的,我们的输入必须用相同的均值 / 方差,否则模型会 "不适应"。

4. 自定义数据集类(读取自己的食物数据)
复制代码
class food_dataset(Dataset):  # 继承PyTorch的Dataset类,自定义数据格式
    def __init__(self, file_path, transform=None):
        self.file_path = file_path  # txt文件路径(存图片路径和标签)
        self.imgs = []  # 存所有图片的路径
        self.labels = []  # 存所有图片的标签(20类对应的数字)
        self.transform = transform  # 预处理函数(训练/验证用不同的)
        
        # 从txt文件读取数据(txt格式:每行是"图片路径 标签",空格分隔)
        with open(self.file_path,encoding='utf-8') as f:
            samples = [x.strip().split(' ') for x in f.readlines()]  # 按行分割,再按空格分路径和标签
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):  # 必须实现:返回数据集总数量
        return len(self.imgs)

    def __getitem__(self, idx):  # 必须实现:根据索引idx返回1个样本(图片+标签)
        # 1. 读取图片
        image = Image.open(self.imgs[idx])  # 用PIL打开图片
        if self.transform:  # 应用预处理(训练集增强/验证集常规处理)
            image = self.transform(image)

        # 2. 处理标签(转成Tensor,且类型是int64,PyTorch要求)
        label = self.labels[idx]
        label = torch.from_numpy(np.array(label, dtype=np.int64))  # 字符串标签转成int64类型的Tensor
        
        return image, label  # 返回(图片Tensor,标签Tensor)

使用前提

你需要有一个train1.txttest1.txt,格式如下(每行 1 个样本):

复制代码
E:\food_data\apple1.jpg 0
E:\food_data\banana2.jpg 1
...

其中 "0、1" 是食物类别的编号(0~19,共 20 类)。

5. 加载训练 / 验证集(批量喂给模型)
复制代码
# 1. 用自定义的food_dataset加载数据(指定txt路径和预处理方式)
training_data = food_dataset(file_path=r'E:\...\train1.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'E:\...\test1.txt', transform=data_transforms['valid'])

# 2. 用DataLoader批量加载(模型训练不能单张图喂,要批量处理提升效率)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 训练集:批量64,打乱顺序(让模型学的更全面)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)  # 验证集:批量64,打乱与否不影响(但这里也设了True,不影响结果)

关键参数

batch_size=64:每次喂给模型 64 张图(根据电脑显存调整,显存小就设 16/32);

shuffle=True(训练集):每次 epoch 打乱数据顺序,避免模型 "记顺序" 而不是学特征。

6. 设备选择(适配 GPU/CPU,加速训练)
复制代码
# 优先用GPU(cuda),其次用苹果芯片的MPS,最后用CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 打印当前用的设备(比如"Using cuda device")

model = resnet_model.to(device)  # 把模型放到指定设备上(必须!否则模型和数据不在一个设备会报错)

为什么要指定设备

GPU(cuda)训练速度是 CPU 的 10~100 倍,比如 ResNet18 用 CPU 训可能要几小时,GPU 只要几分钟;

注意:后续的 "数据(X、y)" 也要放到同一设备(代码里X,y = X.to(device),y.to(device)就是做这个)。

7. 训练函数(模型 "学习" 的过程)
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 设为训练模式(启用 dropout、batchnorm等训练特有的层)
    batch_size_num = 1  # 记录当前训练到第几个batch
    for X, y in dataloader:  # 循环读取每个batch的(图片X,标签y)
        # 1. 把数据放到设备上(和模型同设备)
        X, y = X.to(device), y.to(device)
        
        # 2. 前向传播:模型预测结果
        pred = model(X)  # 输入X,输出20类的概率(shape:[64,20])
        
        # 3. 计算损失(预测结果和真实标签的差距)
        loss = loss_fn(pred, y)  # 用CrossEntropyLoss(多分类任务专用)
        
        # 4. 反向传播:更新参数(只更params_to_update,即新全连接层)
        optimizer.zero_grad()  # 清空上一轮的梯度(避免累积)
        loss.backward()  # 计算梯度(从损失反向传到参数)
        optimizer.step()  # 用梯度更新参数
        
        # 5. 打印损失(每100个batch打印一次,方便监控)
        loss_value = loss.item()  # 把Tensor类型的损失转成Python数值
        if batch_size_num % 100 == 0:
            print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

核心流程

前向传播(算预测)→ 算损失(找差距)→ 反向传播(算梯度)→ 更新参数(缩小差距),这是深度学习训练的核心循环。

8. 测试函数(评估模型效果,保存最佳模型)
复制代码
best_acc = 0  # 记录最佳准确率(初始为0)
def test(dataloader, model, loss_fn):
    global best_acc  # 用全局变量记录最佳准确率
    size = len(dataloader.dataset)  # 验证集总样本数
    num_batches = len(dataloader)  # 验证集总batch数
    model.eval()  # 设为评估模式(禁用 dropout、固定batchnorm参数)
    test_loss = 0  # 总验证损失
    correct = 0  # 正确预测的样本数
    
    with torch.no_grad():  # 禁用梯度计算(评估时不用更新参数,节省内存和时间)
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)  # 预测
            
            # 累加损失和正确数
            test_loss += loss_fn(pred, y).item()  # 累加每个batch的损失
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # 统计正确数:pred.argmax(1)取概率最大的类别编号
    
    # 计算平均损失和准确率
    test_avg_loss = test_loss / num_batches  # 平均每个batch的损失
    test_acc = correct / size * 100  # 准确率(百分比)
    print(f"Avg loss: {test_avg_loss:>7f} \n Accuracy: {test_acc:>5.2f}%")
    
    # 保存最佳模型(准确率更高时更新)
    if correct > best_acc:
        best_acc = correct
        print(f"保存最佳模型,新最佳准确率: {test_acc:>5.2f}%")
        # (可选)这里可以加模型保存代码:torch.save(model.state_dict(), "best_model.pth")
    else:
        # 打印当前最佳准确率
        best_acc_percent = (best_acc / size) * 100
        print(f"最佳准确率: {best_acc_percent:>5.2f}%")
    
    return test_avg_loss, test_acc

关键细节

model.eval()with torch.no_grad():评估时必须加,否则模型状态不对,准确率计算不准;

pred.argmax(1):取每个样本预测概率中最大的那个类别(比如 pred 是 [64,20],argmax (1) 后是 [64],每个元素是 0~19 的类别编号);

保存最佳模型:避免训练后期过拟合导致准确率下降,只保留表现最好的模型。

9. 配置训练工具(损失函数、优化器、学习率调度器)
复制代码
# 1. 损失函数:多分类任务用CrossEntropyLoss(自带softmax,不用自己加)
loss_fn = nn.CrossEntropyLoss()

# 2. 优化器:只优化需要更新的参数(params_to_update,即新全连接层)
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam是常用优化器,学习率0.001(可调整)
# (注释掉的代码)如果要做"微调(Fine-tuning)",就优化所有参数:
# optimizer = torch.optim.Adam(resnet_model.parameters(), lr=0.001)

# 3. 学习率调度器:根据验证损失调整学习率(避免后期学习率太大导致震荡)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',  # 指标"越小越好"(这里是验证损失)
    factor=0.5,  # 学习率调整倍数:如果指标没改善,就乘以0.5(比如0.001→0.0005)
    patience=3   # 连续3个epoch指标没改善,才调整学习率
)
# (注释掉的代码)固定步长调度器:每5个epoch学习率乘以0.5
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

调度器的作用

训练初期用较大的学习率(0.001)快速接近最优解;

后期如果验证损失不再下降,减小学习率(乘以 0.5),让模型更精细地调整参数,避免 "跑过头"。

10. 启动训练循环(控制训练轮次,调用训练 / 测试函数)
复制代码
i = 10  # 训练10个epoch(1个epoch=把训练集完整过一遍)
for j in range(i):
    print(f"Epoch {j+1}\n----------")  # 打印当前是第几个epoch
    train(train_dataloader, model, loss_fn, optimizer)  # 训练一轮
    test_loss, test_acc = test(test_dataloader, model, loss_fn)  # 测试一轮
    scheduler.step(test_loss)  # 根据验证损失调整学习率(对应ReduceLROnPlateau)
    # scheduler.step()  # 如果用StepLR,就用这行(不用传指标)

epoch 的选择

10 个 epoch 是比较合理的初始值,可根据效果调整:如果验证准确率还在上升,就增加 epoch;如果准确率下降(过拟合),就减少 epoch 或加正则化。

三、代码整体逻辑总结(从输入到输出)

  1. 数据端:从 txt 读取食物图片路径和标签→自定义 Dataset 处理图片 + 预处理→DataLoader 批量加载;
  2. 模型端:加载预训练 ResNet18→冻结主干网络→替换 20 类全连接层;
  3. 训练端:按 epoch 循环→每个 epoch 先训练(更新分类头参数)→再测试(评估准确率)→根据验证损失调整学习率;
  4. 目标:用迁移学习快速训练出能分 20 类食物的模型,且避免过拟合、节省计算资源。
相关推荐
源雀数智3 小时前
源雀SCRM开源:企微文件防泄密
java·人工智能·企业微信·流量运营
Honeysea_703 小时前
容器的定义及工作原理
人工智能·深度学习·机器学习·docker·ai·持续部署
fantasy_arch3 小时前
SVT-AV1 svt_aom_motion_estimation_kernel 函数分析
人工智能·算法·av1
Acrel136119655143 小时前
别让电能质量问题拖后腿:工业场景中电能治理的战略意义
大数据·人工智能·能源·创业创新
長琹3 小时前
AES加密算法详细加密步骤代码实现--身份证号码加解密系统
网络·数据库·人工智能·python·密码学
一只鱼丸yo3 小时前
70B大模型也能在笔记本上跑?揭秘让AI“瘦身”的黑科技
人工智能·科技·机器学习·语言模型
极客智造3 小时前
OpenCV C++ 核心:Mat 与像素操作全解析
c++·人工智能·opencv
极客智造3 小时前
OpenCV C++ 色彩空间详解:转换、应用与 LUT 技术
c++·人工智能·opencv
湫兮之风3 小时前
OpenCV: cv::warpAffine()逆仿射变换详解
人工智能·opencv·计算机视觉