【深度学习】pytorch,MNIST手写数字分类

efficientnet_b0的迁移学习

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt

# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10

# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([
    transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0
    transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])

# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 用于绘图的数据
train_losses = []
test_accuracies = []

# 训练模型
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # 计算平均损失
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # 测试准确率
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # Move test data to the correct device
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')

# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

相关推荐
电棍2338 分钟前
AUTODL服务器环境配置和下载数据概述
运维·深度学习·uv
Coding茶水间33 分钟前
基于深度学习的交通事故检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
人工智能·深度学习·yolo·目标检测·机器学习
传说故事1 小时前
【论文自动阅读】How Much 3D Do Video Foundation Models Encode?
人工智能·深度学习·3d
囊中之锥.2 小时前
神经网络原理通俗讲解:结构、偏置、损失函数与梯度下降
人工智能·深度学习·神经网络
棒棒的皮皮2 小时前
YOLO 拓展应用全解析(目标跟踪 / 实例分割 / 姿态估计 / 多目标检测)
深度学习·yolo·目标检测·计算机视觉·目标跟踪
子午2 小时前
【2026原创】眼底眼疾识别系统~Python+深度学习+人工智能+CNN卷积神经网络算法+图像识别
人工智能·python·深度学习
Ai尚研修-贾莲2 小时前
自然科学领域机器学习与深度学习——高维数据预处理—可解释ML/DL—时空建模—不确定性量化-全程AI+Python场景
人工智能·深度学习·机器学习·自然科学·时空建模·高维数据预处理·可解释ml/dl
赵域Phoenix2 小时前
赵煜的时序建模学习手札——三种路线概览(统计学/机器学习/深度学习)
深度学习·机器学习
qq_571099352 小时前
学习周报三十一
人工智能·深度学习·学习
五羟基己醛2 小时前
【深度学习项目】Gan网络下的SAR目标增广
人工智能·深度学习·生成对抗网络