文章目录
- 束搜索算法及其实现
-
-
- [一、束搜索(Beam Search)算法详解](#一、束搜索(Beam Search)算法详解)
-
- [(一) 核心思想](#(一) 核心思想)
- (二)工作流程
- (三)与贪心搜索的对比**
- [二、PyTorch 实现束搜索](#二、PyTorch 实现束搜索)
- 三、使用示例:文本生成
-
- [**场景**:使用 LSTM 模型生成句子,束宽 `k=3`。](#场景:使用 LSTM 模型生成句子,束宽
k=3
。) - **效果分析**:
- [**场景**:使用 LSTM 模型生成句子,束宽 `k=3`。](#场景:使用 LSTM 模型生成句子,束宽
- 四、进阶技巧与注意事项
- 五、适用场景对比
-
束搜索算法及其实现

一、束搜索(Beam Search)算法详解
(一) 核心思想
束搜索是一种启发式图搜索算法 ,用于序列生成任务(如机器翻译、文本生成)。它通过维护固定数量的候选序列(称为 束宽 k
),在每一步扩展时仅保留概率最高的 k
个路径,避免穷举所有可能序列,平衡效率与质量。
(二)工作流程
- 初始化 :起始标记(如
<start>
)加入候选序列,初始得分(概率对数)为 0。 - 迭代扩展 :
- 对当前每个候选序列,生成所有可能的下一词元。
- 计算新序列的累积对数概率:
新得分 = 原得分 + log(当前词概率)
。
- 剪枝 :从所有新候选序列中选择得分最高的
k
个,其余丢弃。 - 终止条件 :
- 所有候选序列生成结束标记
<end>
。 - 达到预设最大长度。
- 所有候选序列生成结束标记
- 输出:选择最终得分最高的序列。
(三)与贪心搜索的对比**
特性 | 贪心搜索 (Greedy) | 束搜索 (Beam Search) |
---|---|---|
搜索宽度 | 1(每步选最优词) | k (保留多个候选) |
全局最优性 | 易陷局部最优 | 更接近全局最优 |
计算成本 | 低 | 较高(约为贪心的 k 倍) |
输出质量 | 可能重复或不连贯 | 更流畅、合理 |
关键点 :束宽
k
越大,生成质量越高,但计算开销越大;k=1
时退化为贪心搜索。
二、PyTorch 实现束搜索
以下是一个支持批量处理的束搜索实现,适用于 Seq2Seq 模型(如 LSTM 或 Transformer):
python
import torch
import torch.nn.functional as F
def beam_search(model, encoder_output, beam_width=3, max_len=20, device="cuda"):
"""
参数:
model: 解码器模型(需实现 `get_next_token_probs` 方法)
encoder_output: 编码器输出张量 (batch_size, seq_len, hidden_size)
beam_width: 束宽 k
max_len: 最大生成长度
"""
batch_size = encoder_output.size(0)
start_token = torch.tensor([0]).to(device) # 起始标记(假设0)
end_token = 1 # 结束标记(假设1)
# 初始化束:[(序列, 得分)] * batch_size
beams = [[[start_token], 0.0]] * batch_size
active_beams = [True] * batch_size # 标记活跃的束
for _ in range(max_len):
new_beams = [[] for _ in range(batch_size)]
for i in range(batch_size):
if not active_beams[i]:
new_beams[i] = beams[i]
continue
seq, score = beams[i][-1] if beams[i] else ([start_token], 0.0)
last_token = seq[-1]
# 获取下一词概率分布 (vocab_size)
with torch.no_grad():
logits = model(last_token.unsqueeze(0), encoder_output[i].unsqueeze(0))
probs = F.log_softmax(logits, dim=-1).squeeze(0) # 对数概率
# 扩展候选:当前得分 + log(新词概率)
topk_probs, topk_tokens = torch.topk(probs, beam_width)
for j in range(beam_width):
new_seq = seq + [topk_tokens[j].item()]
new_score = score + topk_probs[j].item()
candidate = (new_seq, new_score)
new_beams[i].append(candidate)
# 剪枝:保留得分最高的 k 个候选
new_beams[i] = sorted(new_beams[i], key=lambda x: x[1], reverse=True)[:beam_width]
# 检查是否所有束均终止
for i in range(batch_size):
active_beams[i] = any(token[-1][0][-1] != end_token for token in new_beams[i])
beams = new_beams
if not any(active_beams):
break
# 返回每个样本的最优序列
return [max(beam, key=lambda x: x[1])[0] for beam in beams]
代码说明:
- 概率处理 :使用
log_softmax
避免数值下溢,得分累加使用对数概率。 - 剪枝优化 :每步仅保留
k
个高分候选,降低计算复杂度。 - 终止条件 :遇到
<end>
标记或达到最大长度时停止。 - 批量支持:独立处理每个样本的束,适合 GPU 并行。
三、使用示例:文本生成
场景 :使用 LSTM 模型生成句子,束宽 k=3
。
python
# 假设已定义 LSTM 解码器模型
class DecoderLSTM(torch.nn.Module):
def __init__(self, vocab_size, hidden_size):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden):
embedded = self.embedding(x)
output, hidden = self.lstm(embedded, hidden)
logits = self.fc(output)
return logits, hidden
# 模拟输入
vocab_size = 10000
hidden_size = 256
encoder_output = torch.randn(2, 10, hidden_size).to("cuda") # 批量大小=2
model = DecoderLSTM(vocab_size, hidden_size).to("cuda")
# 执行束搜索
generated_seqs = beam_search(model, encoder_output, beam_width=3, max_len=15)
print(generated_seqs) # 输出:[[0, 42, 17, 1], [0, 88, 3, 7, 1]]
效果分析:
- 束宽
k=1
(贪心搜索):可能生成["I am am good"]
等重复句子。 - 束宽
k=3
:生成["I am good"]
或["You are great"]
等更合理序列。
四、进阶技巧与注意事项
-
长度归一化 :
- 长序列因概率连乘得分偏低,需除以长度:
score = score / len(seq)^α
(α
常取 0.7)。
- 长序列因概率连乘得分偏低,需除以长度:
-
多样性优化 :
- 组束搜索(Group Beam Search):将束分为多组,添加多样性惩罚避免相似输出。
python# Hugging Face Transformers 示例 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") outputs = model.generate( inputs, num_beams=8, num_beam_groups=4, # 分4组 diversity_penalty=1.0 # 多样性惩罚强度 )
-
与采样策略结合 :
- Top-K 采样 :从概率最高的
K
个词中随机选择,增加多样性(如创意写作)。 - Top-p(核)采样 :动态选择累积概率 ≥
p
的最小词集(如对话生成)。
- Top-K 采样 :从概率最高的
五、适用场景对比
解码策略 | 特点 | 适用场景 |
---|---|---|
贪心搜索 | 速度快,质量低 | 实时响应、结构化生成 |
束搜索 | 质量高,计算中等 | 机器翻译、图像描述 |
Top-K 采样 | 随机性高,多样性好 | 诗歌生成、故事创作 |
Top-p 采样 | 自适应多样性 | 对话系统、代码生成 |
实践建议 :多数任务优先用束搜索(
k=3~10
);需创造性输出时结合采样策略。
束搜索通过高效平衡搜索质量与计算成本,成为序列生成任务的主流解码方法。理解其原理及实现细节,能显著提升生成模型的输出质量。