大模型面试题12:Torch的基本操作

一、PyTorch 核心操作(全面深度版)

(一)张量基础:创建、类型、设备与属性
1. 张量创建(覆盖所有常用方式)
方法 用途 示例
torch.tensor() 从列表/数组创建(拷贝数据) torch.tensor([1,2,3], dtype=torch.float32)
torch.empty() 创建未初始化张量 torch.empty(2,3) → 随机值(未初始化)
torch.zeros() 全0张量 torch.zeros((2,3), device='cuda')(指定GPU)
torch.ones() 全1张量 torch.ones_like(tensor)(匹配已有张量形状/类型)
torch.rand() 0-1均匀分布随机张量 torch.rand(2,3)
torch.randn() 标准正态分布随机张量 torch.randn_like(tensor)
torch.arange() 整数序列张量 torch.arange(0, 10, 2) → [0,2,4,6,8]
torch.linspace() 等间隔序列张量 torch.linspace(0, 1, 5) → [0,0.25,0.5,0.75,1]
torch.eye() 单位矩阵 torch.eye(3) → 3×3单位矩阵
2. 数据类型与设备(关键属性)
python 复制代码
# 数据类型转换
x = torch.tensor([1,2])
x = x.to(dtype=torch.float64)  # 或 x.float()/x.int()/x.long()

# 设备切换(CPU/GPU)
x = x.to(device='cuda')  # 等价于 x.cuda()(需GPU可用)
x = x.cpu()              # 切回CPU

# 核心属性
print(x.dtype)   # 数据类型(如torch.float32)
print(x.device)  # 存储设备(如cuda:0/cpu)
print(x.shape)   # 形状(等价于x.size())
print(x.ndim)    # 维度数
print(x.requires_grad)  # 是否追踪梯度
3. 索引与切片(多维张量通用)
python 复制代码
x = torch.randn(3,4,5)  # 3×4×5张量

# 基础索引(维度依次索引)
x[0]          # 取第0个维度,shape=(4,5)
x[0, 1]       # 取第0维第1行,shape=(5,)
x[0, 1, 2]    # 标量

# 切片(start:end:step)
x[:, 1:3, ::2]  # 所有0维,1-2行(左闭右开),列步长2 → shape=(3,2,3)

# 布尔索引
mask = x > 0  # 布尔张量,shape同x
x[mask]       # 取出所有>0的元素(1D张量)

# 高级索引(花式索引)
x[[0,2], :, [1,3]]  # 0/2维,所有行,1/3列 → shape=(2,4,2)
(二)维度操作(补充细节与边界案例)
1. 转置/维度重排(内存细节)
python 复制代码
x = torch.randn(2,3,4)

# transpose:仅交换两个维度,返回视图(共享内存)
x_t = x.transpose(0,2)  # shape=(4,3,2)
print(x_t.is_contiguous())  # False(非连续内存)

# permute:重排所有维度,同样返回视图
x_p = x.permute(2,1,0)  # shape=(4,3,2)

# 连续化(非连续张量无法view,需contiguous())
x_p_contig = x_p.contiguous()
2. 升维/降维(注意dim范围)
python 复制代码
x = torch.randn(3)  # shape=(3,)

# unsqueeze:dim范围[-ndim-1, ndim](负数为倒数维度)
x_1 = x.unsqueeze(0)  # shape=(1,3)
x_2 = x.unsqueeze(-1) # shape=(3,1)

# squeeze:仅移除size=1的维度,无op则移除所有
y = torch.randn(1,3,1,4)
y_s1 = y.squeeze(0)   # shape=(3,1,4)(仅移除0维)
y_s2 = y.squeeze()    # shape=(3,4)(移除所有size=1维度)
3. 拼接(cat)vs 堆叠(stack)(核心对比)
操作 核心差异 失败案例(原因)
torch.cat 沿已有维度拼接,维度数不变 cat([torch.rand(2,3), torch.rand(2,4)], dim=1) → 成功; cat([torch.rand(2,3), torch.rand(3,3)], dim=0) → 成功; cat([torch.rand(2,3), torch.rand(2,3)], dim=2) → 失败(无2维)
torch.stack 沿新维度堆叠,维度数+1 stack([torch.rand(2,3), torch.rand(2,3)], dim=2) → shape=(2,3,2)(成功); stack([torch.rand(2,3), torch.rand(3,3)], dim=0) → 失败(形状不一致)
4. 形状修改(view vs reshape vs resize_)
方法 内存特性 适用场景 示例
view() 仅连续张量可用,返回视图 确定张量连续,需共享内存 x.contiguous().view(6)
reshape() 非连续时自动拷贝,通用 不确定连续性,优先用 x.transpose(0,1).reshape(6)
resize_() 原地修改,可扩容(补0) 需原地改形状,允许维度不匹配 x.resize_(4) → 若原长度>4则截断,<4则补0
(三)乘法(极致详细:规则+边界+应用)
1. 逐元素乘法(* / torch.mul
  • 规则 :对应位置元素相乘,支持广播机制(核心!);

  • 广播规则:从最后一维开始匹配,维度大小相同/其中一个为1/缺失则广播;

  • 示例

    python 复制代码
    # 基础(形状一致)
    a = torch.tensor([[1,2],[3,4]])
    b = torch.tensor([[5,6],[7,8]])
    print(a * b)  # [[5,12],[21,32]]
    
    # 广播((2,2) × (1,2))
    c = torch.tensor([[5,6]])
    print(a * c)  # [[5,12],[15,24]](c广播为(2,2))
    
    # 广播((2,2) × (2,1))
    d = torch.tensor([[5],[7]])
    print(a * d)  # [[5,10],[21,28]](d广播为(2,2))
2. 矩阵乘法(@ / torch.matmul / torch.mm
方法 维度支持 广播支持 示例(shape变化)
torch.mm() 仅2D张量 mm((2,3), (3,4)) → (2,4)
torch.matmul 2D/3D/更高维 matmul((5,2,3), (3,4)) → (5,2,4)(广播)
@ 同matmul 等价于matmul,语法糖
  • matmul多维广播核心规则
    对于张量 a (..., n, m)b (..., m, p),结果为 (..., n, p)
    若维度数不一致,自动在缺失维度补1后广播:

    python 复制代码
    # 3D × 2D(广播)
    a = torch.randn(5,2,3)  # (batch, n, m)
    b = torch.randn(3,4)     # (m, p) → 广播为(5,3,4)
    print(torch.matmul(a, b).shape)  # (5,2,4)
    
    # 1D × 2D(特殊:1D视为行向量)
    c = torch.randn(3)       # (m,) → 视为(1,3)
    d = torch.randn(3,4)     # (m,p)
    print(torch.matmul(c, d).shape)  # (4,)(等价于(1,3)@(3,4) → (1,4) → 降维)
3. 批量矩阵乘法(torch.bmm
  • 规则 :仅3D张量,批量内逐矩阵相乘,无广播

  • 要求a (b, n, m) + b (b, m, p)(b, n, p)(b为batch数,必须一致);

  • 示例

    python 复制代码
    a = torch.randn(5,2,3)
    b = torch.randn(5,3,4)
    print(torch.bmm(a, b).shape)  # (5,2,4)
    
    # 错误案例(batch数不一致)
    b_bad = torch.randn(6,3,4)
    torch.bmm(a, b_bad)  # 报错:batch size mismatch
4. 点积(torch.dot
  • 规则:仅1D张量,逐元素乘积求和(输出标量);

  • 注意 :高维张量会先展平为1D,不推荐用于高维!

    python 复制代码
    # 正确用法(1D)
    a = torch.tensor([1,2,3])
    b = torch.tensor([4,5,6])
    print(torch.dot(a, b))  # 1×4 + 2×5 + 3×6 = 32
    
    # 不推荐(高维展平)
    c = torch.tensor([[1,2],[3,4]])
    d = torch.tensor([[5,6],[7,8]])
    print(torch.dot(c, d))  # 1×5 + 2×6 + 3×7 + 4×8 = 70(展平后计算)
5. 外积(torch.outer
  • 规则 :仅1D张量,a(i) × b(j) 生成2D矩阵 (len(a), len(b))

  • 示例

    python 复制代码
    a = torch.tensor([1,2])
    b = torch.tensor([3,4,5])
    print(torch.outer(a, b))
    # [[1×3, 1×4, 1×5],
    #  [2×3, 2×4, 2×5]] → [[3,4,5],[6,8,10]]
6. 克罗内克积(torch.kron
  • 规则A ⊗ B,将A的每个元素与B相乘并扩展,维度为 (A.shape[0]*B.shape[0], A.shape[1]*B.shape[1])

  • 应用:矩阵扩展、信号处理;

  • 示例

    python 复制代码
    a = torch.tensor([[1,2],[3,4]])
    b = torch.tensor([[0,1],[1,0]])
    print(torch.kron(a, b))
    # [[1×0,1×1, 2×0,2×1],
    #  [1×1,1×0, 2×1,2×0],
    #  [3×0,3×1, 4×0,4×1],
    #  [3×1,3×0, 4×1,4×0]] → [[0,1,0,2],[1,0,2,0],[0,3,0,4],[3,0,4,0]]
7. 乘法速查表(避坑)
需求场景 推荐方法 禁止方法
2D矩阵相乘 matmul/@ */mul
批量3D矩阵相乘(无广播) bmm mm
批量3D矩阵相乘(有广播) matmul bmm
1D张量求和乘积 dot mm
1D张量两两相乘生成2D outer matmul(需reshape)
(四)其他核心操作(补充)
1. 广播机制(独立规则)
  • 核心:维度从后往前匹配,满足以下条件则广播:
    1. 维度大小相同;
    2. 其中一个维度大小为1;
    3. 其中一个张量缺失该维度。
  • 示例:(3,1,4) + (2,4) → 广播为 (3,2,4) + (3,2,4)
2. 内存管理
python 复制代码
# 连续张量判断与转换
x = x.transpose(0,1)
print(x.is_contiguous())  # False
x = x.contiguous()        # 转换为连续张量(拷贝数据)

# 共享内存(视图)vs 拷贝
y = x.view(6)    # 视图,共享内存
z = x.clone()    # 深拷贝,独立内存
3. 梯度相关
python 复制代码
x = torch.tensor([1.0,2.0], requires_grad=True)  # 开启梯度追踪
y = x * 2
y.backward(torch.tensor([1.0,1.0]))  # 反向传播
print(x.grad)  # 梯度值:[2.0, 2.0]

以上内容覆盖了Pandas从基础到进阶的全场景操作,包含详细规则、示例、边界案例和避坑指南,可直接落地到实战场景中。

相关推荐
aitoolhub2 小时前
在线PPT设计工具深度测评:功能覆盖与用户体验对比
深度学习·powerpoint·ux·在线设计
java1234_小锋3 小时前
Transformer 大语言模型(LLM)基石 - Transformer简介
深度学习·语言模型·llm·transformer·大语言模型
子午3 小时前
【垃圾识别系统】Python+TensorFlow+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
咚咚王3 小时前
人工智能之数据分析 Pandas:第八章 数据可视化
人工智能·数据分析·pandas
水木姚姚4 小时前
PyTorch在Microsft windows 11下的使用
人工智能·pytorch·windows
shayudiandian4 小时前
用FastAPI部署深度学习模型
人工智能·深度学习·fastapi
JoannaJuanCV4 小时前
深度学习框架keras使用—(1)CNN经典模型:VGGNet
深度学习·cnn·keras
_oP_i4 小时前
常见、主流、可靠的机器学习与深度学习训练集网站
人工智能·深度学习·机器学习