【深度学习】pytorch——常用工具模块

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch------常用工具模块

数据处理 torch.utils.data模块

在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像、文本、语音或其它二进制数据等。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果。考虑到这点,PyTorch提供了几个高效便捷的工具,以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

  • Dataset 类:用于表示数据集,可以通过继承这个类来创建自定义的数据集。
  • DataLoader 类:用于批量加载数据,可以指定批量大小、是否打乱数据等参数。

Dataset

在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现两个Python魔法方法:

  • __getitem__:返回一条数据,或一个样本。obj[index]等价于obj.__getitem__(index)
  • __len__:返回样本的数量。len(obj)等价于obj.__len__()

DataLoader

要创建一个 DataLoader,我们需要指定以下参数:

  • 数据集实例:这通常是你自定义的数据集类的实例,例如 CustomDataset。
  • 批量大小(batch_size):用于指定每个批量包含的样本数量。
  • 是否打乱数据(shuffle):指定是否在每个 epoch 开始时对数据进行打乱,通常在训练过程中会打乱数据,而在验证或测试过程中不会。
  • 多线程加载(num_workers):指定用于数据加载的线程数,可以加快数据加载速度。

sampler

在 PyTorch 中,torch.utils.data.sampler 模块包含了多种用来对数据进行采样的类,例如 SequentialSamplerRandomSamplerSubsetRandomSampler 等。这些采样器可以用于创建自定义的数据采样策略,以满足不同的训练需求。

下面是一些常用采样器的用法举例:

  1. SequentialSampler:顺序采样器,在每个 epoch 中按顺序遍历整个数据集。
python 复制代码
from torch.utils.data import DataLoader, SequentialSampler

# 创建顺序采样器
sampler = SequentialSampler(dataset)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. RandomSampler:随机采样器,每个 epoch 随机对数据集进行采样。
python 复制代码
from torch.utils.data import DataLoader, RandomSampler

# 创建随机采样器
sampler = RandomSampler(dataset, replacement=True, num_samples=100)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. SubsetRandomSampler:从给定索引中随机采样子集。
python 复制代码
from torch.utils.data import SubsetRandomSampler

# 创建一个索引列表
indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# 使用随机子集采样器创建数据加载器
sampler = SubsetRandomSampler(indices)
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
  1. WeightedRandomSampler:加权随机采样,允许根据每个样本的权重来进行采样,从而更灵活地处理不平衡的数据集。这在处理类别不平衡、稀有事件或其他特定情况下非常有用。
python 复制代码
from torch.utils.data import DataLoader, WeightedRandomSampler

# 假设有一个数据集和对应的样本权重
dataset = YourDataset()
weights = [0.1, 0.5, 0.8, 0.3, 0.6]  # 每个样本的权重

# 创建加权随机采样器
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)

# 使用采样器创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

在这个示例中,WeightedRandomSampler 接受一个权重列表作为输入,并可以指定采样的样本数目和是否使用放回抽样(replacement=True 表示可以重复采样同一个样本)。

这些示例展示了如何使用不同的采样器来创建数据加载器,并指定不同的采样策略。可以根据具体的训练需求选择合适的采样器,并结合数据加载器来灵活地管理数据的采样和训练过程。

torch.utils.data的使用

假设有一个包含图像数据和对应标签的数据集,将创建一个自定义的数据集类来加载这些数据,并使用 DataLoader 来批量加载数据供模型训练使用。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader

# 假设你有图像数据和对应标签的数据集
class CustomDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        if self.transform:
            x = self.transform(x)
        return x, y

# 创建数据集实例
# 假设 data 和 targets 是你的图像数据和对应标签
custom_dataset = CustomDataset(data, targets, transform=your_transforms)

# 使用 DataLoader 批量加载数据
batch_size = 32
shuffle = True
num_workers = 4  # 可以加快数据加载速度的线程数

data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# 遍历数据加载器,获取批量数据
for inputs, labels in data_loader:
    # 在这里执行模型训练或推理
    pass
  • 创建了一个自定义的数据集类 CustomDataset,该类继承自 torch.utils.data.Dataset,并实现了 lengetitem 方法。
  • 创建了数据集实例 custom_dataset,并使用 DataLoader 实例 data_loader 批量加载数据。
  • 通过遍历数据加载器,可以获取批量的输入数据和对应的标签,用于模型的训练或推理过程。

计算机视觉工具包 torchvision

PyTorch 的计算机视觉工具包 torchvision 提供了一系列用于图像处理和计算机视觉任务的工具和数据集。它包含了常用的数据集(如 ImageNet、CIFAR10、COCO 等)、图像变换操作、模型架构以及预训练的模型等功能,方便用户快速构建和训练计算机视觉模型。

以下是 torchvision 中一些常用的功能和模块:

  1. 数据集和数据加载器torchvision.datasets 模块提供了常用的图像数据集,例如 CIFAR, COCO, MNIST 等,并且可以通过 torchvision.transforms 模块中的图像变换操作对数据进行预处理。同时,torchvision.transforms 还提供了各种图像变换操作,如裁剪、缩放、翻转等,用于数据增强和预处理。

  2. 模型架构和预训练模型torchvision.models 模块包含了一些经典的计算机视觉模型,如 ResNet、VGG、AlexNet 等,同时还提供了这些模型在 ImageNet 数据集上预训练的参数。这些预训练模型可以用于迁移学习或者基准测试。

  3. 图像工具函数torchvision.utils 模块中提供了一些图像操作的工具函数,比如保存图像、绘制边界框、可视化图像等功能。

使用 torchvision 可以大大简化计算机视觉任务的开发过程,提高开发效率,特别是在处理图像数据、构建模型、模型评估等方面提供了很多便利。

torchvision.datasets模块

torchvision.datasets 模块是 PyTorch 中用于加载和处理常见图像数据集的模块。这个模块提供了许多流行的图像数据集,使得用户可以轻松地获取这些数据集并用于模型训练和评估。

一些常见的数据集包括:

  1. MNIST: 包含手写数字图片的数据集,常用于图像分类任务。

  2. CIFAR10 和 CIFAR100: 分别包含 10 个类别和 100 个类别的彩色图片数据集,也用于图像分类任务。

  3. ImageNet: 包含数百万张图片,涵盖了数千个类别,常用于大规模图像分类和目标检测任务。

  4. COCO (Common Objects in COntext): 包含了大量的标注的图像,用于目标检测、实例分割等任务。

torchvision.datasets 不仅提供了这些数据集的接口,还提供了数据加载器(data loader),从而可以方便地将数据集加载到模型中进行训练和测试。

使用 torchvision.datasets,我们可以通过几行简单的代码来加载这些数据集,并且可以对数据进行预处理、数据增强等操作,为模型训练提供方便。

例如,以下是使用 torchvision.datasets 加载 CIFAR-10 数据集的示例代码:

python 复制代码
import torchvision
import torchvision.transforms as transforms

# 定义数据转换操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

这样就可以轻松地获取 CIFAR-10 数据集,并且进行相应的数据预处理,为模型训练做准备。

torchvision.transforms模块

transforms 模块是 PyTorch 中 torchvision 库的一部分,它提供了各种图像预处理和数据增强的函数,用于在训练神经网络时对图像数据进行处理。

以下是 transforms 模块中常用的一些函数:

  1. Pad(padding, fill=0, padding_mode='constant'):对图像进行填充。

  2. ToTensor():将 PIL 图像或 ndarray 转换为 tensor,并且将数值范围缩放到 [0, 1] 或 [-1, 1]。

  3. Normalize(mean, std):对 tensor 进行标准化,减去均值然后除以标准差。这个操作通常用于对输入数据进行归一化处理。

  4. Resize(size):调整图像大小为指定的尺寸。

  5. RandomHorizontalFlip():随机水平翻转图像,用于数据增强。

  6. RandomVerticalFlip():随机垂直翻转图像,用于数据增强。

  7. RandomRotation(degrees):随机旋转图像一定角度,用于数据增强。

  8. RandomCrop(size):随机裁剪图像到指定的尺寸,用于数据增强。

  9. ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机改变图像的亮度、对比度、饱和度和色相,用于数据增强。

  10. ToPILImage():将张量(tensor)转换为 PIL 图像格式

这些函数可以通过 transforms.Compose() 组合在一起,构成一个图像预处理流水线,然后应用于加载的图像数据上,以便在训练神经网络时进行数据处理和增强。

python 复制代码
import torchvision.transforms as transforms
from PIL import Image

# 加载图像
image_path = "data/dogcat/cat.12484.jpg"
image = Image.open(image_path)
image.show()

# 定义图像变换操作
transform = transforms.Compose([
    transforms.Pad(padding=10, fill=0, padding_mode='constant'),  # 填充操作
    transforms.Resize(256),  # 调整图像大小为 256x256
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    transforms.RandomRotation(degrees=45),  # 随机旋转图像最多45度
    transforms.RandomCrop(224),  # 随机裁剪图像到224x224
    #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 颜色增强
    transforms.ToTensor(),  # 转换为张量
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # 标准化
])

# 应用图像变换操作
transformed_image = transform(image)

# 显示处理后的图像
transforms.ToPILImage()(transformed_image).show()


torchvision.models模块

torchvision.models 模块提供了在 PyTorch 中使用的一系列经典的预训练模型,例如 ResNet、VGG、AlexNet、GoogLeNet 等。这些预训练模型可以用于图像分类、目标检测、语义分割等任务,并且方便用户进行迁移学习或微调。

使用 torchvision.models 模块,我们可以轻松地访问这些经典模型,并且可以加载预训练的权重参数,从而在自己的数据集上进行模型训练或推理。

以下是一些常用的预训练模型:

  1. ResNet: 包括 ResNet-18、ResNet-34、ResNet-50 等不同深度的 ResNet 模型,用于图像分类任务。

  2. VGG: 包括 VGG-11、VGG-16、VGG-19 等不同深度的 VGG 模型,也用于图像分类任务。

  3. AlexNet: AlexNet 是一个较早期的深度卷积神经网络模型,也常用于图像分类任务。

  4. GoogLeNet: GoogLeNet 是由 Google 提出的深度卷积神经网络,适用于图像分类和目标检测任务。

  5. DenseNet: 密集连接网络(DenseNet)是另一个流行的卷积神经网络结构,适用于图像分类和其他计算机视觉任务。

通过 torchvision.models,我们可以很方便地加载这些模型,并且可以直接用于自己的任务,或者进行进一步的微调以适应特定的数据集和任务需求。

以下是一个示例,展示了如何使用 torchvision.models 加载预训练的 ResNet-18 模型:

python 复制代码
import torchvision.models as models

# 加载预训练的 ResNet-18 模型
resnet18 = models.resnet18(pretrained=True)

# 对模型进行微调或者用于推理

可视化工具 Visdom

Visdom 官方文档(https://github.com/fossasia/visdom

Visdom 是一个用于创建实时交互式可视化的工具,最初是由 Facebook 的人工智能研究团队开发的,其开源于2017年3月,用于支持深度学习模型的可视化和监控。它提供了一个基于 Web 的用户界面,允许用户在浏览器中实时查看和操作可视化结果。Visdom 主要针对 PyTorch 和 Torch 等深度学习框架,但也可以与其他框架集成使用。

Visdom 的主要特点包括:

  1. 实时交互式可视化:Visdom 支持实时更新可视化结果,并且允许用户通过简单的交互方式进行操作,如缩放、平移、标注等,从而更好地理解数据和模型的行为。

  2. 多种类型的可视化:Visdom 提供了多种类型的可视化工具,包括折线图、条形图、散点图、热力图、直方图、图像等,满足了不同类型数据的可视化需求。

  3. 多用户支持:Visdom 支持多用户共享可视化结果,多个用户可以同时查看和操作可视化数据,这在团队协作以及教学研究方面非常有用。

  4. 语言无关性:Visdom 可以与多种编程语言进行集成,尤其是在 Python 和 Lua 等语言中应用较为广泛。

  5. 灵活的部署方式:Visdom 可以作为一个独立的服务器运行,也可以嵌入到现有的 Python 代码中,使得可视化过程更加灵活和定制化。

总的来说,Visdom 是一个功能强大、易于使用的可视化工具,特别适用于深度学习模型的训练过程监控、结果展示以及模型行为分析。通过实时交互式的可视化,用户可以更好地理解和优化他们的深度学习模型。

两个重要概念

在 Visdom 中,有两个重要的概念:窗口(window)和环境(environment)。

  1. 窗口(Window)

    在 Visdom 中,窗口是指用户界面中的一个可视化区域,用于展示特定类型的数据可视化结果,比如折线图、散点图、图像等。每个窗口都有一个唯一的标识符,可以通过这个标识符来更新或关闭窗口中的内容。用户可以在同一环境下创建多个窗口,用于同时展示不同类型的数据可视化结果,比如训练损失曲线、模型预测结果等。

  2. 环境(Environment)

    环境是 Visdom 中用于组织窗口的概念,可以理解为一个命名空间,用于区分不同类型或不同任务的可视化结果。不同环境的可视化结果相互隔离,互不影响,在使用时如果不指定env,默认使用main。不同用户、不同程序一般使用不同的env。

这两个概念的引入使得 Visdom 在展示和组织数据可视化结果时更加灵活和清晰,同时也方便用户对不同类型的数据进行管理和交互操作。

Visdom的使用

要使用 Visdom 进行可视化,您需要按照以下步骤进行设置和操作:

  1. 安装 Visdom

    首先,您需要在您的环境中安装 Visdom。可以使用以下命令使用 pip 安装 Visdom 库:

    复制代码
    pip install visdom
  2. 启动 Visdom 服务器

    在安装完成后,您需要启动 Visdom 服务器。可以在终端中运行以下命令启动服务器:

    复制代码
    python -m visdom.server

    这将在本地启动一个 Visdom 服务器,并显示服务器的 URL 地址,默认为 http://localhost:8097

  3. 连接到 Visdom 服务器

    在您的 Python 脚本中,您需要导入 Visdom 库并连接到正在运行的 Visdom 服务器。可以使用以下代码片段连接到服务器:

    python 复制代码
    import visdom
    
    # 创建 Visdom 客户端对象
    vis = visdom.Visdom()

    此时,您的客户端将通过默认的本地连接地址连接到 Visdom 服务器。

  4. 创建窗口并显示数据

    您可以使用 Visdom 客户端对象创建窗口,并将数据显示在窗口中。以下是一个简单的示例,展示如何在折线图窗口中显示一些数据:

    python 复制代码
    import visdom
    
    # 创建 Visdom 客户端对象
    vis = visdom.Visdom()
    
    # 创建折线图窗口并显示数据
    vis.line(Y=[0], X=[0], win='my_plot', opts=dict(title='My Plot'))
    vis.line(Y=[4, 2, 3], X=[1, 2, 3], win='my_plot', update='append')

    这将创建一个名为 "my_plot" 的折线图窗口,并在窗口中显示数据点 (1, 4),(2, 2),(3, 3)。之后,您可以通过不断更新数据来更新窗口中的图表。

这只是一个简单的使用示例,Visdom 还提供了许多其他类型的窗口和选项,用于展示和操作各种类型的数据。

python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test1',use_incoming_socket=False)

x = t.arange(1, 30, 0.01)
y = t.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
  • vis = visdom.Visdom(env=u'test1'),用于构建一个客户端,客户端除指定env之外,还可以指定host、port等参数。

  • vis作为一个客户端对象,可以使用常见的画图函数,包括:

    • line:类似Matlab中的plot操作,用于记录某些标量的变化,如损失、准确率等
    • image:可视化图片,可以是输入的图片,也可以是GAN生成的图片,还可以是卷积核的信息
    • text:用于记录日志等文字信息,支持html格式
    • histgram:可视化分布,主要是查看数据、参数的分布
    • scatter:绘制散点图
    • bar:绘制柱状图
    • pie:绘制饼状图

Visdom同时支持PyTorch的tensor和Numpy的ndarray两种数据结构,但不支持Python的int、float等类型,因此每次传入时都需先将数据转成ndarray或tensor。上述操作的参数一般不同,但有两个参数是绝大多数操作都具备的:

  • win:用于指定pane的名字,如果不指定,visdom将自动分配一个新的pane。如果两次操作指定的win名字一样,新的操作将覆盖当前pane的内容,因此建议每次操作都重新指定win。
  • opts:选项,接收一个字典,常见的option包括titlexlabelylabelwidth等,主要用于设置pane的显示格式。

往往我们在训练网络的过程中需不断更新数值,如损失值等,这时就需要指定参数update='append'来避免覆盖之前的数值。

python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test1',use_incoming_socket=False)

# append 追加数据
for ii in range(0, 10):
    # y = x
    x = t.Tensor([ii])
    y = x
    vis.line(X=x, Y=y, win='polynomial',name='Trace', update='append' if ii>0 else None)
    
# updateTrace 新增一条线
x = t.arange(0, 9, 0.1)
y = (x ** 2) / 9
vis.line(X=x, Y=y, win='polynomial', name='this is a new Trace',update='new')

vis.image

  • image接收一个二维或三维向量, H × W H\times W H×W或 3 × H × W 3 \times H\times W 3×H×W,前者是黑白图像,后者是彩色图像。
  • images接收一个四维向量 N × C × H × W N\times C\times H\times W N×C×H×W, C C C可以是1或3,分别代表黑白和彩色图像。可实现类似torchvision中make_grid的功能,将多张图片拼接在一起。images也可以接收一个二维或三维的向量,此时它所实现的功能与image一致。
python 复制代码
import torch as t
import visdom

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
vis = visdom.Visdom(env=u'test',use_incoming_socket=False)

# 可视化一个随机的黑白图片
vis.image(t.randn(64, 64).numpy())

# 随机可视化一张彩色图片
vis.image(t.randn(3, 64, 64).numpy(), win='random2')

# 可视化36张随机的彩色图片,每一行6张
vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title':'random_imgs'})

vis.text

在 Visdom 的 vis.text 函数中,可以使用 HTML 标签来自定义文本的样式和布局。以下是一个示例,展示如何在 vis.text 中使用不同的 HTML 标签和属性:

python 复制代码
import visdom

# 连接到 Visdom 服务器
viz = visdom.Visdom()

# 创建一个文本窗口,并使用 HTML 标签来设置样式和布局
html_content = """
<h1 style="color: red;">这是一个标题</h1>
<p style="font-size: 20px;">这是一个段落</p>
<ul>
    <li>列表项1</li>
    <li>列表项2</li>
    <li>列表项3</li>
</ul>
"""

viz.text(html_content)

在这个示例中,我们使用 HTML 标签和属性来设置文本的样式和布局。通过使用 <h1> 标签,我们将文本设置为红色的标题。使用 <p> 标签,我们将文本设置为字体大小为 20px 的段落。使用 <ul><li> 标签,我们创建了一个无序列表。

当调用 viz.text 并传入带有 HTML 标签的文本内容时,Visdom 会解析该内容并相应地显示在文本窗口中。

请注意,有些 HTML 标签和属性可能在 Visdom 中不被完全支持,或者显示效果可能会因浏览器兼容性而有所区别。

相关推荐
QQ676580084 分钟前
智慧工厂之扬尘识别 铲车装载识别 工程重型机械识别 磁铁识别 深度学习YOLO格式图像识别第10435期
人工智能·深度学习·yolo·扬尘识别·铲车装载·工程重型机械·磁铁识别
Raink老师8 分钟前
【AI面试临阵磨枪】KV Cache 是什么?为什么能加速推理?如何实现?
人工智能·ai 面试
newsxun40 分钟前
第十六届北京国际电影节东郎分会场启幕
人工智能
大嘴皮猴儿41 分钟前
从零开始学商品图翻译:小白也能快速掌握的多语言文字处理与上架技巧
大数据·ide·人工智能·macos·新媒体运营·xcode·自动翻译
思绪无限42 分钟前
YOLOv5至YOLOv12升级:行人跌倒检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·yolov12·yolo全家桶·行人跌倒检测系统
大黄说说43 分钟前
AI大模型对内容创作的颠覆:机遇、版权争议与行业新规则
人工智能
captain_AIouo1 小时前
OZON航海引领者Captain AI指引运营新航向
大数据·人工智能·经验分享·aigc
AI医影跨模态组学1 小时前
PLOS Medicine 中山大学肿瘤防治中心蔡木炎等团队:基于多视角深度学习的组织病理学分析用于II期结直肠癌的预后与治疗分层
人工智能·深度学习·论文·医学·医学影像
起个名字总是说已存在1 小时前
github开源AI技能:Awesome DESIGN.md让页面设计无限可能
人工智能·开源·github
Aray12341 小时前
大模型推理全栈技术解析:从Transformer到RoPE/YaRN的上下文优化
人工智能·深度学习·transformer