用PyTorch训练一个猫狗分类器

数据集准备

使用Kaggle提供的猫狗分类数据集(Dogs vs Cats),包含25,000张图片(12,500张猫和12,500张狗)。数据集可从以下链接下载:

解压后目录结构应如下:

复制代码
data/
    train/
        cat.0.jpg
        dog.0.jpg
        ...
    test/
        test1.jpg
        test2.jpg
        ...

数据预处理

使用torchvision.transforms对图像进行标准化和增强:

python 复制代码
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

数据加载

使用ImageFolderDataLoader加载数据:

python 复制代码
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_dataset = ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = ImageFolder('data/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

模型定义

使用预训练的ResNet18模型并进行微调:

python 复制代码
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # 替换全连接层

训练配置

设置损失函数和优化器:

python 复制代码
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

训练循环

实现完整的训练过程:

python 复制代码
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

模型评估

在测试集上评估模型性能:

python 复制代码
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

源码保存

将完整代码保存为cat_dog_classifier.py

python 复制代码
# 完整代码包含上述所有片段
# 添加必要的import语句和主函数入口

模型保存与加载

训练完成后保存模型:

python 复制代码
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

加载已保存的模型:

python 复制代码
model.load_state_dict(torch.load('cat_dog_classifier.pth'))

扩展建议

  1. 尝试使用不同的预训练模型(如VGG16、EfficientNet)
  2. 增加数据增强方法(如随机旋转、颜色抖动)
  3. 实现学习率调度器
  4. 添加早停机制防止过拟合
  5. 使用TensorBoard记录训练过程

完整项目代码可参考GitHub仓库: https://github.com/example/cat-dog-classifier-pytorch

相关推荐
buttonupAI8 小时前
今日Reddit各AI板块高价值讨论精选(2025-12-20)
人工智能
2501_904876488 小时前
2003-2021年上市公司人工智能的采纳程度测算数据(含原始数据+计算结果)
人工智能
竣雄8 小时前
计算机视觉:原理、技术与未来展望
人工智能·计算机视觉
救救孩子把8 小时前
44-机器学习与大模型开发数学教程-4-6 大数定律与中心极限定理
人工智能·机器学习
Rabbit_QL9 小时前
【LLM评价指标】从概率到直觉:理解语言模型的困惑度
人工智能·语言模型·自然语言处理
呆萌很9 小时前
HSV颜色空间过滤
人工智能
roman_日积跬步-终至千里9 小时前
【人工智能导论】02-搜索-高级搜索策略探索篇:从约束满足到博弈搜索
java·前端·人工智能
FL16238631299 小时前
[C#][winform]基于yolov11的淡水鱼种类检测识别系统C#源码+onnx模型+评估指标曲线+精美GUI界面
人工智能·yolo·目标跟踪
爱笑的眼睛1110 小时前
从 Seq2Seq 到 Transformer++:深度解构与自构建现代机器翻译核心组件
java·人工智能·python·ai
小润nature10 小时前
AI时代对编程技能学习方式的根本变化(1)
人工智能