用PyTorch训练一个猫狗分类器

数据集准备

使用Kaggle提供的猫狗分类数据集(Dogs vs Cats),包含25,000张图片(12,500张猫和12,500张狗)。数据集可从以下链接下载:

解压后目录结构应如下:

复制代码
data/
    train/
        cat.0.jpg
        dog.0.jpg
        ...
    test/
        test1.jpg
        test2.jpg
        ...

数据预处理

使用torchvision.transforms对图像进行标准化和增强:

python 复制代码
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

数据加载

使用ImageFolderDataLoader加载数据:

python 复制代码
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_dataset = ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = ImageFolder('data/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

模型定义

使用预训练的ResNet18模型并进行微调:

python 复制代码
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # 替换全连接层

训练配置

设置损失函数和优化器:

python 复制代码
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

训练循环

实现完整的训练过程:

python 复制代码
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

模型评估

在测试集上评估模型性能:

python 复制代码
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total}%')

源码保存

将完整代码保存为cat_dog_classifier.py

python 复制代码
# 完整代码包含上述所有片段
# 添加必要的import语句和主函数入口

模型保存与加载

训练完成后保存模型:

python 复制代码
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

加载已保存的模型:

python 复制代码
model.load_state_dict(torch.load('cat_dog_classifier.pth'))

扩展建议

  1. 尝试使用不同的预训练模型(如VGG16、EfficientNet)
  2. 增加数据增强方法(如随机旋转、颜色抖动)
  3. 实现学习率调度器
  4. 添加早停机制防止过拟合
  5. 使用TensorBoard记录训练过程

完整项目代码可参考GitHub仓库: https://github.com/example/cat-dog-classifier-pytorch

相关推荐
陈天伟教授25 分钟前
人工智能应用-机器听觉:7. 统计合成法
人工智能·语音识别
笨蛋不要掉眼泪1 小时前
Spring Boot集成LangChain4j:与大模型对话的极速入门
java·人工智能·后端·spring·langchain
昨夜见军贴06161 小时前
IACheck AI审核技术赋能消费认证:为智能宠物喂食器TELEC报告构筑智能合规防线
人工智能·宠物
DisonTangor1 小时前
阿里开源语音识别模型——Qwen3-ASR
人工智能·开源·语音识别
万事ONES1 小时前
ONES 签约北京高级别自动驾驶示范区专设国有运营平台——北京车网
人工智能·机器学习·自动驾驶
qyr67891 小时前
深度解析:3D细胞培养透明化试剂供应链与主要制造商分布
大数据·人工智能·3d·市场分析·市场报告·3d细胞培养·细胞培养
软件开发技术深度爱好者1 小时前
浅谈人工智能(AI)对个人发展的影响
人工智能
一路向北he2 小时前
esp32 arduino环境的搭建
人工智能
SmartBrain2 小时前
Qwen3-VL 模型架构及原理详解
人工智能·语言模型·架构·aigc
renhongxia12 小时前
AI算法实战:逻辑回归在风控场景中的应用
人工智能·深度学习·算法·机器学习·信息可视化·语言模型·逻辑回归