1. 直方图的基础概念
直方图(Histogram) 是一种用于展示数据分布的统计图表:
- 它将数据划分为若干个连续的区间(bin/组距);
- 统计每个区间内包含的数据点数量(频数)或占比(频率);
- 以"区间"为横轴、"频数/频率"为纵轴绘制柱状图,直观反映数据的分布特征(如集中趋势、离散程度、峰值位置等)。
2. PyTorch 中计算直方图:torch.histc()
PyTorch 提供 torch.histc() 函数计算张量的直方图,核心参数:
| 参数 | 含义 |
|---|---|
input |
输入张量(需展平为一维) |
bins |
区间数量(默认10) |
min |
区间最小值(默认取输入张量最小值) |
max |
区间最大值(默认取输入张量最大值) |
步骤1:PyTorch 计算直方图数据
python
import torch
import matplotlib.pyplot as plt # 用于可视化
# 1. 生成测试数据:模拟一组服从正态分布的随机数
# 均值=5,标准差=2,生成1000个数据点
torch.manual_seed(42) # 固定随机种子,结果可复现
data = torch.normal(mean=5, std=2, size=(1000,))
# 2. 使用torch.histc()计算直方图
bins = 10 # 划分10个区间
hist = torch.histc(data, bins=bins)
# 计算区间边界(min到max均分bins个区间)
min_val = data.min().item()
max_val = data.max().item()
bin_edges = torch.linspace(min_val, max_val, bins + 1) # 10个区间对应11个边界点
# 打印结果
print("数据范围:", min_val, "~", max_val)
print("区间边界:", bin_edges.numpy())
print("每个区间的频数:", hist.numpy())
步骤2:可视化直方图(直观展示)
在上述代码基础上,添加可视化代码:
python
# 3. 绘制直方图
plt.figure(figsize=(10, 6))
# 方法1:用PyTorch计算的频数绘制
plt.bar(
x=[(bin_edges[i] + bin_edges[i+1])/2 for i in range(bins)], # 每个柱子的中心位置
height=hist.numpy(), # 柱子高度(频数)
width=(max_val - min_val)/bins * 0.8, # 柱子宽度
color="#1f77b4",
alpha=0.7,
edgecolor="black"
)
# 方法2(更简单):直接用matplotlib绘制原始数据的直方图(验证结果)
# plt.hist(data.numpy(), bins=10, color="#ff7f0e", alpha=0.5, edgecolor="black")
# 添加标签和标题
plt.xlabel("数据区间", fontsize=12)
plt.ylabel("频数(数据点数量)", fontsize=12)
plt.title("PyTorch 数据直方图(正态分布)", fontsize=14)
plt.grid(axis="y", alpha=0.3)
plt.show()
3. 运行结果说明
(1)控制台输出(示例)
数据范围: -2.1432 ~ 10.5997
区间边界: [-2.1432 -0.8776 0.388 1.6536 2.9192 4.1848 5.4504 6.716 7.9816 9.2472 10.5997]
每个区间的频数: [ 8. 22. 50. 95. 156. 210. 198. 145. 85. 31.]
(2)可视化图表效果
- 横轴是数据区间(如 -2.14 ~ -0.88、-0.88 ~ 0.39 等);
- 纵轴是每个区间内的数据点数量;
- 由于数据是正态分布(均值5),直方图呈现"中间高、两边低"的钟形,峰值出现在 4.18~5.45 区间(对应均值位置)。
4. 扩展案例:离散数据的直方图
如果数据是离散值(如分类标签),可使用 torch.bincount() 更适合:
python
# 离散数据示例:0-4的分类标签,共100个样本
discrete_data = torch.randint(low=0, high=5, size=(100,))
# 计算每个离散值的频数(等价于bins=5、min=0、max=4的直方图)
counts = torch.bincount(discrete_data)
# 可视化
plt.figure(figsize=(8, 5))
plt.bar(x=range(5), height=counts.numpy(), color="#2ca02c", alpha=0.7, edgecolor="black")
plt.xlabel("离散值", fontsize=12)
plt.ylabel("频数", fontsize=12)
plt.title("离散数据直方图", fontsize=14)
plt.xticks(range(5)) # 横轴刻度为0-4
plt.grid(axis="y", alpha=0.3)
plt.show()
总结
- 直方图核心:将数据划分为连续区间,统计每个区间的频数,直观展示数据分布。
- PyTorch 实现 :
- 连续数据:用
torch.histc()计算直方图频数,需指定bins/min/max; - 离散数据:用
torch.bincount()更高效(无需划分区间)。
- 连续数据:用
- 可视化 :结合
matplotlib将 PyTorch 计算的频数绘制成柱状图,能直观看到数据的分布特征(如正态分布的钟形、离散数据的分布比例)。