Transformer Masked loss原理精讲及其PyTorch逐行实现

Masked Loss 的核心原理是:在计算损失函数时,只考虑真实有意义的词元(token),而忽略掉为了数据对齐而填充的无意义的填充词元(padding token)。

这是重要的技术,可以确保模型专注于学习有意义的任务,并得到一个正确的性能评估。

1.原理精讲

为什么需要 Masked Loss?

在训练神经网络时,我们通常会用一个批次(batch)的数据进行训练,而不是一次只用一个样本。对于自然语言处理任务,我们会一次性处理多句话。但这些句子的长度都几乎不一样。

例如,我们有一个包含两个句子的批次:

["我", "是", "学生"] (长度为 3)

["今天", "天气", "真", "好"] (长度为 4)

为了将它们放入一个统一的张量(tensor)中进行高效的并行计算,我们必须将较短的句子"填充"到一个统一的长度(通常是这个批次中最长句子的长度)。我们会使用一个特殊的 <pad> 词元来完成这个任务。

填充后的数据就变成了:

["我", "是", "学生", "<pad>"]

["今天", "天气", "真", "好"]

现在,问题来了。当模型在训练时,它会为每个位置都生成一个预测。对于第一句话,它也会尝试在第4个位置预测 <pad>。如果我们不加处理,损失函数就会计算模型预测 <pad> 的准确度,并把这个"误差"也算进总的损失里。

这样做有两个坏处:

  1. 浪费计算资源:强迫模型去学习一个无意义的任务------"在句子末尾预测填充符"。

  2. 评估指标失真:这个无意义任务的损失会"稀释"我们真正关心的、对真实词元的预测损失,导致我们无法准确评估模型的真实性能。

Masked Loss 就是为了解决这个问题而生的。它的目标就是创建一个"掩码(mask)",告诉损失函数不计算PAD。

PyTorch 逐行实现

在 PyTorch 中,实现 Masked Loss 非常简单,因为 nn.CrossEntropyLoss 已经内置了处理它的高效方法。

我们将一步步模拟这个过程。

第零步:准备工作

我们先导入库,并设定一些基本参数。

python 复制代码
import torch
import torch.nn as nn

#设定参数


BATCH_SIZE = 2      # 一个批次里有2句话

SEQ_LEN = 5         # 统一填充后的句子长度是5

VOCAB_SIZE = 10     # 假设我们的词汇表很小,只有10个词

PADDING_IDX = 0     # 我们约定,ID为0的词元就是 <pad> 填充符

代码解释 : 我们设定了一个场景:一个批次包含2个句子,每个句子被填充到长度5,词汇表共10个词,并且我们用 0 来代表 <pad>

第一步:模拟模型输出和真实标签

我们创建两个张量:一个是模型预测的 logits,另一个是带填充的真实标签 target

python 复制代码
# 模拟模型的原始输出 (logits)
# 形状: (批量大小, 序列长度, 词汇表大小)
logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

# 模拟真实的标签 (ground truth)
# 注意其中包含了 PADDING_IDX (0)
target = torch.tensor([
    [1, 5, 4, 2, PADDING_IDX],  # 第1句话,最后一个是padding
    [3, 8, 7, PADDING_IDX, PADDING_IDX]   # 第2句话,最后两个是padding
])

print("模型预测 Logits 的形状:", logits.shape)
print("真实标签 Target 的形状:", target.shape)
print("真实标签内容:\n", target)

代码解释logits 是模型对每个位置、每个词的预测得分。target 是我们的"标准答案",可以看到,为了对齐,较短的句子末尾被填充了 0

第二步:定义损失函数

python 复制代码
# 定义交叉熵损失函数
# 关键:告诉损失函数,所有标签值为 PADDING_IDX 的位置都被忽略

criterion = nn.CrossEntropyLoss(ignore_index=PADDING_IDX)

ignore_index=PADDING_IDX 这个参数就是实现 Masked Loss 的方法。当我们把 padding_idx (这里是0) 传给它,CrossEntropyLoss 在内部计算时,会自动跳过所有目标标签是 0 的位置。

第三步:调整张量形状

CrossEntropyLoss 期望的输入形状是:Input: (N, C)Target: (N),其中 N 是样本总数,C 是类别数。而我们现在的 logitstarget 都是二维的批次数据,需要调整一下。

python 复制代码
# CrossEntropyLoss 需要的输入形状是 (N, C)
# N 是总的需要计算的元素数量, C是类别数 (即词汇表大小)
# 我们用 .view() 来重塑张量

# 将 logits 从 (2, 5, 10) 变为 (10, 10)
reshaped_logits = logits.view(-1, VOCAB_SIZE)

# 将 target 从 (2, 5) 变为 (10)
reshaped_target = target.view(-1)

print("\n重塑后的 Logits 形状:", reshaped_logits.shape)
print("重塑后的 Target 形状:", reshaped_target.shape)

代码解释 : 我们把 (BATCH_SIZE, SEQ_LEN) 这两个维度"压平"成一个维度。-1 是一个占位符,告诉 PyTorch 自动计算这个维度的大小(在这里就是 2 * 5 = 10)。

第四步:计算损失

现在,所有准备工作都已就绪,我们可以直接计算损失。

python 复制代码
# 计算损失
# criterion 会自动使用我们设置的 ignore_index=0 来忽略填充位置
loss = criterion(reshaped_logits, reshaped_target)

print(f"\n计算出的 Masked Loss 是: {loss.item()}")

代码解释 : 尽管 reshaped_target 中仍然包含 0,但由于我们在第二步中设置了 ignore_index=0,这些位置的损失不会被计算和累加 。最终得到的 loss 值,是只基于那 7 个真实词元([1, 5, 4, 2][3, 8, 7])计算出来的平均损失。

这样,我们就用非常简洁的方式实现了 Masked Loss。

相关推荐
用户5191495848457 分钟前
curl --continue-at 参数异常行为分析:文件覆盖与删除风险
人工智能·aigc
用户84913717547167 分钟前
joyagent智能体学习(第1期):项目概览与架构解析
人工智能·llm·agent
是乐谷7 分钟前
阿里云杭州 AI 产品法务岗位信息分享(2025 年 8 月)
java·人工智能·阿里云·面试·职场和发展·机器人·云计算
用户5191495848459 分钟前
初识ARIA时我希望有人告诉我的事:Web无障碍开发指南
人工智能·aigc
AI知识管理20 分钟前
AI知识管理产品落地设计方案
人工智能·产品
weixin_5079299123 分钟前
第G7周:Semi-Supervised GAN 理论与实战
人工智能·pytorch·深度学习
天才测试猿27 分钟前
常见的Jmeter压测问题
自动化测试·软件测试·python·测试工具·jmeter·职场和发展·压力测试
mortimer28 分钟前
一次与“顽固”外部程序的艰难交锋:subprocess 调用exe踩坑实录
windows·python·ai编程
一叶飘零_sweeeet41 分钟前
IDEA 插件 Trae AI 全攻略
java·人工智能·intellij-idea
SEO_juper1 小时前
AI 搜索时代:引领变革,重塑您的 SEO 战略
人工智能·搜索引擎·seo·数字营销·seo优化