用PyTorch训练一个猫狗分类器

数据集准备

使用Kaggle提供的猫狗分类数据集(Dogs vs Cats),包含25,000张图片(12,500张猫和12,500张狗)。数据集可从以下链接下载:

解压后目录结构应如下:

复制代码
data/
    train/
        cat.0.jpg
        dog.0.jpg
        ...
    test/
        test1.jpg
        test2.jpg
        ...

数据预处理

使用torchvision.transforms对图像进行标准化和增强:

python 复制代码
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

数据加载

使用ImageFolderDataLoader加载数据:

python 复制代码
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_dataset = ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = ImageFolder('data/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

模型定义

使用预训练的ResNet18模型并进行微调:

python 复制代码
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # 替换全连接层

训练配置

设置损失函数和优化器:

python 复制代码
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

训练循环

实现完整的训练过程:

python 复制代码
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

模型评估

在测试集上评估模型性能:

python 复制代码
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

源码保存

将完整代码保存为cat_dog_classifier.py

python 复制代码
# 完整代码包含上述所有片段
# 添加必要的import语句和主函数入口

模型保存与加载

训练完成后保存模型:

python 复制代码
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

加载已保存的模型:

python 复制代码
model.load_state_dict(torch.load('cat_dog_classifier.pth'))

扩展建议

  1. 尝试使用不同的预训练模型(如VGG16、EfficientNet)
  2. 增加数据增强方法(如随机旋转、颜色抖动)
  3. 实现学习率调度器
  4. 添加早停机制防止过拟合
  5. 使用TensorBoard记录训练过程

完整项目代码可参考GitHub仓库: https://github.com/example/cat-dog-classifier-pytorch

相关推荐
默默开发2 小时前
完整版:本地电脑 + WiFi 搭建 AI 自动炒股 + 自我学习系统
人工智能·学习·电脑
zzh940772 小时前
2026年AI文件上传功能实战:聚合站处理图片、PDF、PPT全指南
人工智能·pdf·powerpoint
新缸中之脑6 小时前
Paperless-NGX实战文档管理
人工智能
无极低码8 小时前
ecGlypher新手安装分步指南(标准化流程)
人工智能·算法·自然语言处理·大模型·rag
grant-ADAS8 小时前
记录paddlepaddleOCR从环境到使用默认模型,再训练自己的数据微调模型再推理
人工智能·深度学习
炎爆的土豆翔8 小时前
OpenCV 阈值二值化优化实战:LUT 并行、手写 AVX2 与 cv::threshold 性能对比
人工智能·opencv·计算机视觉
智能相对论9 小时前
从AWE看到海尔智慧家庭步步引领
人工智能
云和数据.ChenGuang9 小时前
魔搭社区 测试AI案例故障
人工智能·深度学习·机器学习·ai·mindstudio
小锋学长生活大爆炸9 小时前
【工具】无需Token!WebAI2API将网页AI转为API使用
人工智能·深度学习·chatgpt·openclaw
昨夜见军贴06169 小时前
AI审核赋能司法鉴定:IACheck如何保障刑事证据检测报告精准无误、经得起推敲?
人工智能