卷积神经网络实现图像分类

复制代码
# 1.导入依赖包
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from torchsummary import summary

BATCH_SIZE = 8


# 2. 获取数据集
def create_dataset():
    # 加载数据集:训练集数据和测试数据
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))
    # 返回数据集结果
    return train, valid


# if __name__ == '__main__':
#     # 数据集加载
#     train_dataset, valid_dataset = create_dataset()
#     # 数据集类别
#     print("数据集类别:", train_dataset.class_to_idx)
#     # 数据集中的图像数据
#     print("训练集数据集:", train_dataset.data.shape)
#     print("测试集数据集:", valid_dataset.data.shape)
#     # 图像展示
#     plt.figure(figsize=(2, 2))
#     plt.imshow(train_dataset.data[1])
#     plt.title(train_dataset.targets[1])
#     plt.show()


# 3.模型构建
class ImageClassification(nn.Module):
    # 定义网络结构
    def __init__(self):
        super(ImageClassification, self).__init__()
        # 定义网络层:卷积层+池化层
        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层
        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    # 定义前向传播
    def forward(self, x):
        # 卷积+relu+池化
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        # 卷积+relu+池化
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)
        # 将特征图做成以为向量的形式:相当于特征向量
        x = x.reshape(x.size(0), -1)
        # 全连接层
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # 返回输出结果
        return self.out(x)


# if __name__ == '__main__':
#     # 模型实例化
#     model = ImageClassification()
#     summary(model, input_size=(3, 32, 32), batch_size=1)

# 4.训练函数编写
def train(model, train_dataset):
    criterion = nn.CrossEntropyLoss()  # 构建损失函数
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 构建优化方法
    epoch = 20  # 训练轮数
    for epoch_idx in range(epoch):
        # 构建数据加载器
        dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        sam_num = 0  # 样本数量
        total_loss = 0.0  # 损失总和
        start = time.time()  # 开始时间
        # 遍历数据进行网络训练
        for x, y in dataloader:
            output = model(x)
            loss = criterion(output, y)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新
            total_loss += loss.item()  # 统计损失和
            sam_num += 1
        print('epoch:%2s loss:%.5f time:%.2fs' % (epoch_idx + 1, total_loss / sam_num, time.time() - start))
    # 模型保存
    torch.save(model.state_dict(), 'data/image_classification.pth')




def test(valid_dataset):
    # 构建数据加载器
    dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型并加载训练好的权重
    model = ImageClassification()
    model.load_state_dict(torch.load('data/image_classification.pth'))
    model.eval()
    # 计算精度
    total_correct = 0
    total_samples = 0
    # 遍历每个batch的数据,获取预测结果,计算精度
    for x, y in dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)
        # 打印精度
    print('Acc: %.2f' % (total_correct / total_samples))


if __name__ == '__main__':
    # 数据集加载
    train_dataset, valid_dataset = create_dataset()
    # 模型实例化
    model = ImageClassification()
    # 模型训练
    # train(model, train_dataset)
    # 模型预测
    test(valid_dataset)
相关推荐
冰西瓜6001 天前
深度学习的数学原理(十)—— 权重如何自发分工
人工智能·深度学习·计算机视觉
Faker66363aaa1 天前
Mask R-CNN实现植物存在性检测与分类详解_基于R50-FPN-GRoIE_1x_COCO模型分析
人工智能·分类·cnn
csdn_life181 天前
训练式推理:算力通缩时代下下一代AI部署范式的创新与落地
人工智能·深度学习·机器学习
Coding茶水间1 天前
基于深度学习的猪识别系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·python·深度学习·yolo·目标检测
Sunhen_Qiletian1 天前
深度学习之模型的部署、web框架 服务端及客户端案例
人工智能·深度学习
LaughingZhu1 天前
Product Hunt 每日热榜 | 2026-02-15
人工智能·经验分享·深度学习·神经网络·产品运营
cyforkk1 天前
YAML 配置文件中的常见陷阱:内联字典与块映射混用
人工智能·深度学习·机器学习
Katecat996631 天前
基于YOLO11-EfficientViT的辉长岩及其相关岩石类型计算机视觉识别分类系统_1
人工智能·计算机视觉·分类
月光有害1 天前
深入解析批归一化 (Batch Normalization): 稳定并加速深度学习的基石
开发语言·深度学习·batch
Suryxin.1 天前
从0开始复现nano-vllm「llm_engine.py」
人工智能·python·深度学习·ai·vllm