在深度学习的数据处理中经常需要统计或筛选 张量(Tensor) 中的唯一值,比如去重、统计类别数量、计算唯一标签数等。
PyTorch 提供了一个非常方便的函数 ------ torch.unique(),可以轻松完成这些操作。
本文将带你深入了解 torch.unique() 的用法、参数、返回值以及实际应用场景。
文章目录
-
- [一、什么是 `torch.unique()`?](#一、什么是
torch.unique()?) - 二、函数语法
- 三、参数说明
- 四、基本用法
-
- [🎯 示例 1:基础去重](#🎯 示例 1:基础去重)
- [🎯 示例 2:不排序](#🎯 示例 2:不排序)
- 五、返回索引与计数
-
- [🎯 示例 3:`return_inverse`](#🎯 示例 3:
return_inverse) - [🎯 示例 4:`return_counts`](#🎯 示例 4:
return_counts) - [🎯 示例 5:同时返回多个结果](#🎯 示例 5:同时返回多个结果)
- [🎯 示例 3:`return_inverse`](#🎯 示例 3:
- [六、按维度去重(dim 参数)](#六、按维度去重(dim 参数))
-
- [🎯 示例 6:按行去重](#🎯 示例 6:按行去重)
- [🎯 示例 7:按列去重](#🎯 示例 7:按列去重)
- [七、`torch.unique()` 与 NumPy 对比](#七、
torch.unique()与 NumPy 对比) - 八、实际应用场景
-
- [1. 分类问题中统计类别数量](#1. 分类问题中统计类别数量)
- [2. 计算样本分布(类别频率)](#2. 计算样本分布(类别频率))
- [3. 在图像分割中统计像素类别](#3. 在图像分割中统计像素类别)
- [⚠️ 九、注意事项](#⚠️ 九、注意事项)
- [📚 参考资料](#📚 参考资料)
- [一、什么是 `torch.unique()`?](#一、什么是
一、什么是 torch.unique()?
torch.unique() 是 PyTorch 中的一个去重函数,用于返回张量中所有的唯一元素(unique elements)。
它类似于 Python 的 set() 或 NumPy 的 np.unique(),但专为 GPU 加速的张量操作 设计。
二、函数语法
python
torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)
三、参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
input |
Tensor |
输入张量 |
sorted |
bool |
是否对结果排序(默认 True) |
return_inverse |
bool |
是否返回原张量中每个值在唯一值列表中的索引 |
return_counts |
bool |
是否返回每个唯一值的出现次数 |
dim |
int 或 None |
按指定维度去重,默认对整个张量去重 |
四、基本用法
🎯 示例 1:基础去重
python
import torch
x = torch.tensor([1, 2, 2, 3, 3, 3])
unique_x = torch.unique(x)
print(unique_x)
输出:
plain
tensor([1, 2, 3])
✅ 结果去除了重复值,并自动排序。
🎯 示例 2:不排序
python
x = torch.tensor([3, 2, 1, 3, 2])
unique_x = torch.unique(x, sorted=False)
print(unique_x)
输出:
plain
tensor([3, 2, 1])
当 sorted=False 时,结果的顺序与首次出现的顺序一致。
五、返回索引与计数
🎯 示例 3:return_inverse
return_inverse=True 会返回一个索引张量,表示原张量中每个元素在唯一值(即新张量)中的位置。
python
x = torch.tensor([2, 1, 2, 3])
u, inv = torch.unique(x, return_inverse=True)
print(u)
print(inv)
输出:
plain
tensor([1, 2, 3])
tensor([1, 0, 1, 2])
解释:
- 唯一值为
[1, 2, 3] - 原数组
[2, 1, 2, 3]中:- 第一个元素 2 → 索引 1
- 第二个元素 1 → 索引 0
- 第三个元素 2 → 索引 1
- 第四个元素 3 → 索引 2
🎯 示例 4:return_counts
return_counts=True 会返回每个唯一值出现的次数。
python
x = torch.tensor([1, 2, 2, 3, 3, 3])
u, counts = torch.unique(x, return_counts=True)
print(u)
print(counts)
输出:
plain
tensor([1, 2, 3])
tensor([1, 2, 3])
表示:
- 值 1 出现 1 次
- 值 2 出现 2 次
- 值 3 出现 3 次
🎯 示例 5:同时返回多个结果
你可以同时返回 unique 值、inverse 索引和计数:
python
x = torch.tensor([1, 2, 2, 3, 3, 3])
u, inv, counts = torch.unique(x, return_inverse=True, return_counts=True)
print(u)
print(inv)
print(counts)
输出:
plain
tensor([1, 2, 3])
tensor([0, 1, 1, 2, 2, 2])
tensor([1, 2, 3])
六、按维度去重(dim 参数)
默认情况下,torch.unique() 会将张量展开成一维后去重。
但如果你希望在特定维度上去重(如按行或按列),可以使用 dim 参数。
🎯 示例 6:按行去重
python
x = torch.tensor([[1, 2],
[1, 2],
[3, 4]])
unique_rows = torch.unique(x, dim=0)
print(unique_rows)
输出:
plain
tensor([[1, 2],
[3, 4]])
表示第 1、2 行重复,只保留一个。
🎯 示例 7:按列去重
python
x = torch.tensor([[1, 1, 3],
[2, 2, 4]])
unique_cols = torch.unique(x, dim=1)
print(unique_cols)
输出:
plain
tensor([[1, 3],
[2, 4]])
七、torch.unique() 与 NumPy 对比
| 功能 | PyTorch (torch.unique) |
NumPy (np.unique) |
|---|---|---|
| 默认排序 | ✅ 是 | ✅ 是 |
| 支持 GPU | ✅ 是 | ❌ 否 |
| 返回 inverse 索引 | ✅ 是 | ✅ 是 |
| 返回 counts | ✅ 是 | ✅ 是 |
| 按维度去重 | ✅ 是(dim) |
❌ 不直接支持 |
| 性能 | 高(GPU 支持) | 仅 CPU |
八、实际应用场景
1. 分类问题中统计类别数量
python
labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
classes = torch.unique(labels)
print(f"共有 {len(classes)} 个类别: {classes.tolist()}")
输出:
plain
共有 4 个类别: [0, 1, 2, 3]
2. 计算样本分布(类别频率)
python
labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
u, counts = torch.unique(labels, return_counts=True)
for c, cnt in zip(u.tolist(), counts.tolist()):
print(f"类别 {c}: {cnt} 个样本")
输出:
plain
类别 0: 2 个样本
类别 1: 2 个样本
类别 2: 2 个样本
类别 3: 1 个样本
3. 在图像分割中统计像素类别
例如在语义分割任务中,计算 mask 图像中有多少个不同的像素类别:
python
mask = torch.randint(0, 5, (256, 256)) # 随机生成类别标签
num_classes = len(torch.unique(mask))
print(f"图像中共有 {num_classes} 个类别")
⚠️ 九、注意事项
torch.unique()** 默认会对结果排序**,如果在意性能,可以设置sorted=False。- 对高维张量使用
dim去重时,必须保证该维度的所有元素形状一致。 - 对大张量使用
return_counts或return_inverse时可能会消耗更多显存。
📚 参考资料

