pytorch跑手写体实验

目录

1、环境条件

2、代码实现

3、总结


1、环境条件

  1. pycharm编译器
  2. pytorch依赖
  3. matplotlib依赖
  4. numpy依赖等等

2、代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据变换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# 定义 LeNet-5 模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型、损失函数和优化器
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

# 保存模型
torch.save(model.state_dict(), 'lenet5.pth')
print('Model saved to lenet5.pth')

# 加载模型
model = LeNet5()
model.load_state_dict(torch.load('lenet5.pth'))
model.to(device)
model.eval()

# 在测试集上评估模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

# 加载并预处理本地图片进行预测
from PIL import Image

def load_and_preprocess_image(image_path):
    img = Image.open(image_path).convert('L')  # 转为灰度图
    img = img.resize((28, 28))
    img = np.array(img, dtype=np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # 归一化到[-1, 1]
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0)  # 添加批次和通道维度
    return img.to(device)

# 预测本地图片
image_path = '4.png'  # 替换为你的本地图片路径
img = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
model.eval()
with torch.no_grad():
    outputs = model(img)
    _, predicted = torch.max(outputs, 1)

# 打印预测结果
predicted_label = predicted.item()
print(f'预测结果: {predicted_label}')

# 显示图片及预测结果
img_np = img.cpu().numpy().squeeze()
plt.imshow(img_np, cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()

解释:torch.save()方法完成模型的保存,image_path为本地图片,用于测试

3、总结

安装环境是比较难的点,均使用pip install 。。指令进行依赖环境的安装,其他的比较简单。

学习之所以会想睡觉,是因为那是梦开始的地方。

ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)

------不写代码不会凸的小刘

相关推荐
Q_Q51100828515 分钟前
python+django/flask的眼科患者随访管理系统 AI智能模型
spring boot·python·django·flask·node.js·php
忙碌54417 分钟前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
LZ_Keep_Running19 分钟前
智能变电巡检:AI检测新突破
人工智能
InfiSight智睿视界1 小时前
AI 技术助力汽车美容行业实现精细化运营管理
大数据·人工智能
没有钱的钱仔2 小时前
机器学习笔记
人工智能·笔记·机器学习
听风吹等浪起2 小时前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer
SunnyDays10112 小时前
如何使用Python高效转换Excel到HTML
python·excel转html
Q_Q5110082852 小时前
python+django/flask的在线学习系统的设计与实现 积分兑换礼物
spring boot·python·django·flask·node.js·php
化作星辰2 小时前
深度学习_原理和进阶_PyTorch入门(2)后续语法3
人工智能·pytorch·深度学习
boonya2 小时前
ChatBox AI 中配置阿里云百炼模型实现聊天对话
人工智能·阿里云·云计算·chatboxai