卷积神经网络实现图像分类

复制代码
# 1.导入依赖包
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from torchsummary import summary

BATCH_SIZE = 8


# 2. 获取数据集
def create_dataset():
    # 加载数据集:训练集数据和测试数据
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))
    # 返回数据集结果
    return train, valid


# if __name__ == '__main__':
#     # 数据集加载
#     train_dataset, valid_dataset = create_dataset()
#     # 数据集类别
#     print("数据集类别:", train_dataset.class_to_idx)
#     # 数据集中的图像数据
#     print("训练集数据集:", train_dataset.data.shape)
#     print("测试集数据集:", valid_dataset.data.shape)
#     # 图像展示
#     plt.figure(figsize=(2, 2))
#     plt.imshow(train_dataset.data[1])
#     plt.title(train_dataset.targets[1])
#     plt.show()


# 3.模型构建
class ImageClassification(nn.Module):
    # 定义网络结构
    def __init__(self):
        super(ImageClassification, self).__init__()
        # 定义网络层:卷积层+池化层
        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层
        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    # 定义前向传播
    def forward(self, x):
        # 卷积+relu+池化
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        # 卷积+relu+池化
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)
        # 将特征图做成以为向量的形式:相当于特征向量
        x = x.reshape(x.size(0), -1)
        # 全连接层
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # 返回输出结果
        return self.out(x)


# if __name__ == '__main__':
#     # 模型实例化
#     model = ImageClassification()
#     summary(model, input_size=(3, 32, 32), batch_size=1)

# 4.训练函数编写
def train(model, train_dataset):
    criterion = nn.CrossEntropyLoss()  # 构建损失函数
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 构建优化方法
    epoch = 20  # 训练轮数
    for epoch_idx in range(epoch):
        # 构建数据加载器
        dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        sam_num = 0  # 样本数量
        total_loss = 0.0  # 损失总和
        start = time.time()  # 开始时间
        # 遍历数据进行网络训练
        for x, y in dataloader:
            output = model(x)
            loss = criterion(output, y)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新
            total_loss += loss.item()  # 统计损失和
            sam_num += 1
        print('epoch:%2s loss:%.5f time:%.2fs' % (epoch_idx + 1, total_loss / sam_num, time.time() - start))
    # 模型保存
    torch.save(model.state_dict(), 'data/image_classification.pth')




def test(valid_dataset):
    # 构建数据加载器
    dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型并加载训练好的权重
    model = ImageClassification()
    model.load_state_dict(torch.load('data/image_classification.pth'))
    model.eval()
    # 计算精度
    total_correct = 0
    total_samples = 0
    # 遍历每个batch的数据,获取预测结果,计算精度
    for x, y in dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)
        # 打印精度
    print('Acc: %.2f' % (total_correct / total_samples))


if __name__ == '__main__':
    # 数据集加载
    train_dataset, valid_dataset = create_dataset()
    # 模型实例化
    model = ImageClassification()
    # 模型训练
    # train(model, train_dataset)
    # 模型预测
    test(valid_dataset)
相关推荐
CoovallyAIHub3 小时前
港大&字节重磅发布DanceGRPO:突破视觉生成RLHF瓶颈,多项任务性能提升超180%!
深度学习·算法·计算机视觉
CoovallyAIHub4 小时前
英伟达ViPE重磅发布!解决3D感知难题,SLAM+深度学习完美融合(附带数据集下载地址)
深度学习·算法·计算机视觉
惯导马工1 天前
【论文导读】ORB-SLAM3:An Accurate Open-Source Library for Visual, Visual-Inertial and
深度学习·算法
隐语SecretFlow2 天前
国人自研开源隐私计算框架SecretFlow,深度拆解框架及使用【开发者必看】
深度学习
Billy_Zuo2 天前
人工智能深度学习——卷积神经网络(CNN)
人工智能·深度学习·cnn
羊羊小栈2 天前
基于「YOLO目标检测 + 多模态AI分析」的遥感影像目标检测分析系统(vue+flask+数据集+模型训练)
人工智能·深度学习·yolo·目标检测·毕业设计·大作业
l12345sy2 天前
Day24_【深度学习—广播机制】
人工智能·pytorch·深度学习·广播机制
IT古董2 天前
【第五章:计算机视觉-项目实战之图像分类实战】1.经典卷积神经网络模型Backbone与图像-(4)经典卷积神经网络ResNet的架构讲解
人工智能·计算机视觉·cnn
九章云极AladdinEdu2 天前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
研梦非凡3 天前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d