【Deep Learning】PyTorch: MNIST Handwritten Digit Classification

Transfer learning with efficientnet_b0: load the ImageNet-pretrained weights, replace the classifier head with a 10-class linear layer, and fine-tune the network on MNIST.

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt

# Hyperparameters
batch_size = 240
learning_rate = 0.001
num_epochs = 10

# Preprocessing: resize the images and replicate the single grayscale channel to three channels
transform = transforms.Compose([
    transforms.Resize(224),  # resize to the 224x224 input expected by EfficientNet-B0
    transforms.Grayscale(num_output_channels=3),  # copy the single channel to three channels
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean and std
])

# Load the MNIST datasets
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# Create the data loaders (lower num_workers if your machine has fewer CPU cores)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the pretrained EfficientNet-B0 and replace the final classification layer
# (on torchvision >= 0.13, weights=models.EfficientNet_B0_Weights.DEFAULT is the non-deprecated equivalent of pretrained=True)
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST has 10 classes
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Metrics collected for plotting
train_losses = []
test_accuracies = []

# Train the model
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # Average training loss for this epoch
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate accuracy on the test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # Move test data to the correct device
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# Save the trained weights
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')

# Plot the training loss and test accuracy curves
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.show()

After 10 epochs of training, the test accuracy is very strong:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%
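
To reuse the trained network later, the saved state dict can be loaded back into a freshly constructed efficientnet_b0 with the same 10-class head. A minimal inference sketch, assuming the 'mnist_efficientnetb0.pth' checkpoint written by the script above and reusing the same preprocessing:

python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the same architecture (no pretrained weights needed) and load the saved state dict
model = models.efficientnet_b0()
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
model.load_state_dict(torch.load('mnist_efficientnetb0.pth', map_location=device))
model.to(device)
model.eval()

# Same preprocessing as during training
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict a single image from the MNIST test set
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
image, label = test_dataset[0]
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))  # add the batch dimension
    pred = logits.argmax(dim=1).item()
print(f'Predicted: {pred}, Ground truth: {label}')

The same pattern works for any single image, as long as it goes through the identical resize/3-channel/normalize pipeline before being fed to the model.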

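The script above fine-tunes every parameter of the pretrained network. A common lighter-weight transfer-learning variant freezes the pretrained feature extractor and trains only the new 10-class head, which trains faster and uses less GPU memory, usually at some cost in final accuracy. A minimal sketch of that setup; only the model and optimizer construction changes, the rest of the training loop stays as above:

python
import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.efficientnet_b0(pretrained=True)

# Freeze the pretrained feature extractor
for param in model.features.parameters():
    param.requires_grad = False

# The replaced head is newly created, so its parameters stay trainable
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)

# Pass only the trainable parameters to the optimizer
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
)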