文章目录
前言
在大语言模型时代,我们常常使用交叉熵损失函数来计算loss,因此,理解该loss的计算流程有助于帮助我们对训练过程有更清晰的认知。本文从以下几个角度介绍nn.CrossEntropyLoss()
- 使用该函数的前期准备:如何组织函数的输入(logits & labels)
- 该函数流程
- 常用参数
- 该文章内容仅为个人理解,如有误解,欢迎讨论
什么是CrossEntropyLoss
这部分并不是本文的重点,我们仅介绍在语言模型的训练过程中,如何利用该loss
- 相关信息可见:本人博客
- 以及官网:CrossEntropyLoss官网
语言模型中的CrossEntropyLoss
计算loss的前期准备
在huggingface-transformers
源码中,我们在语言模型的forward
中总是能看到这样一段函数。我们以LlamaForCausalLM
为例:Llama源码
python
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
对于Decoder-only
模型,在训练时,我们的目标是next token prediction
,任务流程如下
-
假定我们是常规的问答任务,问题是"where is the capital of China",label为"The capital is Beijing"。该任务的目标为,当输入为"where is the capital of China"时,
-
我们对question和label进行拼接和tokenize化,一般转化结果 (tokenize忽略) 为:< bos > where is the capital of China < sep > The capital is Beijing < eos >
- < bos>为句子开头的标志
- < sep>用于分隔question和label,本质作用是,当模型看到时就知道:问题结束了,下一个token要输出答案了
- < eos>为生成结束的标志
- 假定每个词算一个token (忽略空格),那么输入一共有13个token
-
这时我们将整个序列输入到模型中,模型在每个token的位置都生成一个向量,我们利用
lm_head
将最后一层的hidden state转化成词表大小的向量logits
,用于后续利用Softmax
确定每个token的概率 -
现在模型有了输出logits,怎么计算loss?
-
对比labels和logits之间的差异来计算loss
-
现在一共有13个token,生成了13个logits,每个logits都是用于生成next token的。那么很直接的,我们来对比该logits生成的next token准不准就好了
-
输入:< bos> where is the capital of China < sep> The capital is Beijing < eos>
-
对比情况为:< sep>->The, The->capital, ..., is->Beijing, Beijing->< eos>
- < sep>对应位置要生成The,..., Beijing对应位置要输出< eos>
-
我们可以将输入右移一位作为labels: where is the capital of China < sep> The capital is Beijing
- 可以看到,对于输入来说, < eos>位置没有对应的需要生成的token,因此我们去掉该token
- 对于labels,< bos>不需要生成,因此我们去掉该token
-
因此,我们在计算loss时,对logits去尾,labels是输入掐头且右移一位
-
在代码中对应
pythonshift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous()
-
-
CrossEntropyLoss的输入
此时还不能直接将shift_logits
和shift_labels
进行对比,来计算loss。因为我们上面的操作只是为了<sep> The capital is Beijing
和The capital is Beijing <eos>
中的token能一一对应起来,对于其他部分生成的token,我们并没有要求(因为不是answer,不需要生成)
CrossEntropyLoss
函数中有一个参数为ignore_idx
,默认值为-100。labels值设置为-100的位置不会计算loss- 因此我们将除了需要计算loss的位置 (最后5个位置)的labels都设置为-100
- 最终,需要输入到
CrossEntropyLoss
中的inputs和labels为- inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing ]对应的logits
- 注意:不需要进行Softmax,直接传logits即可,函数内部有更稳定的Softmax计算方式
- labels为: [-100, -100, -100, -100, -100, -100, -100, The, capital, is, Beijing, < eos>]
- 我们在训练时,构造输入和labels要注意构造为这种形式
- inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing ]对应的logits
CrossEntropyLoss的输出
默认情况下,输出为mean
,即各个token计算得到loss的平均值(在token-level上平均,分母是token的个数)
python
import torch
import torch.nn as nn
# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])
# 标签,其中第二个样本的标签为 ignore_index (-100)
labels = torch.tensor([0, -100, 2])
# 定义 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
# 计算损失
loss = criterion(logits, labels)
print(f"Loss: {loss}")
>>> Loss: 0.51058030128479
-
常用参数:
-
reduction
:控制loss的输出形式,共三种'none', 'mean', 'sum'
,默认为'mean'
-
mean: 每个token计算得到的loss的平均值
-
none: 直接返回每个token计算得到的loss
-
例子:
pythonimport torch import torch.nn as nn # 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3) logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]]) # 标签,其中第二个样本的标签为 ignore_index (-100) labels = torch.tensor([0, -100, 2]) # 定义 CrossEntropyLoss criterion = nn.CrossEntropyLoss(reduction='none') # 计算损失 loss = criterion(logits, labels) print(f"Loss: {loss}") >>> Loss: tensor([0.4170, 0.0000, 0.6041])
-
-
sum: 所有token对应loss求和
-
-
额外说明
对最上面的代码补充说明
python
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
- 训练数据往往是按batch组织的,shape为
(batch_size, seq_len, vocab_size)
- 我们将所有batch的token压缩为一个序列,计算整个序列的loss,这样比较方便