PyTorch 基础知识
基础知识
查看GPU
bash
nvidia-smi
查看PyTorch是否正常调用GPU
python
import torch
torch.cuda.is_available()
启动jupyter notebook
bash
jupyter notebook
dir()函数 可以查看包(如 torch
)中有哪些模块、子包或属性,类似打开一个工具箱,查看里面的各种子工具箱。继续对其中的某个分隔区调用 dir()
,可以查看更深层的内容(如 dir(torch.cuda)
会显示 torch.cuda
内部的模块或函数)。如果 dir()
的输出里出现双下划线 __
包裹的名称,一般说明这是特殊属性或方法,而不是一个子模块。
help()函数 可以查看函数或类的官方文档/说明书,了解它的用途和用法。例如,help(torch.cuda.is_available)
可以查看该函数是否存在,以及它返回什么结果(返回 True 或 False 来指示是否可使用 CUDA)。注意在查看函数文档时,函数名后不带括号 (如 help(torch.cuda.is_available)
而不是 help(torch.cuda.is_available())
)。
python
help(torch.cuda.is_available)
# Help on function is_available in module torch.cuda:
# is_available() -> bool
# Return a bool indicating if CUDA is currently available.
os.path.join
可以将自动使用正确的路径分隔符,将多个路径组件拼接在一起。
python
import os
path1 = r"C:\Users\lenovo\Desktop\cpp\main\java\org\example\webchat"
path2 = r"Server"
print(os.path.join(path1, path2))
# C:\Users\lenovo\Desktop\cpp\main\java\org\example\webchat\Server
os.path.split
将路径拆分为 (目录路径, 文件名)
的元组。
python
import os
file_path = "/home/user/documents/file.txt"
directory, filename = os.path.split(file_path)
print(directory, filename)
# /home/user/documents file.txt
os.path.splitext
将路径拆分为 (文件名, 文件拓展名)
的元组。
python
import os
file_path = "/home/user/documents/file.txt"
filename, extendname = os.path.splitext(file_path)
print(filename, extendname)
# /home/user/documents/file .txt
split()
是Python内置的字符串方法,可以按照指定的分隔符拆分字符串,并返回一个列表。
python
string.split(separator, maxsplit)
# separator 可选 用于拆分字符串的分隔符(默认为空格)
# maxsplit 可选 最多拆分几次(默认拆分所有)
python
text = "apple,banana,orange"
result = text.split(",")
print(result)
# ['apple', 'banana', 'orange']
从PIL库中导入Image模块,可以便捷地使用图像操作功能。
python
from PIL import Image
# 加载图像
img = Image.open('example.jpg')
# 显示图像
img.show()
# 保存图像为不同格式
img.save('new_image.png')
# 将图像转换成其他模式,灰度图L,彩色图RGB
img_gray = img.convert('L')
# 调整图像的尺寸
img_resized = img.resize((200, 200))
Python格式化输出
python
a = "hello"
print("{}, Li Hua".format(a))
Ctrl+P 可以提示函数里面需要什么参数
Dataset类
Dataset是PyTorch用于数据加载的基类,主要用于构建自定义数据集 。Dataset类可以提供索引访问数据的能力,类似 list[index]
,也可以处理不同类型的数据(图片、文本、视频等),可以在__getitem__
方法中进行数据增强或转换等预处理。
Dataset需要继承 torch.utils.data.Dataset
,并重写以下方法:
python
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
""" 初始化数据 """
self.data = data
self.labels = labels
def __getitem__(self, index):
""" 根据索引 index 返回数据 """
x = self.data[index]
y = self.labels[index]
return x, y # 返回数据和标签
def __len__(self):
""" 返回数据集大小 """
return len(self.data)
# 创建数据集
data = torch.randn(100, 3) # 100 个样本,每个样本 3 维
labels = torch.randint(0, 2, (100,)) # 100 个二分类标签
# 实例化数据集
dataset = MyDataset(data, labels)
# 访问数据
print(len(dataset)) # 100
print(dataset[0]) # 第 0 个样本的数据和标签
以蚂蚁蜜蜂二分类数据集为例,先观察文件结构:
在hymenoptera_data文件夹下,分为train训练集和val验证集,在训练集中,分为ants和bees两个文件夹,每个文件夹下是对应的图片。
python
from torch.utils.data import Dataset # 导入Dataset类,用于创建自定义数据集
from PIL import Image # 导入PIL库,用于读取图片
import os # 导入os库,用于处理文件路径
# 创建一个继承自Dataset类的自定义数据集类
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
"""
初始化方法,创建数据集实例时调用
root_dir: 数据集根目录,包含了所有类别的文件夹
label_dir: 类别文件夹名(例如"ants"或"bees")
"""
self.root_dir = root_dir # 保存数据集根目录路径
self.label_dir = label_dir # 保存类别文件夹名
self.path = os.path.join(self.root_dir, self.label_dir) # 构建当前类别的完整路径
# 获取该类别下所有图片文件名,生成图片地址列表
# ['0013035.jpg', '1030023514_aad5c608f9.jpg', ... ]
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
"""
根据给定的索引返回一张图片及其标签
idx: 索引值
return: 返回指定索引处的图片和标签
"""
img_name = self.img_path[idx] # 获取图片文件名
# 构建图片的完整路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path) # 使用PIL读取图像文件
label = self.label_dir # 标签就是文件夹名(例如"ants"或"bees")
return img, label # 返回图片和标签
def __len__(self):
"""
返回数据集中图片的数量
return: 数据集中的图片数量
"""
return len(self.img_path) # 数据集大小,等于图片列表的长度
# 创建数据集实例
# 数据集根目录
root_dir = r"D:\deeplearning_ai_books\tudui\Dataset\hymenoptera_data\train"
ants_label_dir = "ants" # 蚂蚁类别的文件夹名
bees_label_dir = "bees" # 蜜蜂类别的文件夹名
# 创建蚂蚁类别和蜜蜂类别的两个数据集实例
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
# 获取蚂蚁数据集中的第一张图片及其标签
img, label = ants_dataset[0]
img.show() # 可视化第一张图片
# 将蚂蚁数据集和蜜蜂数据集合并为一个训练数据集
train_dataset = ants_dataset + bees_dataset
Torchvision DataSets
torchvision 主要模块
模块 | 作用 |
---|---|
torchvision.datasets |
提供标准数据集(如 CIFAR-10, MNIST, COCO) |
torchvision.io |
处理输入输出(较少用) |
torchvision.models |
预训练模型(如 ResNet, VGG, Faster R-CNN) |
torchvision.ops |
特殊操作(较少用) |
torchvision.transforms |
提供图像变换(Resize, ToTensor, Normalize 等) |
torchvision.utils |
提供工具(如 TensorBoard 可视化) |
torchvision 中提供了多种标准数据集,可以方便地进行下载、加载和预处理,并结合 transforms 进行数据增强和格式转换。
以 CIFAR-10 数据集为例,共有10类,每类6000张图片。图片默认是 PIL Image,标签是整数 0-9,可通过 CIFAR10.classes
获取对应类别名称。
python
# 构造函数
torchvision.datasets.CIFAR10(
root: str, # 数据集存储路径(字符串)
train: bool = True, # True 加载训练集,False 加载测试集
transform: Optional[Callable] = None, # 图像转换函数
target_transform: Optional[Callable] = None, # 标签转换函数
download: bool = False # True自动下载数据集,False加载本地数据
)
python
import torchvision
# 下载并加载 CIFAR-10 数据集
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./Dataset", train=False, download=True)
print(train_set[0])
# (<PIL.Image.Image image mode=RGB size=32x32 at 0x2537AFD81F0>, 6)
# 查看测试集的第一张图片和类别
img, target = train_set[0]
# 查看所有类别
print(train_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 查看具体类别
print(train_set.classes[target])
# frog
# 显示图片
img.show()
由于CIFAR-10原始数据是 PIL Image格式,无法直接用于PyTorch训练。所以可以在数据集下载的过程中就对其进行 transforms 处理,如 ToTensor
操作。
python
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_tranforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True,
transform=dataset_tranforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./Dataset", train=False,
transform=dataset_tranforms, download=True)
writer = SummaryWriter("logs")
for i in range(10):
img, _ = train_set[i]
writer.add_image("CIFAR-10", img, i)
writer.close()
DataLoader类
DataLoader 是 PyTorch 提供的数据加载器 ,用于从数据集中批量采样,并提供数据预处理、打乱顺序、多进程加载等功能。
Dataset类相当于一叠扑克牌,存储了所有数据。而 DataLoader 类相当于发牌,每次从 Dataset 中取出一定数量(batch_size)的数据。
python
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset, # 传入数据集
batch_size=64, # 每个 batch 取 64 张图片
shuffle=True, # 是否打乱数据
num_workers=0, # 线程数(Windows 需设为 0)
drop_last=False # 是否丢弃最后一个数量不满 batch_size 的 batch
)
batch_size
决定每次加载多少数据,batch_size=len(dataset)
一次性加载所有数据。
drop_last
的作用是如果数据集的大小不是 batch_size
的整数倍,是否丢弃最后不足 batch_size
的部分。
shuffle
的作用是决定是否在每个 epoch
重新打乱数据。一般训练集需要打乱数据,而测试集不需要打乱数据。如果**shuffle=True
** 那么两轮 epoch 取的数据是不同的顺序。
取单个数据:
python
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True,
transform=dataset_tranforms, download=True)
img, target = train_set[0]
print(img.shape)
# torch.Size([3, 32, 32])
print(target)
# 6
批量取数据:
python
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True,
transform=dataset_tranforms, download=True)
dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, drop_last=False)
for data in dataloader:
imgs, targets = data
print(imgs.shape)
print(targets.shape)
# torch.Size([64, 3, 32, 32])
# torch.Size([64])
# ......
# torch.Size([16, 3, 32, 32])
# torch.Size([16])
结合 DataLoader 与 SummaryWriter 进行可视化:
python
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# 预处理
dataset_tranforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 加载数据集
train_set = torchvision.datasets.CIFAR10(root="./Dataset", train=True,
transform=dataset_tranforms, download=True)
# 创建DataLoader
dataloader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, drop_last=False)
step = 0
for data in dataloader:
imgs, _ = data
writer.add_images("DataLoader", imgs, global_step=step)
step = step + 1
writer.close()
Transforms
torchvision.transforms
提供了一系列图像预处理工具,用于对图像进行变换,如尺寸调整、归一化、格式转换等,以便在训练模型时使用。
python
from torchvision import transforms
两种读取图片的方式
python
import cv2
from PIL import Image
img_path = r"Dataset/hymenoptera_data/train/ants/0013035.jpg"
# **OpenCV**
# cv2 处理速度快,但**通道顺序为** **BGR**
img_cv = cv2.imread(img_path)
print(type(img_cv))
# <class 'numpy.ndarray'>
# **PIL Image**
# 常见的 Python 图像库,易用但不适合计算
img_PIL = Image.open(img_path)
print(type(img_PIL))
# <class 'PIL.JpegImagePlugin.JpegImageFile'>
但是 PIL Image 和 numpy.ndarray 是普通图像格式,无法直接用于神经网络训练。Tensor 更适用于 PyTorch,支持 GPU 加速计算。同时 Tensor 包含梯度信息 ,可用于反向传播训练 ,形状为 (C, H, W)
(通道数、高度、宽度)。
ToTensor()
函数可以将将 PIL Image 或 numpy.ndarray 转换为 Tensor(同时归一化到 [0,1])。
python
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
writer = SummaryWriter("logs")
img_path = r"C:\Users\lenovo\Desktop\camera\QIXIA\niu.jpg"
img_PIL = Image.open(img_path)
# transforms是一个类, ToTensor是类的方法,因此**调用时需要先实例化**
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img_PIL)
writer.add_image("Image_Tensor", img_tensor, global_step=0, dataformats="CHW")
print(type(img_tensor))
# <class 'torch.Tensor'>
print(img_tensor.shape)
# torch.Size([3, 3712, 5568])
writer.close()
ToPILImage()
函数可以将 Tensor 或 numpy.ndarray 转换为 PIL Image,适用于显示转换后的图片,以及方便使用 PIL 处理图片
python
img_ToPIL = transforms.ToPILImage()
img_TOPIL = img_ToPIL(img_tensor)
print(type(img_TOPIL))
# <class 'PIL.Image.Image'>
Normalize()
函数对 Tensor 进行归一化,转换为自定义均值、标准差的分布。
o u t p u t = i n p u t − m e a n s t d output = \frac{input - mean}{std} output=stdinput−mean
如果原始图像的像素值在 [ 0 , 1 ] [0,1] [0,1] 范围内,并且设置均值和标准差为0.5,那么归一化后的像素值范围为 [ − 1 , 1 ] [-1,1] [−1,1] 。
python
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
print(img_tensor[0][0][0])
# tensor(0.3137)
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
# tensor(-0.3725)
Resize()
函数可以调整图片尺寸。
python
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img_tensor)
print(img_tensor.shape, img_resize.shape)
# torch.Size([3, 3712, 5568]) torch.Size([3, 512, 512])
等比例缩放,即只缩放短边到固定值,其他自动调整以保持宽高比
python
trans_resize_2 = transforms.Resize(512)
img_resize_2 = trans_resize_2(img_tensor)
print(img_tensor.shape, img_resize_2.shape)
# torch.Size([3, 3712, 5568]) torch.Size([3, 512, 768])
RandomCrop()
函数可以随机裁剪图片。
python
# 随机裁剪 500×1000 区域 5 次
trans_crop = transforms.RandomCrop((500, 1000))
for i in range(5):
img_crop = trans_crop(img_tensor)
CenterCrop()
函数可以裁剪位于图像正中心固定大小。
python
trans_center = transforms.CenterCrop(1000)
img_center = trans_center(img_tensor)
Compose()
函数可以组合并按顺序处理多个预处理变换。
python
# 用列表的形式列出所有的预处理变换组合
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
transforms.Resize(1024),
transforms.RandomCrop((500, 1000))
])
transformed = transform(img_PIL)
print(transformed.shape)
# torch.Size([3, 500, 1000])
TensorBoard
TensorBoard可以可视化训练过程 Loss 的变化,以及模型在不同阶段的输出
启动TensorBoard
bash
tensorboard --logdir=logs # 默认端口6006
tensorboard --logdir=logs --port=6007 # 指定端口
SummaryWriter
类用于创建日志文件并记录训练过程,会直接向自定义 log_dir
文件夹写入事件文件,供TensorBoard 解析。
python
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="logs") # 事件文件存储到自定义的logs目录
writer.close()
如果数据有误,就删除logdir文件夹,重新运行代码,终端终止TensorBoard后重新启动
add_scalar()
方法可以记录标量数据,如训练损失等
python
add_scalar(tag, scalar_value, global_step)
# tag 数据标签(类似图表标题)
# scalar_value 需要保存的数值(y 轴)
# global_step 当前训练步数(x 轴)
python
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(100):
writer.add_scalar("y=2x", 2*i, i) # (标题, y轴数值, x轴步数)
writer.close()
add_image()
方法可以记录图像数据
python
add_image(tag, img_tensor, global_step, dataformats="CHW")
# tag 图像的名称(类似图表标题)
# img_tensor 图片数据(支持 **torch.Tensor、numpy.array**)
# global_step 当前训练步数(用于对比不同时期的结果)
# dataformats HWC(高度、宽度、通道)
python
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter("logs")
img_path = r"Dataset/hymenoptera_data/train/ants/0013035.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL) # 转换为 numpy.array 数组
print(img_array.shape)
# 要保证dataformats和图像数组的shape对应
writer.add_image("image", img_array, 0, dataformats="HWC")
writer.close()
通过TensorBoard,可以直观的查看不同训练步长下图形的变化。