【深度学习】pytorch,MNIST手写数字分类

efficientnet_b0的迁移学习

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt

# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10

# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([
    transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0
    transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])

# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 用于绘图的数据
train_losses = []
test_accuracies = []

# 训练模型
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # 计算平均损失
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # 测试准确率
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # Move test data to the correct device
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')

# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

相关推荐
禁默1 分钟前
Ops-Transformer深入:CANN生态Transformer专用算子库赋能多模态生成效率跃迁
人工智能·深度学习·transformer·cann
杜子不疼.3 分钟前
基于CANN GE图引擎的深度学习模型编译与优化技术
人工智能·深度学习
chaser&upper13 分钟前
预见未来:在 AtomGit 解码 CANN ops-nn 的投机采样加速
人工智能·深度学习·神经网络
机器懒得学习29 分钟前
智能股票分析系统
python·深度学习·金融
vx_biyesheji000133 分钟前
豆瓣电影推荐系统 | Python Django 协同过滤 Echarts可视化 深度学习 大数据 毕业设计源码
大数据·爬虫·python·深度学习·django·毕业设计·echarts
饭饭大王6661 小时前
当 AI 系统开始“自省”——在 `ops-transformer` 中嵌入元认知能力
人工智能·深度学习·transformer
TechWJ2 小时前
CANN ops-nn神经网络算子库技术剖析:NPU加速的基石
人工智能·深度学习·神经网络·cann·ops-nn
心疼你的一切2 小时前
拆解 CANN 仓库:实现 AIGC 文本生成昇腾端部署
数据仓库·深度学习·aigc·cann
哈__2 小时前
CANN加速图神经网络GNN推理:消息传递与聚合优化
人工智能·深度学习·神经网络
User_芊芊君子2 小时前
CANN_MetaDef图定义框架全解析为AI模型构建灵活高效的计算图表示
人工智能·深度学习·神经网络