【深度学习】pytorch——常用工具模块

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch------常用工具模块

数据处理 torch.utils.data模块

在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像、文本、语音或其它二进制数据等。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果。考虑到这点,PyTorch提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

  • Dataset 类:用于表示数据集,可以通过继承这个类来创建自定义的数据集。
  • DataLoader 类:用于批量加载数据,可以指定批量大小、是否打乱数据等参数。

Dataset

在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现两个Python魔法方法:

  • __getitem__:返回一条数据,或一个样本。obj[index]等价于obj.__getitem__(index)
  • __len__:返回样本的数量。len(obj)等价于obj.__len__()

DataLoader

要创建一个 DataLoader,我们需要指定以下参数:

  • 数据集实例:这通常是你自定义的数据集类的实例,例如 CustomDataset。
  • 批量大小(batch_size):用于指定每个批量包含的样本数量。
  • 是否打乱数据(shuffle):指定是否在每个 epoch 开始时对数据进行打乱,通常在训练过程中会打乱数据,而在验证或测试过程中不会。
  • 多线程加载(num_workers):指定用于数据加载的线程数,可以加快数据加载速度。

sampler

在 PyTorch 中,torch.utils.data.sampler 模块包含了多种用来对数据进行采样的类,例如 SequentialSamplerRandomSamplerSubsetRandomSampler 等。这些采样器可以用于创建自定义的数据采样策略,以满足不同的训练需求。

下面是一些常用采样器的用法举例:

  1. SequentialSampler:顺序采样器,在每个 epoch 中按顺序遍历整个数据集。
python 复制代码
from torch.utils.data import DataLoader, SequentialSampler

# 创建顺序采样器
sampler = SequentialSampler(dataset)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. RandomSampler:随机采样器,每个 epoch 随机对数据集进行采样。
python 复制代码
from torch.utils.data import DataLoader, RandomSampler

# 创建随机采样器
sampler = RandomSampler(dataset, replacement=True, num_samples=100)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. SubsetRandomSampler:从给定索引中随机采样子集。
python 复制代码
from torch.utils.data import SubsetRandomSampler

# 创建一个索引列表
indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# 使用随机子集采样器创建数据加载器
sampler = SubsetRandomSampler(indices)
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. WeightedRandomSampler:加权随机采样,允许根据每个样本的权重来进行采样,从而更灵活地处理不平衡的数据集。这在处理类别不平衡、稀有事件或其他特定情况下非常有用。
python 复制代码
from torch.utils.data import DataLoader, WeightedRandomSampler

# 假设有一个数据集和对应的样本权重
dataset = YourDataset()
weights = [0.1, 0.5, 0.8, 0.3, 0.6]  # 每个样本的权重

# 创建加权随机采样器
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

在这个示例中,WeightedRandomSampler 接受一个权重列表作为输入,并可以指定采样的样本数目和是否使用放回抽样(replacement=True 表示可以重复采样同一个样本)。

这些示例展示了如何使用不同的采样器来创建数据加载器,并指定不同的采样策略。可以根据具体的训练需求选择合适的采样器,并结合数据加载器来灵活地管理数据的采样和训练过程。

torch.utils.data的使用

假设有一个包含图像数据和对应标签的数据集,将创建一个自定义的数据集类来加载这些数据,并使用 DataLoader 来批量加载数据供模型训练使用。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader

# 假设你有图像数据和对应标签的数据集
class CustomDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        if self.transform:
            x = self.transform(x)
        return x, y

# 创建数据集实例
# 假设 data 和 targets 是你的图像数据和对应标签
custom_dataset = CustomDataset(data, targets, transform=your_transforms)

# 使用 DataLoader 批量加载数据
batch_size = 32
shuffle = True
num_workers = 4  # 可以加快数据加载速度的线程数

data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# 遍历数据加载器,获取批量数据
for inputs, labels in data_loader:
    # 在这里执行模型训练或推理
    pass
  • 创建了一个自定义的数据集类 CustomDataset,该类继承自 torch.utils.data.Dataset,并实现了 lengetitem 方法。
  • 创建了数据集实例 custom_dataset,并使用 DataLoader 实例 data_loader 批量加载数据。
  • 通过遍历数据加载器,可以获取批量的输入数据和对应的标签,用于模型的训练或推理过程。

计算机视觉工具包 torchvision

PyTorch 的计算机视觉工具包 torchvision 提供了一系列用于图像处理和计算机视觉任务的工具和数据集。它包含了常用的数据集(如 ImageNet、CIFAR10、COCO 等)、图像变换操作、模型架构以及预训练的模型等功能,方便用户快速构建和训练计算机视觉模型。

以下是 torchvision 中一些常用的功能和模块:

  1. 数据集和数据加载器torchvision.datasets 模块提供了常用的图像数据集,例如 CIFAR, COCO, MNIST 等,并且可以通过 torchvision.transforms 模块中的图像变换操作对数据进行预处理。同时,torchvision.transforms 还提供了各种图像变换操作,如裁剪、缩放、翻转等,用于数据增强和预处理。

  2. 模型架构和预训练模型torchvision.models 模块包含了一些经典的计算机视觉模型,如 ResNet、VGG、AlexNet 等,同时还提供了这些模型在 ImageNet 数据集上预训练的参数。这些预训练模型可以用于迁移学习或者基准测试。

  3. 图像工具函数torchvision.utils 模块中提供了一些图像操作的工具函数,比如保存图像、绘制边界框、可视化图像等功能。

使用 torchvision 可以大大简化计算机视觉任务的开发过程,提高开发效率,特别是在处理图像数据、构建模型、模型评估等方面提供了很多便利。

torchvision.datasets模块

torchvision.datasets 模块是 PyTorch 中用于加载和处理常见图像数据集的模块。这个模块提供了许多流行的图像数据集,使得用户可以轻松地获取这些数据集并用于模型训练和评估。

一些常见的数据集包括:

  1. MNIST: 包含手写数字图片的数据集,常用于图像分类任务。

  2. CIFAR10 和 CIFAR100: 分别包含 10 个类别和 100 个类别的彩色图片数据集,也用于图像分类任务。

  3. ImageNet: 包含数百万张图片,涵盖了数千个类别,常用于大规模图像分类和目标检测任务。

  4. COCO (Common Objects in COntext): 包含了大量的标注的图像,用于目标检测、实例分割等任务。

torchvision.datasets 不仅提供了这些数据集的接口,还提供了数据加载器(data loader),从而可以方便地将数据集加载到模型中进行训练和测试。

使用 torchvision.datasets,我们可以通过几行简单的代码来加载这些数据集,并且可以对数据进行预处理、数据增强等操作,为模型训练提供方便。

例如,以下是使用 torchvision.datasets 加载 CIFAR-10 数据集的示例代码:

python 复制代码
import torchvision
import torchvision.transforms as transforms

# 定义数据转换操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

这样就可以轻松地获取 CIFAR-10 数据集,并且进行相应的数据预处理,为模型训练做准备。

torchvision.transforms模块

transforms 模块是 PyTorch 中 torchvision 库的一部分,它提供了各种图像预处理和数据增强的函数,用于在训练神经网络时对图像数据进行处理。

以下是 transforms 模块中常用的一些函数:

  1. Pad(padding, fill=0, padding_mode='constant'):对图像进行填充。

  2. ToTensor():将 PIL 图像或 ndarray 转换为 tensor,并且将数值范围缩放到 [0, 1] 或 [-1, 1]。

  3. Normalize(mean, std):对 tensor 进行标准化,减去均值然后除以标准差。这个操作通常用于对输入数据进行归一化处理。

  4. Resize(size):调整图像大小为指定的尺寸。

  5. RandomHorizontalFlip():随机水平翻转图像,用于数据增强。

  6. RandomVerticalFlip():随机垂直翻转图像,用于数据增强。

  7. RandomRotation(degrees):随机旋转图像一定角度,用于数据增强。

  8. RandomCrop(size):随机裁剪图像到指定的尺寸,用于数据增强。

  9. ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机改变图像的亮度、对比度、饱和度和色相,用于数据增强。

  10. ToPILImage():将张量(tensor)转换为 PIL 图像格式

这些函数可以通过 transforms.Compose() 组合在一起,构成一个图像预处理流水线,然后应用于加载的图像数据上,以便在训练神经网络时进行数据处理和增强。

python 复制代码
import torchvision.transforms as transforms
from PIL import Image

# 加载图像
image_path = "data/dogcat/cat.12484.jpg"
image = Image.open(image_path)
image.show()

# 定义图像变换操作
transform = transforms.Compose([
    transforms.Pad(padding=10, fill=0, padding_mode='constant'),  # 填充操作
    transforms.Resize(256),  # 调整图像大小为 256x256
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    transforms.RandomRotation(degrees=45),  # 随机旋转图像最多45度
    transforms.RandomCrop(224),  # 随机裁剪图像到224x224
    #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色增强
    transforms.ToTensor(),  # 转换为张量
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # 标准化
])

# 应用图像变换操作
transformed_image = transform(image)

# 显示处理后的图像
transforms.ToPILImage()(transformed_image).show()


torchvision.models模块

torchvision.models 模块提供了在 PyTorch 中使用的一系列经典的预训练模型,例如 ResNet、VGG、AlexNet、GoogLeNet 等。这些预训练模型可以用于图像分类、目标检测、语义分割等任务,并且方便用户进行迁移学习或微调。

使用 torchvision.models 模块,我们可以轻松地访问这些经典模型,并且可以加载预训练的权重参数,从而在自己的数据集上进行模型训练或推理。

以下是一些常用的预训练模型:

  1. ResNet: 包括 ResNet-18、ResNet-34、ResNet-50 等不同深度的 ResNet 模型,用于图像分类任务。

  2. VGG: 包括 VGG-11、VGG-16、VGG-19 等不同深度的 VGG 模型,也用于图像分类任务。

  3. AlexNet: AlexNet 是一个较早期的深度卷积神经网络模型,也常用于图像分类任务。

  4. GoogLeNet: GoogLeNet 是由 Google 提出的深度卷积神经网络,适用于图像分类和目标检测任务。

  5. DenseNet: 密集连接网络(DenseNet)是另一个流行的卷积神经网络结构,适用于图像分类和其他计算机视觉任务。

通过 torchvision.models,我们可以很方便地加载这些模型,并且可以直接用于自己的任务,或者进行进一步的微调以适应特定的数据集和任务需求。

以下是一个示例,展示了如何使用 torchvision.models 加载预训练的 ResNet-18 模型:

python 复制代码
import torchvision.models as models

# 加载预训练的 ResNet-18 模型
resnet18 = models.resnet18(pretrained=True)

# 对模型进行微调或者用于推理

可视化工具 Visdom

Visdom 官方文档(https://github.com/fossasia/visdom

Visdom 是一个用于创建实时交互式可视化的工具,最初是由 Facebook 的人工智能研究团队开发的,其开源于2017年3月,用于支持深度学习模型的可视化和监控。它提供了一个基于 Web 的用户界面,允许用户在浏览器中实时查看和操作可视化结果。Visdom 主要针对 PyTorch 和 Torch 等深度学习框架,但也可以与其他框架集成使用。

Visdom 的主要特点包括:

  1. 实时交互式可视化:Visdom 支持实时更新可视化结果,并且允许用户通过简单的交互方式进行操作,如缩放、平移、标注等,从而更好地理解数据和模型的行为。

  2. 多种类型的可视化:Visdom 提供了多种类型的可视化工具,包括折线图、条形图、散点图、热力图、直方图、图像等,满足了不同类型数据的可视化需求。

  3. 多用户支持:Visdom 支持多用户共享可视化结果,多个用户可以同时查看和操作可视化数据,这在团队协作以及教学研究方面非常有用。

  4. 语言无关性:Visdom 可以与多种编程语言进行集成,尤其是在 Python 和 Lua 等语言中应用较为广泛。

  5. 灵活的部署方式:Visdom 可以作为一个独立的服务器运行,也可以嵌入到现有的 Python 代码中,使得可视化过程更加灵活和定制化。

总的来说,Visdom 是一个功能强大、易于使用的可视化工具,特别适用于深度学习模型的训练过程监控、结果展示以及模型行为分析。通过实时交互式的可视化,用户可以更好地理解和优化他们的深度学习模型。

两个重要概念

在 Visdom 中,有两个重要的概念:窗口(window)和环境(environment)。

  1. 窗口(Window)

    在 Visdom 中,窗口是指用户界面中的一个可视化区域,用于展示特定类型的数据可视化结果,比如折线图、散点图、图像等。每个窗口都有一个唯一的标识符,可以通过这个标识符来更新或关闭窗口中的内容。用户可以在同一环境下创建多个窗口,用于同时展示不同类型的数据可视化结果,比如训练损失曲线、模型预测结果等。

  2. 环境(Environment)

    环境是 Visdom 中用于组织窗口的概念,可以理解为一个命名空间,用于区分不同类型或不同任务的可视化结果。不同环境的可视化结果相互隔离,互不影响,在使用时如果不指定env,默认使用main。不同用户、不同程序一般使用不同的env。

这两个概念的引入使得 Visdom 在展示和组织数据可视化结果时更加灵活和清晰,同时也方便用户对不同类型的数据进行管理和交互操作。

Visdom的使用

要使用 Visdom 进行可视化,您需要按照以下步骤进行设置和操作:

  1. 安装 Visdom

    首先,您需要在您的环境中安装 Visdom。可以使用以下命令使用 pip 安装 Visdom 库:

    pip install visdom
    
  2. 启动 Visdom 服务器

    在安装完成后,您需要启动 Visdom 服务器。可以在终端中运行以下命令启动服务器:

    python -m visdom.server
    

    这将在本地启动一个 Visdom 服务器,并显示服务器的 URL 地址,默认为 http://localhost:8097

  3. 连接到 Visdom 服务器

    在您的 Python 脚本中,您需要导入 Visdom 库并连接到正在运行的 Visdom 服务器。可以使用以下代码片段连接到服务器:

    python 复制代码
    import visdom
    
    # 创建 Visdom 客户端对象
    vis = visdom.Visdom()

    此时,您的客户端将通过默认的本地连接地址连接到 Visdom 服务器。

  4. 创建窗口并显示数据

    您可以使用 Visdom 客户端对象创建窗口,并将数据显示在窗口中。以下是一个简单的示例,展示如何在折线图窗口中显示一些数据:

    python 复制代码
    import visdom
    
    # 创建 Visdom 客户端对象
    vis = visdom.Visdom()
    
    # 创建折线图窗口并显示数据
    vis.line(Y=[0], X=[0], win='my_plot', opts=dict(title='My Plot'))
    vis.line(Y=[4, 2, 3], X=[1, 2, 3], win='my_plot', update='append')

    这将创建一个名为 "my_plot" 的折线图窗口,并在窗口中显示数据点 (1, 4),(2, 2),(3, 3)。之后,您可以通过不断更新数据来更新窗口中的图表。

这只是一个简单的使用示例,Visdom 还提供了许多其他类型的窗口和选项,用于展示和操作各种类型的数据。

python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test1',use_incoming_socket=False)

x = t.arange(1, 30, 0.01)
y = t.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
  • vis = visdom.Visdom(env=u'test1'),用于构建一个客户端,客户端除指定env之外,还可以指定host、port等参数。

  • vis作为一个客户端对象,可以使用常见的画图函数,包括:

    • line:类似Matlab中的plot操作,用于记录某些标量的变化,如损失、准确率等
    • image:可视化图片,可以是输入的图片,也可以是GAN生成的图片,还可以是卷积核的信息
    • text:用于记录日志等文字信息,支持html格式
    • histgram:可视化分布,主要是查看数据、参数的分布
    • scatter:绘制散点图
    • bar:绘制柱状图
    • pie:绘制饼状图

Visdom同时支持PyTorch的tensor和Numpy的ndarray两种数据结构,但不支持Python的int、float等类型,因此每次传入时都需先将数据转成ndarray或tensor。上述操作的参数一般不同,但有两个参数是绝大多数操作都具备的:

  • win:用于指定pane的名字,如果不指定,visdom将自动分配一个新的pane。如果两次操作指定的win名字一样,新的操作将覆盖当前pane的内容,因此建议每次操作都重新指定win。
  • opts:选项,接收一个字典,常见的option包括titlexlabelylabelwidth等,主要用于设置pane的显示格式。

往往我们在训练网络的过程中需不断更新数值,如损失值等,这时就需要指定参数update='append'来避免覆盖之前的数值。

python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test1',use_incoming_socket=False)

# append 追加数据
for ii in range(0, 10):
    # y = x
    x = t.Tensor([ii])
    y = x
    vis.line(X=x, Y=y, win='polynomial',name='Trace', update='append' if ii>0 else None)
    
# updateTrace 新增一条线
x = t.arange(0, 9, 0.1)
y = (x ** 2) / 9
vis.line(X=x, Y=y, win='polynomial', name='this is a new Trace',update='new')

vis.image

  • image接收一个二维或三维向量, H × W H\times W H×W或 3 × H × W 3 \times H\times W 3×H×W,前者是黑白图像,后者是彩色图像。
  • images接收一个四维向量 N × C × H × W N\times C\times H\times W N×C×H×W, C C C可以是1或3,分别代表黑白和彩色图像。可实现类似torchvision中make_grid的功能,将多张图片拼接在一起。images也可以接收一个二维或三维的向量,此时它所实现的功能与image一致。
python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test',use_incoming_socket=False)

# 可视化一个随机的黑白图片
vis.image(t.randn(64, 64).numpy())

# 随机可视化一张彩色图片
vis.image(t.randn(3, 64, 64).numpy(), win='random2')

# 可视化36张随机的彩色图片,每一行6张
vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title':'random_imgs'})

vis.text

在 Visdom 的 vis.text 函数中,可以使用 HTML 标签来自定义文本的样式和布局。以下是一个示例,展示如何在 vis.text 中使用不同的 HTML 标签和属性:

python 复制代码
import visdom

# 连接到 Visdom 服务器
viz = visdom.Visdom()

# 创建一个文本窗口,并使用 HTML 标签来设置样式和布局
html_content = """
<h1 style="color: red;">这是一个标题</h1>
<p style="font-size: 20px;">这是一个段落</p>
<ul>
    <li>列表项1</li>
    <li>列表项2</li>
    <li>列表项3</li>
</ul>
"""

viz.text(html_content)

在这个示例中,我们使用 HTML 标签和属性来设置文本的样式和布局。通过使用 <h1> 标签,我们将文本设置为红色的标题。使用 <p> 标签,我们将文本设置为字体大小为 20px 的段落。使用 <ul><li> 标签,我们创建了一个无序列表。

当调用 viz.text 并传入带有 HTML 标签的文本内容时,Visdom 会解析该内容并相应地显示在文本窗口中。

请注意,有些 HTML 标签和属性可能在 Visdom 中不被完全支持,或者显示效果可能会因浏览器兼容性而有所区别。

相关推荐
四口鲸鱼爱吃盐15 分钟前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗16 分钟前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
bastgia1 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
菜狗woc1 小时前
opencv-python的简单练习
人工智能·python·opencv
15年网络推广青哥1 小时前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
weixin_387545642 小时前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
engchina2 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
paixiaoxin3 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
OpenCSG3 小时前
CSGHub开源版本v1.2.0更新
人工智能
weixin_515202493 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习