Pytorch中的数据加载

数据加载之前,先学习两个Pytorch中的类:Dataset和DataLoader。在 PyTorch 中,Dataset 和 DataLoader 是两个非常重要的类,用于高效地加载和处理数据。它们通常一起使用,以便在训练深度学习模型时更好地管理数据。如果数据是图像数据,可以使用Image模块来完成图片操作。除此之外,可能还需要用到Python标准库中的os模块,如文件路径拼接,文件列表。

1、Dataset类和DataLoader类,Image模块、os模块

1. Dataset 类

Dataset 是一个抽象类,表示一个数据集。你可以通过继承 Dataset 类来创建自定义的数据集。Dataset 类主要定义了以下两个核心方法:

  • len(self): 返回数据集中样本的数量。

  • getitem(self, idx): 根据索引 idx 返回对应的样本。

    通过这些方法,Dataset 类允许你以统一的方式访问数据集中的样本。

自定义 Dataset 示例

python 复制代码
from torch.utils.data import Dataset
import torch

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 示例数据
data = torch.randn(100, 3)  # 100个样本,每个样本有3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签,0或1

# 创建自定义数据集
dataset = CustomDataset(data, labels)

# 访问数据集中的样本
sample, label = dataset[0]
print(f"Sample: {sample}, Label: {label}")

更详细的介绍可以见:Pytorch中的torch.utils.data.Dataset 类

2. DataLoader 类

DataLoader 是一个迭代器,用于从 Dataset 中高效地加载数据。它提供了以下功能:

  • 批量加载数据: 可以将数据分成多个小批量(mini-batches)进行加载。
  • 多线程加载: 可以使用多个线程并行加载数据,减少 I/O 瓶颈。
  • 数据打乱: 可以在每次迭代时打乱数据顺序,以避免模型过拟合。
  • 自定义采样策略: 可以通过 Sampler 和 BatchSampler 自定义数据加载的顺序。

使用 DataLoader 示例

python 复制代码
from torch.utils.data import DataLoader

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# 遍历DataLoader
for batch_idx, (samples, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx}:")
    print(f"Samples: {samples}, Labels: {labels}")

参数说明

  • dataset: 需要加载的数据集,通常是 Dataset 类的实例。
  • batch_size: 每个批次的样本数量。
  • shuffle: 是否在每个 epoch 打乱数据顺序。
  • num_workers: 用于数据加载的子进程数量。如果设置为 0,则数据加载在主进程中进行。

总结

  • Dataset 类用于定义数据集的结构和如何访问数据。
  • DataLoader 类用于高效地加载数据,支持批量加载、多线程加载和数据打乱等功能。

结合使用 Dataset 和 DataLoader 可以让你在训练深度学习模型时更加高效地处理数据。

3. Image模块

PIL(Python Imaging Library)是一个用于图像处理的Python库,而Image是其中的一个核心模块,提供了丰富的图像操作功能。PIL库已经停止更新,但其分支Pillow继续维护并扩展了功能,因此现在通常使用Pillow来代替PIL。

Image模块的主要功能

Image模块提供了以下主要功能:

打开和保存图像:

  • Image.open(file):打开图像文件。
  • Image.save(file, format):保存图像到文件。

图像属性:

  • image.size:获取图像的尺寸(宽度, 高度)。
  • image.mode:获取图像的模式(如 "RGB", "L", "CMYK" 等)。
  • image.format:获取图像的格式(如 "JPEG", "PNG" 等)。

图像操作:

  • image.resize(size):调整图像大小。
  • image.rotate(angle):旋转图像。
  • image.crop(box):裁剪图像。
  • image.transpose(method):翻转或旋转图像(如 Image.FLIP_LEFT_RIGHT)。

图像转换:

  • image.convert(mode):转换图像模式(如将彩色图像转换为灰度图像)。
  • image.thumbnail(size):生成缩略图。

图像显示:

  • image.show():显示图像(使用系统默认的图像查看器)。

图像处理:

  • image.filter(filter):应用滤镜(如模糊、锐化等)。
  • image.paste(im, box):将一张图像粘贴到另一张图像上。

示例代码

以下是一个简单的示例,展示如何使用Image模块:

python 复制代码
from PIL import Image

# 打开图像
image = Image.open("example.jpg")

# 获取图像属性
print(f"Size: {image.size}")
print(f"Format: {image.format}")
print(f"Mode: {image.mode}")

# 调整图像大小
resized_image = image.resize((200, 200))

# 旋转图像
rotated_image = image.rotate(45)

# 保存图像
resized_image.save("resized_example.jpg")
rotated_image.save("rotated_example.jpg")

# 显示图像
image.show()

总结

Image模块是Pillow库中用于处理图像的核心模块,提供了从打开、操作到保存图像的完整功能。无论是简单的图像处理任务还是复杂的图像操作,Image模块都能胜任。

4. os模块

4.1 os.listdir函数

os.listdir是 Python 标准库 os 模块中的一个函数,用于列出指定目录中的所有文件和子目录的名称。它是一个非常常用的函数,用于遍历目录内容或检查文件系统中的文件。

函数定义

python 复制代码
os.listdir(path='.')

参数:

path(可选):要列出内容的目录路径。默认为当前目录(.)。

返回值:

返回一个包含目录中所有文件和子目录名称的列表(list)。列表中的每个元素是一个字符串,表示文件或目录的名称。

主要特点

  1. 列出文件和目录:
  • os.listdir 会返回指定目录下的所有文件和子目录的名称,包括隐藏文件(以 . 开头的文件)。
  • 它不会递归列出子目录中的内容。
  1. 返回相对路径:
  • 返回的文件名是相对于给定路径的相对路径,而不是绝对路径。
  1. 不区分文件和目录:
  • 返回的列表中包含文件和目录的名称,但不会区分它们是文件还是目录。如果需要区分,可以结合 os.path.isfile 或os.path.isdir 使用。
  1. 不保证顺序:
  • 返回的文件列表的顺序是不确定的,通常取决于文件系统的实现。

示例代码

  1. 列出当前目录的内容
python 复制代码
import os

# 列出当前目录中的所有文件和子目录
contents = os.listdir()
print("当前目录内容:", contents)
  1. 列出指定目录的内容
python 复制代码
import os

# 列出指定目录中的所有文件和子目录
path = "/path/to/your/directory"
contents = os.listdir(path)
print(f"目录 {path} 的内容:", contents)
  1. 区分文件和目录
python 复制代码
import os

path = "/path/to/your/directory"
contents = os.listdir(path)

# 区分文件和目录
files = [f for f in contents if os.path.isfile(os.path.join(path, f))]
directories = [d for d in contents if os.path.isdir(os.path.join(path, d))]

print("文件:", files)
print("目录:", directories)
  1. 列出特定类型的文件
python 复制代码
import os

path = "/path/to/your/directory"
contents = os.listdir(path)

# 列出所有 .txt 文件
txt_files = [f for f in contents if f.endswith(".txt")]
print("文本文件:", txt_files)

注意事项

  1. 路径问题:
  • 如果路径不存在或没有访问权限,os.listdir 会抛出 FileNotFoundError 或 PermissionError 异常。
  • 可以使用 os.path.exists 检查路径是否存在。
  1. 隐藏文件:
  • os.listdir 会列出隐藏文件(如 .gitignore 或 .DS_Store),如果需要过滤隐藏文件,可以手动处理。
  1. 性能问题:
  • 对于包含大量文件的目录,os.listdir 可能会比较慢。如果需要更高效的文件遍历,可以考虑使用 os.scandir。

与 os.scandir 的区别

os.listdir 返回的是文件名列表,而 os.scandir 返回的是 DirEntry 对象的迭代器,提供了更多的文件信息(如文件类型、文件属性等),性能也更好。例如:

python 复制代码
import os

path = "/path/to/your/directory"
with os.scandir(path) as entries:
    for entry in entries:
        print(entry.name, entry.is_file())

总结

os.listdir 是一个简单而强大的工具,用于列出目录中的文件和子目录。它非常适合快速查看目录内容,但在需要更多文件信息或处理大量文件时,可以考虑使用 os.scandir 或其他更高级的文件遍历方法。

4.2 os.path.join函数

os.path.join 是 Python 标准库 os.path 模块中的一个函数,用于将多个路径组件连接成一个完整的路径。它能够根据操作系统的不同,自动使用正确的路径分隔符(例如,在 Windows 上使用 \,在 Linux 和 macOS 上使用 /),从而避免手动拼接路径时可能出现的错误。

函数定义

python 复制代码
os.path.join(path, *paths)

参数:

  • path:第一个路径组件(通常是目录路径)。
  • *paths:可变参数,表示后续的路径组件(可以是文件名或子目录名)。

返回值:

返回一个字符串,表示拼接后的路径。

主要特点

  1. 跨平台兼容性:
  • os.path.join 会根据操作系统的不同,自动使用正确的路径分隔符。
    例如:
    在 Windows 上:os.path.join("dir", "file.txt") 返回 "dir\file.txt"。
    在 Linux/macOS 上:os.path.join("dir", "file.txt") 返回 "dir/file.txt"。
  1. 处理多余的分隔符:
  • 如果路径组件中已经包含路径分隔符,os.path.join 会自动处理多余的分隔符,避免出现重复的分隔符。
  1. 支持绝对路径:
  • 如果某个路径组件是绝对路径(例如以 / 或 C:\ 开头),os.path.join 会忽略之前的路径组件,从该绝对路径开始拼接。
  1. 灵活性:
  • 可以拼接任意数量的路径组件。

示例代码

  1. 基本用法
python 复制代码
import os

# 拼接路径
path = os.path.join("dir", "subdir", "file.txt")
print("拼接后的路径:", path)
  • 在 Windows 上输出:
    拼接后的路径: dir\subdir\file.txt
  • 在 Linux/macOS 上输出:
    拼接后的路径: dir/subdir/file.txt
  1. 处理绝对路径
python 复制代码
import os

# 如果某个组件是绝对路径,之前的组件会被忽略
path = os.path.join("dir", "/absolute/path", "file.txt")
print("拼接后的路径:", path)

输出:

拼接后的路径: /absolute/path/file.txt

  1. 处理多余的分隔符
python 复制代码
import os

# 自动处理多余的分隔符
path = os.path.join("dir/", "/subdir/", "file.txt")
print("拼接后的路径:", path)
  • 在 Windows 上输出:
    拼接后的路径: dir\subdir\file.txt
  • 在 Linux/macOS 上输出:
    拼接后的路径: dir/subdir/file.txt
  1. 拼接多个组件
python 复制代码
import os

# 拼接多个路径组件
path = os.path.join("root", "dir1", "dir2", "file.txt")
print("拼接后的路径:", path)
  • 在 Windows 上输出:
    拼接后的路径: root\dir1\dir2\file.txt
  • 在 Linux/macOS 上输出:
    拼接后的路径: root/dir1/dir2/file.txt

注意事项

  1. 路径规范化:
    os.path.join 不会自动规范化路径(例如处理 . 或 ...)。如果需要规范化路径,可以使用 os.path.normpath。
python 复制代码
import os
path = os.path.join("dir", "..", "file.txt")
normalized_path = os.path.normpath(path)
print("规范化后的路径:", normalized_path)

输出:

python 复制代码
规范化后的路径: file.txt
  1. 空路径组件:
    如果某个路径组件为空字符串,os.path.join 会忽略它。
python 复制代码
import os
path = os.path.join("dir", "", "file.txt")
print("拼接后的路径:", path)

输出:

python 复制代码
拼接后的路径: dir/file.txt
  1. 避免手动拼接路径:
    手动拼接路径(例如使用 + 或字符串格式化)可能会导致跨平台兼容性问题,因此推荐使用 os.path.join。

总结

os.path.join 是一个非常有用的工具,用于安全、跨平台地拼接路径。它能够自动处理路径分隔符和多余的分隔符,确保生成的路径在不同操作系统上都能正常工作。在编写文件路径相关的代码时,强烈推荐使用 os.path.join 来避免潜在的错误。

2、数据加载示例

2.1 示例一

2.1.1 了解数据集内容及格式

首先,将数据集放到项目所在文件夹中。数据集示例如下:


整个数据集分为训练集(train)和验证集(val);训练集和验证集中都包含两种数据:蚂蚁(ants)和蜜蜂(bees),文件名就是数据的标签;ants中又包含很多张图片数据,bees同理:

2.1.2 编写程序加载数据集

  1. 导入所需的类或模块
python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os
  1. 编写数据类
      编写一个表示数据集的类,继承自Dataset类,并重写方法__init__()和方法__getitem__()和方法__len__():
python 复制代码
class Mydata(Dataset):

    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.listdir(self.path)

    def __getitem__(self, idx):
       img_name = self.image_path[idx]
       img_path = os.path.join(self.path, img_name)
       img = Image.open(img_path)
       label = self.label_dir
       return img,label

    def __len__(self):
        return len(self.image_path)
  • 创建实例对象时,需要传入根目录(root_dir,即上述数据集的"train"或者"val"),以及标签目录(label_dir,即数据集中的"ants"或者"bees")。方法__init__()将路径拼接,并得到图片文件名的列表。
  • 方法__getitem__()在对该示例索引时,通过索引找到图片文件名的列表中的名字(img_name = self.image_path[idx]),通过路径拼接找到图片文件路径(img_path = os.path.join(self.path, img_name))并读取图片文件(img = Image.open(img_path)),最后返回图片对象和标签。
  • 方法__len__()返回图片文件名的列表的长度,即图片的数量。
  1. 读取训练数据集

创建Mydata实例并传入所需的文件路径("hymenoptera_data/train"和 "ants")。

python 复制代码
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
ants_dataset = Mydata(root_dir, ants_label_dir)

通过索引该实例查看一下返回值:返回的是一个图像对象和一个标签

python 复制代码
>>> ants_dataset[0]
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512 at 0x2BDD6D6A220>, 'ants')

用两个变量分别接收图片对象和其标签,并展示一下图片:

python 复制代码
>>> img, label=ants_dataset[0]
>>> img.show()

训练集中还有一类bees数据。也用一样的方法加载,最后把两个数据集(ants和bees)合并拼接(+)起来:

python 复制代码
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset

合并之后的数据集长度等于两个数据集长度之和:

python 复制代码
>>> len(ants_dataset)
124
>>> len(bees_dataset)
121
>>> len(train_dataset)
245

可以看到合并之后的数据集索引为123以及之前的标签是ants,索引为124及之后的标签是bees:

python 复制代码
train_dataset[123][1]
'ants'
train_dataset[124][1]
'bees'

2.2 示例二

2.2.1 了解数据集内容及格式

首先,将数据集放到项目所在文件夹中。数据集示例如下:

同样包含训练集(train)和验证集(val)。

不过,不同的是,例如在train文件夹下,图片数据和图片标签分别存放在不同的文件夹中(如ants_image和ants_label分别存放蚂蚁的图片和对应图片的标签)。


ants_image中为.ipg的图片文件,ants_label中为.txt文本文件,他们的名字是一一对应的。

文本文件中是该图片的标签。

2.2.2 编写程序加载数据集

  1. 导入所需的类或模块
python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os
  1. 编写数据类
      编写一个表示数据集的类,继承自Dataset类,并重写方法__init__()和方法__getitem__()和方法__len__():
python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os

class Mydata(Dataset):

    def __init__(self,root_dir, image_dir, label_dir):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_list = os.listdir(self.image_path)
        self.label_list = os.listdir(self.label_path)

    def __getitem__(self, idx):
       img_name = self.image_list[idx]
       img_item_path = os.path.join(self.image_path, img_name)
       img = Image.open(img_item_path)

       label_name = self.label_list[idx]
       label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
       with open(label_item_path, 'r') as f:
           label = f.readline()

       return img,label

    def __len__(self):
        return len(self.image_list)
  • 创建实例对象时,需要传入根目录(root_dir,即上述数据集的"train"或者"val"),以及图像文件目录(image_dir,即数据集中的"ants_image"或者"bees_image"),标签目录(label_dir,即数据集中的"ants_label"或者"bees_label")。方法__init__()将路径拼接,并得到图片文件名的列表和图片标签文件名的列表。
  • 方法__getitem__()在对该示例索引时,通过索引找到图片文件名列表中的名字(img_name = self.image_list[idx]),通过路径拼接找到图片文件路径(img_item_path = os.path.join(self.image_path, img_name))并读取图片文件(img = Image.open(img_item_path));通过索引找到对应的标签文件名(label_name = self.label_list[idx]),通过路径拼接找到标签文件路径(label_item_path = os.path.join(self.root_dir, self.label_dir, label_name))并读取.txt文件的内容。最后返回图片对象和标签。
  • 方法__len__()返回图片文件名的列表的长度,即图片的数量。
  1. 读取训练数据集

创建Mydata实例并传入所需的文件路径("练手数据集/train", "ants_image"和"ants_label")。

python 复制代码
root_dir = "练手数据集/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
ants_dataset = Mydata(root_dir, ants_image_dir, ants_label_dir)

通过索引该实例查看一下返回值:返回的是一个图像对象和一个标签

python 复制代码
>>> ants_dataset[1]
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x333 at 0x1DC370ACB80>, 'ants')

用两个变量分别接收图片对象和其标签,并展示一下图片:

python 复制代码
>>> img, label=ants_dataset[1]
>>> img.show()

训练集中还有一类bees数据。也用一样的方法加载,最后把两个数据集(ants和bees)合并拼接(+)起来:

python 复制代码
root_dir = "练手数据集/train"
ants_image_dir = "ants_image"
ants_label_dir = "ants_label"
bees_image_dir = "bees_image"
bees_label_dir = "bees_label"
ants_dataset = Mydata(root_dir, ants_image_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_image_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset

合并之后的数据集长度等于两个数据集长度之和:

python 复制代码
>>> len(ants_dataset)
124
>>> len(bees_dataset)
121
>>> len(train_dataset)
245

可以看到合并之后的数据集索引为123以及之前的标签是ants,索引为124及之后的标签是bees:

python 复制代码
train_dataset[123][1]
'ants'
train_dataset[124][1]
'bees'
相关推荐
果冻人工智能2 分钟前
Linux 之父把 AI 泡沫喷了个遍:90% 是营销,10% 是现实。
人工智能
PacosonSWJTU5 分钟前
python基础-07-模式匹配与正则表达式
python·mysql·正则表达式
程序员总部9 分钟前
单例模式在Python中的实现和应用
开发语言·python·单例模式
demonlg011212 分钟前
Go 语言 fmt 模块的完整方法详解及示例
开发语言·后端·golang
测试盐22 分钟前
django入门教程之cookie和session【六】
后端·python·django
冷琴199623 分钟前
基于python+django的商城网站-电子商城管理系统源码+运行
开发语言·python·django
右恩35 分钟前
jupyter使用过程中遇到的问题
ide·python·jupyter
果冻人工智能38 分钟前
Sal Khan 和 Bill Gates 对 AI 的看法错了
人工智能
Tadecanlan1 小时前
[C++面试] 你了解视图吗?
开发语言·c++