PyTorch 基础知识

PyTorch 基础知识

基础知识

查看GPU

bash 复制代码
nvidia-smi

查看PyTorch是否正常调用GPU

python 复制代码
import torch
torch.cuda.is_available()

启动jupyter notebook

bash 复制代码
jupyter notebook

dir()函数 可以查看包(如 torch)中有哪些模块、子包或属性,类似打开一个工具箱,查看里面的各种子工具箱。继续对其中的某个分隔区调用 dir(),可以查看更深层的内容(如 dir(torch.cuda) 会显示 torch.cuda 内部的模块或函数)。如果 dir() 的输出里出现双下划线 __ 包裹的名称,一般说明这是特殊属性或方法,而不是一个子模块。

help()函数 可以查看函数或类的官方文档/说明书,了解它的用途和用法。例如,help(torch.cuda.is_available) 可以查看该函数是否存在,以及它返回什么结果(返回 True 或 False 来指示是否可使用 CUDA)。注意在查看函数文档时,函数名后不带括号 (如 help(torch.cuda.is_available) 而不是 help(torch.cuda.is_available()))。

python 复制代码
help(torch.cuda.is_available)

# Help on function is_available in module torch.cuda:

# is_available() -> bool
#    Return a bool indicating if CUDA is currently available.

os.path.join 可以将自动使用正确的路径分隔符,将多个路径组件拼接在一起。

python 复制代码
import os

path1 = r"C:\Users\lenovo\Desktop\cpp\main\java\org\example\webchat"
path2 = r"Server"
print(os.path.join(path1, path2))
# C:\Users\lenovo\Desktop\cpp\main\java\org\example\webchat\Server

os.path.split 将路径拆分为 (目录路径, 文件名) 的元组。

python 复制代码
import os

file_path = "/home/user/documents/file.txt"
directory, filename = os.path.split(file_path)
print(directory, filename)
# /home/user/documents file.txt

os.path.splitext 将路径拆分为 (文件名, 文件拓展名) 的元组。

python 复制代码
import os

file_path = "/home/user/documents/file.txt"
filename, extendname = os.path.splitext(file_path)
print(filename, extendname)
# /home/user/documents/file .txt

split() 是Python内置的字符串方法,可以按照指定的分隔符拆分字符串,并返回一个列表。

python 复制代码
string.split(separator, maxsplit)
# separator 可选 用于拆分字符串的分隔符(默认为空格)
# maxsplit  可选 最多拆分几次(默认拆分所有)
python 复制代码
text = "apple,banana,orange"
result = text.split(",")
print(result)
# ['apple', 'banana', 'orange']

从PIL库中导入Image模块,可以便捷地使用图像操作功能。

python 复制代码
from PIL import Image
# 加载图像
img = Image.open('example.jpg')
# 显示图像
img.show()
# 保存图像为不同格式
img.save('new_image.png')
# 将图像转换成其他模式,灰度图L,彩色图RGB
img_gray = img.convert('L')  
# 调整图像的尺寸
img_resized = img.resize((200, 200))

Python格式化输出

python 复制代码
a = "hello"
print("{}, Li Hua".format(a))

Ctrl+P 可以提示函数里面需要什么参数

Dataset类

Dataset是PyTorch用于数据加载的基类,主要用于构建自定义数据集 。Dataset类可以提供索引访问数据的能力,类似 list[index] ,也可以处理不同类型的数据(图片、文本、视频等),可以在__getitem__ 方法中进行数据增强或转换等预处理。

Dataset需要继承 torch.utils.data.Dataset ,并重写以下方法:

python 复制代码
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        """ 初始化数据 """
        self.data = data
        self.labels = labels

    def __getitem__(self, index):
        """ 根据索引 index 返回数据 """
        x = self.data[index]
        y = self.labels[index]
        return x, y  # 返回数据和标签

    def __len__(self):
        """ 返回数据集大小 """
        return len(self.data)

# 创建数据集
data = torch.randn(100, 3)  # 100 个样本,每个样本 3 维
labels = torch.randint(0, 2, (100,))  # 100 个二分类标签

# 实例化数据集
dataset = MyDataset(data, labels)

# 访问数据
print(len(dataset))  # 100
print(dataset[0])  # 第 0 个样本的数据和标签

以蚂蚁蜜蜂二分类数据集为例,先观察文件结构:

在hymenoptera_data文件夹下,分为train训练集和val验证集,在训练集中,分为ants和bees两个文件夹,每个文件夹下是对应的图片。

python 复制代码
from torch.utils.data import Dataset  # 导入Dataset类,用于创建自定义数据集
from PIL import Image  # 导入PIL库,用于读取图片
import os  # 导入os库,用于处理文件路径

# 创建一个继承自Dataset类的自定义数据集类
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        """
        初始化方法,创建数据集实例时调用
        root_dir: 数据集根目录,包含了所有类别的文件夹
        label_dir: 类别文件夹名(例如"ants"或"bees")
        """
        self.root_dir = root_dir  # 保存数据集根目录路径
        self.label_dir = label_dir  # 保存类别文件夹名
        self.path = os.path.join(self.root_dir, self.label_dir)  # 构建当前类别的完整路径
        # 获取该类别下所有图片文件名,生成图片地址列表
        # ['0013035.jpg', '1030023514_aad5c608f9.jpg', ... ]
        self.img_path = os.listdir(self.path)  

    def __getitem__(self, idx):
        """
        根据给定的索引返回一张图片及其标签
        idx: 索引值
        return: 返回指定索引处的图片和标签
        """
        img_name = self.img_path[idx]  # 获取图片文件名
        # 构建图片的完整路径
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  
        img = Image.open(img_item_path)  # 使用PIL读取图像文件
        label = self.label_dir  # 标签就是文件夹名(例如"ants"或"bees")
        return img, label  # 返回图片和标签

    def __len__(self):
        """
        返回数据集中图片的数量
        return: 数据集中的图片数量
        """
        return len(self.img_path)  # 数据集大小,等于图片列表的长度

# 创建数据集实例
# 数据集根目录
root_dir = r"D:\deeplearning_ai_books\tudui\Dataset\hymenoptera_data\train"  
ants_label_dir = "ants"  # 蚂蚁类别的文件夹名
bees_label_dir = "bees"  # 蜜蜂类别的文件夹名

# 创建蚂蚁类别和蜜蜂类别的两个数据集实例
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# 获取蚂蚁数据集中的第一张图片及其标签
img, label = ants_dataset[0]
img.show()  # 可视化第一张图片

# 将蚂蚁数据集和蜜蜂数据集合并为一个训练数据集
train_dataset = ants_dataset + bees_dataset

Torchvision DataSets

torchvision 主要模块

模块 作用
torchvision.datasets 提供标准数据集(如 CIFAR-10, MNIST, COCO)
torchvision.io 处理输入输出(较少用)
torchvision.models 预训练模型(如 ResNet, VGG, Faster R-CNN)
torchvision.ops 特殊操作(较少用)
torchvision.transforms 提供图像变换(Resize, ToTensor, Normalize 等)
torchvision.utils 提供工具(如 TensorBoard 可视化)

torchvision 中提供了多种标准数据集,可以方便地进行下载、加载和预处理,并结合 transforms 进行数据增强和格式转换。

以 CIFAR-10 数据集为例,共有10类,每类6000张图片。图片默认是 PIL Image,标签是整数 0-9,可通过 CIFAR10.classes 获取对应类别名称。

python 复制代码
# 构造函数
torchvision.datasets.CIFAR10(
    root: str,                                     # 数据集存储路径(字符串)
    train: bool = True,                            # True 加载训练集,False 加载测试集
    transform: Optional[Callable] = None,          # 图像转换函数
    target_transform: Optional[Callable] = None,   # 标签转换函数
    download: bool = False                         # True自动下载数据集,False加载本地数据
)
python 复制代码
import torchvision

# 下载并加载 CIFAR-10 数据集
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./Dataset", train=False, download=True)

print(train_set[0])
# (<PIL.Image.Image image mode=RGB size=32x32 at 0x2537AFD81F0>, 6)

# 查看测试集的第一张图片和类别
img, target = train_set[0]

# 查看所有类别
print(train_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 查看具体类别
print(train_set.classes[target])
# frog

# 显示图片
img.show()

由于CIFAR-10原始数据是 PIL Image格式,无法直接用于PyTorch训练。所以可以在数据集下载的过程中就对其进行 transforms 处理,如 ToTensor 操作。

python 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_tranforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, 
                   transform=dataset_tranforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./Dataset", train=False, 
                   transform=dataset_tranforms, download=True)

writer = SummaryWriter("logs")

for i in range(10):
    img, _ = train_set[i]
    writer.add_image("CIFAR-10", img, i)

writer.close()

DataLoader类

DataLoader 是 PyTorch 提供的数据加载器 ,用于从数据集中批量采样,并提供数据预处理、打乱顺序、多进程加载等功能。

Dataset类相当于一叠扑克牌,存储了所有数据。而 DataLoader 类相当于发牌,每次从 Dataset 中取出一定数量(batch_size)的数据。

python 复制代码
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,        # 传入数据集
    batch_size=64,  # 每个 batch 取 64 张图片
    shuffle=True,   # 是否打乱数据
    num_workers=0,  # 线程数(Windows 需设为 0)
    drop_last=False # 是否丢弃最后一个数量不满 batch_size 的 batch
)

batch_size 决定每次加载多少数据,batch_size=len(dataset) 一次性加载所有数据。

drop_last 的作用是如果数据集的大小不是 batch_size 的整数倍,是否丢弃最后不足 batch_size 的部分。

shuffle 的作用是决定是否在每个 epoch 重新打乱数据。一般训练集需要打乱数据,而测试集不需要打乱数据。如果**shuffle=True** 那么两轮 epoch 取的数据是不同的顺序

取单个数据

python 复制代码
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, 
                        transform=dataset_tranforms, download=True)

img, target = train_set[0]
print(img.shape)
# torch.Size([3, 32, 32])
print(target)
# 6

批量取数据

python 复制代码
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, 
                        transform=dataset_tranforms, download=True)

dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, drop_last=False)
for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    print(targets.shape)
    # torch.Size([64, 3, 32, 32])
    # torch.Size([64])
    # ......
    # torch.Size([16, 3, 32, 32])
    # torch.Size([16])

结合 DataLoader 与 SummaryWriter 进行可视化:

python 复制代码
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")

# 预处理
dataset_tranforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 加载数据集
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, 
                        transform=dataset_tranforms, download=True)
# 创建DataLoader
dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, drop_last=False)

step = 0
for data in dataloader:
    imgs, _ = data
    writer.add_images("DataLoader", imgs, global_step=step)
    step = step + 1

writer.close()

Transforms

torchvision.transforms 提供了一系列图像预处理工具,用于对图像进行变换,如尺寸调整、归一化、格式转换等,以便在训练模型时使用。

python 复制代码
from torchvision import transforms

两种读取图片的方式

python 复制代码
import cv2
from PIL import Image

img_path = r"Dataset/hymenoptera_data/train/ants/0013035.jpg"

# **OpenCV**
# cv2 处理速度快,但**通道顺序为** **BGR**
img_cv = cv2.imread(img_path)
print(type(img_cv))
# <class 'numpy.ndarray'>

# **PIL Image**
# 常见的 Python 图像库,易用但不适合计算
img_PIL = Image.open(img_path)
print(type(img_PIL))
# <class 'PIL.JpegImagePlugin.JpegImageFile'>

但是 PIL Imagenumpy.ndarray 是普通图像格式,无法直接用于神经网络训练。Tensor 更适用于 PyTorch,支持 GPU 加速计算。同时 Tensor 包含梯度信息 ,可用于反向传播训练 ,形状为 (C, H, W)(通道数、高度、宽度)。

ToTensor() 函数可以将将 PIL Imagenumpy.ndarray 转换为 Tensor(同时归一化到 0,1)。

python 复制代码
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

writer = SummaryWriter("logs")
img_path = r"C:\Users\lenovo\Desktop\camera\QIXIA\niu.jpg"
img_PIL = Image.open(img_path)

# transforms是一个类, ToTensor是类的方法,因此**调用时需要先实例化**
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img_PIL)
writer.add_image("Image_Tensor", img_tensor, global_step=0, dataformats="CHW")

print(type(img_tensor))
# <class 'torch.Tensor'>

print(img_tensor.shape)
# torch.Size([3, 3712, 5568])

writer.close()

ToPILImage() 函数可以将 Tensornumpy.ndarray 转换为 PIL Image,适用于显示转换后的图片,以及方便使用 PIL 处理图片

python 复制代码
img_ToPIL = transforms.ToPILImage()
img_TOPIL = img_ToPIL(img_tensor)
print(type(img_TOPIL))
# <class 'PIL.Image.Image'>

Normalize() 函数对 Tensor 进行归一化,转换为自定义均值、标准差的分布。

o u t p u t = i n p u t − m e a n s t d output = \frac{input - mean}{std} output=stdinput−mean

如果原始图像的像素值在 0 , 1 0,1 0,1 范围内,并且设置均值和标准差为0.5,那么归一化后的像素值范围为 − 1 , 1 -1,1 −1,1

python 复制代码
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
print(img_tensor[0][0][0])
# tensor(0.3137)

img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
# tensor(-0.3725)

Resize() 函数可以调整图片尺寸。

python 复制代码
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img_tensor)
print(img_tensor.shape, img_resize.shape)
# torch.Size([3, 3712, 5568]) torch.Size([3, 512, 512])

等比例缩放,即只缩放短边到固定值,其他自动调整以保持宽高比

python 复制代码
trans_resize_2 = transforms.Resize(512)
img_resize_2 = trans_resize_2(img_tensor)
print(img_tensor.shape, img_resize_2.shape)
# torch.Size([3, 3712, 5568]) torch.Size([3, 512, 768])

RandomCrop() 函数可以随机裁剪图片。

python 复制代码
# 随机裁剪 500×1000 区域 5 次
trans_crop = transforms.RandomCrop((500, 1000))
for i in range(5):
    img_crop = trans_crop(img_tensor)

CenterCrop() 函数可以裁剪位于图像正中心固定大小。

python 复制代码
trans_center = transforms.CenterCrop(1000)
img_center = trans_center(img_tensor)

Compose() 函数可以组合并按顺序处理多个预处理变换。

python 复制代码
# 用列表的形式列出所有的预处理变换组合
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    transforms.Resize(1024),
    transforms.RandomCrop((500, 1000))
])
transformed = transform(img_PIL)
print(transformed.shape)
# torch.Size([3, 500, 1000])

TensorBoard

TensorBoard可以可视化训练过程 Loss 的变化,以及模型在不同阶段的输出

启动TensorBoard

bash 复制代码
tensorboard --logdir=logs  # 默认端口6006
tensorboard --logdir=logs --port=6007  # 指定端口

SummaryWriter 类用于创建日志文件并记录训练过程,会直接向自定义 log_dir 文件夹写入事件文件,供TensorBoard 解析。

python 复制代码
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs")  # 事件文件存储到自定义的logs目录
writer.close()

如果数据有误,就删除logdir文件夹,重新运行代码,终端终止TensorBoard后重新启动

add_scalar() 方法可以记录标量数据,如训练损失等

python 复制代码
add_scalar(tag, scalar_value, global_step)
# tag              数据标签(类似图表标题)
# scalar_value     需要保存的数值(y 轴)
# global_step      当前训练步数(x 轴)
python 复制代码
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(100):
    writer.add_scalar("y=2x", 2*i, i)    # (标题, y轴数值, x轴步数)
writer.close()

add_image() 方法可以记录图像数据

python 复制代码
add_image(tag, img_tensor, global_step, dataformats="CHW")
# tag            图像的名称(类似图表标题)
# img_tensor     图片数据(支持 **torch.Tensor、numpy.array**)
# global_step    当前训练步数(用于对比不同时期的结果)
# dataformats    HWC(高度、宽度、通道)
python 复制代码
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

writer = SummaryWriter("logs")

img_path = r"Dataset/hymenoptera_data/train/ants/0013035.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)     # 转换为 numpy.array 数组
print(img_array.shape)
# 要保证dataformats和图像数组的shape对应
writer.add_image("image", img_array, 0, dataformats="HWC")

writer.close()

通过TensorBoard,可以直观的查看不同训练步长下图形的变化。

相关推荐
中科院提名者几秒前
BERT 模型的运行机制及DistilBERT 的蒸馏压缩过程
人工智能·深度学习·bert
李二。几秒前
鸿蒙原生ArkTS-太空探索新闻AI
人工智能·华为·harmonyos
乐于分享的阿乐几秒前
(二)VSCode搭建python环境(详细图文保姆级教程)
ide·vscode·python
z小猫不吃鱼1 分钟前
14 BERT 的 Masked Language Modeling 详解
人工智能
努力的章鱼bro1 分钟前
CUDA编程入门
c++·人工智能·cuda
Bode_20023 分钟前
移动多智能体现场柔性测量与自适应质检的难点与实现路径
人工智能·计算机视觉·制造
Honker_yhw3 分钟前
大数据管理与应用系列丛书《数据挖掘》(吕欣等著)读书笔记-集成学习与 AdaBoost
人工智能·数据挖掘·集成学习
weixin_408099674 分钟前
2026 AI生成图片快速去水印的5种实测方法(附在线工具 + Python/Java/PHP API代码)
java·人工智能·python·api接口·ai去水印·石榴智能·自动去水印
云智慧AIOps社区5 分钟前
直击BEYOND Expo 2026 | 云智慧Cloudwise亮相澳门,发布“三层战略”护航 AI 数实共生
运维·人工智能·运维自动化·ai基础设施可靠性
行业研究员7 分钟前
2026 AI Agent记忆解决方案:腾讯云数据库提供全场景支撑
数据库·人工智能·腾讯云·ai记忆