Pytorch中的直方图

1. 直方图的基础概念

直方图(Histogram) 是一种用于展示数据分布的统计图表:

  • 它将数据划分为若干个连续的区间(bin/组距)
  • 统计每个区间内包含的数据点数量(频数)或占比(频率);
  • 以"区间"为横轴、"频数/频率"为纵轴绘制柱状图,直观反映数据的分布特征(如集中趋势、离散程度、峰值位置等)。

2. PyTorch 中计算直方图:torch.histc()

PyTorch 提供 torch.histc() 函数计算张量的直方图,核心参数:

参数 含义
input 输入张量(需展平为一维)
bins 区间数量(默认10)
min 区间最小值(默认取输入张量最小值)
max 区间最大值(默认取输入张量最大值)
步骤1:PyTorch 计算直方图数据
python 复制代码
import torch
import matplotlib.pyplot as plt  # 用于可视化

# 1. 生成测试数据:模拟一组服从正态分布的随机数
# 均值=5,标准差=2,生成1000个数据点
torch.manual_seed(42)  # 固定随机种子,结果可复现
data = torch.normal(mean=5, std=2, size=(1000,))

# 2. 使用torch.histc()计算直方图
bins = 10  # 划分10个区间
hist = torch.histc(data, bins=bins)
# 计算区间边界(min到max均分bins个区间)
min_val = data.min().item()
max_val = data.max().item()
bin_edges = torch.linspace(min_val, max_val, bins + 1)  # 10个区间对应11个边界点

# 打印结果
print("数据范围:", min_val, "~", max_val)
print("区间边界:", bin_edges.numpy())
print("每个区间的频数:", hist.numpy())
步骤2:可视化直方图(直观展示)

在上述代码基础上,添加可视化代码:

python 复制代码
# 3. 绘制直方图
plt.figure(figsize=(10, 6))
# 方法1:用PyTorch计算的频数绘制
plt.bar(
    x=[(bin_edges[i] + bin_edges[i+1])/2 for i in range(bins)],  # 每个柱子的中心位置
    height=hist.numpy(),  # 柱子高度(频数)
    width=(max_val - min_val)/bins * 0.8,  # 柱子宽度
    color="#1f77b4",
    alpha=0.7,
    edgecolor="black"
)

# 方法2(更简单):直接用matplotlib绘制原始数据的直方图(验证结果)
# plt.hist(data.numpy(), bins=10, color="#ff7f0e", alpha=0.5, edgecolor="black")

# 添加标签和标题
plt.xlabel("数据区间", fontsize=12)
plt.ylabel("频数(数据点数量)", fontsize=12)
plt.title("PyTorch 数据直方图(正态分布)", fontsize=14)
plt.grid(axis="y", alpha=0.3)
plt.show()

3. 运行结果说明

(1)控制台输出(示例)
复制代码
数据范围: -2.1432 ~ 10.5997
区间边界: [-2.1432  -0.8776   0.388    1.6536   2.9192   4.1848   5.4504   6.716    7.9816   9.2472  10.5997]
每个区间的频数: [  8.  22.  50.  95. 156. 210. 198. 145.  85.  31.]
(2)可视化图表效果
  • 横轴是数据区间(如 -2.14 ~ -0.88、-0.88 ~ 0.39 等);
  • 纵轴是每个区间内的数据点数量;
  • 由于数据是正态分布(均值5),直方图呈现"中间高、两边低"的钟形,峰值出现在 4.18~5.45 区间(对应均值位置)。

4. 扩展案例:离散数据的直方图

如果数据是离散值(如分类标签),可使用 torch.bincount() 更适合:

python 复制代码
# 离散数据示例:0-4的分类标签,共100个样本
discrete_data = torch.randint(low=0, high=5, size=(100,))
# 计算每个离散值的频数(等价于bins=5、min=0、max=4的直方图)
counts = torch.bincount(discrete_data)

# 可视化
plt.figure(figsize=(8, 5))
plt.bar(x=range(5), height=counts.numpy(), color="#2ca02c", alpha=0.7, edgecolor="black")
plt.xlabel("离散值", fontsize=12)
plt.ylabel("频数", fontsize=12)
plt.title("离散数据直方图", fontsize=14)
plt.xticks(range(5))  # 横轴刻度为0-4
plt.grid(axis="y", alpha=0.3)
plt.show()

总结

  1. 直方图核心:将数据划分为连续区间,统计每个区间的频数,直观展示数据分布。
  2. PyTorch 实现
    • 连续数据:用 torch.histc() 计算直方图频数,需指定 bins/min/max
    • 离散数据:用 torch.bincount() 更高效(无需划分区间)。
  3. 可视化 :结合 matplotlib 将 PyTorch 计算的频数绘制成柱状图,能直观看到数据的分布特征(如正态分布的钟形、离散数据的分布比例)。
相关推荐
哈__2 小时前
CANN多模型并发部署方案
人工智能·pytorch
DeniuHe3 小时前
Pytorch中的众数
人工智能·pytorch·python
DeniuHe13 小时前
torch.distribution函数详解
pytorch
退休钓鱼选手16 小时前
[ Pytorch教程 ] 神经网络的基本骨架 torch.nn -Neural Network
pytorch·深度学习·神经网络
DeniuHe17 小时前
用 PyTorch 库创建了一个随机张量,并演示了多种张量取整和分解操作
pytorch
Network_Engineer21 小时前
从零手写LSTM:从门控原理到PyTorch源码级实现
人工智能·pytorch·lstm
多恩Stone1 天前
【3D-AICG 系列-1】Trellis v1 和 Trellis v2 的区别和改进
人工智能·pytorch·python·算法·3d·aigc
2501_901147831 天前
PyTorch DDP官方文档学习笔记(核心干货版)
pytorch·笔记·学习·算法·面试
铁手飞鹰1 天前
[深度学习]常用的库与操作
人工智能·pytorch·python·深度学习·numpy·scikit-learn·matplotlib