【基于深度学习的验证码识别】---- part3数据加载、模型等API介绍(1)

一、MNIST数据集

MNIST(Modified National Institute of Standards and Technology)数据集是计算机视觉和机器学习领域最经典的入门级数据集之一,主要用于手写数字识别任务。

使用示例(以PyTorch为例)

复制代码
from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)
- 复制代码
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)

# 训练集长度
print(len(mnist_train))
# 取第一个图片
print(mnist_train[0])
image = mnist_train[5000][0]
# 打印出图片
plt.imshow(image)
plt.show()
print(mnist_train[5000][1])

二、数据加载

在PyTorch中,使用DataLoader加载MNIST数据集时,参数的合理配置直接影响训练效率和模型性能。以下是核心参数的详细说明及其在MNIST场景中的应用:

- 复制代码
from torch.utils.data import DataLoader

参数:batch_size、shuffle、num_workers、pin_memory、drop_last

1、batch_size(批次大小)
  • 定义 :每个批次包含的样本数量。例如,batch_size=64表示每次迭代加载64张图像。
  • 作用:定义每个批次包含的样本数量。例如,若batch_size=64,则每次迭代从数据集中加载64张手写数字图像。
  • MNIST应用
    MNIST图像尺寸为28x28,单个样本数据量小,通常可设置较大的batch_size(如64或128)以充分利用显存并加速训练。
    显存不足时需减小batch_size,否则会引发内存错误(OOM)
2、 shuffle(数据打乱)
  • 定义:是否在每个训练周期(epoch)开始时随机打乱数据顺序。

  • 作用

    • 防止模型偏见 :避免模型学习到数据顺序特征(如MNIST训练集需设为True)。
    • 测试集处理 :测试时通常设为False以保持评估结果一致性。
  • MNIST应用

    python 复制代码
    # 训练集打乱,测试集不打乱
    train_loader = DataLoader(..., shuffle=True)
    test_loader = DataLoader(..., shuffle=False)
3、 num_workers(子进程数)
  • 定义:用于并行加载数据的子进程数量。默认为0(主进程加载)。

  • 作用

    • 加速数据加载:多进程并行读取数据(建议设为CPU核心数的2~4倍,如4或8)。
    • 资源平衡:MNIST数据量小,过高值可能导致内存溢出(需实验调优)。
  • MNIST应用

    python 复制代码
    # 使用4个子进程加载数据
    train_loader = DataLoader(..., num_workers=4)
4、pin_memory(内存锁定)
  • 定义:是否将数据复制到CUDA固定内存(pinned memory)。
  • 作用
    • 加速GPU传输 :启用后,数据从CPU到GPU的传输速度更快(GPU训练时强烈建议设为True)。
    • 资源占用:仅对GPU有效,CPU训练时可忽略。
  • MNIST应用
python 复制代码
  # GPU训练时启用内存锁定
  train_loader = DataLoader(..., pin_memory=True)
5、 drop_last(丢弃末批)
  • 定义 :当数据集大小无法被batch_size整除时,是否丢弃最后一个不完整批次。

  • 作用

    • 避免小批次影响 :丢弃末尾样本(如MNIST训练集60000样本,batch_size=64时最后一个批次含16样本)。
    • 分布式训练对齐:需所有批次大小一致时启用。
  • MNIST应用

    python 复制代码
    # 丢弃不完整批次
    train_loader = DataLoader(..., batch_size=64, drop_last=True)

代码示例

python 复制代码
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

三、图片处理 transform

在深度学习中,图像数据通常需要进行预处理(如缩放、裁剪、归一化等)以适应模型的输入要求。PyTorch 提供了 torchvision.transforms 模块,用于定义和实现这些图像处理操作。

transforms 的作用

transforms 是一个用于图像预处理的工具集,可以将一系列图像处理操作组合在一起,形成一个处理流水线(pipeline)。这些操作通常包括:

  • 数据增强:增加数据的多样性,防止模型过拟合。
  • 数据标准化:将数据转换为模型所需的格式(如归一化到特定范围)。
  • 数据转换:将图像转换为张量(Tensor)格式,以便输入模型。
常用 transforms 操作
1、基础操作
  • Resize: 调整图像大小。
python 复制代码
transforms.Resize((height, width))  # 将图像调整为指定大小
  • CenterCrop: 从图像中心裁剪指定大小的区域。
python 复制代码
transforms.CenterCrop(size)  # 裁剪大小为 (size, size)
  • RandomCrop: 随机裁剪图像。
python 复制代码
transforms.RandomCrop(size)  # 随机裁剪大小为 (size, size)
  • RandomHorizontalFlip: 随机水平翻转图像。
python 复制代码
transforms.RandomHorizontalFlip(p=0.5)  # 以 50% 的概率水平翻转
  • RandomRotation: 随机旋转图像。
python 复制代码
transforms.RandomRotation(degrees=30)  # 随机旋转 ±30 度
2、 颜色变换
  • ColorJitter: 随机改变图像的亮度、对比度、饱和度和色调。
python 复制代码
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
  • Grayscale: 将图像转换为灰度图。
python 复制代码
transforms.Grayscale(num_output_channels=1)  # 转换为单通道灰度图
3、 归一化和标准化
  • ToTensor: 将图像(PIL 或 NumPy 格式)转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。
python 复制代码
transforms.ToTensor()

在使用 transforms.ToTensor() 处理图像后,PyTorch 会将图像的通道维度移动到最前面。

transforms.ToTensor() 的作用

1.将图像转换为张量:

输入的图像通常是 PIL 图像或 NumPy 数组,形状为 (H, W, C),其中:
H 是图像的高度(Height)。
W 是图像的宽度(Width)。
C 是图像的通道数(Channels,例如 RGB 图像为 3,灰度图像为 1)。
transforms.ToTensor() 会将图像转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。

2通道维度的变化:

转换后的张量形状为 (C, H, W),即通道维度被移动到最前面。

这种格式是 PyTorch 的标准输入格式,便于后续的模型处理。

  • Normalize: 对图像进行标准化处理(减去均值,除以标准差)。
python 复制代码
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

这里的均值和标准差通常是根据数据集计算的(例如 ImageNet 的均值和标准差)。

4、 组合操作
  • Compose: 将多个操作组合成一个流水线。
python 复制代码
 transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
示例代码
python 复制代码
from torchvision import datasets, transforms

# 定义 transforms 流水线
transform = transforms.Compose([
    transforms.Resize((32, 32)),          # 调整图像大小为 32x32
    transforms.RandomHorizontalFlip(),    # 随机水平翻转
    transforms.ToTensor(),                # 转换为张量,并缩放到 [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

# 加载 MNIST 数据集并应用 transforms
train_dataset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

# 查看处理后的图像
image, label = train_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 32, 32])
总结
操作 作用
Resize 调整图像大小。
CenterCrop 从图像中心裁剪指定大小的区域。
RandomCrop 随机裁剪图像。
RandomHorizontalFlip 随机水平翻转图像。
RandomRotation 随机旋转图像。
ColorJitter 随机改变图像的亮度、对比度、饱和度和色调。
Grayscale 将图像转换为灰度图。
ToTensor 将图像转换为张量,并缩放到 [0, 1]。
Normalize 对图像进行标准化处理(减去均值,除以标准差)。
Compose 将多个操作组合成一个流水线。
如何在数据加载过程中看到图片的样子

先轴交换,再利用make_grid合并再处理成数组.numpy()后,就可以展示出来

python 复制代码
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

my_transforms = transforms.Compose(
    [transforms.PILToTensor(),
     ]
)
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())
dataloader = DataLoader(mnist_train, batch_size=5, shuffle=True) #DataLoader 初始化
for (image, label) in dataloader:# 遍历 DataLoader
    print(image.shape) #torch.Size([5, 1, 28, 28])
    print(label) #tensor([3, 1, 2, 8, 3])
    print(make_grid(image).shape)  #torch.Size([3, 32, 152])  使用 make_grid 将图像拼接成网格
    image = make_grid(image).permute(1,2,0).numpy()#调整网格图像的维度并转换为 NumPy 数组
    plt.imshow(image) #使用 Matplotlib 显示图像
    plt.show()
    exit()
相关推荐
Sunday_ding2 小时前
NLP 与常见的nlp应用
人工智能·自然语言处理
一ge科研小菜鸡2 小时前
当下主流 AI 模型对比:ChatGPT、DeepSeek、Grok 及其他前沿技术
人工智能
ai产品老杨3 小时前
全流程数字化管理的智慧物流开源了。
前端·javascript·vue.js·人工智能·安全
mzgong3 小时前
图像分割的mask有空洞怎么修补
人工智能·opencv·计算机视觉
一面千人3 小时前
从零开始:基于 PyTorch 的图像分类模型
pytorch·深度学习·cnn·图像分类·模型优化·cifar-10·调试经验·前沿趋势
墨绿色的摆渡人3 小时前
pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解
人工智能·pytorch·python
scdifsn3 小时前
动手学深度学习11.9. Adadelta-笔记&练习(PyTorch)
pytorch·笔记·深度学习·优化器·adadelta算法
QBorfy4 小时前
08篇 AI从零开始 - LangChain学习与实战(5) 基于RAG开发问答机器人
前端·人工智能·deepseek
赛卡4 小时前
Python直方图:从核密度估计到高维空间解析
开发语言·人工智能·python·matlab