pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。

函数签名

复制代码
torch.nn.functional.one_hot(tensor, num_classes=-1)

参数

  1. tensor:

    • 输入的整数张量。该张量的每个元素都表示一个类别索引。
    • tensor 的数据类型必须是整数类型(如 torch.LongTensortorch.IntTensor)。
  2. num_classes:

    • 输出独热编码向量的长度,即类别的总数。如果设置为默认值 -1,则 num_classes 会自动设置为输入张量中最大值加1,即 max(tensor) + 1
    • 如果指定 num_classes,生成的每个独热向量的长度就是 num_classes,即使某些类别索引可能小于该值。

输出

  • 输出是一个新张量,其中输入张量的每个整数都被转换为一个独热编码向量。
  • 输出张量的形状为:(*input_shape, num_classes),即在输入张量的最后增加一个维度,代表类别的独热编码。

独热编码示例

独热编码是指在一个向量中,只有一个位置是1,其余位置都是0。例如,如果有三个类别,类别0可以表示为 [1, 0, 0],类别1 表示为 [0, 1, 0],类别2 表示为 [0, 0, 1]

示例

示例 1:简单独热编码
复制代码
import torch
import torch.nn.functional as F

# 假设有类别索引 [0, 1, 2]
labels = torch.tensor([0, 1, 2])
one_hot = F.one_hot(labels, num_classes=3)

print(one_hot)

输出:

复制代码
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

在这里,类别索引 [0, 1, 2] 分别被编码为独热向量 [1, 0, 0], [0, 1, 0][0, 0, 1]

示例 2:自定义类别数量
复制代码
# 输入类别索引为 [0, 1, 4]
labels = torch.tensor([0, 1, 4])
one_hot = F.one_hot(labels, num_classes=5)

print(one_hot)

输出:

复制代码
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1]])

即使 labels 中最大值是 4,指定了 num_classes=5,独热向量的长度为 5。

示例 3:多维输入
复制代码
# 输入为二维张量
labels = torch.tensor([[0, 1], [2, 3]])
one_hot = F.one_hot(labels, num_classes=4)

print(one_hot)

输出:

复制代码
tensor([[[1, 0, 0, 0],
         [0, 1, 0, 0]],

        [[0, 0, 1, 0],
         [0, 0, 0, 1]]])

输出张量的形状为 (2, 2, 4),即在输入形状 (2, 2) 的基础上,在最后增加了一个维度来表示类别的独热编码。

应用场景

  1. 分类任务: 在神经网络的分类任务中,通常需要将类别标签转换为独热编码。例如在多分类问题中,将标签转换为独热编码后,可以与交叉熵损失函数配合使用。

  2. 序列数据处理: 在自然语言处理任务中,可以使用独热编码将词汇表中的每个单词转换为独热向量,表示该单词在词汇表中的位置。

  3. 距离计算: 在某些算法中,使用独热编码表示类别或索引可以帮助计算不同类别或位置之间的距离。

总结

torch.nn.functional.one_hot 是一个简单但强大的工具,用于将整数标签或类别索引转换为独热编码。它通常用于分类问题的标签预处理,特别是在多类别分类任务中非常有用。

相关推荐
西柚小萌新22 分钟前
【深入浅出PyTorch】--9.使用ONNX进行部署并推理
人工智能·pytorch·python
LDG_AGI24 分钟前
【推荐系统】深度学习训练框架(十):PyTorch Dataset—PyTorch数据基石
人工智能·pytorch·分布式·python·深度学习·机器学习
长桥夜波37 分钟前
机器学习日报23
人工智能·机器学习
roman_日积跬步-终至千里39 分钟前
【模式识别与机器学习(9)】数据预处理-第一部分:数据基础认知
人工智能·机器学习
AI人工智能+1 小时前
表格识别技术:完整还原银行对账单表格结构、逻辑关系及视觉布局,大幅提升使处理速度提升
人工智能·深度学习·ocr·表格识别
胡乱编胡乱赢1 小时前
Decaf攻击:联邦学习中的数据分布分解攻击
人工智能·深度学习·机器学习·联邦学习·decaf攻击
远上寒山1 小时前
DINO 系列(v1/v2/v3)之二:DINOv2 原理的详细介绍
人工智能·深度学习·自监督·dinov2·自蒸馏·dino系列
_codemonster1 小时前
深度学习实战(基于pytroch)系列(四十)长短期记忆(LSTM)从零开始实现
人工智能·深度学习·lstm
青云交2 小时前
Java 大视界 -- Java 大数据机器学习模型在自然语言处理中的跨语言信息检索与知识融合
机器学习·自然语言处理·java 大数据·知识融合·跨语言信息检索·多语言知识图谱·低资源语言处理
_Twink1e2 小时前
【HCIA-AIV4.0】2025题库+解析(二)
人工智能·深度学习·机器学习