pytorch cnn 实现猫狗分类

### 文章目录

  • [@[toc]](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [1. 导入必要的库](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [2. 定义数据集类](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [3. 数据预处理和加载](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [4. 定义 CNN 模型](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [5. 定义损失函数和优化器](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [6. 训练模型](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [7. 保存模型](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [8. 使用模型进行预测](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [9 完整代码](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)
  • [10. 总结](#文章目录 @[toc] 1. 导入必要的库 2. 定义数据集类 3. 数据预处理和加载 4. 定义 CNN 模型 5. 定义损失函数和优化器 6. 训练模型 7. 保存模型 8. 使用模型进行预测 9 完整代码 10. 总结)

1. 导入必要的库

py 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

2. 定义数据集类

我们将创建一个自定义数据集类来加载猫狗图片。

py 复制代码
class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.image_paths = []
        self.labels = []

        # 遍历 cat 和 dog 目录,加载图片路径和标签
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir, img_name))
                self.labels.append(idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')  # 确保图片是 RGB 格式
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

3. 数据预处理和加载

定义数据预处理方法,并加载数据集。

py 复制代码
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 调整图片大小
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

# 加载数据集
train_dataset = CatDogDataset(root_dir='path_to_train_data', transform=transform)
val_dataset = CatDogDataset(root_dir='path_to_val_data', transform=transform)
# 数据加载器

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

4. 定义 CNN 模型

py 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

model = CNN()

5. 定义损失函数和优化器

py 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

6. 训练模型

py 复制代码
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # 验证模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

7. 保存模型

py 复制代码
torch.save(model.state_dict(), 'cat_dog_classifier.pth')

8. 使用模型进行预测

py 复制代码
# 加载模型
model.load_state_dict(torch.load('cat_dog_classifier.pth'))
model.eval()

# 预测函数
def predict_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    
    return 'cat' if predicted.item() == 0 else 'dog'

# 使用模型进行预测
image_path = 'path_to_test_image.jpg'
prediction = predict_image(image_path)
print(f"The image is a {prediction}")

# 使用模型进行预测
image_path = 'path_to_test_image.jpg'
prediction = predict_image(image_path)
print(f"The image is a {prediction}")

9 完整代码

py 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cats', 'dogs']
        self.image_paths = []
        self.labels = []

        # 遍历 cat 和 dog 目录,加载图片路径和标签
        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            num_pets = 0
            for img_name in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir, img_name))
                self.labels.append(idx)
                # print("class_dir : ", img_name)
                # num_pets = num_pets + 1
                # if num_pets >= 5000:
                #     break


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')  # 确保图片是 RGB 格式
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # 调整图片大小
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

# 加载数据集
train_dataset = CatDogDataset(root_dir='D:/Cache/dataset/PetImages/train', transform=transform)
val_dataset = CatDogDataset(root_dir='D:/Cache/dataset/PetImages/valid', transform=transform)

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

model = CNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def train():
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

        # 验证模型
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        print(f"Validation Accuracy: {100 * correct / total:.2f}%")


    torch.save(model.state_dict(), 'cat_dog_classifier.pth')

# 预测函数
def predict_image(image_path):
    
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    
    return 'cat' if predicted.item() == 0 else 'dog'


def test():
    # 使用模型进行预测
    # 加载模型
    model.load_state_dict(torch.load('cat_dog_classifier.pth'))
    model.eval()
    image_path = 'D:/Cache/dataset/PetImages/Dog/6.jpg'
    image_path = 'D:/develop/pytorch/dogcat/img/training/dogs/dog1.jpg'
    image_path = 'D:/develop/pytorch/dogcat/img/training/cats/4.jpg'
    prediction = predict_image(image_path)
    print(f"The image is a {prediction}")

import matplotlib.pyplot as plt

def test1():
    image_path = 'D:/develop/pytorch/dogcat/img/training/cats/3.jpg'
    img = Image.open(image_path).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    img = transform(img)
    img = img.unsqueeze(0)  # 添加batch维度

    model.load_state_dict(torch.load('cat_dog_classifier.pth'))
    model.eval()
    prediction = predict_image(image_path)

    class_names = ['cat', 'dog']
    print("Predicted class:", prediction)

    plt.imshow(img.squeeze().numpy().transpose((1, 2, 0)))
    plt.show()

if __name__ == '__main__':
    train()
    test1()
复制代码
D:\develop\pytorch\dogcat>python3.7 dogVsCat.py
C:\Users\yosola\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\PIL\TiffImagePlugin.py:864: UserWarning: Truncated File Read
  warnings.warn(str(msg))
Epoch [1/10], Loss: 0.5782
Validation Accuracy: 76.25%
Epoch [2/10], Loss: 0.4676
Validation Accuracy: 81.10%
Epoch [3/10], Loss: 0.4201
Validation Accuracy: 84.90%
Epoch [4/10], Loss: 0.3605
Validation Accuracy: 88.25%
Epoch [5/10], Loss: 0.2949
Validation Accuracy: 92.50%
Epoch [6/10], Loss: 0.2234
Validation Accuracy: 95.90%
Epoch [7/10], Loss: 0.1562
Validation Accuracy: 98.00%
Epoch [8/10], Loss: 0.1069
Validation Accuracy: 98.60%
Epoch [9/10], Loss: 0.0907
Validation Accuracy: 99.70%
Epoch [10/10], Loss: 0.0785
Validation Accuracy: 99.50%
Predicted class: cat

10. 总结

我们定义了一个自定义数据集类 CatDogDataset 来加载猫狗图片。

使用 PyTorch 的 DataLoader 加载数据。

定义了一个简单的 CNN 模型进行训练。

保存训练好的模型,并使用模型进行预测。

你可以根据需要调整模型的架构、超参数和数据增强方法。希望这个示例对你有帮助!

相关推荐
东皇太星6 小时前
ResNet (2015)(卷积神经网络)
人工智能·神经网络·cnn
qy-ll20 小时前
深度学习——CNN入门
人工智能·深度学习·cnn
hacker7071 天前
openGauss 在K12教育场景的数据处理测评:CASE WHEN 实现高效分类
人工智能·分类·数据挖掘
java1234_小锋2 天前
基于Python深度学习的车辆车牌识别系统(PyTorch2卷积神经网络CNN+OpenCV4实现)视频教程 - 自定义字符图片数据集
python·深度学习·cnn·车牌识别
AI即插即用2 天前
即插即用系列 | CVPR 2025 WPFormer:用于表面缺陷检测的查询式Transformer
人工智能·深度学习·yolo·目标检测·cnn·视觉检测·transformer
大数据魔法师2 天前
分类与回归算法(六)- 集成学习(随机森林、梯度提升决策树、Stacking分类)相关理论
分类·回归·集成学习
AI即插即用2 天前
即插即用系列 | 2025 MambaNeXt-YOLO 炸裂登场!YOLO 激吻 Mamba,打造实时检测新霸主
人工智能·pytorch·深度学习·yolo·目标检测·计算机视觉·视觉检测
大数据魔法师2 天前
分类与回归算法(五)- 决策树分类
决策树·分类·回归
忘却的旋律dw2 天前
使用LLM模型的tokenizer报错AttributeError: ‘dict‘ object has no attribute ‘model_type‘
人工智能·pytorch·python
happy egg3 天前
随机森林分类VS回归
随机森林·分类·回归