pytorch跑手写体实验

目录

1、环境条件

2、代码实现

3、总结


1、环境条件

  1. pycharm编译器
  2. pytorch依赖
  3. matplotlib依赖
  4. numpy依赖等等

2、代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# 定义 LeNet-5 模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型、损失函数和优化器
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# 保存模型
torch.save(model.state_dict(), 'lenet5.pth')
print('Model saved to lenet5.pth')

# 加载模型
model = LeNet5()
model.load_state_dict(torch.load('lenet5.pth'))
model.to(device)
model.eval()

# 在测试集上评估模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

# 加载并预处理本地图片进行预测
from PIL import Image

def load_and_preprocess_image(image_path):
    img = Image.open(image_path).convert('L')  # 转为灰度图
    img = img.resize((28, 28))
    img = np.array(img, dtype=np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # 归一化到[-1, 1]
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0)  # 添加批次和通道维度
    return img.to(device)

# 预测本地图片
image_path = '4.png'  # 替换为你的本地图片路径
img = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
model.eval()
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs, 1)

# 打印预测结果
predicted_label = predicted.item()
print(f'预测结果: {predicted_label}')

# 显示图片及预测结果
img_np = img.cpu().numpy().squeeze()
plt.imshow(img_np, cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()

解释:torch.save()方法完成模型的保存,image_path为本地图片,用于测试

3、总结

安装环境是比较难的点,均使用pip install 。。指令进行依赖环境的安装,其他的比较简单。

学习之所以会想睡觉,是因为那是梦开始的地方。

ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)

------不写代码不会凸的小刘

相关推荐
小爬虫程序猿几秒前
利用Python爬虫速卖通按关键字搜索AliExpress商品
开发语言·爬虫·python
诚威_lol_中大努力中14 分钟前
关于VQ-GAN利用滑动窗口生成 高清图像
人工智能·神经网络·生成对抗网络
Eiceblue19 分钟前
使用Python获取PDF文本和图片的精确位置
开发语言·python·pdf
我叫czc22 分钟前
【Python高级353】python实现多线程版本的TCP服务器
服务器·python·tcp/ip
爱数学的程序猿26 分钟前
Python入门:6.深入解析Python中的序列
android·服务器·python
中关村科金35 分钟前
中关村科金智能客服机器人如何解决客户个性化需求与标准化服务之间的矛盾?
人工智能·机器人·在线客服·智能客服机器人·中关村科金
逸_38 分钟前
Product Hunt 今日热榜 | 2024-12-25
人工智能
Luke Ewin44 分钟前
基于3D-Speaker进行区分说话人项目搭建过程报错记录 | 通话录音说话人区分以及语音识别 | 声纹识别以及语音识别 | pyannote-audio
人工智能·语音识别·声纹识别·通话录音区分说话人
comli_cn1 小时前
使用清华源安装python包
开发语言·python
DashVector1 小时前
如何通过HTTP API检索Doc
数据库·人工智能·http·阿里云·数据库开发·向量检索