pytorch scatter_ 函数介绍

scatter_ 是 PyTorch 中的一个原地操作函数,用于在给定的索引处将某些值填充到张量的指定维度中。它的常见用途之一是将类别标签转换为 one-hot 编码,不过它也适用于其他场景,如在特定索引处更新张量的值。

scatter_ 函数的签名如下:

复制代码
scatter_(dim, index, src)
  • dim:指定操作的维度。即在这个维度上更新值。
  • index:包含索引的张量,指定要更新值的位置。
  • src:要填入的值。可以是一个标量(单个值),也可以是一个张量。

使用示例

1. 使用 scatter_ 实现 one-hot 编码

我们可以通过 scatter_ 将类别标签转换为 one-hot 编码。

代码示例:

复制代码
import torch

# logits 模拟网络输出 [batch_size, num_classes]
logits = torch.tensor([[2.0, 1.0, 0.1], [1.5, 2.5, 0.5]])  # 形状 [2, 3]

# target 是真实的类别标签 [batch_size]
target = torch.tensor([0, 2])  # 形状 [2]

# 创建一个与 logits 相同大小的全零张量
target_onehot = torch.zeros_like(logits)  # 形状 [2, 3]

# 使用 scatter_ 函数在第 1 维(类别维度)根据 target 的索引设置为 1
target_onehot.scatter_(1, target.view(-1, 1), 1)

print("one-hot 编码:")
print(target_onehot)

注:target.view(-1, 1) 中的 1 指的是将 target 张量的形状重新调整为 两维 ,并且使得第二个维度的大小固定为 1

输出:

复制代码
tensor([[1., 0., 0.],  # 第一个样本,类别为 0
        [0., 0., 1.]])  # 第二个样本,类别为 2
2. 使用 scatter_ 更新指定位置的值

你还可以用 scatter_ 在张量的指定位置填充任意值。这里是一个简单的例子,将特定索引的位置设置为自定义的数值:

代码示例:

复制代码
import torch

# 创建一个 3x3 的全零张量
tensor = torch.zeros(3, 3)

# 定义索引
index = torch.tensor([[0, 2, 1]])  # 每行对应位置的索引

# 要填入的值
src = torch.tensor([[5, 9, 7]])

# 使用 scatter_ 在第 1 维(列)填充 src 的值到指定索引位置
tensor.scatter_(1, index, src)

print("填充值后的张量:")
print(tensor)

输出:

复制代码
tensor([[5., 7., 9.],  # 在索引 0 处填 5,索引 1 处填 7,索引 2 处填 9
        [5., 7., 9.],
        [5., 7., 9.]])
3. scatter_ 与广播

scatter_ 支持广播机制。你可以使用一个标量值来替换指定的索引位置,也可以使用一个与 index 兼容的张量来填充不同的值。

代码示例:

复制代码
import torch

# 创建一个 4x3 的全零张量
tensor = torch.zeros(4, 3)

# 定义索引
index = torch.tensor([[0, 2, 1], [2, 1, 0], [1, 0, 2], [2, 1, 0]])  # 4x3 的索引

# 要填入的值
src = torch.tensor([5, 9, 7])  # 广播到每一行

# 使用 scatter_ 在第 1 维(列)填充 src 的值到指定索引位置
tensor.scatter_(1, index, src)

print("填充值后的张量:")
print(tensor)

输出:

复制代码
tensor([[5., 7., 9.],
        [7., 9., 5.],
        [9., 5., 7.],
        [7., 9., 5.]])

总结

scatter_ 函数可以根据指定的索引,在目标张量的某个维度上填充源张量或标量的值。它的常见应用包括:

  • one-hot 编码:将类别标签转换为 one-hot 格式。
  • 更新张量特定位置的值:可以根据索引在张量的某些特定位置填入新的值。

它灵活且高效,适合用于需要对张量的特定索引进行操作的场景。

相关推荐
R²AIN SUITE1 分钟前
快消零售AI转型:R²AIN SUITE如何破解效率困局
大数据·人工智能·产品运营
ONLYOFFICE4 分钟前
集成 ONLYOFFICE 与 AI 插件,为您的服务带来智能文档编辑器
人工智能·ai·编辑器·onlyoffice·文档编辑器·文档预览·文档协作
一个天蝎座 白勺 程序猿10 分钟前
GpuGeek全栈AI开发实战:从零构建企业级大模型生产管线(附完整案例)
人工智能·gpugeek
love530love12 分钟前
家用或办公 Windows 电脑玩人工智能开源项目配备核显的必要性(含 NPU 及显卡类型补充)
人工智能·windows·python·开源·电脑
深圳市青牛科技实业有限公司14 分钟前
D2203使用手册—高压、小电流LDO产品4.6V~36V、150mA
人工智能·单片机·嵌入式硬件·电动工具·工业散热风扇
shengjk118 分钟前
序列化和反序列化:从理论到实践的全方位指南
java·大数据·开发语言·人工智能·后端·ai编程
AI大模型顾潇19 分钟前
[特殊字符] 本地大模型编程实战(29):用大语言模型LLM查询图数据库NEO4J(2)
前端·数据库·人工智能·语言模型·自然语言处理·prompt·neo4j
2501_9153743534 分钟前
数据清洗的艺术:如何为AI模型准备高质量数据集?
人工智能·机器学习
山北雨夜漫步37 分钟前
机器学习 Day17 朴素贝叶斯算法-----概率论知识
人工智能·算法·机器学习
愚公搬代码1 小时前
【愚公系列】《Manus极简入门》038-数字孪生设计师:“虚实映射师”
人工智能·agi·ai agent·智能体·manus