【笔记】unsqueeze

unsqueeze是 PyTorch 中的一个方法,用于在指定位置插入一个维度为 1 的新维度。这个操作对于调整张量的形状非常有用,尤其是在需要匹配特定维度要求(例如模型输入或 `torchvision.utils.make_grid` 函数的要求)时。

理解 unsqueeze

假设你有一个形状为 [2, 3] 的二维张量:

python 复制代码
tensor = torch.randn(2, 3)
print(tensor.shape)  # 输出: torch.Size([2, 3])

如果你想要把这个张量变成三维的,比如形状变为 [1, 2, 3],就可以使用 unsqueeze方法。你可以指定在哪一个维度上增加新的维度(从0开始计数)。

  • 在第0维增加新维度:tensor.unsqueeze(0)

  • 在第1维增加新维度:tensor.unsqueeze(1)

例如:

python 复制代码
# 在第0维增加新维度
new_tensor_0 = tensor.unsqueeze(0)
print(new_tensor_0.shape)  # 输出: torch.Size([1, 2, 3])

# 在第1维增加新维度
new_tensor_1 = tensor.unsqueeze(1)
print(new_tensor_1.shape)  # 输出: torch.Size([2, 1, 3])

应用场景

在我的代码上下文中,unsqueeze主要用于确保传入 `make_grid` 的张量具有正确的维度make_grid 需要输入是一个四维张量 (B, C, H, W),其中:

  • B表示批量大小(Batch Size)

  • C表示通道数(Channels)

  • H表示高度(Height)

  • W表示宽度(Width)

例如,如果有一个形状为 [7, 224, 224]的 mask 张量(即它只有三个维度),而你需要将其转换为四个维度的形式以满足 make_grid 的要求,你可以使用 unsqueeze(1)来在第二个维度(通道维度)上增加一个新的维度:

python 复制代码
masks = masks.unsqueeze(1)  # 将 [7, 224, 224] 转换为 [7, 1, 224, 224]

这样,mask 的形状就变成了 [7, 1, 224, 224],符合 make_grid的输入要求。

相关推荐
古译汉书7 分钟前
嵌入式铁头山羊STM32-各章节详细笔记-查阅传送门
数据结构·笔记·stm32·单片机·嵌入式硬件·个人开发
2301_800050993 小时前
DNS 服务器
linux·运维·笔记
汇能感知3 小时前
光谱相机的未来趋势
经验分享·笔记·科技
风已经起了6 小时前
FPGA学习笔记——图像处理之对比度调节(直方图均衡化)
图像处理·笔记·学习·fpga开发·fpga
go_bai6 小时前
Linux--常见工具
linux·开发语言·经验分享·笔记·vim·学习方法
sjh21008 小时前
【学习笔记】20年前的微芯an1078foc技术,smo滑模位置估计,反电动势波形还不错,为何位置估计反而超前了呢?
笔记·学习
航Hang*9 小时前
Kurt-Blender零基础教程:第3章:材质篇——第1节:材质基础~原理化BSDF,添加有纹理材质与用蒙版做纹理叠加
笔记·blender·材质·建模
泽虞10 小时前
《C++程序设计》笔记p4
linux·开发语言·c++·笔记·算法
峰顶听歌的鲸鱼11 小时前
29.Linux防火墙管理
linux·运维·网络·笔记·学习方法
jun~11 小时前
SQLMap绕过 Web 应用程序保护靶机(打靶记录)
linux·笔记·学习·安全·web安全