PyTorch 张量维度处理详解

PyTorch 张量维度处理

在深度学习中,张量维度处理是核心能力 ,尤其是多模态融合、序列建模或 CNN 特征提取中。PyTorch 中许多函数都有 dim 参数,它们的行为都是基于维度的索引进行操作。理解这些函数可以帮你系统掌握维度操作。


1. torch.cat(tensors, dim)

作用:沿指定维度拼接张量。

  • dim=0 → batch 拼接(样本数增加)
  • dim=1 → token/feature 拼接(序列长度增加)
  • dim=-1 → embedding/channel 拼接(每个 token 增加维度)
python 复制代码
A = torch.randn(32, 256, 128)  # [batch, token, embedding]
B = torch.randn(32, 128, 128)
C = torch.cat([A, B], dim=1)  # 拼接 token 维度
print(C.shape)  # [32, 384, 128]

2. torch.stack(tensors, dim)

作用:沿新维度堆叠张量,增加一个维度。

  • cat 区别:stack 会增加新维度,cat 沿已有维度拼接。
  • 场景:需要把多个相同 shape 的张量合并成一批。
python 复制代码
A = torch.randn(32, 128)
B = torch.randn(32, 128)
C = torch.stack([A, B], dim=1)  # 新增维度
print(C.shape)  # [32, 2, 128]

3. torch.sum(input, dim, keepdim=False)

作用:沿指定维度求和。

  • dim:沿哪个维度求和
  • keepdim=True 保留维度(长度为1),便于后续广播
python 复制代码
x = torch.randn(32, 256, 128)
y = torch.sum(x, dim=1)        # [32, 128],token 维度求和
y_keep = torch.sum(x, dim=1, keepdim=True)  # [32,1,128]

4. torch.mean(input, dim, keepdim=False)

作用:沿指定维度求平均

python 复制代码
x = torch.randn(32, 256, 128)
y = torch.mean(x, dim=1)  # [32,128],token 平均

5. torch.max(input, dim) / torch.min(input, dim)

作用:沿指定维度求最大值/最小值

  • 返回两个值:values, indices
  • values 沿 dim 求最大,indices 是最大值索引
python 复制代码
x = torch.randn(32, 256, 128)
max_val, max_idx = torch.max(x, dim=1)  # token 维度
print(max_val.shape)  # [32,128]
print(max_idx.shape)  # [32,128]

6. torch.argmax(input, dim) / torch.argmin(input, dim)

作用:沿指定维度返回最大/最小值的索引

python 复制代码
x = torch.randn(32, 256, 128)
indices = torch.argmax(x, dim=1)
print(indices.shape)  # [32,128]

7. torch.sort(input, dim, descending=False)

作用:沿 dim 排序

python 复制代码
x = torch.randn(32, 256)
sorted_x, indices = torch.sort(x, dim=1, descending=True)
print(sorted_x.shape)  # [32,256]

8. torch.unsqueeze(input, dim) / torch.squeeze(input, dim)

  • unsqueeze(dim) → 增加长度为1的维度
  • squeeze(dim) → 删除长度为1的维度
python 复制代码
x = torch.randn(32, 256)
x_seq = x.unsqueeze(-1)  # [32,256,1]
x_squeezed = x_seq.squeeze(-1)  # [32,256]

9. torch.transpose(input, dim0, dim1) / permute(*dims)

  • transpose(dim0, dim1) → 交换两个维度

  • permute(*dims) → 重新排列所有维度

  • 场景:

    • CNN 数据 [batch, H, W, C][batch, C, H, W](PyTorch 默认 NCHW)
    • RNN 输入 [seq_len, batch, input][batch, seq_len, input](batch_first=True)
python 复制代码
x = torch.randn(32, 28, 28, 3)
x = x.permute(0, 3, 1, 2)  # [32,3,28,28]

10. torch.flatten(input, start_dim=0, end_dim=-1)

  • 展开指定维度区间

  • 常用场景:

    • [batch, token_num, embedding][batch, token_num*embedding] → Linear 输入
python 复制代码
x = torch.randn(32, 256, 128)
x_flat = x.flatten(start_dim=1)
print(x_flat.shape)  # [32, 32768]

11. torch.repeat / expand (涉及 dim)

  • repeat(*sizes) → 复制维度数据
  • expand(*sizes) → 扩展维度但不占用额外内存
python 复制代码
x = torch.randn(32, 1, 128)
x_repeat = x.repeat(1, 256, 1)  # [32,256,128] token 复制
x_expand = x.expand(-1, 256, -1)  # [32,256,128] 不占内存

12. torch.split(tensor, split_size_or_sections, dim=?)

  • 沿指定维度拆分张量
  • 返回 list of tensors
python 复制代码
x = torch.randn(32, 384, 128)
M1, M2, M3 = torch.split(x, [128,128,128], dim=1)  # 沿 token 拼回原来的三个模态

13. torch.chunk(tensor, chunks, dim=?)

  • 沿 dim 平均拆分成指定块数
python 复制代码
x = torch.randn(32, 384, 128)
chunks = torch.chunk(x, 3, dim=1)  # 三等分 token

14. 其他含 dim 的常用函数

函数 作用 场景
torch.mean 沿 dim 平均 token/feature 平均
torch.std / torch.var 沿 dim 计算标准差/方差 token/feature 标准化
torch.norm 沿 dim 计算范数 特征归一化
torch.cumsum 累加 序列累积和
torch.cumprod 累积乘 序列累积乘
torch.topk 返回 top k 最大值 注意 topk 的 dim 指定维度
torch.index_select 沿 dim 按索引选择 token/feature 选择

15. 维度处理核心规律

  1. 最后一维通常是 feature/embedding/channel → Linear / Conv 作用于此

  2. dim 参数决定操作沿哪个维度进行(sum、mean、cat、split、topk 等)

  3. unsqueeze/squeeze/flatten/permute → 改变张量结构,不改变数据

  4. 二维 vs 三维输入

    • [batch, features] → MLP
    • [batch, token, embedding] → Transformer/RNN
    • [batch, C, H, W] → CNN
  5. concat 拼接维度

    • dim=0 → batch
    • dim=1 → token/feature
    • dim=-1 → embedding/channel
  6. flatten(start_dim=1) → 把序列和 embedding 展平为 MLP 输入

  7. split/chunk → 拆分模态或序列

相关推荐
郝学胜-神的一滴2 小时前
Python对象的自省机制:深入探索对象的内心世界
开发语言·python·程序人生·算法
CHrisFC2 小时前
电力线路器材行业LIMS系统应用全解析
网络·人工智能·安全
tjjucheng2 小时前
小程序定制开发哪家有数据支持
python
cxr8282 小时前
稀缺的炼金术:用第一性原理与系统思维在绝境中构建认知优势
人工智能·思维模型·认知·认知框架
qdprobot2 小时前
具身智能小智AI小车图形化编程Mixly MQTT MCP AIOT控制齐护机器人
人工智能·机器人
说私域2 小时前
全民电商时代下的链动2+1模式与S2B2C商城小程序:社交裂变与供应链协同的营销革命
开发语言·人工智能·小程序·php·流量运营
M宝可梦2 小时前
I-JEPA CVPR2023 LeCun所说的world model和视频生成模型是一回事儿吗
人工智能·大语言模型·世界模型·lecun·jepa
云卓SKYDROID2 小时前
无人机防撞模块技术解析
人工智能·无人机·高科技·云卓科技·技术解析、
marteker2 小时前
迪士尼将营销业务整合为一个专注于协同和灵活的部门
人工智能