面试-Torch函数

0. 连续张量和非连续张量

1.核心含义: "连续(contiguous)" 描述的是张量底层数据在内存中的存储方式。
2.连续张量: 张量的元素在内存中按 "行优先" 顺序连续排列,没有间隔,能通过 固定步长遍历 所有元素;
3.非连续张量: 经过transpose()、permute()等操作后,张量的维度顺序变了,但底层数据的存储顺序没改,导致元素在内存中不再连续,遍历需要不规则步长。

用 "书" 举例:

  • 连续张量:书按[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2]的顺序摆放在一排,没有空隙;
  • 非连续张量(如转置后):维度变成[列,行],但书的摆放顺序还是原来的[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2],此时要按列取数(如[0,0]→[1,0]→[0,1]→[1,1]),需要跳着找书,内存不连续。

1. torch.view()

核心作用: 重塑张量形状,采用方式是 "共享内存"(修改新张量会影响原张量的位置),要求张量是 "连续的(contiguous)",否则会报错。
特点: 不改变原始 x ,通过共享内存的方式改变张量的形状,并且仅支持连续张量。因为view()需要按固定步长重塑维度。

进一步解读: PyTorch 中像view()这类操作,并不会复制张量的底层数据,而是 创建一个新的 "视图(view) " ------ 新张量和原张量共用同一块内存空间,只是对数据的 "解读方式"(维度、步长)不同。因此,修改新张量的某个元素,原张量对应位置的元素也会同步改变,反之亦然。

python 复制代码
import torch
# 原始张量
x = torch.randn(2,3)
print("原始x shape:", x.shape) # ([2,6])

# 重塑
x_view = x.view(2,3,3) 
print("重塑x shape:", x_view.shape) # ([2,3,3])

# 验证共享内存
x_view[0,0,0] = 100.0
print("原始x[0,0]:", x[0,0]) # tensor(100.)

2. torch.reshape()

核心作用: 重塑张量形状,无需张量连续,是更推荐的通用重塑方法。
特点: reshape兼容非连续张量,view仅支持连续张量;功能上几乎等价,新手优先用reshape。

python 复制代码
import torch

# 原始张量
x = torch.arange(12).reshape(3,4) # torch.Size([3,4])
print("x shape:", x.shape)

# 重塑为[4,3]
x_trans = x.transpose(0,1) # 连续张量->非连续张量
x_reshape = x_trans.reshape(4,3) # [3,4] -> [4,3]
print("x_reshape:", x_reshape.shape)

# 展平
x_flat = x_reshape.reshape(-1) # -1 表示自动计算维度
print("x_flat shape:", x_flat.shape)

3. torch.triu()

核心作用: 提取张量的上三角部分,其余元素置 0;常用来构造因果掩码(如 Transformer 的自注意力)。
特点: 提取张量的上三角部分。其中,diagonal(对角线偏移,默认 0,diagonal=1表示主对角线以上的部分)。

python 复制代码
import torch

# 原始矩阵:torch.ones(x,y) 创建
x = torch.ones(3,3)

# 提取上三角部分,(diagonal=1:主对角线以上保留)
x_triu = torch.triu(x, diagonal=1)
print("x_triu:", x_triu)
# 输出:
# tensor([[0., 1., 1.],
#         [0., 0., 1.],
#         [0., 0., 0.]])

# 构造因果掩码矩阵
seq_len = 3
mask = torch.triu(torch.full(seqlen, seqlen), float("-inf"), diagonal=1) 
print("mask:", mask) 
# 输出:
# tensor([[-inf, -inf, -inf],
#         [ -inf, -inf, -inf],
#         [ -inf, -inf, -inf]])  

4. torch.full()

核心作用: 创建指定形状、所有元素均为固定值 的张量;常用于构造掩码(如负无穷、0/1 掩码)。
特点: 必须得传入默认值。
参数: size(张量形状)、fill_value(填充值)、device(可选,指定设备)。比 torch.ones(x,y) 和 torch.zeros(x,y) 要更灵活。

python 复制代码
import torch

# 创建2x3的全5张量 torch.full((seq, seq), float("-inf"))
x_full = torch.full((2, 3), 5.0)
print("x_full:\n", x_full)
# 输出:
# tensor([[5., 5., 5.],
#         [5., 5., 5.]])

# 创建3x3的全负无穷张量(注意力掩码常用)
mask = torch.full((3, 3), float("-inf"), device="cpu")
print("mask:\n", mask)
# 输出:
# tensor([[-inf, -inf, -inf],
#         [-inf, -inf, -inf],
#         [-inf, -inf, -inf]])

5. torch.transpose()

核心作用: 交换张量的两个维度;常用于矩阵转置、调整注意力张量的维度顺序(如[bsz, seq_len, heads]→[bsz, heads, seq_len])。
参数: dim0、dim1(要交换的两个维度索引)。

python 复制代码
import torch
# 通过 torch.randn()、torch.full()、torch.ones()创建张量
x = torch.randn(1,512,16,1024) 
x = torch.full((1,512,16,1024),"float(-inf)")
x = torch.full((1,512,16,1024)) # 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素
x = torch.ones(1,512,16,1024)
print("x:", x) # torch.Size([bsz, seq, heads, dim])

6. torch.cat()

核心作用:指定维度 上拼接多个张量;要求除拼接维度外,其他维度形状完全一致。
参数: tensors(待拼接的张量列表)、dim(拼接维度)。

python 复制代码
import torch
# 通过 torch.randn()、torch.full()、torch.ones()创建张量
x = torch.randn(1,512,16,1024) 
x = torch.full((1,512,16,1024),"float(-inf)")
x = torch.full((1,512,16,1024)) # 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素
x = torch.ones(1,512,16,1024)
print("x:", x) # [bsz, seq, heads, dim]

x2 = torch.randn(1,512,8,1024)

# 在维度2上进行拼接
x_cat = torch.cat([x1, x2], dim=2)
print("x_cat shape:", x_cat.shape) # torch.Size([1,512,24,1024])

# 注意力 KV 缓存拼接
past_kv = torch.randn(1, 10, 1024) # [bsz, seq, dim],这里seq代表已经处理了 10 个kv健
cur_kv = torch.randn(1, 1, 1024) # 当前 kv 键值对
new_kv = torch.cat([past_kv,new_kv], dim=1)
print("new_kv cache:", new_kv) # torch.cat([a, b], dim=c):torch.Size([1, 11, 1024])

7. torch.arange()

核心作用: 创建连续整数序列的一维张量;常用于生成索引、位置编码等。
特点: torch.arange() 是根据步长来生成张量的,没有默认值,只能生成一维张量;torch.full() 能生成任意维度张量,且支持默认值;torch.randn() 随机生成指定维度的张量,不支持默认值。
参数: start(起始值,默认 0)、end(结束值,不包含)、step(步长,默认 1)。

python 复制代码
# 生成0到9的整数:[0,1,2,...,9]
x1 = torch.arange(10)
print("x1:", x1)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# 生成1到9,步长2:[1,3,5,7,9]
x2 = torch.arange(1, 10, 2)
print("x2:", x2)  # tensor([1, 3, 5, 7, 9])

# 结合size()使用:生成与张量某维度长度匹配的索引
x = torch.randn(2, 5, 8)
# 生成0到x.size(1)-1的索引(x.size(1)=5)
idx = torch.arange(x.size(1))
print("idx:", idx)  # tensor([0, 1, 2, 3, 4])

8. tensor.size() / tensor.shape

核心作用 :获取张量的形状信息;size()是方法,shape是属性,功能几乎等价。
参数 :dim(可选,指定维度索引,返回该维度的长度;不指定则返回 torch.Size 对象)。如 x.size(0) 代表张量中第一个维度的大小。

python 复制代码
x = torch.randn(2, 3, 4)

# 获取整体形状
print("x.size():", x.size())  # torch.Size([2, 3, 4])
print("x.shape:", x.shape)    # torch.Size([2, 3, 4])

# 获取指定维度的长度
print("维度0长度:", x.size(0))  # 2(批次大小)
print("维度1长度:", x.size(1))  # 3(序列长度)
print("维度2长度:", x.size(2))  # 4(特征维度)

# 解包形状(常用操作)
bsz, seq_len, hidden_dim = x.size()
print(f"批次:{bsz}, 序列长度:{seq_len}, 特征维度:{hidden_dim}")  # 批次:2, 序列长度:3, 特征维度:4

9. torch.unsqueeze() / torch.squeeze()

核心作用: 插入和删除指定维度,插入和删除的维度的长度为1.
torch.unsqueeze(tensor, dim) :在指定维度插入一个维度(维度长度为 1),常用于扩展掩码维度;
torch.squeeze(tensor, dim) :删除长度为 1 的维度,简化张量形状。

python 复制代码
# unsqueeze:扩展维度(注意力掩码常用)
mask = torch.randn(2, 3)  # torch.Size([2,3])
# 插入维度1和2:shape [2,1,1,3](匹配注意力分数维度)
mask_unsq = mask.unsqueeze(1).unsqueeze(2)
print("mask_unsq shape:", mask_unsq.shape)  # torch.Size([2, 1, 1, 3])

# squeeze:删除长度为1的维度
x = torch.randn(2, 1, 3, 1)
x_sq = x.squeeze()  # 删除所有长度为1的维度
print("x_sq shape:", x_sq.shape)  # torch.Size([2, 3])

总结:

  • 形状调整:reshape(通用)、view(共享内存)是核心,优先用reshape;size()/shape用于获取形状信息。
  • 维度操作:transpose(交换维度)、unsqueeze/squeeze(增 / 删维度)、cat(拼接张量)是维度调整高频函数。
  • 特殊张量创建:arange(生成序列)、full(固定值张量)、triu(上三角矩阵)常用于掩码、索引构造。
  • 记忆要点:cat要求非拼接维度形状一致;triu(diagonal=1)是 Transformer 因果掩码的核心;unsqueeze是扩展掩码维度的常用操作。
相关推荐
aiguangyuan2 小时前
基于BERT的中文命名实体识别实战解析
人工智能·python·nlp
量子-Alex2 小时前
【大模型RLHF】Training language models to follow instructions with human feedback
人工智能·语言模型·自然语言处理
晚霞的不甘2 小时前
Flutter for OpenHarmony 实现计算几何:Graham Scan 凸包算法的可视化演示
人工智能·算法·flutter·架构·开源·音视频
陈天伟教授2 小时前
人工智能应用- 语言处理:04.统计机器翻译
人工智能·自然语言处理·机器翻译
Dfreedom.2 小时前
图像处理中的对比度增强与锐化
图像处理·人工智能·opencv·锐化·对比度增强
wenzhangli72 小时前
OoderAgent 企业版 2.0 发布的意义:一次生态战略的全面升级
人工智能·开源
AI_56783 小时前
SQL性能优化全景指南:从量子执行计划到自适应索引的终极实践
数据库·人工智能·学习·adb
cyyt3 小时前
深度学习周报(2.2~2.8)
人工智能·深度学习
阿杰学AI3 小时前
AI核心知识92——大语言模型之 Self-Attention Mechanism(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·transformer·自注意力机制