PyTorch API 详细中文文档
按模块分类,涵盖核心函数与用法示例
目录
- 张量操作 (Tensor Operations)
- 数学运算 (Math Operations)
- 自动求导 (Autograd)
- 神经网络模块 (torch.nn)
- 优化器 (torch.optim)
- 数据加载与处理 (torch.utils.data)
- 设备管理 (Device Management)
- 模型保存与加载
- 分布式训练 (Distributed Training)
- 实用工具函数
1. 张量操作 (Tensor Operations)
1.1 张量创建
函数 | 描述 | 示例 |
---|---|---|
torch.tensor(data, dtype, device) |
从数据创建张量 | torch.tensor([1,2,3], dtype=torch.float32) |
torch.zeros(shape) |
创建全零张量 | torch.zeros(2,3) |
torch.ones(shape) |
创建全一张量 | torch.ones(5) |
torch.rand(shape) |
均匀分布随机张量 | torch.rand(3,3) |
torch.randn(shape) |
标准正态分布张量 | torch.randn(4,4) |
torch.arange(start, end, step) |
创建等差序列 | torch.arange(0, 10, 2) → [0,2,4,6,8] |
torch.linspace(start, end, steps) |
线性间隔序列 | torch.linspace(0, 1, 5) → [0, 0.25, 0.5, 0.75, 1] |
1.2 张量属性
属性/方法 | 描述 | 示例 |
---|---|---|
.shape |
张量维度 | x = torch.rand(2,3); x.shape → torch.Size([2,3]) |
.dtype |
数据类型 | x.dtype → torch.float32 |
.device |
所在设备 | x.device → device(type='cpu') |
.requires_grad |
是否追踪梯度 | x.requires_grad = True |
1.3 张量变形
函数 | 描述 | 示例 |
---|---|---|
.view(shape) |
调整形状(不复制数据) | x = torch.arange(6); x.view(2,3) |
.reshape(shape) |
类似 view ,但自动处理内存连续性 |
x.reshape(3,2) |
.permute(dims) |
调整维度顺序 | x = torch.rand(2,3,4); x.permute(1,2,0) |
.squeeze(dim) |
去除大小为1的维度 | x = torch.rand(1,3); x.squeeze(0) → shape [3] |
.unsqueeze(dim) |
添加大小为1的维度 | x = torch.rand(3); x.unsqueeze(0) → shape [1,3] |
2. 数学运算 (Math Operations)
2.1 逐元素运算
函数 | 描述 | 示例 |
---|---|---|
torch.add(x, y) |
加法 | torch.add(x, y) 或 x + y |
torch.mul(x, y) |
乘法 | torch.mul(x, y) 或 x * y |
torch.exp(x) |
指数运算 | torch.exp(torch.tensor([1.0])) → [2.7183] |
torch.log(x) |
自然对数 | torch.log(torch.exp(tensor([2.0]))) → [2.0] |
torch.clamp(x, min, max) |
限制值范围 | torch.clamp(x, min=0, max=1) |
2.2 矩阵运算
函数 | 描述 | 示例 |
---|---|---|
torch.matmul(x, y) |
矩阵乘法 | x = torch.rand(2,3); y = torch.rand(3,4); torch.matmul(x, y) |
torch.inverse(x) |
矩阵求逆 | x = torch.rand(3,3); inv_x = torch.inverse(x) |
torch.eig(x) |
特征值分解 | eigenvalues, eigenvectors = torch.eig(x) |
2.3 统计运算
函数 | 描述 | 示例 |
---|---|---|
torch.sum(x, dim) |
沿维度求和 | x = torch.rand(2,3); torch.sum(x, dim=1) |
torch.mean(x, dim) |
沿维度求均值 | torch.mean(x, dim=0) |
torch.max(x, dim) |
沿维度求最大值 | values, indices = torch.max(x, dim=1) |
torch.argmax(x, dim) |
最大值索引 | indices = torch.argmax(x, dim=1) |
3. 自动求导 (Autograd)
3.1 梯度计算
函数/属性 | 描述 | 示例 |
---|---|---|
x.backward() |
反向传播计算梯度 | x = torch.tensor(2.0, requires_grad=True); y = x**2; y.backward() |
x.grad |
查看梯度值 | x.grad → 4.0 (若 y = x² ) |
torch.no_grad() |
禁用梯度追踪 | with torch.no_grad(): y = x * 2 |
detach() |
分离张量(不追踪梯度) | y = x.detach() |
3.2 梯度控制
函数 | 描述 |
---|---|
x.retain_grad() |
保留非叶子节点的梯度 |
torch.autograd.grad(outputs, inputs) |
手动计算梯度 |
示例:
python
x = torch.tensor(3.0, requires_grad=True)
y = x**3 + 2*x
dy_dx = torch.autograd.grad(y, x) # 返回 (torch.tensor(29.0),)
4. 神经网络模块 (torch.nn)
4.1 层定义
类 | 描述 | 示例 |
---|---|---|
nn.Linear(in_features, out_features) |
全连接层 | layer = nn.Linear(784, 256) |
nn.Conv2d(in_channels, out_channels, kernel_size) |
卷积层 | conv = nn.Conv2d(3, 16, kernel_size=3) |
nn.LSTM(input_size, hidden_size) |
LSTM 层 | lstm = nn.LSTM(100, 50) |
nn.Dropout(p=0.5) |
Dropout 层 | dropout = nn.Dropout(0.2) |
4.2 激活函数
函数 | 描述 | 示例 |
---|---|---|
nn.ReLU() |
ReLU 激活 | F.relu(x) 或 nn.ReLU()(x) |
nn.Sigmoid() |
Sigmoid 函数 | torch.sigmoid(x) |
nn.Softmax(dim) |
Softmax 归一化 | F.softmax(x, dim=1) |
4.3 损失函数
类 | 描述 | 示例 |
---|---|---|
nn.MSELoss() |
均方误差 | loss_fn = nn.MSELoss() |
nn.CrossEntropyLoss() |
交叉熵损失 | loss = loss_fn(outputs, labels) |
nn.BCELoss() |
二分类交叉熵 | loss_fn = nn.BCELoss() |
5. 优化器 (torch.optim)
5.1 优化器定义
类 | 描述 | 示例 |
---|---|---|
optim.SGD(params, lr) |
随机梯度下降 | optimizer = optim.SGD(model.parameters(), lr=0.01) |
optim.Adam(params, lr) |
Adam 优化器 | optimizer = optim.Adam(model.parameters(), lr=0.001) |
optim.RMSprop(params, lr) |
RMSprop 优化器 | optimizer = optim.RMSprop(params, lr=0.01) |
5.2 优化器方法
方法 | 描述 | 示例 |
---|---|---|
optimizer.zero_grad() |
清空梯度 | optimizer.zero_grad() |
optimizer.step() |
更新参数 | loss.backward(); optimizer.step() |
optimizer.state_dict() |
获取优化器状态 | state = optimizer.state_dict() |
6. 数据加载与处理 (torch.utils.data)
6.1 数据集类
类/函数 | 描述 | 示例 |
---|---|---|
Dataset |
自定义数据集基类 | 继承并实现 __len__ 和 __getitem__ |
DataLoader(dataset, batch_size, shuffle) |
数据加载器 | loader = DataLoader(dataset, batch_size=64, shuffle=True) |
自定义数据集示例:
python
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
6.2 数据预处理 (TorchVision)
python
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256), # 调整图像大小
transforms.ToTensor(), # 转为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 标准化
])
7. 设备管理 (Device Management)
7.1 设备切换
函数/方法 | 描述 | 示例 |
---|---|---|
.to(device) |
移动张量/模型到设备 | x = x.to('cuda:0') |
torch.cuda.is_available() |
检查 GPU 是否可用 | if torch.cuda.is_available(): ... |
torch.cuda.empty_cache() |
清空 GPU 缓存 | torch.cuda.empty_cache() |
8. 模型保存与加载
函数 | 描述 | 示例 |
---|---|---|
torch.save(obj, path) |
保存对象(模型/参数) | torch.save(model.state_dict(), 'model.pth') |
torch.load(path) |
加载对象 | model.load_state_dict(torch.load('model.pth')) |
model.state_dict() |
获取模型参数字典 | params = model.state_dict() |
9. 分布式训练 (Distributed Training)
函数/类 | 描述 | 示例 |
---|---|---|
nn.DataParallel(model) |
单机多卡并行 | model = nn.DataParallel(model) |
torch.distributed.init_process_group() |
初始化分布式训练 | 需配合多进程使用 |
10. 实用工具函数
函数 | 描述 | 示例 |
---|---|---|
torch.cat(tensors, dim) |
沿维度拼接张量 | torch.cat([x, y], dim=0) |
torch.stack(tensors, dim) |
堆叠张量(新建维度) | torch.stack([x, y], dim=1) |
torch.split(tensor, split_size, dim) |
分割张量 | chunks = torch.split(x, 2, dim=0) |
常见问题与技巧
-
GPU 内存不足
- 使用
batch_size
较小的值 - 启用混合精度训练 (
torch.cuda.amp
) - 使用
torch.utils.checkpoint
节省内存
- 使用
-
梯度爆炸/消失
- 使用梯度裁剪:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- 调整权重初始化方法
- 使用梯度裁剪:
-
模型推理模式
pythonmodel.eval() # 关闭 Dropout 和 BatchNorm 的随机性 with torch.no_grad(): outputs = model(inputs)
文档说明
- 本文档基于 PyTorch 2.5 编写,部分 API 可能不兼容旧版本。
- 更详细的参数说明请参考 PyTorch 官方文档。