pytorch | pytorch改变tensor维度的方法

pytorch 的 Tensor 类有很多方法可以用来改变 tensor 的维度。这里介绍几种常用的方法:

  • view(shape):返回一个新的 tensor,它具有给定的形状。如果元素总数不变,则可以用它来改变 tensor 的维度。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_view = t.view(3, 2)
print(t_view.shape)  # torch.Size([3, 2])
  • unsqueeze(dim):返回一个新的 tensor,它的指定位置插入了一个新的维度。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_unsqueeze = t.unsqueeze(0)
print(t_unsqueeze.shape)  # torch.Size([1, 2, 3])

t_unsqueeze = t.unsqueeze(1)
print(t_unsqueeze.shape)  # torch.Size([2, 1, 3])

t_unsqueeze = t.unsqueeze(2)
print(t_unsqueeze.shape)  # torch.Size([2, 3, 1])
  • squeeze(dim):返回一个新的 tensor,它的指定位置的维度的大小为 1 的维度被删除。例如:
bash 复制代码
import torch

t = torch.tensor([
    [[1], [2], [3]],
    [[4], [5], [6]]
])
print(t.shape)  # torch.Size([2, 3, 1])

t_squeeze = t.squeeze(2)
print(t_squeeze.shape)  # torch.Size([2, 3])

t_squeeze = t.squeeze()
print(t_squeeze.shape)  # torch.Size([2, 3])
  • transpose(dim0, dim1):返回一个新的 tensor,它的排列被交换。例如:
bash 复制代码
import torch

t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])
print(t.shape)  # torch.Size([2, 3])

t_transpose = t.transpose(0, 1)
print(t_transpose.shape)  # torch.Size([3, 2])

t_transpose = t.transpose(1, 0)
print(t_transpose.shape)  # torch.Size([3, 2])

还有一些其他的方法,例如 permute() 和 contiguous(),可以用来改变 tensor 的维度。有关这些方法的更多信息,可以参考 pytorch 官方文档:https://pytorch.org/docs/stable/tensors.html。

相关推荐
kszlgy2 小时前
Day 52 神经网络调参指南
python
wrj的博客3 小时前
python环境安装
python·学习·环境配置
康康的AI博客3 小时前
腾讯王炸:CodeMoment - 全球首个产设研一体 AI IDE
ide·人工智能
中达瑞和-高光谱·多光谱3 小时前
中达瑞和LCTF:精准调控光谱,赋能显微成像新突破
人工智能
mahtengdbb13 小时前
【目标检测实战】基于YOLOv8-DynamicHGNetV2的猪面部检测系统搭建与优化
人工智能·yolo·目标检测
Pyeako3 小时前
深度学习--BP神经网络&梯度下降&损失函数
人工智能·python·深度学习·bp神经网络·损失函数·梯度下降·正则化惩罚
清 澜4 小时前
大模型面试400问第一部分第一章
人工智能·大模型·大模型面试
不大姐姐AI智能体4 小时前
搭了个小红书笔记自动生产线,一句话生成图文,一键发布,支持手机端、电脑端发布
人工智能·经验分享·笔记·矩阵·aigc
摘星编程4 小时前
OpenHarmony环境下React Native:Geolocation地理围栏
python