pytorch | pytorch改变tensor维度的方法

pytorch 的 Tensor 类有很多方法可以用来改变 tensor 的维度。这里介绍几种常用的方法:

  • view(shape):返回一个新的 tensor,它具有给定的形状。如果元素总数不变,则可以用它来改变 tensor 的维度。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_view = t.view(3, 2)
print(t_view.shape)  # torch.Size([3, 2])
  • unsqueeze(dim):返回一个新的 tensor,它的指定位置插入了一个新的维度。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_unsqueeze = t.unsqueeze(0)
print(t_unsqueeze.shape)  # torch.Size([1, 2, 3])

t_unsqueeze = t.unsqueeze(1)
print(t_unsqueeze.shape)  # torch.Size([2, 1, 3])

t_unsqueeze = t.unsqueeze(2)
print(t_unsqueeze.shape)  # torch.Size([2, 3, 1])
  • squeeze(dim):返回一个新的 tensor,它的指定位置的维度的大小为 1 的维度被删除。例如:
bash 复制代码
import torch

t = torch.tensor([
    [[1], [2], [3]],
    [[4], [5], [6]]
])
print(t.shape)  # torch.Size([2, 3, 1])

t_squeeze = t.squeeze(2)
print(t_squeeze.shape)  # torch.Size([2, 3])

t_squeeze = t.squeeze()
print(t_squeeze.shape)  # torch.Size([2, 3])
  • transpose(dim0, dim1):返回一个新的 tensor,它的排列被交换。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_transpose = t.transpose(0, 1)
print(t_transpose.shape)  # torch.Size([3, 2])

t_transpose = t.transpose(1, 0)
print(t_transpose.shape)  # torch.Size([3, 2])

还有一些其他的方法,例如 permute() 和 contiguous(),可以用来改变 tensor 的维度。有关这些方法的更多信息,可以参考 pytorch 官方文档:https://pytorch.org/docs/stable/tensors.html。

相关推荐
金融小师妹9 分钟前
3月美联储货币政策决策的动态博弈——基于就业市场数据与通胀预测的AI模型分析
大数据·人工智能·深度学习·机器学习
多恩Stone10 分钟前
【C++ debug】在 VS Code 中无 Attach 调试 Python 调用的 C++ 扩展
开发语言·c++·python
IDZSY043019 分钟前
【机乎】国内版Moltbook低调上线,AI智能体社交悄然生长
人工智能
葡萄城技术团队23 分钟前
从 Shortcut 的爆火,看 AI 时代电子表格的技术底座与架构演进
人工智能·架构
罗政31 分钟前
AI批量识图实战:车牌号提取(包括新能源)
人工智能
XW010599932 分钟前
4-11判断素数
前端·python·算法·素数
冰西瓜60036 分钟前
深度学习的数学原理(十三)—— CNN实战
人工智能·深度学习·cnn
AI街潜水的八角41 分钟前
工业缺陷检测实战——RSDDs北交轨道缺陷分割
人工智能
深蓝电商API41 分钟前
爬虫增量更新:基于时间戳与哈希去重
爬虫·python
人工智能AI技术44 分钟前
什么是多模态
人工智能