详细介绍Pytorch中torchvision的相关使用

torchvision 是 PyTorch 的一个官方库,主要用于处理计算机视觉任务。提供了许多常用的数据集、模型架构、图像转换等功能,使得计算机视觉任务的开发变得更加高效和便捷。以下是对 torchvision 主要功能的详细介绍:

1. 数据集(Datasets)

torchvision 提供了许多常用的计算机视觉数据集,如 CIFAR-10、MNIST、ImageNet 等。这些数据集可以直接通过 torchvision.datasets 模块加载。

示例:加载 CIFAR-10 数据集
python 复制代码
from torchvision import datasets
from torch.utils.data import DataLoader

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)

# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

2. 图像转换(Transforms)

torchvision.transforms 模块提供了许多常用的图像转换操作,如裁剪、缩放、旋转、翻转等。这些转换操作可以单独使用,也可以组合使用。

示例:组合图像转换操作
python 复制代码
from torchvision import transforms

# 定义转换操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 应用转换操作
train_dataset.transform = transform
test_dataset.transform = transform

3. 预训练模型(Models)

torchvision.models 模块提供了许多常用的预训练模型,如 ResNet、VGG、AlexNet、DenseNet 等。这些模型可以直接用于迁移学习或作为基准模型。

示例:加载预训练的 ResNet-50 模型
python 复制代码
from torchvision import models
import torch.nn as nn

# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)

# 修改最后一层以适应新的分类任务
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)

4. 数据加载器(DataLoader)

torch.utils.data.DataLoader 是一个实用的数据加载器,可以与 torchvision 提供的数据集一起使用,方便地进行批量加载和数据迭代。

示例:使用 DataLoader 加载数据
python 复制代码
from torch.utils.data import DataLoader

# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 训练模型
for images, labels in train_loader:
    # 训练代码
    pass

5. 自定义数据集(Custom Datasets)

如果需要使用自定义数据集,可以继承 torch.utils.data.Dataset 类,并实现 __len____getitem__ 方法。

示例:自定义数据集
python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)
        return image

# 使用自定义数据集
custom_dataset = CustomDataset(root_dir='path/to/dataset', transform=transform)
custom_loader = DataLoader(custom_dataset, batch_size=64, shuffle=True)

6. 可视化(Visualization)

torchvision 还提供了一些用于可视化的工具,如 torchvision.utils.make_grid 可以将多个图像拼接成一个网格图像。

示例:可视化图像
python 复制代码
import matplotlib.pyplot as plt
from torchvision import utils

# 获取一批图像
images, labels = next(iter(train_loader))

# 将图像拼接成网格
grid = utils.make_grid(images)

# 显示图像
plt.imshow(grid.permute(1, 2, 0))
plt.show()
相关推荐
AAD555888991 小时前
数字仪表LCD显示识别与读数:数字0-9、小数点及单位kwh检测识别实战
python
开源技术3 小时前
Python Pillow 优化,打开和保存速度最快提高14倍
开发语言·python·pillow
Niuguangshuo3 小时前
深入解析Stable Diffusion基石——潜在扩散模型(LDMs)
人工智能·计算机视觉·stable diffusion
迈火3 小时前
SD - Latent - Interposer:解锁Stable Diffusion潜在空间的创意工具
人工智能·gpt·计算机视觉·stable diffusion·aigc·语音识别·midjourney
wfeqhfxz25887823 小时前
YOLO13-C3k2-GhostDynamicConv烟雾检测算法实现与优化
人工智能·算法·计算机视觉
芝士爱知识a4 小时前
2026年AI面试软件推荐
人工智能·面试·职场和发展·大模型·ai教育·考公·智蛙面试
Li emily4 小时前
解决港股实时行情数据 API 接入难题
人工智能·python·fastapi
Aaron15884 小时前
基于RFSOC的数字射频存储技术应用分析
c语言·人工智能·驱动开发·算法·fpga开发·硬件工程·信号处理
J_Xiong01174 小时前
【Agents篇】04:Agent 的推理能力——思维链与自我反思
人工智能·ai agent·推理
wfeqhfxz25887824 小时前
农田杂草检测与识别系统基于YOLO11实现六种杂草自动识别_1
python