PyTorch 的统计三剑客:bucketize, bincount 与 histogram

在深度学习(尤其是计算机视觉)的开发过程中,我们经常面临这样的需求:

  • "这个特征图里,有多少像素的值超过了 0.5?"
  • "我想把连续的置信度分数映射成离散的等级(高/中/低)。"
  • "在这个 Batch 的分割 Mask 中,每一类的像素点各有多少?"

为了解决这些问题,PyTorch 提供了三个看似相似、实则各司其职的函数:torch.bucketizetorch.bincounttorch.histogram。很多人容易混淆它们,或者在只需要简单计数时误用了复杂的直方图函数。本文将从底层逻辑、函数关系到实战场景,带你彻底梳理这三个函数。

概念速览

我们可以用三个角色来比喻这三个函数:

  1. torch.bucketize (寻址员):负责"定位"。它不关心总数,只告诉你每个元素属于哪个区间。
  2. torch.bincount (点票员):负责"计数"。它只处理整数,负责统计每个整数 ID 出现了多少次。
  3. torch.histogram (统计局):负责"宏观概览"。它是前两者的集大成者,直接把连续数据变成统计分布。

功能介绍

torch.bucketize ------ 离散化映射

它的核心作用是将连续的浮点数映射为离散的索引:

  • 输入:任意数值张量。
  • 核心参数:boundaries (一维单调递增序列)。
  • 输出:与输入形状相同的 Int64 张量(存储的是边界的索引)。

加入边界是 [0, 10, 20]

  • 输入 5 -> 落入 [0, 10) -> 返回索引 1
  • 输入 25 -> 大于 20 -> 返回索引 3 (越界)
python 复制代码
import torch

boundaries = torch.tensor([0.0, 3.0, 5.0])
values = torch.tensor([1.0, 3.0, 6.0])

# right=False: boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]
# right=True: boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]
# 1.0, 3.0 在 (0.0, 3.0] -> index 1
# 6.0 在 5.0 之后   -> index 3
indices = torch.bucketize(values, boundaries)
# 结果: tensor([1, 1, 3])
print(indices)

应用场景:

  • Embedding Lookup 的前置步骤、将置信度分级。
  • 特征量化中,对特征基于算好的量化点进行量化处理。

torch.bincount ------ 整数高效计数

它是 PyTorch 中处理非负整数统计的高效工具。

  • 输入:必须是 1D 非负整数张量。
  • 输出:1D 张量,长度默认为 max(input) + 1
  • 核心参数:
    • minlength:强制输出长度至少为 N。防止因为某个 Batch 缺少某类数据导致张量形状不固定。
    • weights:如果指定,它计算的不再是次数,而是对应位置权重的累加和。
python 复制代码
data = torch.tensor([1, 1, 3, 1])

# 统计频次
# 0出现0次, 1出现3次, 2出现0次, 3出现1次
counts = torch.bincount(data, minlength=5)
# 结果: tensor([0, 3, 0, 1, 0])
print(counts)

应用场景:

  • 语义分割计算 mIoU(统计混淆矩阵)
  • 计算加权平均值。

这里要注意与torch.unique的区别。两者的核心区别就在于"位置"与"数值"的对应关系:

  1. torch.bincount(稠密/直接映射):
    • 位置即数值:输出张量的第 i i i个位置(索引),专门用来存整数 i i i出现的次数。
    • 填补空缺:如果你的输入只有 [1, 100],它必须生成一个长度为 101 的向量。为了存那个 100,它被迫把中间的 2~99 全部填 0。
    • 内存隐患:如果数据中有个极大值(比如 ID=1,000,000),哪怕只有一个数据,它也会申请一百万长度的数组,非常浪费内存。
  2. torch.unique(稀疏/压缩映射):
    • 只存存在的:它不关心数值对应的索引位置,而是返回两个数组:values(有什么数)和 counts(出现了几次)。
    • 紧凑:对于输入 [1, 100],它只返回 values=[1, 100]counts=[1, 1],完全没有中间的"空窗期"。

所以整体而言,

  • 如果是分类任务(类别 0-10),数据稠密且范围固定 用 bincount(极快)。
  • 如果是用户ID统计(ID可能是 9527),数据稀疏且范围巨大 ,用unique(省内存)。

torch.histogram ------ 端到端直方图

这是最直观的统计工具。

  • 输入:浮点数张量。
  • 输出:hist (频次) 和 bin_edges (边界)。
  • 特点:它内部其实隐式地执行了"分桶"和"计数"两个步骤。
python 复制代码
data = torch.randn(1000)
# 将数据分为 10 个区间统计
hist, edges = torch.histogram(data, bins=10)
# tensor([  4.,  14.,  64., 131., 194., 216., 199., 112.,  53.,  13.])
print(hist)
# tensor([-3.1851, -2.5927, -2.0003, -1.4079, -0.8154, -0.2230,  0.3694,  0.9619,
#          1.5543,  2.1467,  2.7391])
print(edges)

应用场景:

  • 数据分布的可视化:查看特征、 Loss 或梯度的分布情况。

三者关系的本质:A + B ≈ C

这三个函数之间存在着深层的推导关系:
Histogram ≈ bincount ( bucketize ( Input ) ) \text{Histogram} \approx \text{bincount}(\text{bucketize}(\text{Input})) Histogram≈bincount(bucketize(Input))

  1. 先用 bucketize 把连续值变成离散的桶索引。
  2. 再用 bincount 统计这些索引出现的次数。
  3. 这就等于 histogram 的结果。

可视化验证案例:学生成绩分级

为了证明这一点,我们模拟一个"学生成绩统计"的任务:将 0-100 分的成绩分为 E, D, C, B, A 五个等级。

python 复制代码
import matplotlib.pyplot as plt
import numpy as np
import torch

# 设置随机种子保证可复现
torch.manual_seed(42)

# ==========================================
# 1. 数据准备
# ==========================================
# 模拟 1000 个学生的成绩,符合正态分布,均值 75,标准差 12
scores = torch.randn(1000) * 12 + 75
scores = torch.clamp(scores, 0, 100)  # 截断在 0-100 分之间

# 定义分数段边界 (5个区间需要6个边界)
# E: [0, 60), D: [60, 70), C: [70, 80), B: [80, 90), A: [90, 100]
boundaries = torch.tensor([0.0, 60.0, 70.0, 80.0, 90.0, 100.0])
grade_names = ["E (<60)", "D (60-70)", "C (70-80)", "B (80-90)", "A (90+)"]

# ==========================================
# 2. 方法 A: Bucketize + Bincount
# ==========================================
# A1. bucketize: 确定每个分数属于哪个桶 (返回索引)
# right=False 表示左闭右开 [ ),这是常用的逻辑
# 注意: bucketize 可能返回 0 (小于第一个边界) 或 len (大于最后一个边界)
# 在此案例中,因为我们要模拟 Histogram 的 bin 逻辑,通常期望索引范围是 1 到 len(boundaries)-1
# 为了对齐,我们做一点简单的索引调整逻辑
bin_indices = torch.bucketize(scores, boundaries, right=False)  # 左闭右开

# 修正索引:bucketize 如果值 >= 0 且 < 60,会返回索引 1 (因为 0是第一个边界)
# bin_indices 的值域主要会在 1, 2, 3, 4, 5 对应五个区间
# 我们希望 0-60 是第0个桶。所以通常 index - 1。
bin_indices = torch.clamp(bin_indices, 1, len(boundaries) - 1) - 1

# A2. bincount: 统计每个桶里有多少人
# minlength=5 保证即使没人得 A 也能显示 0
counts_from_bincount = torch.bincount(bin_indices, minlength=len(boundaries) - 1)

print(f"方法A (Bucketize+Bincount) 计数: {counts_from_bincount}")

# ==========================================
# 3. 方法 B: 直接法 (Histogram)
# ==========================================
counts_from_hist, _ = torch.histogram(scores, bins=boundaries)

print(f"方法B (Histogram) 计数:          {counts_from_hist}")

# 验证两者是否相等
assert torch.allclose(counts_from_bincount.float(), counts_from_hist), (
    "统计结果不一致!"
)

# ==========================================
# 4. 可视化对比
# ==========================================
plt.figure(figsize=(15, 6))

# 子图 1: 原始数据的分布 (Bucketize 的作用对象)
plt.subplot(1, 3, 1)
plt.title("Step 1: Raw Data & Boundaries (Bucketize Input)")
plt.hist(scores.numpy(), bins=50, color="lightgray", alpha=0.7, label="Raw Scores")
for b in boundaries:
    plt.axvline(b, color="r", linestyle="--", alpha=0.5)
plt.xlabel("Score")
plt.ylabel("Frequency")
plt.legend()
plt.text(
    10, 5, "Bucketize determines\nwhich interval\neach point belongs to", fontsize=9
)

# 子图 2: 离散化后的索引 (Bincount 的作用对象)
plt.subplot(1, 3, 2)
plt.title("Step 2: Discretized Indices (Bincount Input)")
# 这里展示 bin_indices 的值
unique_ids, counts = torch.unique(bin_indices, return_counts=True)
plt.bar(unique_ids.numpy(), counts.numpy(), color="skyblue", edgecolor="black")
plt.xticks(range(5), grade_names, rotation=45)
plt.xlabel("Grade Class (Index)")
plt.ylabel("Count")
plt.text(-0.5, max(counts) * 0.9, "Bincount simply counts\nthese integers", fontsize=9)

# 子图 3: 最终直方图 (Histogram 的直接输出)
plt.subplot(1, 3, 3)
plt.title("Step 3: Final Histogram (Result)")
plt.bar(
    range(5), counts_from_hist.numpy(), color="salmon", alpha=0.8, edgecolor="black"
)
plt.xticks(range(5), grade_names, rotation=45)
plt.xlabel("Grade Class")
plt.ylabel("Count")
plt.text(-0.5, max(counts) * 0.9, "Histogram does\nboth steps at once", fontsize=9)

plt.tight_layout()
plt.show()

该用哪个?

工程中,选择哪个函数取决于实际数据类型和最终目的:

你的需求 推荐函数 理由
数据是连续的 (Float),我想看分布图 torch.histogram 最简单,一行代码出结果。
数据是连续的,但我需要知道"谁在哪* torch.bucketize histogram 丢弃了空间位置信息,而 bucketize 保留了每个元素的索引。
数据是整数 (类别ID/Mask),我想统计个数 torch.bincount 速度最快,专为离散整数设计。
计算"加权"平均值或总和 torch.bincount(weights=...) 利用 weights 参数可以轻松实现 Scatter Add 的功能。
数据非常稀疏 (ID=1000000) torch.unique bincount 会产生巨大的全 0 向量,浪费显存;unique 只存有的值。
相关推荐
sunsunyu032 小时前
基于OpenCV的图像重复检测算法实战
python·计算机视觉
王解2 小时前
CTF学习
学习
AI视觉网奇2 小时前
PlayerStreaming 驱动audio2face 学习笔记
笔记·学习·ue5
码农三叔2 小时前
(8-3-02)自动驾驶中的无地图环境路径探索:D* Lite路径规划系统(2)
人工智能·python·机器人·自动驾驶·路径规划·d star lite
承渊政道2 小时前
Linux系统学习【Linux基础指令以及权限问题】
linux·服务器·学习
Engineer邓祥浩2 小时前
设计模式学习(11) 23-9 组合模式
学习·设计模式·组合模式
专注于大数据技术栈2 小时前
java学习--什么是线程安全和不安全
java·学习·安全
这儿有一堆花2 小时前
从文本到像素:AI图像生成的底层逻辑解析
人工智能·机器学习·计算机视觉
Engineer邓祥浩2 小时前
设计模式学习(13) 23-11 享元模式
学习·设计模式·享元模式