基于 PyTorch 框架写一个图片分类模型

根据您的需求,我将提供一个基于PyTorch框架的简单图片分类模型示例。这个例子使用的是卷积神经网络(CNN)来对CIFAR-10数据集进行分类,CIFAR-10包含60,000个32x32彩色图像,分为10个类别。请注意,实际应用中您可能需要根据具体任务调整模型结构、训练参数等。

首先,请确保您已经安装了PyTorch和torchvision库。如果没有安装,可以通过以下命令安装:

Shell 复制代码
pip install torch torchvision

下面是基本的模型定义和训练代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义简单的CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc = nn.Linear(32 * 8 * 8, 10)  # 假设经过两次池化后尺寸变为8x8

    def forward(self, x):
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 32 * 8 * 8)  # 扁平化
        x = self.fc(x)
        return x

# 数据预处理
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载CIFAR-10数据集
trainset = datasets.CIFAR10(root='./data', train=True,
                            download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                           download=True, transform=transform)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# 初始化模型、损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / (i + 1)}')

print('Finished Training.')

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

这段代码展示了如何定义一个简单的CNN模型,加载CIFAR-10数据集,并进行训练和测试。请根据您的具体需求调整模型结构、训练轮数、学习率等参数。如果您的任务更加复杂或数据集不同,可能需要更复杂的模型架构和训练策略。

相关推荐
文火冰糖的硅基工坊20 小时前
[嵌入式系统-83]:算力芯片的类型与主流架构
人工智能·重构·架构
视觉语言导航1 天前
ICRA-2025 | 阿德莱德机器人拓扑导航探索!TANGO:具有局部度量控制的拓扑目标可穿越性感知具身导航
人工智能·机器人·具身智能
西猫雷婶1 天前
CNN卷积计算
人工智能·神经网络·cnn
格林威1 天前
常规线扫描镜头有哪些类型?能做什么?
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·工业镜头
lyx33136967591 天前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化
倔强青铜三1 天前
苦练Python第63天:零基础玩转TOML配置读写,tomllib模块实战
人工智能·python·面试
递归不收敛1 天前
吴恩达机器学习课程(PyTorch 适配)学习笔记:3.3 推荐系统全面解析
pytorch·学习·机器学习
B站计算机毕业设计之家1 天前
智慧交通项目:Python+YOLOv8 实时交通标志系统 深度学习实战(TT100K+PySide6 源码+文档)✅
人工智能·python·深度学习·yolo·计算机视觉·智慧交通·交通标志
高工智能汽车1 天前
棱镜观察|极氪销量遇阻?千里智驾左手服务吉利、右手对标华为
人工智能·华为
txwtech1 天前
第6篇 OpenCV RotatedRect如何判断矩形的角度
人工智能·opencv·计算机视觉