在 PyTorch 中,tensor.view()
是一个常用的方法,用于改变张量(Tensor)的形状(shape),但不会改变其数据本身。它类似于 NumPy 的 reshape()
,但有一些关键区别。
1. 基本用法
python
import torch
x = torch.arange(1, 10) # shape: [9]
print(x)
# tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
# 改变形状为 (3, 3)
y = x.view(3, 3)
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
关键点:
-
不改变数据,只是重新排列维度。
-
新形状的元素数量必须与原张量一致 ,否则会报错:
pythonx.view(2, 5) # ❌ 错误!因为 2×5=10,但 x 只有 9 个元素
2. 自动推断维度(-1 的作用)
如果不想手动计算某个维度的大小,可以用 -1
,PyTorch 会自动计算:
python
x = torch.arange(1, 10) # shape: [9]
# 自动计算行数,确保列数是 3
y = x.view(-1, 3) # shape: [3, 3]
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
# 自动计算列数,确保行数是 3
z = x.view(3, -1) # shape: [3, 3]
print(z)
# 输出同上
3. view()
vs reshape()
方法 | 是否共享内存 | 是否适用于非连续存储 | 适用场景 |
---|---|---|---|
view() |
✅ 共享内存(修改会影响原张量) | ❌ 仅适用于连续存储的张量 | 高效改变形状(推荐优先使用) |
reshape() |
✅ 可能共享内存(如果可能) | ✅ 适用于非连续存储 | 更通用,但可能额外复制数据 |
示例对比
python
x = torch.arange(1, 10) # 连续存储
# view() 可以正常工作
y = x.view(3, 3)
# 如果张量不连续(如转置后),view() 会报错
x_transposed = x.t() # 转置后存储不连续
# z = x_transposed.view(9) # ❌ RuntimeError: view size is not compatible with input tensor's size and stride
# reshape() 可以处理非连续存储
z = x_transposed.reshape(9) # ✅
4. 常见用途
(1) 展平张量(Flatten)
python
x = torch.randn(4, 5) # shape: [4, 5]
flattened = x.view(-1) # shape: [20]
(2) 调整 CNN 特征图维度
python
# 假设 CNN 输出是 [batch_size, channels, height, width]
features = torch.randn(32, 64, 7, 7) # shape: [32, 64, 7, 7]
# 展平成 [batch_size, channels * height * width] 用于全连接层
flattened = features.view(32, -1) # shape: [32, 64*7*7] = [32, 3136]
(3) 交换维度(类似 permute
)
python
x = torch.randn(2, 3, 4) # shape: [2, 3, 4]
y = x.view(2, 4, 3) # shape: [2, 4, 3](相当于交换最后两维)
5. 注意事项
-
view()
只适用于连续存储的张量 ,否则会报错,此时应该用reshape()
或先.contiguous()
:pythonx_non_contiguous = x.t() # 转置后不连续 x_contiguous = x_non_contiguous.contiguous() # 变成连续存储 y = x_contiguous.view(...) # 现在可以用 view()
-
view()
返回的新张量与原张量共享内存 ,修改其中一个会影响另一个:pythonx = torch.arange(1, 10) y = x.view(3, 3) y[0, 0] = 100 # 修改 y 会影响 x print(x) # tensor([100, 2, 3, 4, 5, 6, 7, 8, 9])
总结
操作 | 推荐方法 |
---|---|
改变形状(连续张量) | view() |
改变形状(非连续张量) | reshape() 或 .contiguous().view() |
展平张量 | x.view(-1) 或 torch.flatten(x) |
调整 CNN 特征图维度 | features.view(batch_size, -1) |
view()
是 PyTorch 中高效调整张量形状的首选方法,但要注意内存共享和连续性限制! 🚀
展平(Flatten)或改变形状(如 view
、reshape
)的核心原则是保持张量的总元素个数(numel()
)不变,只是重新排列这些元素的维度。
1. 元素总数不变原则
无论原始张量是几维的(1D、2D、3D 或更高维),转换后的新形状必须满足:
原形状的元素总数 = 新形状的元素总数 \text{原形状的元素总数} = \text{新形状的元素总数} 原形状的元素总数=新形状的元素总数
即:
元素总数需满足:
dim 1 × dim 2 × ⋯ × dim n = new_dim 1 × new_dim 2 × ⋯ × new_dim m \text{dim}_1 \times \text{dim}_2 \times \dots \times \text{dim}_n = \text{new\_dim}_1 \times \text{new\_dim}_2 \times \dots \times \text{new\_dim}_m dim1×dim2×⋯×dimn=new_dim1×new_dim2×⋯×new_dimm
示例:
python
import torch
x = torch.arange(24) # 1D 张量,24 个元素
print(x.numel()) # 输出:24
# 转换为 2D 张量:4 行 × 6 列(4×6=24)
y = x.view(4, 6) # 形状 [4, 6]
# 转换为 3D 张量:2×3×4(2×3×4=24)
z = x.view(2, 3, 4) # 形状 [2, 3, 4]
2. 自动推断维度(-1
的作用)
在指定新形状时,可以用 -1
代表"自动计算该维度大小" ,PyTorch 会根据总元素数和其他已知维度推导出 -1
的值。
规则:
推断的维度 = 总元素数 已知维度的乘积 \text{推断的维度} = \frac{\text{总元素数}}{\text{已知维度的乘积}} 推断的维度=已知维度的乘积总元素数
示例:
python
x = torch.arange(24) # 24 个元素
# 自动计算行数,确保列数为 6
y = x.view(-1, 6) # 形状 [4, 6](因为 24/6=4)
# 自动计算列数,确保行数为 3
z = x.view(3, -1) # 形状 [3, 8](因为 24/3=8)
错误示例:
如果维度乘积不匹配总元素数,会报错:
python
x.view(5, -1) # ❌ 报错!24 无法被 5 整除
3. 展平(Flatten)的本质
展平是将任意维度的张量转换为一维或二维的形式:
- 一维展平 :
x.view(-1)
或x.flatten()
将所有元素排成一行,形状变为[num_elements]
。 - 二维展平(保留批处理维度) :
x.view(batch_size, -1)
或nn.Flatten()
保持batch_size
不变,其余维度合并为第二维,形状变为[batch_size, features]
。
示例:
python
x = torch.randn(2, 3, 4) # 形状 [2, 3, 4],总元素数=24
# 一维展平
flatten_1d = x.view(-1) # 形状 [24]
# 二维展平(保留第0维 batch_size=2)
flatten_2d = x.view(2, -1) # 形状 [2, 12](因为 3×4=12)
4. 为什么需要手动指定部分维度?
- 全连接层的输入要求 :
通常需要二维张量[batch_size, features]
,因此需明确保留batch_size
,其余维度展平。 - 避免歧义 :
例如,若张量形状为[32, 64, 7, 7]
,想展平成[32, 3136]
,需明确第二维是3136
(即64×7×7
),而-1
让 PyTorch 自动计算。
代码对比:
python
# 明确指定第二维
flatten_explicit = x.view(32, 64*7*7) # 形状 [32, 3136]
# 用 -1 自动计算
flatten_auto = x.view(32, -1) # 形状 [32, 3136](推荐)
5. 关键总结
- 元素总数不变:形状变换的本质是重新排列数据,不增删元素。
-1
的作用:自动计算该维度大小,确保总元素数匹配。- 展平的应用场景 :
- 全连接层前必须将多维特征转换为一维向量(如 CNN 的
[batch, C, H, W]
→[batch, C*H*W]
)。 - 数据预处理时调整输入形状(如图像展平为向量)。
- 全连接层前必须将多维特征转换为一维向量(如 CNN 的