卷积神经网络实现图像分类

复制代码
# 1.导入依赖包
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
from torchsummary import summary

BATCH_SIZE = 8


# 2. 获取数据集
def create_dataset():
    # 加载数据集:训练集数据和测试数据
    train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))
    valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))
    # 返回数据集结果
    return train, valid


# if __name__ == '__main__':
#     # 数据集加载
#     train_dataset, valid_dataset = create_dataset()
#     # 数据集类别
#     print("数据集类别:", train_dataset.class_to_idx)
#     # 数据集中的图像数据
#     print("训练集数据集:", train_dataset.data.shape)
#     print("测试集数据集:", valid_dataset.data.shape)
#     # 图像展示
#     plt.figure(figsize=(2, 2))
#     plt.imshow(train_dataset.data[1])
#     plt.title(train_dataset.targets[1])
#     plt.show()


# 3.模型构建
class ImageClassification(nn.Module):
    # 定义网络结构
    def __init__(self):
        super(ImageClassification, self).__init__()
        # 定义网络层:卷积层+池化层
        self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 全连接层
        self.linear1 = nn.Linear(576, 120)
        self.linear2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    # 定义前向传播
    def forward(self, x):
        # 卷积+relu+池化
        x = torch.relu(self.conv1(x))
        x = self.pool1(x)
        # 卷积+relu+池化
        x = torch.relu(self.conv2(x))
        x = self.pool2(x)
        # 将特征图做成以为向量的形式:相当于特征向量
        x = x.reshape(x.size(0), -1)
        # 全连接层
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # 返回输出结果
        return self.out(x)


# if __name__ == '__main__':
#     # 模型实例化
#     model = ImageClassification()
#     summary(model, input_size=(3, 32, 32), batch_size=1)

# 4.训练函数编写
def train(model, train_dataset):
    criterion = nn.CrossEntropyLoss()  # 构建损失函数
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 构建优化方法
    epoch = 20  # 训练轮数
    for epoch_idx in range(epoch):
        # 构建数据加载器
        dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        sam_num = 0  # 样本数量
        total_loss = 0.0  # 损失总和
        start = time.time()  # 开始时间
        # 遍历数据进行网络训练
        for x, y in dataloader:
            output = model(x)
            loss = criterion(output, y)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播
            optimizer.step()  # 参数更新
            total_loss += loss.item()  # 统计损失和
            sam_num += 1
        print('epoch:%2s loss:%.5f time:%.2fs' % (epoch_idx + 1, total_loss / sam_num, time.time() - start))
    # 模型保存
    torch.save(model.state_dict(), 'data/image_classification.pth')




def test(valid_dataset):
    # 构建数据加载器
    dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # 加载模型并加载训练好的权重
    model = ImageClassification()
    model.load_state_dict(torch.load('data/image_classification.pth'))
    model.eval()
    # 计算精度
    total_correct = 0
    total_samples = 0
    # 遍历每个batch的数据,获取预测结果,计算精度
    for x, y in dataloader:
        output = model(x)
        total_correct += (torch.argmax(output, dim=-1) == y).sum()
        total_samples += len(y)
        # 打印精度
    print('Acc: %.2f' % (total_correct / total_samples))


if __name__ == '__main__':
    # 数据集加载
    train_dataset, valid_dataset = create_dataset()
    # 模型实例化
    model = ImageClassification()
    # 模型训练
    # train(model, train_dataset)
    # 模型预测
    test(valid_dataset)
相关推荐
胡耀超4 小时前
DataOceanAI Dolphin(ffmpeg音频转化教程) 多语言(中国方言)语音识别系统部署与应用指南
python·深度学习·ffmpeg·音视频·语音识别·多模态·asr
HUIMU_5 小时前
DAY12&DAY13-新世纪DL(Deeplearning/深度学习)战士:破(改善神经网络)1
人工智能·深度学习
mit6.8246 小时前
[1Prompt1Story] 注意力机制增强 IPCA | 去噪神经网络 UNet | U型架构分步去噪
人工智能·深度学习·神经网络
Coovally AI模型快速验证6 小时前
YOLO、DarkNet和深度学习如何让自动驾驶看得清?
深度学习·算法·yolo·cnn·自动驾驶·transformer·无人机
科大饭桶7 小时前
昇腾AI自学Day2-- 深度学习基础工具与数学
人工智能·pytorch·python·深度学习·numpy
努力还债的学术吗喽7 小时前
2021 IEEE【论文精读】用GAN让音频隐写术骗过AI检测器 - 对抗深度学习的音频信息隐藏
人工智能·深度学习·生成对抗网络·密码学·音频·gan·隐写
weixin_507929919 小时前
第G7周:Semi-Supervised GAN 理论与实战
人工智能·pytorch·深度学习
AI波克布林11 小时前
发文暴论!线性注意力is all you need!
人工智能·深度学习·神经网络·机器学习·注意力机制·线性注意力
Blossom.11811 小时前
把 AI 推理塞进「 8 位 MCU 」——0.5 KB RAM 跑通关键词唤醒的魔幻之旅
人工智能·笔记·单片机·嵌入式硬件·深度学习·机器学习·搜索引擎
2502_9271612813 小时前
DAY 40 训练和测试的规范写法
人工智能·深度学习·机器学习