beam search是一种用于序列生成的搜索算法,它在每一步扩展多个候选序列(称为beam),并保留最有可能的k个(beam width)序列,直到生成结束。
这里尝试模拟beam search优化LLM输出过程,以及如何在实际中应用和优化beam search。
所用示例参考和修改自网络资料。
1 beam search原理
Beam Search 是自回归生成模型 的一种解码策略,介于贪心搜索 和穷举搜索之间。
| 方法 | 原理 | 时间复杂度 | 结果质量 |
|---|---|---|---|
| 贪心搜索 | 每一步只选概率最高的词 | O(n) | 可能陷入局部最优 |
| 穷举搜索 | 考虑所有可能路径 | O(Vⁿ) | 全局最优,但不现实 |
| Beam Search | 每一步保留 top-K 个候选 | O(K·n) | 近似最优,效率高 |
2 beam search示例
这里使用一个简单的例子,并且假设模型已经训练好,可以通过函数获取下一个单词的概率分布。
假设这个简单的语言模型,用于生成五言诗句。
假设每一步预测只依赖于已生成的部分,并且有一个模型函数可以给出下一个词的概率。
步骤如下所示:
初始化beam,只包含一个空的序列,并且得分(通常为对数概率)为0。
对于每一步(直到达到最大长度或结束符):
a. 对于当前beam中的每个序列,用模型预测下一个词的概率分布。
b. 从每个序列的下一步概率中选出Top k个候选(k为beam size),计算新的序列和得分(当前得分加上对数概率)。
c. 将所有候选序列按照得分排序,选出Top k个作为新的beam。
当所有序列都达到结束符或达到最大长度时,停止。然后从beam中选择得分最高的序列作为输出。
示例代码如下。
import numpy as np
from collections import defaultdict
import math
# 模拟一个简单的语言模型输出概率
class SimpleLanguageModel:
def __init__(self):
# 词汇表
self.vocab = ["春", "风", "花", "月", "夜", "江", "山", "水", "云", "雨",
"人", "生", "梦", "酒", "歌", "<start>", "<end>"]
self.word2id = {word: i for i, word in enumerate(self.vocab)}
self.id2word = {i: word for i, word in enumerate(self.vocab)}
# 模拟一个简单的概率分布 (实际中来自神经网络)
self.prob_table = {
"<start>": [("春", 0.4), ("风", 0.3), ("月", 0.2), ("江", 0.1)],
"春": [("风", 0.5), ("花", 0.3), ("雨", 0.1), ("江", 0.1)],
"风": [("花", 0.4), ("月", 0.3), ("云", 0.2), ("雨", 0.1)],
"花": [("月", 0.5), ("夜", 0.3), ("春", 0.1), ("雨", 0.1)],
"月": [("夜", 0.6), ("光", 0.2), ("明", 0.1), ("圆", 0.1)],
"夜": [("色", 0.4), ("深", 0.3), ("晚", 0.2), ("来", 0.1)],
"江": [("山", 0.5), ("水", 0.3), ("南", 0.1), ("流", 0.1)],
"山": [("水", 0.6), ("色", 0.2), ("青", 0.1), ("高", 0.1)],
"水": [("流", 0.5), ("声", 0.3), ("清", 0.1), ("深", 0.1)],
"云": [("雨", 0.6), ("雾", 0.2), ("散", 0.1), ("飘", 0.1)],
"雨": [("后", 0.5), ("中", 0.3), ("前", 0.1), ("夜", 0.1)],
"人": [("生", 0.6), ("间", 0.2), ("心", 0.1), ("情", 0.1)],
"生": [("梦", 0.5), ("死", 0.3), ("活", 0.1), ("命", 0.1)],
"梦": [("里", 0.5), ("中", 0.3), ("幻", 0.1), ("醒", 0.1)],
"酒": [("歌", 0.6), ("醉", 0.3), ("杯", 0.05), ("诗", 0.05)],
"歌": [("声", 0.5), ("唱", 0.3), ("舞", 0.1), ("欢", 0.1)],
}
def predict_next_word(self, current_word):
"""返回下一个词的概率分布"""
if current_word not in self.prob_table:
# 如果不在表中,返回均匀分布
return [(word, 1.0/len(self.vocab)) for word in self.vocab]
return self.prob_table[current_word]
def get_top_k(self, current_word, k=3):
"""返回top-k个候选词"""
predictions = self.predict_next_word(current_word)
predictions.sort(key=lambda x: x[1], reverse=True)
return predictions[:k]
# 实现 Beam Search
def beam_search_decode(model, beam_width=3, max_length=5):
"""
使用 Beam Search 生成诗句
"""
# 初始 beam,只有一个空序列
beams = [{
'sequence': ['<start>'],
'score': 0.0, # 使用对数概率,初始为0(log(1)=0)
'last_word': '<start>'
}]
for step in range(max_length):
all_candidates = []
for beam in beams:
# 如果最后一个词是结束符,则不再扩展
if beam['last_word'] == '<end>' or step == max_length - 1:
all_candidates.append(beam)
continue
# 获取下一个词的候选
top_k = model.get_top_k(beam['last_word'], beam_width)
for word, prob in top_k:
# 计算新得分(对数概率相加)
# 为了避免浮点数下溢,使用对数概率
new_score = beam['score'] + math.log(prob + 1e-10) # 加一个小数避免log(0)
# 构建新序列
new_sequence = beam['sequence'] + [word]
candidate = {
'sequence': new_sequence,
'score': new_score,
'last_word': word
}
all_candidates.append(candidate)
# 按得分排序,保留最好的 beam_width 个
all_candidates.sort(key=lambda x: x['score'], reverse=True)
beams = all_candidates[:beam_width]
print(f"\n第 {step+1} 步后的候选序列:")
for i, beam in enumerate(beams):
seq_str = ''.join([w for w in beam['sequence'] if w not in ['<start>', '<end>']])
print(f" Beam {i+1}: {seq_str:10s} 得分: {beam['score']:.3f}")
# 检查是否所有beam都结束了
if all(beam['last_word'] == '<end>' for beam in beams):
break
# 选择最佳序列
best_beam = max(beams, key=lambda x: x['score'])
# 去掉起始标记
final_sequence = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
return ''.join(final_sequence), best_beam['score']
# 运行示例
print("=" * 50)
print("Beam Search 诗词生成示例")
print("=" * 50)
model = SimpleLanguageModel()
# 使用不同beam宽度进行比较
for beam_width in [1, 2, 3]:
print(f"\n{'='*30}")
print(f"Beam Width = {beam_width}")
print(f"{'='*30}")
result, score = beam_search_decode(model, beam_width=beam_width, max_length=5)
print(f"\n生成的诗句: {result}")
print(f"最终得分: {score:.3f}")
输出如下所示,这里得分为log(prob),所以数值为负。
==================================================
Beam Search 诗词生成示例
==================================================
==============================
Beam Width = 1
==============================
第 1 步后的候选序列:
Beam 1: 春 得分: -0.916
第 2 步后的候选序列:
Beam 1: 春风 得分: -1.609
第 3 步后的候选序列:
Beam 1: 春风花 得分: -2.526
第 4 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
第 5 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
生成的诗句: 春风花月
最终得分: -3.219
==============================
Beam Width = 2
==============================
第 1 步后的候选序列:
Beam 1: 春 得分: -0.916
Beam 2: 风 得分: -1.204
第 2 步后的候选序列:
Beam 1: 春风 得分: -1.609
Beam 2: 春花 得分: -2.120
第 3 步后的候选序列:
Beam 1: 春风花 得分: -2.526
Beam 2: 春风月 得分: -2.813
第 4 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
Beam 2: 春风月夜 得分: -3.324
第 5 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
Beam 2: 春风月夜 得分: -3.324
生成的诗句: 春风花月
最终得分: -3.219
==============================
Beam Width = 3
==============================
第 1 步后的候选序列:
Beam 1: 春 得分: -0.916
Beam 2: 风 得分: -1.204
Beam 3: 月 得分: -1.609
第 2 步后的候选序列:
Beam 1: 春风 得分: -1.609
Beam 2: 月夜 得分: -2.120
Beam 3: 春花 得分: -2.120
第 3 步后的候选序列:
Beam 1: 春风花 得分: -2.526
Beam 2: 春风月 得分: -2.813
Beam 3: 春花月 得分: -2.813
第 4 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
Beam 2: 春风月夜 得分: -3.324
Beam 3: 春花月夜 得分: -3.324
第 5 步后的候选序列:
Beam 1: 春风花月 得分: -3.219
Beam 2: 春风月夜 得分: -3.324
Beam 3: 春花月夜 得分: -3.324
生成的诗句: 春风花月
最终得分: -3.219
3 beam search应用
这里探索各种beam_search的变体应用。
3.1 长度惩罚beam search
带长度归一化的 Beam Search示例
def beam_search_with_length_penalty(model, beam_width=3, max_length=10, alpha=0.7):
"""带长度归一化的 Beam Search"""
beams = [{'sequence': ['<start>'], 'score': 0.0, 'last_word': '<start>'}]
for step in range(max_length):
all_candidates = []
for beam in beams:
if beam['last_word'] == '<end>':
all_candidates.append(beam)
continue
top_k = model.get_top_k(beam['last_word'], beam_width)
for word, prob in top_k:
# 对数概率
log_prob = math.log(prob + 1e-10)
# 长度归一化:得分除以长度^alpha
new_length = len(beam['sequence']) + 1
length_penalty = (new_length ** alpha)
# 新的得分(注意:因为是负对数概率,所以除以长度惩罚因子)
new_score = (beam['score'] * (len(beam['sequence']) ** alpha) + log_prob) / length_penalty
new_sequence = beam['sequence'] + [word]
candidate = {
'sequence': new_sequence,
'score': new_score,
'last_word': word
}
all_candidates.append(candidate)
# 排序并选择top-k
all_candidates.sort(key=lambda x: x['score'], reverse=True)
beams = all_candidates[:beam_width]
# 返回最佳结果
best_beam = max(beams, key=lambda x: x['score'])
final_sequence = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
return ''.join(final_sequence)
3.2 多样化beam search
这时一个多样化的 Beam Search变体的示例代码。
def diverse_beam_search(model, num_groups=2, beam_width=2, max_length=5):
"""多样化 Beam Search,增加结果的多样性"""
# 初始化每个组
groups = []
for _ in range(num_groups):
groups.append([{
'sequence': ['<start>'],
'score': 0.0,
'last_word': '<start>'
}])
for step in range(max_length):
all_group_candidates = []
for group_idx, beams in enumerate(groups):
group_candidates = []
for beam in beams:
if beam['last_word'] == '<end>':
group_candidates.append(beam)
continue
top_k = model.get_top_k(beam['last_word'], beam_width * 2)
for word, prob in top_k:
# 添加组间多样性惩罚
diversity_penalty = 0.0
# 检查其他组中是否已有这个词
for other_group_idx in range(num_groups):
if other_group_idx != group_idx:
# 简单惩罚:如果其他组的候选中有这个词,降低分数
# 这里简化处理,实际会更复杂
pass
log_prob = math.log(prob + 1e-10) - diversity_penalty
new_score = beam['score'] + log_prob
new_sequence = beam['sequence'] + [word]
candidate = {
'sequence': new_sequence,
'score': new_score,
'last_word': word,
'group': group_idx
}
group_candidates.append(candidate)
# 每组内排序
group_candidates.sort(key=lambda x: x['score'], reverse=True)
groups[group_idx] = group_candidates[:beam_width]
all_group_candidates.extend(group_candidates[:beam_width])
# 从所有组中选择最佳结果
best_candidate = max(all_group_candidates, key=lambda x: x['score'])
final_sequence = [w for w in best_candidate['sequence'] if w not in ['<start>', '<end>']]
return ''.join(final_sequence)
4 可视化beam search
示例代码如下所示
def visualize_beam_search(model, beam_width=2, max_length=4):
"""可视化 Beam Search 的搜索过程"""
beams = [{'sequence': ['<start>'], 'score': 0.0, 'last_word': '<start>'}]
print("Beam Search 搜索过程可视化:")
print("=" * 40)
for step in range(max_length):
print(f"\n时间步 {step+1}:")
print("-" * 20)
all_candidates = []
for beam_idx, beam in enumerate(beams):
print(f"\n扩展 Beam {beam_idx+1}: {' '.join(beam['sequence'])}")
if beam['last_word'] == '<end>':
all_candidates.append(beam)
continue
top_k = model.get_top_k(beam['last_word'], beam_width)
for word, prob in top_k:
new_score = beam['score'] + math.log(prob + 1e-10)
new_sequence = beam['sequence'] + [word]
candidate = {
'sequence': new_sequence,
'score': new_score,
'last_word': word
}
all_candidates.append(candidate)
print(f" -> {word} (p={prob:.3f}, 得分={new_score:.3f})")
# 排序并选择top-k
all_candidates.sort(key=lambda x: x['score'], reverse=True)
beams = all_candidates[:beam_width]
print(f"\n当前最佳候选:")
for i, beam in enumerate(beams):
seq_str = ' '.join(beam['sequence'])
print(f" Beam {i+1}: {seq_str:20s} 得分: {beam['score']:.3f}")
# 最终结果
print("\n" + "=" * 40)
print("最终结果:")
best_beam = max(beams, key=lambda x: x['score'])
final_seq = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
print(f"生成序列: {' '.join(final_seq)}")
print(f"最终得分: {best_beam['score']:.3f}")
# 运行可视化
model = SimpleLanguageModel()
visualize_beam_search(model, beam_width=2, max_length=4)
输出如下所示
Beam Search 搜索过程可视化:
========================================
时间步 1:
扩展 Beam 1: <start>
-> 春 (p=0.400, 得分=-0.916)
-> 风 (p=0.300, 得分=-1.204)
当前最佳候选:
Beam 1: <start> 春 得分: -0.916
Beam 2: <start> 风 得分: -1.204
时间步 2:
扩展 Beam 1: <start> 春
-> 风 (p=0.500, 得分=-1.609)
-> 花 (p=0.300, 得分=-2.120)
扩展 Beam 2: <start> 风
-> 花 (p=0.400, 得分=-2.120)
-> 月 (p=0.300, 得分=-2.408)
当前最佳候选:
Beam 1: <start> 春 风 得分: -1.609
Beam 2: <start> 春 花 得分: -2.120
时间步 3:
扩展 Beam 1: <start> 春 风
-> 花 (p=0.400, 得分=-2.526)
-> 月 (p=0.300, 得分=-2.813)
扩展 Beam 2: <start> 春 花
-> 月 (p=0.500, 得分=-2.813)
-> 夜 (p=0.300, 得分=-3.324)
当前最佳候选:
Beam 1: <start> 春 风 花 得分: -2.526
Beam 2: <start> 春 风 月 得分: -2.813
时间步 4:
扩展 Beam 1: <start> 春 风 花
-> 月 (p=0.500, 得分=-3.219)
-> 夜 (p=0.300, 得分=-3.730)
扩展 Beam 2: <start> 春 风 月
-> 夜 (p=0.600, 得分=-3.324)
-> 光 (p=0.200, 得分=-4.423)
当前最佳候选:
Beam 1: <start> 春 风 花 月 得分: -3.219
Beam 2: <start> 春 风 月 夜 得分: -3.324
========================================
最终结果:
生成序列: 春 风 花 月
最终得分: -3.219
Click to add a cell.
reference