Pytorch--tensor.view()

在 PyTorch 中,tensor.view() 是一个常用的方法,用于改变张量(Tensor)的形状(shape),但不会改变其数据本身。它类似于 NumPy 的 reshape(),但有一些关键区别。


1. 基本用法

python 复制代码
import torch

x = torch.arange(1, 10)  # shape: [9]
print(x)
# tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])

# 改变形状为 (3, 3)
y = x.view(3, 3)  
print(y)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

关键点

  • 不改变数据,只是重新排列维度。

  • 新形状的元素数量必须与原张量一致 ,否则会报错:

    python 复制代码
    x.view(2, 5)  # ❌ 错误!因为 2×5=10,但 x 只有 9 个元素

2. 自动推断维度(-1 的作用)

如果不想手动计算某个维度的大小,可以用 -1,PyTorch 会自动计算:

python 复制代码
x = torch.arange(1, 10)  # shape: [9]

# 自动计算行数,确保列数是 3
y = x.view(-1, 3)  # shape: [3, 3]
print(y)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

# 自动计算列数,确保行数是 3
z = x.view(3, -1)  # shape: [3, 3]
print(z)
# 输出同上

3. view() vs reshape()

方法 是否共享内存 是否适用于非连续存储 适用场景
view() ✅ 共享内存(修改会影响原张量) ❌ 仅适用于连续存储的张量 高效改变形状(推荐优先使用)
reshape() ✅ 可能共享内存(如果可能) ✅ 适用于非连续存储 更通用,但可能额外复制数据

示例对比

python 复制代码
x = torch.arange(1, 10)  # 连续存储

# view() 可以正常工作
y = x.view(3, 3)  

# 如果张量不连续(如转置后),view() 会报错
x_transposed = x.t()  # 转置后存储不连续
# z = x_transposed.view(9)  # ❌ RuntimeError: view size is not compatible with input tensor's size and stride

# reshape() 可以处理非连续存储
z = x_transposed.reshape(9)  # ✅

4. 常见用途

(1) 展平张量(Flatten)

python 复制代码
x = torch.randn(4, 5)  # shape: [4, 5]
flattened = x.view(-1)  # shape: [20]

(2) 调整 CNN 特征图维度

python 复制代码
# 假设 CNN 输出是 [batch_size, channels, height, width]
features = torch.randn(32, 64, 7, 7)  # shape: [32, 64, 7, 7]

# 展平成 [batch_size, channels * height * width] 用于全连接层
flattened = features.view(32, -1)  # shape: [32, 64*7*7] = [32, 3136]

(3) 交换维度(类似 permute

python 复制代码
x = torch.randn(2, 3, 4)  # shape: [2, 3, 4]
y = x.view(2, 4, 3)  # shape: [2, 4, 3](相当于交换最后两维)

5. 注意事项

  1. view() 只适用于连续存储的张量 ,否则会报错,此时应该用 reshape() 或先 .contiguous()

    python 复制代码
    x_non_contiguous = x.t()  # 转置后不连续
    x_contiguous = x_non_contiguous.contiguous()  # 变成连续存储
    y = x_contiguous.view(...)  # 现在可以用 view()
  2. view() 返回的新张量与原张量共享内存 ,修改其中一个会影响另一个:

    python 复制代码
    x = torch.arange(1, 10)
    y = x.view(3, 3)
    y[0, 0] = 100  # 修改 y 会影响 x
    print(x)  # tensor([100, 2, 3, 4, 5, 6, 7, 8, 9])

总结

操作 推荐方法
改变形状(连续张量) view()
改变形状(非连续张量) reshape().contiguous().view()
展平张量 x.view(-1)torch.flatten(x)
调整 CNN 特征图维度 features.view(batch_size, -1)

view() 是 PyTorch 中高效调整张量形状的首选方法,但要注意内存共享和连续性限制! 🚀
展平(Flatten)或改变形状(如 viewreshape)的核心原则是保持张量的总元素个数(numel())不变,只是重新排列这些元素的维度。


1. 元素总数不变原则

无论原始张量是几维的(1D、2D、3D 或更高维),转换后的新形状必须满足:
原形状的元素总数 = 新形状的元素总数 \text{原形状的元素总数} = \text{新形状的元素总数} 原形状的元素总数=新形状的元素总数

即:

元素总数需满足:
dim 1 × dim 2 × ⋯ × dim n = new_dim 1 × new_dim 2 × ⋯ × new_dim m \text{dim}_1 \times \text{dim}_2 \times \dots \times \text{dim}_n = \text{new\_dim}_1 \times \text{new\_dim}_2 \times \dots \times \text{new\_dim}_m dim1×dim2×⋯×dimn=new_dim1×new_dim2×⋯×new_dimm

示例:
python 复制代码
import torch

x = torch.arange(24)  # 1D 张量,24 个元素
print(x.numel())      # 输出:24

# 转换为 2D 张量:4 行 × 6 列(4×6=24)
y = x.view(4, 6)      # 形状 [4, 6]

# 转换为 3D 张量:2×3×4(2×3×4=24)
z = x.view(2, 3, 4)   # 形状 [2, 3, 4]

2. 自动推断维度(-1 的作用)

在指定新形状时,可以用 -1 代表"自动计算该维度大小" ,PyTorch 会根据总元素数和其他已知维度推导出 -1 的值。

规则:
推断的维度 = 总元素数 已知维度的乘积 \text{推断的维度} = \frac{\text{总元素数}}{\text{已知维度的乘积}} 推断的维度=已知维度的乘积总元素数

示例:
python 复制代码
x = torch.arange(24)  # 24 个元素

# 自动计算行数,确保列数为 6
y = x.view(-1, 6)     # 形状 [4, 6](因为 24/6=4)

# 自动计算列数,确保行数为 3
z = x.view(3, -1)     # 形状 [3, 8](因为 24/3=8)
错误示例:

如果维度乘积不匹配总元素数,会报错:

python 复制代码
x.view(5, -1)  # ❌ 报错!24 无法被 5 整除

3. 展平(Flatten)的本质

展平是将任意维度的张量转换为一维或二维的形式:

  • 一维展平x.view(-1)x.flatten()
    将所有元素排成一行,形状变为 [num_elements]
  • 二维展平(保留批处理维度)x.view(batch_size, -1)nn.Flatten()
    保持 batch_size 不变,其余维度合并为第二维,形状变为 [batch_size, features]
示例:
python 复制代码
x = torch.randn(2, 3, 4)  # 形状 [2, 3, 4],总元素数=24

# 一维展平
flatten_1d = x.view(-1)    # 形状 [24]

# 二维展平(保留第0维 batch_size=2)
flatten_2d = x.view(2, -1) # 形状 [2, 12](因为 3×4=12)

4. 为什么需要手动指定部分维度?

  • 全连接层的输入要求
    通常需要二维张量 [batch_size, features],因此需明确保留 batch_size,其余维度展平。
  • 避免歧义
    例如,若张量形状为 [32, 64, 7, 7],想展平成 [32, 3136],需明确第二维是 3136(即 64×7×7),而 -1 让 PyTorch 自动计算。
代码对比:
python 复制代码
# 明确指定第二维
flatten_explicit = x.view(32, 64*7*7)  # 形状 [32, 3136]

# 用 -1 自动计算
flatten_auto = x.view(32, -1)          # 形状 [32, 3136](推荐)

5. 关键总结

  1. 元素总数不变:形状变换的本质是重新排列数据,不增删元素。
  2. -1 的作用:自动计算该维度大小,确保总元素数匹配。
  3. 展平的应用场景
    • 全连接层前必须将多维特征转换为一维向量(如 CNN 的 [batch, C, H, W][batch, C*H*W])。
    • 数据预处理时调整输入形状(如图像展平为向量)。
相关推荐
新智元几秒前
95 后打造世界首个行动型浏览器——Fellou,从「浏览」到「行动」一键直达!
人工智能·openai
新智元5 分钟前
硅谷 AI 初创要让 60 亿人失业,网友痛批人类叛徒!Jeff Dean 已投
人工智能·openai
haochengxia13 分钟前
vLLM V1 KV Cache Manager 源码学习
人工智能
黑心萝卜三条杠18 分钟前
解锁高性能,YOLOv8 部署至 Jetson Orin Nano 开发板的全攻略
人工智能
小oo呆18 分钟前
【自然语言处理与大模型】Linux环境下Ollama下载太慢了该怎么处理?
linux·服务器·人工智能
ljd21032312423 分钟前
opencv函数展示
人工智能·opencv·计算机视觉
0x21128 分钟前
[论文阅读]Making Retrieval-Augmented Language Models Robust to Irrelevant Context
论文阅读·人工智能·语言模型
新加坡内哥谈技术28 分钟前
大语言模型推理能力的强化学习现状理解GRPO与近期推理模型研究的新见解
人工智能·语言模型·自然语言处理·chatgpt
玛哈特-小易37 分钟前
玛哈特整平机:工业制造中的关键设备
人工智能·制造·精密校平机·精密矫平机·钢板矫平机
智享食事1 小时前
数字化时代下的工业物联网智能体开发平台策略
人工智能