Transformer Masked loss原理精讲及其PyTorch逐行实现

Masked Loss 的核心原理是:在计算损失函数时,只考虑真实有意义的词元(token),而忽略掉为了数据对齐而填充的无意义的填充词元(padding token)。

这是重要的技术,可以确保模型专注于学习有意义的任务,并得到一个正确的性能评估。

1.原理精讲

为什么需要 Masked Loss?

在训练神经网络时,我们通常会用一个批次(batch)的数据进行训练,而不是一次只用一个样本。对于自然语言处理任务,我们会一次性处理多句话。但这些句子的长度都几乎不一样。

例如,我们有一个包含两个句子的批次:

["我", "是", "学生"] (长度为 3)

["今天", "天气", "真", "好"] (长度为 4)

为了将它们放入一个统一的张量(tensor)中进行高效的并行计算,我们必须将较短的句子"填充"到一个统一的长度(通常是这个批次中最长句子的长度)。我们会使用一个特殊的 <pad> 词元来完成这个任务。

填充后的数据就变成了:

["我", "是", "学生", "<pad>"]

["今天", "天气", "真", "好"]

现在,问题来了。当模型在训练时,它会为每个位置都生成一个预测。对于第一句话,它也会尝试在第4个位置预测 <pad>。如果我们不加处理,损失函数就会计算模型预测 <pad> 的准确度,并把这个"误差"也算进总的损失里。

这样做有两个坏处:

  1. 浪费计算资源:强迫模型去学习一个无意义的任务------"在句子末尾预测填充符"。

  2. 评估指标失真:这个无意义任务的损失会"稀释"我们真正关心的、对真实词元的预测损失,导致我们无法准确评估模型的真实性能。

Masked Loss 就是为了解决这个问题而生的。它的目标就是创建一个"掩码(mask)",告诉损失函数不计算PAD。

PyTorch 逐行实现

在 PyTorch 中,实现 Masked Loss 非常简单,因为 nn.CrossEntropyLoss 已经内置了处理它的高效方法。

我们将一步步模拟这个过程。

第零步:准备工作

我们先导入库,并设定一些基本参数。

python 复制代码
import torch
import torch.nn as nn

#设定参数


BATCH_SIZE = 2      # 一个批次里有2句话

SEQ_LEN = 5         # 统一填充后的句子长度是5

VOCAB_SIZE = 10     # 假设我们的词汇表很小,只有10个词

PADDING_IDX = 0     # 我们约定,ID为0的词元就是 <pad> 填充符

代码解释 : 我们设定了一个场景:一个批次包含2个句子,每个句子被填充到长度5,词汇表共10个词,并且我们用 0 来代表 <pad>

第一步:模拟模型输出和真实标签

我们创建两个张量:一个是模型预测的 logits,另一个是带填充的真实标签 target

python 复制代码
# 模拟模型的原始输出 (logits)
# 形状: (批量大小, 序列长度, 词汇表大小)
logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

# 模拟真实的标签 (ground truth)
# 注意其中包含了 PADDING_IDX (0)
target = torch.tensor([
    [1, 5, 4, 2, PADDING_IDX],  # 第1句话,最后一个是padding
    [3, 8, 7, PADDING_IDX, PADDING_IDX]   # 第2句话,最后两个是padding
])

print("模型预测 Logits 的形状:", logits.shape)
print("真实标签 Target 的形状:", target.shape)
print("真实标签内容:\n", target)

代码解释logits 是模型对每个位置、每个词的预测得分。target 是我们的"标准答案",可以看到,为了对齐,较短的句子末尾被填充了 0

第二步:定义损失函数

python 复制代码
# 定义交叉熵损失函数
# 关键:告诉损失函数,所有标签值为 PADDING_IDX 的位置都被忽略

criterion = nn.CrossEntropyLoss(ignore_index=PADDING_IDX)

ignore_index=PADDING_IDX 这个参数就是实现 Masked Loss 的方法。当我们把 padding_idx (这里是0) 传给它,CrossEntropyLoss 在内部计算时,会自动跳过所有目标标签是 0 的位置。

第三步:调整张量形状

CrossEntropyLoss 期望的输入形状是:Input: (N, C)Target: (N),其中 N 是样本总数,C 是类别数。而我们现在的 logitstarget 都是二维的批次数据,需要调整一下。

python 复制代码
# CrossEntropyLoss 需要的输入形状是 (N, C)
# N 是总的需要计算的元素数量, C是类别数 (即词汇表大小)
# 我们用 .view() 来重塑张量

# 将 logits 从 (2, 5, 10) 变为 (10, 10)
reshaped_logits = logits.view(-1, VOCAB_SIZE)

# 将 target 从 (2, 5) 变为 (10)
reshaped_target = target.view(-1)

print("\n重塑后的 Logits 形状:", reshaped_logits.shape)
print("重塑后的 Target 形状:", reshaped_target.shape)

代码解释 : 我们把 (BATCH_SIZE, SEQ_LEN) 这两个维度"压平"成一个维度。-1 是一个占位符,告诉 PyTorch 自动计算这个维度的大小(在这里就是 2 * 5 = 10)。

第四步:计算损失

现在,所有准备工作都已就绪,我们可以直接计算损失。

python 复制代码
# 计算损失
# criterion 会自动使用我们设置的 ignore_index=0 来忽略填充位置
loss = criterion(reshaped_logits, reshaped_target)

print(f"\n计算出的 Masked Loss 是: {loss.item()}")

代码解释 : 尽管 reshaped_target 中仍然包含 0,但由于我们在第二步中设置了 ignore_index=0,这些位置的损失不会被计算和累加 。最终得到的 loss 值,是只基于那 7 个真实词元([1, 5, 4, 2][3, 8, 7])计算出来的平均损失。

这样,我们就用非常简洁的方式实现了 Masked Loss。

相关推荐
山烛22 分钟前
Python 数据可视化之 Matplotlib 库
开发语言·python·matplotlib·数据可视化
莫彩28 分钟前
【大模型论文阅读】2503.01821_On the Power of Context-Enhanced Learning in LLMs
论文阅读·人工智能·语言模型
hhhh明37 分钟前
【调试Bug】网络在训练中输出NaN
人工智能·算法
蛋仔聊测试43 分钟前
SQL语句执行顺序全解析
python·面试
里昆1 小时前
【AI】Jupyterlab中数据集的位置和程序和Pycharm中的区别
人工智能·学习
我的ID配享太庙呀1 小时前
从零开始:在 PyCharm 中搭建 Django 商城的用户注册与登录功能(轮播图+商品页-小白入门版)
数据库·python·django·sqlite·web·教育电商
WSSWWWSSW1 小时前
基于模拟的流程为灵巧机器人定制训练数据
人工智能·chatgpt·机器人
bksheng1 小时前
【SSL证书校验问题】通过 monkey-patch 关掉 SSL 证书校验
网络·爬虫·python·网络协议·ssl
大视码垛机1 小时前
协作机器人掀起工厂革命:码垛场景如何用数据重塑制造业命脉?
大数据·数据库·人工智能
呆头鹅AI工作室1 小时前
[2025CVPR-图象分类方向]SPARC:用于视觉语言模型中零样本多标签识别的分数提示和自适应融合
图像处理·人工智能·python·深度学习·神经网络·计算机视觉·语言模型