pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例

pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`------累积乘积详解与实战示例

    • 一、函数签名与参数说明
    • 二、基础用法
      • [1. 一维张量累积乘积](#1. 一维张量累积乘积)
      • [2. 二维张量按行/按列累积](#2. 二维张量按行/按列累积)
    • [三、`dtype` 参数:避免整数溢出与提升精度](#三、dtype 参数:避免整数溢出与提升精度)
    • 四、典型应用场景
      • [1. 几何序列生成](#1. 几何序列生成)
      • [2. 概率分布的累积乘积](#2. 概率分布的累积乘积)
      • [3. 模型门控或权重衰减](#3. 模型门控或权重衰减)
    • [五、进阶示例:预分配 `out` 张量](#五、进阶示例:预分配 out 张量)
    • 六、小结

在深度学习与科学计算中,往往需要沿某个维度追踪"前面所有元素的乘积",比如几何序列计算、概率分布构建、模型门控/权重衰减等场景。PyTorch 提供的 torch.cumprod 函数可以一行代码搞定这一需求。本文将从函数签名、参数含义、基础用法,到进阶示例、典型应用场景,为你带来最全面的讲解,并附上丰富示例助你快速上手。


一、函数签名与参数说明

python 复制代码
torch.cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None
) → Tensor
  • input:任意维度的输入张量。
  • dim :指定沿哪个维度做累积乘积(0 表示第一个维度,以此类推)。
  • dtype(可选):输出张量的数据类型。如果原张量为整数且会溢出,可通过将其提升到更宽数据类型来避免溢出。
  • out(可选):预先分配好的张量,用于存储输出,避免额外内存分配。

二、基础用法

1. 一维张量累积乘积

python 复制代码
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.cumprod(x, dim=0)
print(y)  # tensor([ 1,  2,  6, 24])
  • y[0] = 1
  • y[1] = 1 * 2 = 2
  • y[2] = 1 * 2 * 3 = 6
  • y[3] = 1 * 2 * 3 * 4 = 24

2. 二维张量按行/按列累积

python 复制代码
x2 = torch.tensor([[1, 2, 3],
                   [4, 5, 6]])
# 沿行(dim=1)累积
row_prod = torch.cumprod(x2, dim=1)
print(row_prod)
# tensor([[  1,   2,   6],
#         [  4,  20, 120]])

# 沿列(dim=0)累积
col_prod = torch.cumprod(x2, dim=0)
print(col_prod)
# tensor([[1, 2,  3],
#         [4, 10, 18]])

三、dtype 参数:避免整数溢出与提升精度

input 为大整数且乘积超出类型范围时,会导致溢出。此时可指定更宽的数据类型:

python 复制代码
x_int = torch.tensor([1000, 1000, 1000], dtype=torch.int32)
# 默认 int32 会溢出
print(torch.cumprod(x_int, dim=0))
# tensor([1000,  -727,  -728], dtype=torch.int32)

# 改为 int64 避免溢出
print(torch.cumprod(x_int, dim=0, dtype=torch.int64))
# tensor([      1000,    1000000, 1000000000])

四、典型应用场景

1. 几何序列生成

几何序列 a , a r , a r 2 , ... a, ar, ar^2, ... a,ar,ar2,... 可用累积乘积实现:

python 复制代码
a, r, n = 2.0, 0.5, 5
ratios = torch.full((n,), r)               # [r, r, r, r, r]
geom = a * torch.cumprod(ratios, dim=0)
print(geom)
# tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625])

2. 概率分布的累积乘积

在构建离散分布的乘积模型时,用累乘来得到联合概率:

python 复制代码
probs = torch.tensor([0.2, 0.3, 0.5])
# 标准化(确保和为1)
probs = probs / probs.sum()
# 获取依次乘积(注意:乘积非累加,因此并非 CDF)
joint = torch.cumprod(probs, dim=0)
print(joint)
# tensor([0.2000, 0.0600, 0.0300])

3. 模型门控或权重衰减

在 RNN、Transformer 等模型中,若需要对前 n 层或时间步做指数衰减,可用累积乘积计算衰减系数:

python 复制代码
decay_rates = torch.linspace(0.9, 0.5, steps=4)  # 每层不同衰减率
coeffs = torch.cumprod(decay_rates, dim=0)      # 累积得到层间总衰减
print(coeffs)
# tensor([0.9000, 0.7200, 0.5040, 0.2520])

五、进阶示例:预分配 out 张量

为了在高性能场景下避免额外内存分配,可以先分配好输出张量,再将结果写入:

python 复制代码
x = torch.arange(1, 1001, dtype=torch.float32)
out = torch.empty_like(x)
torch.cumprod(x, dim=0, out=out)
print(out[:5])  # tensor([1., 2., 6., 24., 120.])

六、小结

  • 功能torch.cumprod 沿指定维度计算输入张量的累计乘积,返回新张量。

  • 关键参数

    • dim:累积轴;
    • dtype:避免整数溢出/提升精度;
    • out:预分配输出提高性能。
  • 常见应用

    1. 几何序列生成;
    2. 概率分布乘积;
    3. 模型门控/权重衰减;
    4. 其它需要"前缀乘积"场景。
相关推荐
心灵彼岸-诗和远方6 分钟前
芯片生态链深度解析(三):芯片设计篇——数字文明的造物主战争
人工智能·制造
DpHard13 分钟前
Vscode 配置python调试环境
ide·vscode·python
小蜗笔记16 分钟前
显卡、Cuda和pytorch兼容问题
人工智能·pytorch·python
高建伟-joe27 分钟前
内容安全:使用开源框架Caffe实现上传图片进行敏感内容识别
人工智能·python·深度学习·flask·开源·html5·caffe
Cloud Traveler43 分钟前
迁移学习:解锁AI高效学习与泛化能力的密钥
人工智能·学习·迁移学习
IT_xiao小巫44 分钟前
AI 实践探索:辅助生成测试用例
人工智能·测试用例
一切皆有可能!!1 小时前
ChromaDB 向量库优化技巧实战
人工智能·语言模型
星川皆无恙1 小时前
大模型学习:Deepseek+dify零成本部署本地运行实用教程(超级详细!建议收藏)
大数据·人工智能·学习·语言模型·架构
观测云1 小时前
观测云产品更新 | 安全监测、事件中心、仪表板AI智能分析等
人工智能·安全
AIGC方案1 小时前
2025 AI如何重构网络安全产品
人工智能·web安全·重构