pytorch简单神经网络模型训练

目录

一、导入包

二、数据预处理

三、定义神经网络

四、训练模型和测试模型

五、程序入口


一、导入包

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库

二、数据预处理

python 复制代码
def data_loader():
    '''数据的预处理'''

    # 定义数据预处理
    transform = transforms.Compose([

        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])

    # 加载FashionMNIST数据集
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='.data', train=False, download=True, transform=transform)

    # 数据集加载器
    train_loader = DataLoader( train_dataset,batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset,batch_size=64, shuffle=False)

    return train_loader, test_loader

这是一个常用于机器学习和深度学习研究中的数据集,包含了10类不同时尚商品的图像,每类有6000张训练图像和1000张测试图像。使用了PyTorch框架中的torchvision库来下载和加载Fashion-MNIST数据集。代码中定义了一个transform,它会将图像转换为张量,并对其进行归一化处理。然后,分别创建训练集和测试集的数据加载器train_loadertest_loader,这些加载器会在训练过程中以批量的形式提供数据。

三、定义神经网络

python 复制代码
# 定义神经网络
class QYNN(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28,128)
        self.fc2 = nn.Linear(128,10)

    def forward(self, x):
        # 将数据展平
        x = torch.flatten(x,start_dim=1)
        # 激活x,方便数据全联接
        x = torch.relu(self.fc1(x))
        # 输出10分类
        x = self.fc2(x)

        return x

代码定义了一个简单的全连接神经网络,适用于Fashion-MNIST这样的图像分类任务。这个网络包含两个全连接层(fc1fc2),分别用于特征提取和分类。

这里是您定义的QYNN类的一些解释:

  • __init__方法定义了网络的结构。网络接受28x28像素的灰度图像作为输入,首先通过一个线性层fc1将784个像素值映射到128个特征,然后通过第二个线性层fc2将128个特征映射到10个输出,对应于10个类别。

  • forward方法定义了数据通过网络的前向传播过程。输入数据首先被展平成一个一维向量,然后通过fc1层,接着是ReLU激活函数,最后通过fc2层输出每个类别的得分。

四、训练模型和测试模型

训练模型

python 复制代码
def train(model, train_loader):
    '''训练模型'''

    # 训练轮数
    epochs = 10
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            # 梯度清零
            optimizer.zero_grad()
            # 将图片塞进去
            outputs = model(inputs)
            # 计算损失
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 更新参数
            optimizer.step()
            # 损失值的累加
            running_loss += loss.item()
        print(f'Epoch:{epoch+1}/{epochs} | Loss: {running_loss/len(train_loader)}')

测试模型

test函数使用了torch.no_grad()来禁用梯度计算,因为在测试阶段我们不需要计算梯度。函数遍历测试数据加载器中的每个批次,将输入数据传递给模型以获取输出,然后使用torch.max函数来获取每个样本的最高得分类别作为预测结果。最后,函数计算预测正确的样本数量与总样本数量,从而得到准确率。

python 复制代码
def test(model, test_loader):
    '''测试模型'''

    correct = 0 # 正确的数量
    total = 0 # 样本的总量

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)

            _, predicted = torch.max(outputs, 1)

            # 获取本次样本的数量
            total += labels.size(0)

            # 预测值 和 标签 相同则正确, 对预测值进行累加
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {correct / total:.2%}')

五、程序入口

python 复制代码
if __name__ == '__main__':
    # 设置随机种子
    torch.manual_seed(21)

    # 实例化神经网络
    model = QYNN()
    # 交叉商
    criterion = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.SGD(model.parameters(),lr = 0.01)
    # 数据集
    train_loader, test_loader = data_loader()
    # 训练样本
    train(model, train_loader)
    # 测设样本
    test(model, test_loader)
相关推荐
碧海银沙音频科技研究院9 小时前
1-1杰理蓝牙SOC的UI配置开发方法
人工智能·深度学习·算法
龙文浩_11 小时前
AI梯度下降与PyTorch张量操作技术指南
人工智能·pytorch·python·深度学习·神经网络·机器学习·自然语言处理
清空mega12 小时前
动手学深度学习——样式迁移
人工智能·深度学习
MRDONG112 小时前
Prompt Engineering进阶指南
人工智能·深度学习·神经网络·机器学习·自然语言处理
QQ6765800813 小时前
基于深度学习YOLO的苹果采摘点图像识别 苹果枝条分割识别 苹果分割检测 苹果茎叶分割识别 果园自动化采摘设备目标识别算法第10386期
深度学习·yolo·自动化·苹果采摘点图像·苹果枝条分割·苹果茎叶分割·果园自动化采摘设备
碧海银沙音频科技研究院13 小时前
虚拟机ubuntu与windows共享文件夹(Samba共享)解决WSL加载SI工程满卡问题
人工智能·深度学习·算法
小江的记录本13 小时前
【Transformer架构】Transformer架构核心知识体系(包括自注意力机制、多头注意力、Encoder-Decoder结构)
java·人工智能·后端·python·深度学习·架构·transformer
AI先驱体验官13 小时前
债小白分析:债务优化服务的新变量、AI能否带来行业升级
大数据·人工智能·深度学习·重构·aigc
SomeB1oody14 小时前
【Python深度学习】2.1. 卷积神经网络(CNN)模型理论(基础):卷积运算、池化、ReLU函数
开发语言·人工智能·python·深度学习·机器学习·cnn
sp_fyf_202417 小时前
【大语言模型】 WizardLM:赋能大型预训练语言模型以遵循复杂指令
人工智能·深度学习·神经网络·语言模型·自然语言处理