torch.nn.functional.one_hot
是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。
函数签名
torch.nn.functional.one_hot(tensor, num_classes=-1)
参数
-
tensor
:- 输入的整数张量。该张量的每个元素都表示一个类别索引。
tensor
的数据类型必须是整数类型(如torch.LongTensor
或torch.IntTensor
)。
-
num_classes
:- 输出独热编码向量的长度,即类别的总数。如果设置为默认值
-1
,则num_classes
会自动设置为输入张量中最大值加1,即max(tensor) + 1
。 - 如果指定
num_classes
,生成的每个独热向量的长度就是num_classes
,即使某些类别索引可能小于该值。
- 输出独热编码向量的长度,即类别的总数。如果设置为默认值
输出
- 输出是一个新张量,其中输入张量的每个整数都被转换为一个独热编码向量。
- 输出张量的形状为:
(*input_shape, num_classes)
,即在输入张量的最后增加一个维度,代表类别的独热编码。
独热编码示例
独热编码是指在一个向量中,只有一个位置是1,其余位置都是0。例如,如果有三个类别,类别0可以表示为 [1, 0, 0]
,类别1 表示为 [0, 1, 0]
,类别2 表示为 [0, 0, 1]
。
示例
示例 1:简单独热编码
import torch
import torch.nn.functional as F
# 假设有类别索引 [0, 1, 2]
labels = torch.tensor([0, 1, 2])
one_hot = F.one_hot(labels, num_classes=3)
print(one_hot)
输出:
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
在这里,类别索引 [0, 1, 2]
分别被编码为独热向量 [1, 0, 0]
, [0, 1, 0]
和 [0, 0, 1]
。
示例 2:自定义类别数量
# 输入类别索引为 [0, 1, 4]
labels = torch.tensor([0, 1, 4])
one_hot = F.one_hot(labels, num_classes=5)
print(one_hot)
输出:
tensor([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 1]])
即使 labels
中最大值是 4,指定了 num_classes=5
,独热向量的长度为 5。
示例 3:多维输入
# 输入为二维张量
labels = torch.tensor([[0, 1], [2, 3]])
one_hot = F.one_hot(labels, num_classes=4)
print(one_hot)
输出:
tensor([[[1, 0, 0, 0],
[0, 1, 0, 0]],
[[0, 0, 1, 0],
[0, 0, 0, 1]]])
输出张量的形状为 (2, 2, 4)
,即在输入形状 (2, 2)
的基础上,在最后增加了一个维度来表示类别的独热编码。
应用场景
-
分类任务: 在神经网络的分类任务中,通常需要将类别标签转换为独热编码。例如在多分类问题中,将标签转换为独热编码后,可以与交叉熵损失函数配合使用。
-
序列数据处理: 在自然语言处理任务中,可以使用独热编码将词汇表中的每个单词转换为独热向量,表示该单词在词汇表中的位置。
-
距离计算: 在某些算法中,使用独热编码表示类别或索引可以帮助计算不同类别或位置之间的距离。
总结
torch.nn.functional.one_hot
是一个简单但强大的工具,用于将整数标签或类别索引转换为独热编码。它通常用于分类问题的标签预处理,特别是在多类别分类任务中非常有用。