PyTorch 张量维度处理
在深度学习中,张量维度处理是核心能力 ,尤其是多模态融合、序列建模或 CNN 特征提取中。PyTorch 中许多函数都有 dim 参数,它们的行为都是基于维度的索引进行操作。理解这些函数可以帮你系统掌握维度操作。
1. torch.cat(tensors, dim)
作用:沿指定维度拼接张量。
- dim=0 → batch 拼接(样本数增加)
- dim=1 → token/feature 拼接(序列长度增加)
- dim=-1 → embedding/channel 拼接(每个 token 增加维度)
python
A = torch.randn(32, 256, 128) # [batch, token, embedding]
B = torch.randn(32, 128, 128)
C = torch.cat([A, B], dim=1) # 拼接 token 维度
print(C.shape) # [32, 384, 128]
2. torch.stack(tensors, dim)
作用:沿新维度堆叠张量,增加一个维度。
- 和
cat区别:stack会增加新维度,cat沿已有维度拼接。 - 场景:需要把多个相同 shape 的张量合并成一批。
python
A = torch.randn(32, 128)
B = torch.randn(32, 128)
C = torch.stack([A, B], dim=1) # 新增维度
print(C.shape) # [32, 2, 128]
3. torch.sum(input, dim, keepdim=False)
作用:沿指定维度求和。
dim:沿哪个维度求和keepdim=True保留维度(长度为1),便于后续广播
python
x = torch.randn(32, 256, 128)
y = torch.sum(x, dim=1) # [32, 128],token 维度求和
y_keep = torch.sum(x, dim=1, keepdim=True) # [32,1,128]
4. torch.mean(input, dim, keepdim=False)
作用:沿指定维度求平均
python
x = torch.randn(32, 256, 128)
y = torch.mean(x, dim=1) # [32,128],token 平均
5. torch.max(input, dim) / torch.min(input, dim)
作用:沿指定维度求最大值/最小值
- 返回两个值:
values, indices values沿 dim 求最大,indices是最大值索引
python
x = torch.randn(32, 256, 128)
max_val, max_idx = torch.max(x, dim=1) # token 维度
print(max_val.shape) # [32,128]
print(max_idx.shape) # [32,128]
6. torch.argmax(input, dim) / torch.argmin(input, dim)
作用:沿指定维度返回最大/最小值的索引
python
x = torch.randn(32, 256, 128)
indices = torch.argmax(x, dim=1)
print(indices.shape) # [32,128]
7. torch.sort(input, dim, descending=False)
作用:沿 dim 排序
python
x = torch.randn(32, 256)
sorted_x, indices = torch.sort(x, dim=1, descending=True)
print(sorted_x.shape) # [32,256]
8. torch.unsqueeze(input, dim) / torch.squeeze(input, dim)
unsqueeze(dim)→ 增加长度为1的维度squeeze(dim)→ 删除长度为1的维度
python
x = torch.randn(32, 256)
x_seq = x.unsqueeze(-1) # [32,256,1]
x_squeezed = x_seq.squeeze(-1) # [32,256]
9. torch.transpose(input, dim0, dim1) / permute(*dims)
-
transpose(dim0, dim1)→ 交换两个维度 -
permute(*dims)→ 重新排列所有维度 -
场景:
- CNN 数据
[batch, H, W, C]→[batch, C, H, W](PyTorch 默认 NCHW) - RNN 输入
[seq_len, batch, input]→[batch, seq_len, input](batch_first=True)
- CNN 数据
python
x = torch.randn(32, 28, 28, 3)
x = x.permute(0, 3, 1, 2) # [32,3,28,28]
10. torch.flatten(input, start_dim=0, end_dim=-1)
-
展开指定维度区间
-
常用场景:
[batch, token_num, embedding]→[batch, token_num*embedding]→ Linear 输入
python
x = torch.randn(32, 256, 128)
x_flat = x.flatten(start_dim=1)
print(x_flat.shape) # [32, 32768]
11. torch.repeat / expand (涉及 dim)
repeat(*sizes)→ 复制维度数据expand(*sizes)→ 扩展维度但不占用额外内存
python
x = torch.randn(32, 1, 128)
x_repeat = x.repeat(1, 256, 1) # [32,256,128] token 复制
x_expand = x.expand(-1, 256, -1) # [32,256,128] 不占内存
12. torch.split(tensor, split_size_or_sections, dim=?)
- 沿指定维度拆分张量
- 返回 list of tensors
python
x = torch.randn(32, 384, 128)
M1, M2, M3 = torch.split(x, [128,128,128], dim=1) # 沿 token 拼回原来的三个模态
13. torch.chunk(tensor, chunks, dim=?)
- 沿 dim 平均拆分成指定块数
python
x = torch.randn(32, 384, 128)
chunks = torch.chunk(x, 3, dim=1) # 三等分 token
14. 其他含 dim 的常用函数
| 函数 | 作用 | 场景 |
|---|---|---|
torch.mean |
沿 dim 平均 | token/feature 平均 |
torch.std / torch.var |
沿 dim 计算标准差/方差 | token/feature 标准化 |
torch.norm |
沿 dim 计算范数 | 特征归一化 |
torch.cumsum |
累加 | 序列累积和 |
torch.cumprod |
累积乘 | 序列累积乘 |
torch.topk |
返回 top k 最大值 | 注意 topk 的 dim 指定维度 |
torch.index_select |
沿 dim 按索引选择 | token/feature 选择 |
15. 维度处理核心规律
-
最后一维通常是 feature/embedding/channel → Linear / Conv 作用于此
-
dim 参数决定操作沿哪个维度进行(sum、mean、cat、split、topk 等)
-
unsqueeze/squeeze/flatten/permute → 改变张量结构,不改变数据
-
二维 vs 三维输入:
[batch, features]→ MLP[batch, token, embedding]→ Transformer/RNN[batch, C, H, W]→ CNN
-
concat 拼接维度:
- dim=0 → batch
- dim=1 → token/feature
- dim=-1 → embedding/channel
-
flatten(start_dim=1) → 把序列和 embedding 展平为 MLP 输入
-
split/chunk → 拆分模态或序列