PyTorch API 详细中文文档,基于PyTorch2.5


PyTorch API 详细中文文档

按模块分类,涵盖核心函数与用法示例


目录

  1. 张量操作 (Tensor Operations)
  2. 数学运算 (Math Operations)
  3. 自动求导 (Autograd)
  4. 神经网络模块 (torch.nn)
  5. 优化器 (torch.optim)
  6. 数据加载与处理 (torch.utils.data)
  7. 设备管理 (Device Management)
  8. 模型保存与加载
  9. 分布式训练 (Distributed Training)
  10. 实用工具函数

1. 张量操作 (Tensor Operations)

1.1 张量创建
函数 描述 示例
torch.tensor(data, dtype, device) 从数据创建张量 torch.tensor([1,2,3], dtype=torch.float32)
torch.zeros(shape) 创建全零张量 torch.zeros(2,3)
torch.ones(shape) 创建全一张量 torch.ones(5)
torch.rand(shape) 均匀分布随机张量 torch.rand(3,3)
torch.randn(shape) 标准正态分布张量 torch.randn(4,4)
torch.arange(start, end, step) 创建等差序列 torch.arange(0, 10, 2)[0,2,4,6,8]
torch.linspace(start, end, steps) 线性间隔序列 torch.linspace(0, 1, 5)[0, 0.25, 0.5, 0.75, 1]
1.2 张量属性
属性/方法 描述 示例
.shape 张量维度 x = torch.rand(2,3); x.shape → torch.Size([2,3])
.dtype 数据类型 x.dtype → torch.float32
.device 所在设备 x.device → device(type='cpu')
.requires_grad 是否追踪梯度 x.requires_grad = True
1.3 张量变形
函数 描述 示例
.view(shape) 调整形状(不复制数据) x = torch.arange(6); x.view(2,3)
.reshape(shape) 类似 view,但自动处理内存连续性 x.reshape(3,2)
.permute(dims) 调整维度顺序 x = torch.rand(2,3,4); x.permute(1,2,0)
.squeeze(dim) 去除大小为1的维度 x = torch.rand(1,3); x.squeeze(0)shape [3]
.unsqueeze(dim) 添加大小为1的维度 x = torch.rand(3); x.unsqueeze(0)shape [1,3]

2. 数学运算 (Math Operations)

2.1 逐元素运算
函数 描述 示例
torch.add(x, y) 加法 torch.add(x, y)x + y
torch.mul(x, y) 乘法 torch.mul(x, y)x * y
torch.exp(x) 指数运算 torch.exp(torch.tensor([1.0]))[2.7183]
torch.log(x) 自然对数 torch.log(torch.exp(tensor([2.0])))[2.0]
torch.clamp(x, min, max) 限制值范围 torch.clamp(x, min=0, max=1)
2.2 矩阵运算
函数 描述 示例
torch.matmul(x, y) 矩阵乘法 x = torch.rand(2,3); y = torch.rand(3,4); torch.matmul(x, y)
torch.inverse(x) 矩阵求逆 x = torch.rand(3,3); inv_x = torch.inverse(x)
torch.eig(x) 特征值分解 eigenvalues, eigenvectors = torch.eig(x)
2.3 统计运算
函数 描述 示例
torch.sum(x, dim) 沿维度求和 x = torch.rand(2,3); torch.sum(x, dim=1)
torch.mean(x, dim) 沿维度求均值 torch.mean(x, dim=0)
torch.max(x, dim) 沿维度求最大值 values, indices = torch.max(x, dim=1)
torch.argmax(x, dim) 最大值索引 indices = torch.argmax(x, dim=1)

3. 自动求导 (Autograd)

3.1 梯度计算
函数/属性 描述 示例
x.backward() 反向传播计算梯度 x = torch.tensor(2.0, requires_grad=True); y = x**2; y.backward()
x.grad 查看梯度值 x.grad4.0(若 y = x²
torch.no_grad() 禁用梯度追踪 with torch.no_grad(): y = x * 2
detach() 分离张量(不追踪梯度) y = x.detach()
3.2 梯度控制
函数 描述
x.retain_grad() 保留非叶子节点的梯度
torch.autograd.grad(outputs, inputs) 手动计算梯度

示例

python 复制代码
x = torch.tensor(3.0, requires_grad=True)  
y = x**3 + 2*x  
dy_dx = torch.autograd.grad(y, x)  # 返回 (torch.tensor(29.0),)  

4. 神经网络模块 (torch.nn)

4.1 层定义
描述 示例
nn.Linear(in_features, out_features) 全连接层 layer = nn.Linear(784, 256)
nn.Conv2d(in_channels, out_channels, kernel_size) 卷积层 conv = nn.Conv2d(3, 16, kernel_size=3)
nn.LSTM(input_size, hidden_size) LSTM 层 lstm = nn.LSTM(100, 50)
nn.Dropout(p=0.5) Dropout 层 dropout = nn.Dropout(0.2)
4.2 激活函数
函数 描述 示例
nn.ReLU() ReLU 激活 F.relu(x)nn.ReLU()(x)
nn.Sigmoid() Sigmoid 函数 torch.sigmoid(x)
nn.Softmax(dim) Softmax 归一化 F.softmax(x, dim=1)
4.3 损失函数
描述 示例
nn.MSELoss() 均方误差 loss_fn = nn.MSELoss()
nn.CrossEntropyLoss() 交叉熵损失 loss = loss_fn(outputs, labels)
nn.BCELoss() 二分类交叉熵 loss_fn = nn.BCELoss()

5. 优化器 (torch.optim)

5.1 优化器定义
描述 示例
optim.SGD(params, lr) 随机梯度下降 optimizer = optim.SGD(model.parameters(), lr=0.01)
optim.Adam(params, lr) Adam 优化器 optimizer = optim.Adam(model.parameters(), lr=0.001)
optim.RMSprop(params, lr) RMSprop 优化器 optimizer = optim.RMSprop(params, lr=0.01)
5.2 优化器方法
方法 描述 示例
optimizer.zero_grad() 清空梯度 optimizer.zero_grad()
optimizer.step() 更新参数 loss.backward(); optimizer.step()
optimizer.state_dict() 获取优化器状态 state = optimizer.state_dict()

6. 数据加载与处理 (torch.utils.data)

6.1 数据集类
类/函数 描述 示例
Dataset 自定义数据集基类 继承并实现 __len____getitem__
DataLoader(dataset, batch_size, shuffle) 数据加载器 loader = DataLoader(dataset, batch_size=64, shuffle=True)

自定义数据集示例

python 复制代码
class MyDataset(Dataset):  
    def __init__(self, data, labels):  
        self.data = data  
        self.labels = labels  
    def __len__(self):  
        return len(self.data)  
    def __getitem__(self, idx):  
        return self.data[idx], self.labels[idx]  
6.2 数据预处理 (TorchVision)
python 复制代码
from torchvision import transforms  

transform = transforms.Compose([  
    transforms.Resize(256),          # 调整图像大小  
    transforms.ToTensor(),           # 转为张量  
    transforms.Normalize(mean=[0.5], std=[0.5])  # 标准化  
])  

7. 设备管理 (Device Management)

7.1 设备切换
函数/方法 描述 示例
.to(device) 移动张量/模型到设备 x = x.to('cuda:0')
torch.cuda.is_available() 检查 GPU 是否可用 if torch.cuda.is_available(): ...
torch.cuda.empty_cache() 清空 GPU 缓存 torch.cuda.empty_cache()

8. 模型保存与加载

函数 描述 示例
torch.save(obj, path) 保存对象(模型/参数) torch.save(model.state_dict(), 'model.pth')
torch.load(path) 加载对象 model.load_state_dict(torch.load('model.pth'))
model.state_dict() 获取模型参数字典 params = model.state_dict()

9. 分布式训练 (Distributed Training)

函数/类 描述 示例
nn.DataParallel(model) 单机多卡并行 model = nn.DataParallel(model)
torch.distributed.init_process_group() 初始化分布式训练 需配合多进程使用

10. 实用工具函数

函数 描述 示例
torch.cat(tensors, dim) 沿维度拼接张量 torch.cat([x, y], dim=0)
torch.stack(tensors, dim) 堆叠张量(新建维度) torch.stack([x, y], dim=1)
torch.split(tensor, split_size, dim) 分割张量 chunks = torch.split(x, 2, dim=0)

常见问题与技巧

  1. GPU 内存不足

    • 使用 batch_size 较小的值
    • 启用混合精度训练 (torch.cuda.amp)
    • 使用 torch.utils.checkpoint 节省内存
  2. 梯度爆炸/消失

    • 使用梯度裁剪:torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    • 调整权重初始化方法
  3. 模型推理模式

    python 复制代码
    model.eval()  # 关闭 Dropout 和 BatchNorm 的随机性  
    with torch.no_grad():  
        outputs = model(inputs)  

文档说明

  • 本文档基于 PyTorch 2.5 编写,部分 API 可能不兼容旧版本。
  • 更详细的参数说明请参考 PyTorch 官方文档
相关推荐
snpgroupcn4 分钟前
SNP亮相2026思爱普中国峰会,助力企业加速数据价值兑现
人工智能
IT乐手4 分钟前
Anthropic 为何限制中国大陆使用 Claude?
人工智能
To_OC7 分钟前
用 ESM 模块化搭建 DeepSeek LLM 调用,顺带用 Prompt 实现轻量 NLP 任务
人工智能·nlp·deepseek
jrjrgood10 分钟前
现货黄金和黄金期货的区别有哪些?如何投资?
大数据·人工智能·区块链
属于自己的天空13 分钟前
确认弹窗太多?一次配好 Claude Code 权限,安心让 AI 干活
人工智能
dearxue19 分钟前
这一次,我们一起把AI的复杂一口吃掉
人工智能·后端
行者-全栈开发25 分钟前
深度解析 WWDC 2026:苹果 AI 全栈技术架构与落地实现路径
人工智能·架构·wwdc
企业老板ai培训28 分钟前
2026中小企业AI应用落地白皮书:从AI短视频矩阵到数字人获客的破局增长趋势
人工智能·矩阵·音视频
SEO_juper41 分钟前
博客文章黄金结构:开头 1 句痛点 + 3 小标题 + 对比 + 总结 + 下载
人工智能·博客·外贸·geo·独立站·跨境电商独立站·文章结构
双翌视觉43 分钟前
工业AI视觉检测中的“小样本困境”
人工智能·计算机视觉·视觉检测