Q1:SFT时,计算LOSS
LLM是自回归生成模型,每次只会生成一个 token,难道 SFT 时,对于一个 (L, D) 的数据,要调用 L 次LLM去计算loss?
A1:
在每个前向传播过程,模型一次性生成整个序列的概率分布,而不是逐个生成令牌。对于一个(L, D)
的输入,那么Attention后,我们还是会得到一个(L, D)
输出。只不过(i, D)
这个向量中存着1~i
中间所有的信息,那么用它就可以生成第 i + 1
个位置的内容。这也是为什么generate函数中每次会取 logits[:, -1] 去生成新的内容。
Q2: SFT时,数据为什么prompt+input+output
LLM是自回归生成模型,在训练时候为什么不是用 prompt + input
作为输入,然后得到 output
再去与真实的 label
计算 loss
更新参数呢?
A2:
首先,如果要是像问题中这种策略去训练,一来每次要调用 l e n g t h o u t p u t length_{output} lengthoutput 次模型,二来模型生成的内容和 label
长度不一定一样,计算 loss
会出问题。其中这主要是因为我们在计算loss时,pytorch中要求loss_function(input, label)
中的 input, label
的shape要一致。然后为了加速收敛,这里其实是一种teacher force
的策略,就在第i个位置,我们会得到一个hidden_state
,然后第i+1
个位置的token
应该由这个hidden_state
去生成,但是我们强制让第i+1
个位置的token
和label
中这个位置的token
一样,也就是在相对正确的环境下再去生成生成第i+1
个位置的hidden_state
。
Q3:SFT时,构造lable
SFT时,构造的lable
为什么要把prompt+input
部分mask
掉。
A3:
像Q1中那样,我们生成的时候是一次性把整个序列的概率分布拿到。然后我们其实不想模型去学会对齐prompt+input
这部分的能力(因为没用),所以把prompt+input
mask 掉,只计算output
部分的loss。