Pytorch中的直方图

1. 直方图的基础概念

直方图(Histogram) 是一种用于展示数据分布的统计图表:

  • 它将数据划分为若干个连续的区间(bin/组距)
  • 统计每个区间内包含的数据点数量(频数)或占比(频率);
  • 以"区间"为横轴、"频数/频率"为纵轴绘制柱状图,直观反映数据的分布特征(如集中趋势、离散程度、峰值位置等)。

2. PyTorch 中计算直方图:torch.histc()

PyTorch 提供 torch.histc() 函数计算张量的直方图,核心参数:

参数 含义
input 输入张量(需展平为一维)
bins 区间数量(默认10)
min 区间最小值(默认取输入张量最小值)
max 区间最大值(默认取输入张量最大值)
步骤1:PyTorch 计算直方图数据
python 复制代码
import torch
import matplotlib.pyplot as plt  # 用于可视化

# 1. 生成测试数据:模拟一组服从正态分布的随机数
# 均值=5,标准差=2,生成1000个数据点
torch.manual_seed(42)  # 固定随机种子,结果可复现
data = torch.normal(mean=5, std=2, size=(1000,))

# 2. 使用torch.histc()计算直方图
bins = 10  # 划分10个区间
hist = torch.histc(data, bins=bins)
# 计算区间边界(min到max均分bins个区间)
min_val = data.min().item()
max_val = data.max().item()
bin_edges = torch.linspace(min_val, max_val, bins + 1)  # 10个区间对应11个边界点

# 打印结果
print("数据范围:", min_val, "~", max_val)
print("区间边界:", bin_edges.numpy())
print("每个区间的频数:", hist.numpy())
步骤2:可视化直方图(直观展示)

在上述代码基础上,添加可视化代码:

python 复制代码
# 3. 绘制直方图
plt.figure(figsize=(10, 6))
# 方法1:用PyTorch计算的频数绘制
plt.bar(
    x=[(bin_edges[i] + bin_edges[i+1])/2 for i in range(bins)],  # 每个柱子的中心位置
    height=hist.numpy(),  # 柱子高度(频数)
    width=(max_val - min_val)/bins * 0.8,  # 柱子宽度
    color="#1f77b4",
    alpha=0.7,
    edgecolor="black"
)

# 方法2(更简单):直接用matplotlib绘制原始数据的直方图(验证结果)
# plt.hist(data.numpy(), bins=10, color="#ff7f0e", alpha=0.5, edgecolor="black")

# 添加标签和标题
plt.xlabel("数据区间", fontsize=12)
plt.ylabel("频数(数据点数量)", fontsize=12)
plt.title("PyTorch 数据直方图(正态分布)", fontsize=14)
plt.grid(axis="y", alpha=0.3)
plt.show()

3. 运行结果说明

(1)控制台输出(示例)
复制代码
数据范围: -2.1432 ~ 10.5997
区间边界: [-2.1432  -0.8776   0.388    1.6536   2.9192   4.1848   5.4504   6.716    7.9816   9.2472  10.5997]
每个区间的频数: [  8.  22.  50.  95. 156. 210. 198. 145.  85.  31.]
(2)可视化图表效果
  • 横轴是数据区间(如 -2.14 ~ -0.88、-0.88 ~ 0.39 等);
  • 纵轴是每个区间内的数据点数量;
  • 由于数据是正态分布(均值5),直方图呈现"中间高、两边低"的钟形,峰值出现在 4.18~5.45 区间(对应均值位置)。

4. 扩展案例:离散数据的直方图

如果数据是离散值(如分类标签),可使用 torch.bincount() 更适合:

python 复制代码
# 离散数据示例:0-4的分类标签,共100个样本
discrete_data = torch.randint(low=0, high=5, size=(100,))
# 计算每个离散值的频数(等价于bins=5、min=0、max=4的直方图)
counts = torch.bincount(discrete_data)

# 可视化
plt.figure(figsize=(8, 5))
plt.bar(x=range(5), height=counts.numpy(), color="#2ca02c", alpha=0.7, edgecolor="black")
plt.xlabel("离散值", fontsize=12)
plt.ylabel("频数", fontsize=12)
plt.title("离散数据直方图", fontsize=14)
plt.xticks(range(5))  # 横轴刻度为0-4
plt.grid(axis="y", alpha=0.3)
plt.show()

总结

  1. 直方图核心:将数据划分为连续区间,统计每个区间的频数,直观展示数据分布。
  2. PyTorch 实现
    • 连续数据:用 torch.histc() 计算直方图频数,需指定 bins/min/max
    • 离散数据:用 torch.bincount() 更高效(无需划分区间)。
  3. 可视化 :结合 matplotlib 将 PyTorch 计算的频数绘制成柱状图,能直观看到数据的分布特征(如正态分布的钟形、离散数据的分布比例)。
相关推荐
盼小辉丶3 天前
PyTorch实战(30)——使用TorchScript和ONNX导出通用PyTorch模型
人工智能·pytorch·深度学习·模型部署
封奚泽优3 天前
使用mmdetection项目进行训练记录
pytorch·python·cuda·mmdetection·mmcv
tony3653 天前
pytorch分布式训练解释
人工智能·pytorch·分布式
weixin_贾3 天前
深度学习基础理论与 PyTorch 实战 —— 从传统机器学习到前沿模型全攻略
pytorch·深度学习·机器学习
大连好光景3 天前
PyTorch深度学习----优化器
pytorch·深度学习·学习
多恩Stone4 天前
【3D-AICG 系列-11】Trellis 2 的 Shape VAE 训练流程梳理
人工智能·pytorch·算法·3d·aigc
隔壁大炮5 天前
08. PyTorch_张量基本创建方式
人工智能·pytorch·python
隔壁大炮5 天前
07. PyTorch框架简介
人工智能·pytorch·python
大鹏的NLP博客5 天前
Rust + PyTorch 实现 BGE 向量检索系统
人工智能·pytorch·rust
勾股导航6 天前
蚁群优化算法
人工智能·pytorch·python