用PyTorch训练一个猫狗分类器

数据集准备

使用Kaggle提供的猫狗分类数据集(Dogs vs Cats),包含25,000张图片(12,500张猫和12,500张狗)。数据集可从以下链接下载:

解压后目录结构应如下:

复制代码
data/
    train/
        cat.0.jpg
        dog.0.jpg
        ...
    test/
        test1.jpg
        test2.jpg
        ...

数据预处理

使用torchvision.transforms对图像进行标准化和增强:

python 复制代码
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

数据加载

使用ImageFolderDataLoader加载数据:

python 复制代码
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_dataset = ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = ImageFolder('data/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

模型定义

使用预训练的ResNet18模型并进行微调:

python 复制代码
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # 替换全连接层

训练配置

设置损失函数和优化器:

python 复制代码
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

训练循环

实现完整的训练过程:

python 复制代码
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

模型评估

在测试集上评估模型性能:

python 复制代码
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

源码保存

将完整代码保存为cat_dog_classifier.py

python 复制代码
# 完整代码包含上述所有片段
# 添加必要的import语句和主函数入口

模型保存与加载

训练完成后保存模型:

python 复制代码
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

加载已保存的模型:

python 复制代码
model.load_state_dict(torch.load('cat_dog_classifier.pth'))

扩展建议

  1. 尝试使用不同的预训练模型(如VGG16、EfficientNet)
  2. 增加数据增强方法(如随机旋转、颜色抖动)
  3. 实现学习率调度器
  4. 添加早停机制防止过拟合
  5. 使用TensorBoard记录训练过程

完整项目代码可参考GitHub仓库: https://github.com/example/cat-dog-classifier-pytorch

相关推荐
Shawn_Shawn4 小时前
mcp学习笔记(一)-mcp核心概念梳理
人工智能·llm·mcp
33三 三like6 小时前
《基于知识图谱和智能推荐的养老志愿服务系统》开发日志
人工智能·知识图谱
芝士爱知识a6 小时前
【工具推荐】2026公考App横向评测:粉笔、华图与智蛙面试App功能对比
人工智能·软件推荐·ai教育·结构化面试·公考app·智蛙面试app·公考上岸
Forrit6 小时前
ptyorch安装
pytorch
腾讯云开发者7 小时前
港科大熊辉|AI时代的职场新坐标——为什么你应该去“数据稀疏“的地方?
人工智能
工程师老罗7 小时前
YoloV1数据集格式转换,VOC XML→YOLOv1张量
xml·人工智能·yolo
yLDeveloper8 小时前
从模型评估、梯度难题到科学初始化:一步步解析深度学习的训练问题
深度学习
Coder_Boy_8 小时前
技术让开发更轻松的底层矛盾
java·大数据·数据库·人工智能·深度学习
啊森要自信8 小时前
CANN ops-cv:面向计算机视觉的 AI 硬件端高效算子库核心架构与开发逻辑
人工智能·计算机视觉·架构·cann