【pytorch】keepdim参数解析

keepdim 是 PyTorch 中的一个参数,常用于各种归约操作(如求和、求均值、求最大值等)。当我们对张量进行归约时,通常会减少该维度的大小,但有时我们希望保持归约后的维度不变,这时就会用到 keepdim=True

举个例子

假设我们有一个 2x3 的张量 x

python 复制代码
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x)

输出:

复制代码
tensor([[1, 2, 3],
        [4, 5, 6]])
1. 不使用 keepdim

我们对张量的某个维度进行求均值操作,例如对维度 1(列)求均值:

python 复制代码
mean_without_keepdim = x.mean(dim=1)
print(mean_without_keepdim)

输出:

复制代码
tensor([2., 5.])

在这种情况下,原本的 2x3 的张量被压缩成了 1D 的张量 [2., 5.],原来的维度 1(列)被"消除"了。

2. 使用 keepdim=True
python 复制代码
mean_with_keepdim = x.mean(dim=1, keepdim=True)
print(mean_with_keepdim)

输出:

复制代码
tensor([[2.],
        [5.]])

在这种情况下,虽然我们在维度 1 上进行了均值操作,但 keepdim=True 保持了维度结构,所以结果仍然是 2x1 的张量,而不是被压缩成 1D 的张量。即原来的维度 1 被保留,只是大小从 3 变成了 1。

总结

  • keepdim=False(默认值):归约操作后,所归约的维度会被移除,张量的维度会减少。
  • keepdim=True:归约操作后,所归约的维度会被保留,张量的维度不变,但该维度的大小变为 1。

这是在处理张量形状时非常有用的功能,尤其是在需要保持张量形状一致性的场景下(比如在某些层归一化操作或在神经网络中)。

相关推荐
宅小年3 分钟前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
AI探索者11 分钟前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者13 分钟前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
九狼17 分钟前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS25 分钟前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区2 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈2 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
FishCoderh2 小时前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅2 小时前
Python函数入门详解(定义+调用+参数)
python
Ray Liang2 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx