pytorch跑手写体实验

目录

1、环境条件

2、代码实现

3、总结


1、环境条件

  1. pycharm编译器
  2. pytorch依赖
  3. matplotlib依赖
  4. numpy依赖等等

2、代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# 定义 LeNet-5 模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型、损失函数和优化器
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# 保存模型
torch.save(model.state_dict(), 'lenet5.pth')
print('Model saved to lenet5.pth')

# 加载模型
model = LeNet5()
model.load_state_dict(torch.load('lenet5.pth'))
model.to(device)
model.eval()

# 在测试集上评估模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

# 加载并预处理本地图片进行预测
from PIL import Image

def load_and_preprocess_image(image_path):
    img = Image.open(image_path).convert('L')  # 转为灰度图
    img = img.resize((28, 28))
    img = np.array(img, dtype=np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # 归一化到[-1, 1]
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0)  # 添加批次和通道维度
    return img.to(device)

# 预测本地图片
image_path = '4.png'  # 替换为你的本地图片路径
img = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
model.eval()
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs, 1)

# 打印预测结果
predicted_label = predicted.item()
print(f'预测结果: {predicted_label}')

# 显示图片及预测结果
img_np = img.cpu().numpy().squeeze()
plt.imshow(img_np, cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()

解释:torch.save()方法完成模型的保存,image_path为本地图片,用于测试

3、总结

安装环境是比较难的点,均使用pip install 。。指令进行依赖环境的安装,其他的比较简单。

学习之所以会想睡觉,是因为那是梦开始的地方。

ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)

------不写代码不会凸的小刘

相关推荐
Agent手记几秒前
制造业物流延迟预警系统,从0到1落地实操指南 | 企业级AI Agent架构实战
人工智能·ai
u0110225123 分钟前
如何自定义查询历史记录面板的展示风格_时间轴样式设计
jvm·数据库·python
2301_769340675 分钟前
HTML怎么实现快捷跳转顶部_HTML固定悬浮锚点按钮【介绍】
jvm·数据库·python
老马952711 分钟前
opencode7-桌面应用实战2
java·人工智能·后端
DogDaoDao12 分钟前
【GitHub】Ruflo:面向 Claude Code 的企业级多智能体编排平台深度解析
人工智能·深度学习·大模型·github·ai编程·claude·ruflo
yuanpan13 分钟前
Python + PyAutoGUI 实战:Windows 自动化办公脚本开发入门
windows·python·自动化
m0_6091604917 分钟前
MySQL如何限制触发器递归调用的深度_防止触发器死循环方法
jvm·数据库·python
璞华Purvar18 分钟前
2026化工新材料PLM行业白皮书:璞华易研,以垂直深耕重构研发数智底座
人工智能
广州灵眸科技有限公司18 分钟前
瑞芯微(EASY EAI)RV1126B yolov11-track多目标跟踪部署教程
linux·开发语言·网络·人工智能·yolo·机器学习·目标跟踪
灵动小溪20 分钟前
claude code工具PC安装部署
人工智能·算法