实践 PyTorch 手写数字识别

py 版本:Python 3.12.7

安装库: pip install numpy torch torchvision matplotlib

运行: python test.py

py版本不对可能无法运行,默认数据集需要科学上网才能下载,默认的验证代码是从验证包里取图片,注释的代码是我本地构造的图片,用库里的图片,替换纯色背景,手写一个数字,大小改到28x28就可以验证了

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from PIL import Image


class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)
    
    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)


def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total


def main():

    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()
    
    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(2):
        for (x, y) in train_data:
            net.zero_grad()
            output = net.forward(x.view(-1, 28*28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    plt.show()

def load_custom_image(image_path):
    """ 加载自定义手写数字图片,并转换为 MNIST 兼容格式 """
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # 转换为灰度图
        transforms.Resize((28, 28)),  # 调整大小
        transforms.ToTensor(),  # 转换为 PyTorch 张量
        transforms.Normalize((0.1307,), (0.3081,))  # 使用 MNIST 的归一化参数
    ])
    image = Image.open(image_path)
    return transform(image).unsqueeze(0)  # 添加 batch 维度

if __name__ == "__main__":
    main()
    """
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()
    
    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(2):
        for (x, y) in train_data:
            net.zero_grad()
            output = net.forward(x.view(-1, 28*28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    
    image_tensor = load_custom_image("C:\\Users\\we\\Desktop\\7.png")
    predict = torch.argmax(net.forward(image_tensor.view(-1, 28*28)))
    print("prediction: " + str(int(predict)))
    """

默认代码验证结果

手写图片

验证结果

来源:【10分钟入门神经网络 PyTorch 手写数字识别】 https://www.bilibili.com/video/BV1GC4y15736/?share_source=copy_web\&vd_source=33a387ed337161d7e4f60dd9167ab954

相关推荐
数据科学作家2 小时前
学数据分析必囤!数据分析必看!清华社9本书覆盖Stata/SPSS/Python全阶段学习路径
人工智能·python·机器学习·数据分析·统计·stata·spss
HXQ_晴天4 小时前
CASToR 生成的文件进行转换
python
CV缝合救星4 小时前
【Arxiv 2025 预发行论文】重磅突破!STAR-DSSA 模块横空出世:显著性+拓扑双重加持,小目标、大场景统统拿下!
人工智能·深度学习·计算机视觉·目标跟踪·即插即用模块
java1234_小锋5 小时前
Scikit-learn Python机器学习 - 特征预处理 - 标准化 (Standardization):StandardScaler
python·机器学习·scikit-learn
Python×CATIA工业智造5 小时前
Python带状态生成器完全指南:从基础到高并发系统设计
python·pycharm
向qian看_-_5 小时前
Linux 使用pip报错(error: externally-managed-environment )解决方案
linux·python·pip
Nicole-----5 小时前
Python - Union联合类型注解
开发语言·python
TDengine (老段)6 小时前
从 ETL 到 Agentic AI:工业数据管理变革与 TDengine IDMP 的治理之道
数据库·数据仓库·人工智能·物联网·时序数据库·etl·tdengine
蓝桉8026 小时前
如何进行神经网络的模型训练(视频代码中的知识点记录)
人工智能·深度学习·神经网络
星期天要睡觉7 小时前
深度学习——数据增强(Data Augmentation)
人工智能·深度学习