目录
- [1. Top-p 和 Top-k 采样](#1. Top-p 和 Top-k 采样)
- [2. LayerNorm 和 RMSNorm](#2. LayerNorm 和 RMSNorm)
- [4. 手撕 Softmax 与 交叉熵 (Cross Entropy)](#4. 手撕 Softmax 与 交叉熵 (Cross Entropy))
- 面试总结与建议
1. Top-p 和 Top-k 采样
概念讲解:
在自回归文本生成中,模型每一步会输出一个概率分布(logits 经过 softmax),我们需要从中采样下一个 token。直接使用整个词汇表采样(即 temperature 缩放后的随机采样)可能导致生成低概率 token,使结果不连贯。Top-k 采样 和 Top-p 采样 是两种常用的截断采样方法,用于限制候选 token 集合,提高生成质量。
- Top-k 采样:
- 做法 :只保留概率最高的 k 个词,把剩下的词概率强制设为 0,然后重新归一化(让剩下的概率和为 1),再从中采样。
- 作用:直接砍掉长尾的低概率词,防止生成生僻字或乱码。
- 缺点:k 是固定的。如果模型很自信(某个词概率 90%),k 太大也会采样到噪音;如果模型很犹豫(概率很平),k 太小会限制多样性。
- Top-p (Nucleus) 采样:
- 做法 :将词按概率从大到小排序,依次累加概率,直到累加和超过 p (比如 0.9)。保留这些词,剩下的截断,重新归一化,再采样。
- 作用:动态调整候选词数量。模型自信时候选词少,模型犹豫时候选词多。
- 现状:目前 LLM 推理中,Top-p 比 Top-k 更常用,或者两者结合。
两种方法可以结合使用(如先取 top-k 再取 top-p),但通常分别实现。
代码实现:
python
def top_k_filtering(logits, top_k=50, temperature=1.0, filter_value=-float('Inf')):
"""
logits: [vocab_size] 或 [batch, vocab_size],模型输出的原始分数
top_k: 保留概率最大的 k 个词
temperature: 温度,大于 1 增加多样性,小于 1 增加确定性
"""
logits = logits / temperature
# 找出所有小于第 k 大值的索引,将其设为 filter_value (即 -inf,softmax 后为 0)
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
return logits
def top_p_filtering(logits, top_p=0.9, temperature=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
"""
logits: [vocab_size] 或 [batch, vocab_size],模型输出的原始分数
top_p: 保留累积概率超过 p 的最小词集
temperature: 温度,大于 1 增加多样性,小于 1 增加确定性
"""
logits = logits / temperature
# 按概率从大到小排序
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) # [batch, vocab]
cumulated_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # [batch, vocab]
# 创建 mask:累积概率 > p 的位置为 True(需要被移除)
sorted_indices_to_remove = cumulated_probs > top_p # [batch, vocab]
sorted_indices_to_remove[:, :min_tokens_to_keep] = False # 保留第一个超过 p 的 token(确保至少有一些 token)
# 注意:sorted_logits 是排序后的,需要映射回原始位置
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value # 将被移除的 logits 设为 -inf(softmax 后概率为 0)
return logits
if __name__ == "__main__":
# 模拟一个 vocab_size=5 的 logits
logits = torch.tensor([[2.0, 1.0, 0.1, 0.1, 0.1]])
print("原始 Logits:", logits)
filtered_logits = top_p_filtering(logits, top_p=0.8, temperature=1.0)
print("过滤后 Logits:", filtered_logits)
# 采样
probs = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print("采样结果 Token ID:", next_token)
2. LayerNorm 和 RMSNorm
概念讲解:
- Layer Normalization (LayerNorm) :对每个样本的每个特征层进行归一化。给定输入
x形状(batch, seq_len, hidden_size),对最后一个维度(hidden_size)计算均值和方差,然后标准化:(x - mean) / sqrt(var + eps),再乘以可学习的缩放参数gamma并加上偏移beta。LayerNorm 广泛用于 Transformer,稳定训练。 - RMSNorm :Root Mean Square Layer Normalization 是 LayerNorm 的一个简化变体。它假设减去均值不是必需的,只使用均方根 (RMS) 进行归一化:
x / RMS(x),其中RMS(x) = sqrt(mean(x^2) + eps)。同样乘以可学习的缩放参数gamma,但没有beta。RMSNorm 计算量更小,且在实验中性能与 LayerNorm 相当。
代码实现:
python
import torch
import torch.nn as nn
class LayerNorm(nn.Module):
"""
层归一化 (Layer Normalization)
公式: y = gamma * (x - mean) / sqrt(var + eps) + beta
"""
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.beta = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x):
# x: (batch, seq_len, hidden_size) 或 (batch, hidden_size)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
class RMSNorm(nn.Module):
"""
RMS 归一化 (Root Mean Square Layer Normalization)
公式: y = gamma * x / RMS(x), 其中 RMS(x) = sqrt(mean(x^2) + eps)
"""
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x):
# x: (batch, seq_len, hidden_size) 或 (batch, hidden_size)
rms = torch.sqrt(torch.mean(x, dim=-1, keepdim=True) + self.eps)
x_norm = x / rms
return self.gamma * x_norm
# 示例用法
if __name__ == "__main__":
batch, seq, hidden = 2, 3, 4
x = torch.randn(batch, seq, hidden)
ln = LayerNorm(hidden)
rms = RMSNorm(hidden)
out_ln = ln(x)
out_rms = rms(x)
print("LayerNorm 输出形状:", out_ln.shape)
print("RMSNorm 输出形状:", out_rms.shape)
4. 手撕 Softmax 与 交叉熵 (Cross Entropy)
概念讲解:
这是深度学习最基础的算子,但面试常考 数值稳定性。
- Softmax : \(P_i = \frac{e^{z_i}}{\sum e^{z_j}}\)
- 问题 :如果 \(z_i\) 很大,\(e^{z_i}\) 会溢出 (Infinity)。
- 解决 :利用 Softmax 的平移不变性,所有 \(z\) 减去最大值 \(\max(z)\)。即 \(P_i = \frac{e^{z_i - \max(z)}}{\sum e^{z_j - \max(z)}}\)。
- Cross Entropy (CE) : \(Loss = -\sum y_i \log(P_i)\)
- 在分类任务中,\(y\) 是 one-hot,所以简化为 \(-\log(P_{target})\)。
- 结合 Softmax :通常不单独算 Softmax 再算 Log,而是合并为
LogSoftmax,数值更稳定。
代码实现:
python
import torch
def stable_softmax(logits, dim=-1):
"""
数值稳定的 Softmax 实现
"""
# 1. 减去最大值,防止 exp 溢出
# keepdim=True 保证形状可以广播
max_logits = torch.max(logits, dim=dim, keepdim=True)[0]
exp_logits = torch.exp(logits - max_logits)
# 2. 归一化
sum_exp_logits = torch.sum(exp_logits, dim=dim, keepdim=True)
probs = exp_logits / sum_exp_logits
return probs
def cross_entropy_loss(logits, targets):
"""
手写交叉熵 Loss
logits: [batch, vocab]
targets: [batch] 类别索引
"""
batch_size = logits.shape[0]
# 1. 数值稳定的 LogSoftmax
# log(softmax(x)) = x - max(x) - log(sum(exp(x - max(x))))
max_logits = torch.max(logits, dim=1, keepdim=True)[0]
log_sum_exp = max_logits + torch.log(torch.sum(torch.exp(logits - max_logits), dim=1, keepdim=True))
log_probs = logits - log_sum_exp
# 2. NLL Loss (Negative Log Likelihood)
# 取出目标类别对应的 log 概率
# targets 需要 unsqueeze 才能 gather
target_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
# 3. 取平均
loss = -torch.mean(target_log_probs)
return loss
# --- 测试 Demo ---
if __name__ == "__main__":
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
targets = torch.tensor([0, 1]) # 第一个样本目标是类 0,第二个是类 1
# 对比 PyTorch 原生实现
torch_ce = nn.CrossEntropyLoss()
torch_loss = torch_ce(logits, targets)
# 对比手写实现
my_loss = cross_entropy_loss(logits, targets)
print(f"PyTorch Loss: {torch_loss.item():.6f}")
print(f"My Loss: {my_loss.item():.6f}")
# 两者应该非常接近
面试总结与建议
- 关于 Shift Right :这是 LLM 训练最核心的数据对齐逻辑。面试时如果能主动提到
contiguous()的作用(内存连续)和ignore_index(处理 padding),会非常加分。 - 关于数值稳定性 :在写 Softmax 和 CrossEntropy 时,必须 提到
max subtraction。如果不提,面试官可能会认为你缺乏工程经验。 - 关于 RMSNorm:现在 LLaMA、Qwen 等主流模型都用 RMSNorm。如果能说出它比 LayerNorm 少了减均值操作,计算更快,且效果在 LLM 上相当,会体现你对前沿架构的了解。
- 关于采样:实际推理中,Temperature 通常是在 Top-k/p 之前应用的。代码中体现这一点会显得更专业。