利用 Resnet50 重新训练,完成宠物数据集的识别,附源代码。。

如果你对深度学习有所了解,知道神经网络可以识别图片,但还没自己动手训练过模型,这篇文章会非常适合你。

这篇文章将使用 PyTorch 和 ResNet50,基于 Oxford-IIIT Pet 数据集(37 类宠物)完成一个完整的训练过程。

这个方法也可以应用到你自己的数据集上,比如识别不同种类的花或物体。

接下来,带你一步步完成这个任务。

Attention:全网最全的 AI 小白到 AI 大神的天梯成长学习路线,几十万原创专栏和硬核视频,点击这里查看:AI小白到AI大神的天梯之路

什么是 ResNet50,为什么选择它?

ResNet50 是一个深度卷积神经网络,包含 50 层,设计用来处理图像分类任务。

它在 ImageNet 数据集上表现优异,能识别 1000 种物体。

我们今天的目标是重新训练它,让它学会识别新的类别------37 种宠物。

选择 ResNet50 的理由很简单------

  • 成熟的结构,它已经被广泛验证,适合大多数图像分类任务。
  • 开箱即用:PyTorch 提供了现成的实现,省去自己设计的麻烦。
  • 高效性:即使从零开始训练,也能得到不错的结果。

下面,我们将训练过程拆成几个关键步骤,逐步讲解。

训练 ResNet50 的四大步骤

步骤 1:准备数据

模型训练的第一步是准备数据。

Oxford-IIIT Pet 数据集包含大量宠物照片,我们需要调整它们的格式,确保模型能正确处理。

代码是这样实现的:

python 复制代码
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 将图像调整为 224x224 像素
    transforms.ToTensor(),          # 将图像转换为 Tensor 格式
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])  # 标准化像素值])   
  • ResizeResNet50 的输入要求是 224x224,所有图像需要统一到这个尺寸。
  • ToTensor将图片从普通格式转为模型能处理的数字格式(范围 0 到 1)。
  • Normalize用 ImageNet 的均值和标准差标准化数据,帮助模型更快收敛。

接着,用 DataLoader 将数据分成小批次:

plain 复制代码
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

这里 batch_size=32 表示每次处理 32 张图片,shuffle=True 打乱顺序,避免模型记住数据的排列。

步骤 2:搭建模型------调整 ResNet50 的结构

ResNet50 是一个现成的模型,但我们需要根据任务调整它。

原始的 ResNet50 输出 1000 类,而我们的数据集只有 37 类,因此需要修改最后一层。

代码实现如下:

python 复制代码
model = torchvision.models.resnet50(weights=None)  # 初始化 ResNet50,不使用预训练权重
model.fc = nn.Linear(model.fc.in_features, 37)     # 将全连接层改为 37 类输出
model = model.to(device)      # 转移到 GPU 或 CPU
  • weights=None表示将从零开始训练模型。
  • model.fc这一行代码修改了模型最后一层(全连接层),将输出特征数改为 37 个,对应 37 类宠物。如果你有自己的数据集,且分类数量与原始模型不一致,也需要进行类似的修改。
  • to(device)根据设备(GPU 或 CPU)运行模型,GPU 会显著加速训练。

步骤 3:定义学习方式

模型需要知道如何学习以及学习步长是什么样的,这样才能优化模型参数的调整过程。

这个过程主要涉及损失函数和优化器。

损失函数衡量的是模型预测值与真实答案之间的差距,优化器则负责调整模型的参数。

用代码中是这样定义的:

python 复制代码
criterion = nn.CrossEntropyLoss()          # 交叉熵损失,用于分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器,学习率 0.001
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 学习率调度器
  • 损失函数采用的是交叉熵损失函数,该函数是多分类任务的标准选择。
  • 优化器Adam 是一种高效的优化算法,lr=0.001 是初始学习率。
  • 调度器每 5 个 epoch,学习率乘以 0.1,逐步降低以稳定训练。

步骤 4:训练与测试------让模型学习和验证

训练其实就是让模型反复调整自己参数的过程,验证则是检查训练的效果。

训练和验证的逻辑分别在两个函数中实现。

训练函数:

python 复制代码
def train(epoch):
    model.train()  # 进入训练模式
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # 清零梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, targets)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
  • train()激活模型的训练模式(启用 dropout/BN 层的全局统计功能)。
  • 流程模型预测 -> 计算损失 -> 调整参数。

测试函数:

python 复制代码
def test(epoch):
    model.eval()  # 进入测试模式
    with torch.no_grad():  # 关闭梯度计算
        for inputs, targets in test_loader:
            outputs = model(inputs)
            # 计算准确率...
  • eval()切换到测试模式,关闭训练时的随机性(Dropout, BN 不再进行全局统计)。
  • no_grad()节省内存,提高测试效率。

主循环运行 20 个 epoch,每次训练后测试,并保存最佳模型:

python 复制代码
for epoch in range(1, 21):
    train(epoch)
    test_acc = test(epoch)
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "best_pet_model.pth")

训练效果

运行完整的代码后,你会看到类似这样的结果:

plain 复制代码
Epoch 1 | Train Acc: 50.23%Epoch 1 | Test Acc: 52.10%...Best Test Accuracy: 85.67%

这表示模型在测试集上的最高准确率达到 85.67%。

如果效果不理想,可以尝试下面的改进方法。

改进建议

使用预训练权重

weights=None 改为 weights='DEFAULT',利用 ImageNet 的经验加速训练。

数据增强

transform 中加入 transforms.RandomHorizontalFlip(),增加数据多样性。

调整参数

尝试不同的学习率(如 0.0001)或 batch_size(如 64),找到最佳组合。

通过以上的四个步骤------准备数据、搭建模型、设定规则、训练测试,你就可以用 ResNet50 训练自己的数据集了。

这个过程并不复杂,只要理解每个部分的逻辑,就能灵活应用到其他任务上。

如果你有自己的数据集,不妨试一试。

宠物训练的完整代码见这里:github.com/dongdongcan...

备注,本文的完整代码最好在 GPU 环境下运行。

我创建了一个《小而精的AI学习圈》知识星球,星球上有几十万字原创高质量的技术专栏分享,同时你也可以在星球向我提问。 点击这里,我们星球见! 点击这里查看所有 AI 技术专栏

相关推荐
幸好我会魔法2 分钟前
常见限流算法及实现
java·开发语言·算法
飞奔的马里奥26 分钟前
力扣Hot100——136. 只出现一次的数字
算法·leetcode·职场和发展
FAREWELL0007542 分钟前
Leetcode做题记录----3
算法·leetcode·职场和发展
limbo01261 小时前
2025-3-17算法打卡
数据结构·算法·leetcode
白云千载尽2 小时前
LMDrive大语言模型加持的自动驾驶闭环系统 原理与复现过程记录
人工智能·经验分享·python·算法·机器学习·语言模型·自动驾驶
CoovallyAIHub2 小时前
99.22%准确率!EfficientNet优化算法实现猪肉新鲜度无损快检
深度学习·算法·计算机视觉
仟濹2 小时前
【递归与动态规划(DP) C/C++】(1)递归 与 动态规划(DP)
算法·动态规划
tpoog2 小时前
[贪心算法]-最大数(lambda 表达式的补充)
算法·贪心算法
橘颂TA2 小时前
【C++】树和二叉树的实现(上)
数据结构·算法·二叉树·
maisui121382 小时前
牛客周赛 Round 85
算法·深度优先·codeforces