1. transpose -- 交换两个维度
import torch
x = torch.randn(2, 3, 5) # shape: (2, 3, 5)
y = x.transpose(0, 1) # shape: (3, 2, 5) 交换第0维和第1维
2. permute -- 任意重排所有维度
x = torch.randn(2, 3, 5)
y = x.permute(2, 0, 1) # shape: (5, 2, 3) 原第2维→新第0维,原第0维→新第1维,原第1维→新第2维
z = x.permute(1, 2, 0) # shape: (3, 5, 2)
常用场景
图像通道调整 :
(B, H, W, C)→(B, C, H, W)用permute(0, 3, 1, 2)转置矩阵 :二维张量
(M, N)→(N, M)用transpose(0, 1)或t()批量矩阵转置 :
(B, M, N)→(B, N, M)用transpose(1, 2)
非连续内存 :transpose 和 permute 会让张量变得不连续,若后续需要 .view() 操作,必须先调用 .contiguous()。
x = torch.randn(2, 3)
y = x.transpose(0, 1) # shape (3,2), 不连续
# y.view(6) # 可能报错!
y = y.contiguous().view(6) # 正确做法