Pytorch笔记之分类

文章目录


前言

使用Pytorch进行MNIST分类,使用TensorDataset与DataLoader封装、加载本地数据集。


一、导入库

python 复制代码
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader # 数据集工具
from load_mnist import load_mnist # 本地数据集

二、数据处理

1、导入本地数据集,将标签值设置为int类型,构建张量

2、使用TensorDataset与DataLoader封装训练集与测试集

python 复制代码
# 构建数据
x_train, y_train, x_test, y_test = \
    load_mnist(normalize=True, flatten=False, one_hot_label=False)
# 数据处理
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.int64))
x_test = torch.from_numpy(x_test.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.int64))
# 数据集封装
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
batch_size = 64
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,
                          shuffle=True)

三、构建模型

输入到全连接层之前需要把(batch_size,28,28)展平为(batch_size,784)

交叉熵损失函数整合了Softmax,在模型中可以不添加Softmax

python 复制代码
# 继承模型
class FC(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 10)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        y = self.fc1(x.view(x.shape[0],-1))
        y = self.softmax(y)
        return y
# 定义模型
model = FC()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

四、迭代训练

从DataLoader中取出x和y,进行前向和反向的计算

python 复制代码
for epoch in range(10):
    print('Epoch:', epoch)
    for i,data in enumerate(train_loader):
        x, y = data
        y_pred = model.forward(x)
        loss = loss_function(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

五、模型评估

在测试集中进行验证

使用.item()获得tensor的取值

python 复制代码
	correct = 0
    for i,data in enumerate(test_loader):
        x, y = data
        y_pred = model.forward(x)
        _, y_pred = torch.max(y_pred, 1)
        correct += (y_pred == y).sum().item()
    acc = correct / len(test_dataset)
    print('Accuracy:{:.2%}'.format(acc))

总结

记录了TensorDataset与DataLoader的使用方法,模型的构建与训练和上一篇Pytorch笔记之回归相似。

相关推荐
Sincerelyplz13 分钟前
【Temproal】快速了解Temproal的核心概念以及使用
笔记·后端·开源
Yo_Becky2 小时前
【PyTorch】PyTorch预训练模型缓存位置迁移,也可拓展应用于其他文件的迁移
人工智能·pytorch·经验分享·笔记·python·程序人生·其他
xinxiangwangzhi_2 小时前
pytorch底层原理学习--PyTorch 架构梳理
人工智能·pytorch·架构
DIY机器人工房2 小时前
0.96寸OLED显示屏 江协科技学习笔记(36个知识点)
笔记·科技·stm32·单片机·嵌入式硬件·学习·江协科技
FF-Studio3 小时前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
future14123 小时前
每日问题总结
经验分享·笔记
循环过三天5 小时前
3-1 PID算法改进(积分部分)
笔记·stm32·单片机·学习·算法·pid
之歆6 小时前
Python-封装和解构-set及操作-字典及操作-解析式生成器-内建函数迭代器-学习笔记
笔记·python·学习
DKPT7 小时前
Java组合模式实现方式与测试方法
java·笔记·学习·设计模式·组合模式
受之以蒙7 小时前
Rust & WASM 之 wasm-bindgen 基础:让 Rust 与 JavaScript 无缝对话
前端·笔记·rust