pytorch网站学习
处理数据样本的代码往往会变得很乱、难以维护;理想情况下,我们希望把数据部分的代码和模型训练部分分开写,这样更容易阅读、也更好维护。
简单说:数据和模型最好"分工明确",不要写在一起。
PyTorch 提供了两个数据处理的"基本工具":
-
torch.utils.data.Dataset
-
torch.utils.data.DataLoader
它们可以用来处理官方内置的数据集 ,也可以用来加载你自己的数据。
Dataset 存储样本及其对应的标签,而 DataLoader 则在 Dataset 周围封装了一个迭代器,以便轻松访问这些样本。
-
Dataset:用于存储样本和对应的标签,类似一个"数据库",它记录了所有数据。
-
DataLoader:基于 Dataset 封装了一个可迭代对象,方便你在训练过程中一次取出一个批次(batch)的数据。
-
Dataset = 数据仓库,负责"存"数据
-
DataLoader = 快递员,负责"送"数据,一批一批送给模型训练用
PyTorch 提供了 Dataset
(负责存数据)和 DataLoader
(负责送数据)两个工具,可以方便地管理、加载各种数据
PyTorch 的领域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集都是 torch.utils.data.Dataset
的子类,,例如,FashionMNIST 数据集就是一个专门用于服装图像识别的预加载数据集,它已经按照 Dataset 接口组织好了数据,你可以直接用来训练和测试模型
参数解释:
✅ root
:这是用来存放训练/测试数据的文件夹路径。
✅ train
:指定是加载训练集(train=True
)还是测试集(train=False
)。
✅ download=True
:如果你指定的 root
路径下没有数据,它会自动联网下载。
✅ transform
和 target_transform
:
-
transform
是对图像特征做的变换(比如转为张量、归一化等) -
target_transform
是对标签做的变换(比如 one-hot 编码)from torchvision import datasets, transforms
定义图像的预处理操作:把图片转成张量
transform = transforms.ToTensor()
加载训练集
train_data = datasets.FashionMNIST(
root="data", # 数据保存目录
train=True, # 加载训练集
download=True, # 如果没有就下载
transform=transform # 图像预处理
)加载测试集
test_data = datasets.FashionMNIST(
root="data",
train=False, # 加载测试集
download=True,
transform=transform
)
如何手动取出数据集里的样本,并把它们可视化显示出来,
遍历和可视化数据集
我们可以像访问列表那样,用下标手动访问数据集:training_data[index]
我们使用 matplotlib
来把训练数据中的一些样本画出来进行可视化。
什么是 training_data[index]?
在 PyTorch 中,像 training_data 这种数据集对象,其实可以像列表(list)一样使用:
image, label = training_data[0] # 取出第一个样本(包括图像和标签)
image 是一张 28×28 的图(张量)
label 是它的标签(比如 "T-shirt/top")
# 标签编号和对应的文字(类别)之间的映射关系
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8)) # 创建一个图形窗口,大小为 8x8 英寸
cols, rows = 3, 3 # 准备画一个 3 行 3 列 的图像网格(共 9 张图)
for i in range(1, cols * rows + 1): # 循环9次(从1到9)
sample_idx = torch.randint(len(training_data), size=(1,)).item() # 随机选一个样本索引
img, label = training_data[sample_idx] # 从训练集中取出图像和标签
figure.add_subplot(rows, cols, i) # 添加一个子图(3x3 的第 i 个格子)
plt.title(labels_map[label]) # 设置图像标题为标签名称(比如 "Sneaker")
plt.axis("off") # 不显示坐标轴
plt.imshow(img.squeeze(), cmap="gray") # 显示图像(压缩维度 + 灰度图)
plt.show() # 显示整张图(9张图一起展示)

如何自己创建一个自定义的数据集(Custom Dataset),让 PyTorch 能读取自己的图片和标签,比如本地的一些图片文件和 CSV 表格。
为你自己的文件创建一个自定义数据集
自定义 Dataset 类时,必须实现三个函数:__init__
(初始化)、__len__
(返回样本总数) 和 __getitem__
(获取指定样本)
如果你不是用官方的数据集(比如 FashionMNIST),而是用你自己文件夹里的图片 + CSV 表里的标签,那就需要自己写一个"自定义数据集类":
-
__init__()
:定义数据集在哪里、怎么加载图片和标签 -
__len__()
:告诉 PyTorch 你一共有多少张图(样本数量)
__len__
函数这个函数的作用是:返回数据集中样本(图片)的数量。
-
__getitem__()
:定义怎么通过索引取出一张图和它的标签(比如dataset[0]
)import os # 用于路径拼接
import pandas as pd # 用于读取 CSV 文件
from torchvision.io import read_image # 用于读取图像(转为张量)
from torch.utils.data import Dataset # 自定义数据集要继承这个类自定义图片数据集类,继承自 PyTorch 的 Dataset 基类
class CustomImageDataset(Dataset):
# 初始化函数:加载CSV标签表、图片文件夹路径、图像和标签的预处理方法
def init(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file) # 读取CSV文件,包含图片文件名和对应标签
self.img_dir = img_dir # 图片所在的文件夹路径
self.transform = transform # 图像的预处理方法(例如缩放、归一化)
self.target_transform = target_transform # 标签的预处理方法(例如转one-hot)# 返回数据集中样本的总数量 def __len__(self): return len(self.img_labels) # 返回 CSV 中的行数(也就是图片数量) # 按照索引返回一张图片和它的标签 def __getitem__(self, idx): # 根据索引从CSV中获取图片文件名,并拼接成完整路径 img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) # 使用 torchvision.io.read_image 读取图片(返回的是Tensor格式) image = read_image(img_path) # 获取对应的标签(CSV第二列) label = self.img_labels.iloc[idx, 1] # 如果定义了图像预处理,就应用它 if self.transform: image = self.transform(image) # 如果定义了标签预处理,就应用它 if self.target_transform: label = self.target_transform(label) # 返回一对数据:(图像,标签) return image, label
__init__
函数
当我们创建 Dataset 数据集对象时,这个 __init__
函数会被运行一次。
在这个函数中,我们设置好图像所在的文件夹路径、标签文件(CSV),以及两种预处理方法(transform)
这个时候 Python 就会自动去运行你写的 __init__
函数,完成以下事情:
做什么 | 举例 |
---|---|
读入标签文件 | 从 CSV 读出每张图对应的标签 |
记住图片路径 | 比如你的图片都在 "images/" 文件夹里 |
保存预处理方法 | 如果你要对图像做缩放、归一化等处理,也在这里传进来 |
你可以把 __getitem__()
想象成这样一个问题:
你对 PyTorch 说:"嘿,帮我从数据集中拿出第 5 张图像,还有它的标签。"
PyTorch 就会执行你写的 __getitem__(5)
,然后:
-
去 CSV 表里看第5行,拿到图像文件名,比如
img5.png
-
拼成路径,比如
images/img5.png
-
用
read_image()
把它读成模型能用的格式(张量) -
拿到它的标签,比如
label=2
(代表"Pullover") -
如果你有设置 transform,就先处理一下
-
返回
(图像张量, 标签)
给你
使用 DataLoader 为训练准备数据
Dataset(数据集)每次只能取出一条数据(特征和标签)。
而在训练模型时,我们通常希望将样本按小批量(minibatch)送入模型,
并且在每一轮训练(epoch)中打乱数据的顺序 ,以减少模型过拟合,
同时利用 Python 的多进程功能来加快数据的读取速度。
DataLoader 是一个可迭代对象,它通过一个简单的 API 帮我们封装了以上所有复杂操作。
这里的 API 就是"别人已经写好的功能接口",你只要用很简单的方式去"调用它",就可以完成很复杂的事情。
就像你开车,不用知道发动机怎么工作,你只需要踩油门,这个"油门"就是给你用的 API。
没有 DataLoader 时的问题 | DataLoader 自动帮你做了什么 |
---|---|
一次只能读一张图 | ✅ 自动按 batch_size 读多张图 |
每次都按固定顺序读 | ✅ 每轮训练前自动打乱数据 |
读取慢(尤其是大数据) | ✅ 用多进程后台加速加载数据 |
写代码复杂 | ✅ 封装好,只要一行就能搞定 |
minibatch (中文叫"小批量")指的是:**每次训练时不把所有数据一次性喂给模型,而是一次取出一小部分来训练。**举个例子:
你有 10,000 张训练图像,不可能一次性都送给模型(太慢/太耗显存)。
你可以这样设置:
batch_size = 64
就是:每次训练用 64 张图,学完一批,再取下一批。
这种方式叫:小批量训练(mini-batch training)
什么是 shuffle(打乱数据)?
定义:shuffle 指的是:在每轮训练开始前,把训练数据的顺序随机打乱。
为什么要打乱?
假如你的数据是按类别排好顺序的(比如先全是猫,后全是狗):
模型可能先学猫学很久,突然一下全是狗,这样容易 过拟合某一类,泛化能力差。
所以我们会在每个 epoch 前加个参数:
DataLoader(..., shuffle=True)
表示:每一轮训练前,重新随机排序数据。
什么是多进程加载(num_workers)?
定义:PyTorch 可以使用多个"后台工作进程(线程)"同时从磁盘里读取图片,加快加载速度。
举个例子:
你用 DataLoader
加载数据时可以设置:
DataLoader(dataset, batch_size=64, num_workers=4)
意思是:开 4 个后台进程来同时读数据!
就像你点外卖,找了 4 个骑手一起送菜,当然比 1 个骑手送得快。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)什么意思啊
这段代码是用 PyTorch 的 DataLoader
,将训练数据和测试数据按小批量分组,并在每轮开始时随机打乱顺序,方便高效地进行模型训练和测试。
遍历 DataLoader
我们已经把数据集加载进了 DataLoader
,现在可以根据需要对数据集进行迭代(逐批处理)。
下面的每次迭代都会返回一批 train_features
(训练特征)和 train_labels
(标签),每批包含 64 个样本和对应的标签(即 batch_size=64)。
因为我们设置了 shuffle=True
,所以在我们把所有批次迭代完之后,数据会被自动打乱顺序。
(如果你想更精细地控制数据加载的顺序,可以了解一下 PyTorch 的 Sampler
机制。)
Samplers
是 PyTorch 中 更灵活地控制数据加载顺序 的工具。
如果你想自己控制"数据加载顺序"、"打乱方式"、"分组策略"等,就可以用 Sampler
来代替 shuffle=True
。
Sampler 是一个类,用来控制 DataLoader
在每一轮训练中应该以什么顺序取数据的索引。
常见的 Sampler 类型
Sampler 类别 | 作用 |
---|---|
SequentialSampler |
按顺序取数据(默认用于 shuffle=False ) |
RandomSampler |
随机打乱数据(默认用于 shuffle=True ) |
SubsetRandomSampler |
只随机抽样部分数据(适合做验证集) |
WeightedRandomSampler |
按权重随机抽样(处理数据不平衡) |
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
从训练集里拿出一批数据,并显示其中一张图片和它的标签