pytorch跑手写体实验

目录

1、环境条件

2、代码实现

3、总结


1、环境条件

  1. pycharm编译器
  2. pytorch依赖
  3. matplotlib依赖
  4. numpy依赖等等

2、代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# 定义 LeNet-5 模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型、损失函数和优化器
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# 保存模型
torch.save(model.state_dict(), 'lenet5.pth')
print('Model saved to lenet5.pth')

# 加载模型
model = LeNet5()
model.load_state_dict(torch.load('lenet5.pth'))
model.to(device)
model.eval()

# 在测试集上评估模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

# 加载并预处理本地图片进行预测
from PIL import Image

def load_and_preprocess_image(image_path):
    img = Image.open(image_path).convert('L')  # 转为灰度图
    img = img.resize((28, 28))
    img = np.array(img, dtype=np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # 归一化到[-1, 1]
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0)  # 添加批次和通道维度
    return img.to(device)

# 预测本地图片
image_path = '4.png'  # 替换为你的本地图片路径
img = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
model.eval()
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs, 1)

# 打印预测结果
predicted_label = predicted.item()
print(f'预测结果: {predicted_label}')

# 显示图片及预测结果
img_np = img.cpu().numpy().squeeze()
plt.imshow(img_np, cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()

解释:torch.save()方法完成模型的保存,image_path为本地图片,用于测试

3、总结

安装环境是比较难的点,均使用pip install 。。指令进行依赖环境的安装,其他的比较简单。

学习之所以会想睡觉,是因为那是梦开始的地方。

ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)

------不写代码不会凸的小刘

相关推荐
曲幽几秒前
FastAPI + Celery 实战:异步任务的坑与解法,我帮你踩了一遍
redis·python·fastapi·web·async·celery·background·task·queue
亦复何言??6 分钟前
BeyondMimic 论文解析
人工智能·算法·机器人
深蓝海拓8 分钟前
使用@property将类方法包装为属性
开发语言·python
Lee川9 分钟前
🛠️ LangChain Tools 实战指南:让 AI 拥有“动手能力”
人工智能
gorgeous(๑>؂<๑)10 分钟前
【CVPR26-索尼】EW-DETR:通过增量低秩检测Transformer实现动态世界目标检测
人工智能·深度学习·目标检测·计算机视觉·transformer
xianluohuanxiang14 分钟前
新能源功率预测的“生死局”:从“能报曲线”到“能做收益”,中间差的不是一点算法
人工智能
码农垦荒笔记31 分钟前
Claude Code 2026 年 3 月全面进化:Auto 模式、Computer Use 与云端持续执行重塑 AI 编程工作流
人工智能·ai 编程·claude code·agentic coding·computer use
threerocks37 分钟前
【Claude Code 系列课程】01 | Claude Code 架构全览
人工智能·ai编程·claude
熊猫代跑得快39 分钟前
Agent 通用架构入门学习
人工智能·agent·智能体
格林威39 分钟前
Baumer相机锂电池极片裁切毛刺检测:防止内部短路的 5 个核心方法,附 OpenCV+Halcon 实战代码!
开发语言·人工智能·数码相机·opencv·计算机视觉·c#·视觉检测