Pytorch中统计学相关的函数

torch.mean()

计算输入张量中所有元素的平均值。可以指定维度进行沿该维度的均值计算,若不指定维度则返回所有元素的均值。

案例代码:

python 复制代码
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mean_all = torch.mean(x)  # 输出: 2.5
mean_dim0 = torch.mean(x, dim=0)  # 输出: tensor([2., 3.])
mean_dim1 = torch.mean(x, dim=1)  # 输出: tensor([1.5, 3.5])

torch.sum()

计算输入张量中所有元素的和。可以指定维度进行沿该维度的求和计算,若不指定维度则返回所有元素的和。

案例代码:

python 复制代码
x = torch.tensor([[1, 2], [3, 4]])
sum_all = torch.sum(x)  # 输出: 10
sum_dim0 = torch.sum(x, dim=0)  # 输出: tensor([4, 6])
sum_dim1 = torch.sum(x, dim=1)  # 输出: tensor([3, 7])

torch.prod()

计算输入张量中所有元素的乘积。可以指定维度进行沿该维度的乘积计算,若不指定维度则返回所有元素的乘积。

案例代码:

python 复制代码
x = torch.tensor([[1, 2], [3, 4]])
prod_all = torch.prod(x)  # 输出: 24
prod_dim0 = torch.prod(x, dim=0)  # 输出: tensor([3, 8])
prod_dim1 = torch.prod(x, dim=1)  # 输出: tensor([2, 12])

torch.max()

返回输入张量中的最大值。可以指定维度进行沿该维度的最大值计算,返回最大值和对应的索引。

案例代码:

python 复制代码
x = torch.tensor([[1, 5], [3, 2]])
max_all = torch.max(x)  # 输出: 5
max_dim0 = torch.max(x, dim=0)  # 返回: (tensor([3, 5]), tensor([1, 0]))
max_dim1 = torch.max(x, dim=1)  # 返回: (tensor([5, 3]), tensor([1, 0]))

torch.min()

返回输入张量中的最小值。可以指定维度进行沿该维度的最小值计算,返回最小值和对应的索引。

案例代码:

python 复制代码
x = torch.tensor([[1, 5], [3, 2]])
min_all = torch.min(x)  # 输出: 1
min_dim0 = torch.min(x, dim=0)  # 返回: (tensor([1, 2]), tensor([0, 1]))
min_dim1 = torch.min(x, dim=1)  # 返回: (tensor([1, 2]), tensor([0, 1]))

torch.argmax()

返回输入张量中最大值的索引。可以指定维度进行沿该维度的最大值索引计算,若不指定维度则返回展平后张量的最大值的索引。

案例代码:

python 复制代码
x = torch.tensor([[1, 5], [3, 2]])
argmax_all = torch.argmax(x)  # 输出: 1
argmax_dim0 = torch.argmax(x, dim=0)  # 输出: tensor([1, 0])
argmax_dim1 = torch.argmax(x, dim=1)  # 输出: tensor([1, 0])

torch.argmin()

返回输入张量中最小值的索引。可以指定维度进行沿该维度的最小值索引计算,若不指定维度则返回展平后张量的最小值的索引。

案例代码:

python 复制代码
x = torch.tensor([[1, 5], [3, 2]])
argmin_all = torch.argmin(x)  # 输出: 0
argmin_dim0 = torch.argmin(x, dim=0)  # 输出: tensor([0, 1])
argmin_dim1 = torch.argmin(x, dim=1)  # 输出: tensor([0, 1])

以下是PyTorch中常用统计函数的详细解释及案例代码:

torch.std()

计算张量的标准差。可选参数dim指定计算维度,unbiased决定是否使用无偏估计(默认True)。

python 复制代码
import torch
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
std_all = torch.std(x)  # 全局标准差
std_dim = torch.std(torch.randn(3,4), dim=1)  # 沿第1维计算

torch.var()

计算张量的方差。参数与std()类似,keepdim可保持输出维度。

python 复制代码
x = torch.tensor([[1,2],[3,4]], dtype=torch.float)
var_all = torch.var(x)  # 1.25
var_dim = torch.var(x, dim=0, keepdim=True)  # 沿列计算 [[2., 2.]]

torch.median()

返回中位数。dim指定维度时返回元组(值,索引)。

python 复制代码
x = torch.tensor([[3,1,4],[2,5,6]])
val_all = torch.median(x)  # 3.5
vals, inds = torch.median(x, dim=1)  # 每行中位数及索引

torch.mode()

返回众数及其索引。对多众数情况返回第一个遇到的。

python 复制代码
x = torch.tensor([1,2,2,3,3,3])
value, index = torch.mode(x)  # (3, 5)

torch.histc()

计算直方图。bins指定区间数,min/max设定范围。

python 复制代码
x = torch.randn(1000)
hist = torch.histc(x, bins=10, min=-3, max=3)  # 10个区间的计数

torch.bincount()

对非负整数张量计数。minlength设置最小输出长度。

python 复制代码
x = torch.tensor([0,1,1,2,2,2])
counts = torch.bincount(x)  # [1,2,3]
weights = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6])
weighted_counts = torch.bincount(x, weights=weights)  # [0.1, 0.5, 1.5]

注意:所有函数默认在浮点张量上操作,输入整数张量时需显式转换类型。对于空输入或无效参数会抛出异常。

相关推荐
在人间耕耘1 小时前
HarmonyOS Vision Kit 视觉AI实战:把官方 Demo 改造成一套能长期复用的组件库
人工智能·深度学习·harmonyos
homelook1 小时前
Transformer与电池管理系统(BMS)的结合是当前 智能电池管理 的前沿研究方向
人工智能·深度学习·transformer
多恩Stone2 小时前
【C++入门扫盲1】C++ 与 Python:类型、编译器/解释器与 CPU 的关系
开发语言·c++·人工智能·python·算法·3d·aigc
QQ4022054962 小时前
Python+django+vue3预制菜半成品配菜平台
开发语言·python·django
百锦再2 小时前
Django实现接口token检测的实现方案
数据库·python·django·sqlite·flask·fastapi·pip
QQ5110082852 小时前
python+springboot+django/flask的校园资料分享系统
spring boot·python·django·flask·node.js·php
QQ_19632884752 小时前
Python-flask框架西山区家政服务评价系统网站设计与开发-Pycharm django
python·pycharm·flask
ccLianLian2 小时前
强化学习·导论
深度学习
遥遥江上月2 小时前
Node.js + Stagehand + Python 部署
开发语言·python·node.js
B站计算机毕业设计超人2 小时前
计算机毕业设计Django+Vue.js音乐推荐系统 音乐可视化 大数据毕业设计 (源码+文档+PPT+讲解)
大数据·vue.js·hadoop·python·spark·django·课程设计