【笔记】unsqueeze

unsqueeze是 PyTorch 中的一个方法,用于在指定位置插入一个维度为 1 的新维度。这个操作对于调整张量的形状非常有用,尤其是在需要匹配特定维度要求(例如模型输入或 `torchvision.utils.make_grid` 函数的要求)时。

理解 unsqueeze

假设你有一个形状为 2, 3 的二维张量:

python 复制代码
tensor = torch.randn(2, 3)
print(tensor.shape)  # 输出: torch.Size([2, 3])

如果你想要把这个张量变成三维的,比如形状变为 1, 2, 3,就可以使用 unsqueeze方法。你可以指定在哪一个维度上增加新的维度(从0开始计数)。

  • 在第0维增加新维度:tensor.unsqueeze(0)

  • 在第1维增加新维度:tensor.unsqueeze(1)

例如:

python 复制代码
# 在第0维增加新维度
new_tensor_0 = tensor.unsqueeze(0)
print(new_tensor_0.shape)  # 输出: torch.Size([1, 2, 3])

# 在第1维增加新维度
new_tensor_1 = tensor.unsqueeze(1)
print(new_tensor_1.shape)  # 输出: torch.Size([2, 1, 3])

应用场景

在我的代码上下文中,unsqueeze主要用于确保传入 `make_grid` 的张量具有正确的维度make_grid 需要输入是一个四维张量 (B, C, H, W),其中:

  • B表示批量大小(Batch Size)

  • C表示通道数(Channels)

  • H表示高度(Height)

  • W表示宽度(Width)

例如,如果有一个形状为 7, 224, 224的 mask 张量(即它只有三个维度),而你需要将其转换为四个维度的形式以满足 make_grid 的要求,你可以使用 unsqueeze(1)来在第二个维度(通道维度)上增加一个新的维度:

python 复制代码
masks = masks.unsqueeze(1)  # 将 [7, 224, 224] 转换为 [7, 1, 224, 224]

这样,mask 的形状就变成了 7, 1, 224, 224,符合 make_grid的输入要求。

相关推荐
三品吉他手会点灯7 小时前
STM32F103 学习笔记-24-I2C-读写EEPROM(第1节)-I2C物理层介绍
笔记·stm32·学习
万物更新_8 小时前
vue框架
前端·javascript·vue.js·笔记
上海观智网络9 小时前
上海小程序定制开发合同怎么签?需要注意什么?
经验分享·笔记·小程序
Ab_stupid9 小时前
CTF-Crypto培训笔记-现代密码
笔记·des·aes·rsa·crypto
IT技术学习9 小时前
打包系统为ISO
笔记
就叫飞六吧10 小时前
数学图形绘制在线网站
笔记
SHARK_pssm10 小时前
【数据结构——树与堆】
c语言·数据结构·经验分享·笔记
怪味&先森11 小时前
读书小结—《认知觉醒》
笔记
杨先生哦11 小时前
2026 热端攻防:AI 驱动 Web 前端安全全景透析
前端·笔记·安全·web安全
Cloud_Shy61812 小时前
解读《Effective Python 3rd Edition》:从练气到老魔(第七章 Item 48 - 50)
开发语言·人工智能·笔记·python·microsoft·学习方法