python模拟beam search优化LLM输出过程

beam search是一种用于序列生成的搜索算法,它在每一步扩展多个候选序列(称为beam),并保留最有可能的k个(beam width)序列,直到生成结束。

这里尝试模拟beam search优化LLM输出过程,以及如何在实际中应用和优化beam search。

所用示例参考和修改自网络资料。

1 beam search原理

Beam Search 是自回归生成模型 的一种解码策略,介于贪心搜索穷举搜索之间。

方法 原理 时间复杂度 结果质量
贪心搜索 每一步只选概率最高的词 O(n) 可能陷入局部最优
穷举搜索 考虑所有可能路径 O(Vⁿ) 全局最优,但不现实
Beam Search 每一步保留 top-K 个候选 O(K·n) 近似最优,效率高

2 beam search示例

这里使用一个简单的例子,并且假设模型已经训练好,可以通过函数获取下一个单词的概率分布。

假设这个简单的语言模型,用于生成五言诗句。

假设每一步预测只依赖于已生成的部分,并且有一个模型函数可以给出下一个词的概率。

步骤如下所示:

  1. 初始化beam,只包含一个空的序列,并且得分(通常为对数概率)为0。

  2. 对于每一步(直到达到最大长度或结束符):

    a. 对于当前beam中的每个序列,用模型预测下一个词的概率分布。

    b. 从每个序列的下一步概率中选出Top k个候选(k为beam size),计算新的序列和得分(当前得分加上对数概率)。

    c. 将所有候选序列按照得分排序,选出Top k个作为新的beam。

  3. 当所有序列都达到结束符或达到最大长度时,停止。然后从beam中选择得分最高的序列作为输出。

示例代码如下。

复制代码
import numpy as np
from collections import defaultdict
import math

# 模拟一个简单的语言模型输出概率
class SimpleLanguageModel:
    def __init__(self):
        # 词汇表
        self.vocab = ["春", "风", "花", "月", "夜", "江", "山", "水", "云", "雨", 
                      "人", "生", "梦", "酒", "歌", "<start>", "<end>"]
        self.word2id = {word: i for i, word in enumerate(self.vocab)}
        self.id2word = {i: word for i, word in enumerate(self.vocab)}
        
        # 模拟一个简单的概率分布 (实际中来自神经网络)
        self.prob_table = {
            "<start>": [("春", 0.4), ("风", 0.3), ("月", 0.2), ("江", 0.1)],
            "春": [("风", 0.5), ("花", 0.3), ("雨", 0.1), ("江", 0.1)],
            "风": [("花", 0.4), ("月", 0.3), ("云", 0.2), ("雨", 0.1)],
            "花": [("月", 0.5), ("夜", 0.3), ("春", 0.1), ("雨", 0.1)],
            "月": [("夜", 0.6), ("光", 0.2), ("明", 0.1), ("圆", 0.1)],
            "夜": [("色", 0.4), ("深", 0.3), ("晚", 0.2), ("来", 0.1)],
            "江": [("山", 0.5), ("水", 0.3), ("南", 0.1), ("流", 0.1)],
            "山": [("水", 0.6), ("色", 0.2), ("青", 0.1), ("高", 0.1)],
            "水": [("流", 0.5), ("声", 0.3), ("清", 0.1), ("深", 0.1)],
            "云": [("雨", 0.6), ("雾", 0.2), ("散", 0.1), ("飘", 0.1)],
            "雨": [("后", 0.5), ("中", 0.3), ("前", 0.1), ("夜", 0.1)],
            "人": [("生", 0.6), ("间", 0.2), ("心", 0.1), ("情", 0.1)],
            "生": [("梦", 0.5), ("死", 0.3), ("活", 0.1), ("命", 0.1)],
            "梦": [("里", 0.5), ("中", 0.3), ("幻", 0.1), ("醒", 0.1)],
            "酒": [("歌", 0.6), ("醉", 0.3), ("杯", 0.05), ("诗", 0.05)],
            "歌": [("声", 0.5), ("唱", 0.3), ("舞", 0.1), ("欢", 0.1)],
        }
    
    def predict_next_word(self, current_word):
        """返回下一个词的概率分布"""
        if current_word not in self.prob_table:
            # 如果不在表中,返回均匀分布
            return [(word, 1.0/len(self.vocab)) for word in self.vocab]
        return self.prob_table[current_word]
    
    def get_top_k(self, current_word, k=3):
        """返回top-k个候选词"""
        predictions = self.predict_next_word(current_word)
        predictions.sort(key=lambda x: x[1], reverse=True)
        return predictions[:k]

# 实现 Beam Search
def beam_search_decode(model, beam_width=3, max_length=5):
    """
    使用 Beam Search 生成诗句
    """
    # 初始 beam,只有一个空序列
    beams = [{
        'sequence': ['<start>'],
        'score': 0.0,  # 使用对数概率,初始为0(log(1)=0)
        'last_word': '<start>'
    }]
    
    for step in range(max_length):
        all_candidates = []
        
        for beam in beams:
            # 如果最后一个词是结束符,则不再扩展
            if beam['last_word'] == '<end>' or step == max_length - 1:
                all_candidates.append(beam)
                continue
            
            # 获取下一个词的候选
            top_k = model.get_top_k(beam['last_word'], beam_width)
            
            for word, prob in top_k:
                # 计算新得分(对数概率相加)
                # 为了避免浮点数下溢,使用对数概率
                new_score = beam['score'] + math.log(prob + 1e-10)  # 加一个小数避免log(0)
                
                # 构建新序列
                new_sequence = beam['sequence'] + [word]
                
                candidate = {
                    'sequence': new_sequence,
                    'score': new_score,
                    'last_word': word
                }
                all_candidates.append(candidate)
        
        # 按得分排序,保留最好的 beam_width 个
        all_candidates.sort(key=lambda x: x['score'], reverse=True)
        beams = all_candidates[:beam_width]
        
        print(f"\n第 {step+1} 步后的候选序列:")
        for i, beam in enumerate(beams):
            seq_str = ''.join([w for w in beam['sequence'] if w not in ['<start>', '<end>']])
            print(f"  Beam {i+1}: {seq_str:10s} 得分: {beam['score']:.3f}")
        
        # 检查是否所有beam都结束了
        if all(beam['last_word'] == '<end>' for beam in beams):
            break
    
    # 选择最佳序列
    best_beam = max(beams, key=lambda x: x['score'])
    # 去掉起始标记
    final_sequence = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
    
    return ''.join(final_sequence), best_beam['score']

# 运行示例
print("=" * 50)
print("Beam Search 诗词生成示例")
print("=" * 50)

model = SimpleLanguageModel()

# 使用不同beam宽度进行比较
for beam_width in [1, 2, 3]:
    print(f"\n{'='*30}")
    print(f"Beam Width = {beam_width}")
    print(f"{'='*30}")
    
    result, score = beam_search_decode(model, beam_width=beam_width, max_length=5)
    print(f"\n生成的诗句: {result}")
    print(f"最终得分: {score:.3f}")

输出如下所示,这里得分为log(prob),所以数值为负。

==================================================

Beam Search 诗词生成示例

==================================================

==============================

Beam Width = 1

==============================

第 1 步后的候选序列:

Beam 1: 春 得分: -0.916

第 2 步后的候选序列:

Beam 1: 春风 得分: -1.609

第 3 步后的候选序列:

Beam 1: 春风花 得分: -2.526

第 4 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

第 5 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

生成的诗句: 春风花月

最终得分: -3.219

==============================

Beam Width = 2

==============================

第 1 步后的候选序列:

Beam 1: 春 得分: -0.916

Beam 2: 风 得分: -1.204

第 2 步后的候选序列:

Beam 1: 春风 得分: -1.609

Beam 2: 春花 得分: -2.120

第 3 步后的候选序列:

Beam 1: 春风花 得分: -2.526

Beam 2: 春风月 得分: -2.813

第 4 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

Beam 2: 春风月夜 得分: -3.324

第 5 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

Beam 2: 春风月夜 得分: -3.324

生成的诗句: 春风花月

最终得分: -3.219

==============================

Beam Width = 3

==============================

第 1 步后的候选序列:

Beam 1: 春 得分: -0.916

Beam 2: 风 得分: -1.204

Beam 3: 月 得分: -1.609

第 2 步后的候选序列:

Beam 1: 春风 得分: -1.609

Beam 2: 月夜 得分: -2.120

Beam 3: 春花 得分: -2.120

第 3 步后的候选序列:

Beam 1: 春风花 得分: -2.526

Beam 2: 春风月 得分: -2.813

Beam 3: 春花月 得分: -2.813

第 4 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

Beam 2: 春风月夜 得分: -3.324

Beam 3: 春花月夜 得分: -3.324

第 5 步后的候选序列:

Beam 1: 春风花月 得分: -3.219

Beam 2: 春风月夜 得分: -3.324

Beam 3: 春花月夜 得分: -3.324

生成的诗句: 春风花月

最终得分: -3.219

3 beam search应用

这里探索各种beam_search的变体应用。

带长度归一化的 Beam Search示例

复制代码
def beam_search_with_length_penalty(model, beam_width=3, max_length=10, alpha=0.7):
    """带长度归一化的 Beam Search"""
    beams = [{'sequence': ['<start>'], 'score': 0.0, 'last_word': '<start>'}]
    
    for step in range(max_length):
        all_candidates = []
        
        for beam in beams:
            if beam['last_word'] == '<end>':
                all_candidates.append(beam)
                continue
            
            top_k = model.get_top_k(beam['last_word'], beam_width)
            
            for word, prob in top_k:
                # 对数概率
                log_prob = math.log(prob + 1e-10)
                
                # 长度归一化:得分除以长度^alpha
                new_length = len(beam['sequence']) + 1
                length_penalty = (new_length ** alpha)
                
                # 新的得分(注意:因为是负对数概率,所以除以长度惩罚因子)
                new_score = (beam['score'] * (len(beam['sequence']) ** alpha) + log_prob) / length_penalty
                
                new_sequence = beam['sequence'] + [word]
                
                candidate = {
                    'sequence': new_sequence,
                    'score': new_score,
                    'last_word': word
                }
                all_candidates.append(candidate)
        
        # 排序并选择top-k
        all_candidates.sort(key=lambda x: x['score'], reverse=True)
        beams = all_candidates[:beam_width]
    
    # 返回最佳结果
    best_beam = max(beams, key=lambda x: x['score'])
    final_sequence = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
    return ''.join(final_sequence)

这时一个多样化的 Beam Search变体的示例代码。

复制代码
def diverse_beam_search(model, num_groups=2, beam_width=2, max_length=5):
    """多样化 Beam Search,增加结果的多样性"""
    # 初始化每个组
    groups = []
    for _ in range(num_groups):
        groups.append([{
            'sequence': ['<start>'],
            'score': 0.0,
            'last_word': '<start>'
        }])
    
    for step in range(max_length):
        all_group_candidates = []
        
        for group_idx, beams in enumerate(groups):
            group_candidates = []
            
            for beam in beams:
                if beam['last_word'] == '<end>':
                    group_candidates.append(beam)
                    continue
                
                top_k = model.get_top_k(beam['last_word'], beam_width * 2)
                
                for word, prob in top_k:
                    # 添加组间多样性惩罚
                    diversity_penalty = 0.0
                    # 检查其他组中是否已有这个词
                    for other_group_idx in range(num_groups):
                        if other_group_idx != group_idx:
                            # 简单惩罚:如果其他组的候选中有这个词,降低分数
                            # 这里简化处理,实际会更复杂
                            pass
                    
                    log_prob = math.log(prob + 1e-10) - diversity_penalty
                    new_score = beam['score'] + log_prob
                    
                    new_sequence = beam['sequence'] + [word]
                    
                    candidate = {
                        'sequence': new_sequence,
                        'score': new_score,
                        'last_word': word,
                        'group': group_idx
                    }
                    group_candidates.append(candidate)
            
            # 每组内排序
            group_candidates.sort(key=lambda x: x['score'], reverse=True)
            groups[group_idx] = group_candidates[:beam_width]
            all_group_candidates.extend(group_candidates[:beam_width])
    
    # 从所有组中选择最佳结果
    best_candidate = max(all_group_candidates, key=lambda x: x['score'])
    final_sequence = [w for w in best_candidate['sequence'] if w not in ['<start>', '<end>']]
    return ''.join(final_sequence)

示例代码如下所示

复制代码
def visualize_beam_search(model, beam_width=2, max_length=4):
    """可视化 Beam Search 的搜索过程"""
    beams = [{'sequence': ['<start>'], 'score': 0.0, 'last_word': '<start>'}]
    
    print("Beam Search 搜索过程可视化:")
    print("=" * 40)
    
    for step in range(max_length):
        print(f"\n时间步 {step+1}:")
        print("-" * 20)
        
        all_candidates = []
        
        for beam_idx, beam in enumerate(beams):
            print(f"\n扩展 Beam {beam_idx+1}: {' '.join(beam['sequence'])}")
            
            if beam['last_word'] == '<end>':
                all_candidates.append(beam)
                continue
            
            top_k = model.get_top_k(beam['last_word'], beam_width)
            
            for word, prob in top_k:
                new_score = beam['score'] + math.log(prob + 1e-10)
                new_sequence = beam['sequence'] + [word]
                
                candidate = {
                    'sequence': new_sequence,
                    'score': new_score,
                    'last_word': word
                }
                all_candidates.append(candidate)
                
                print(f"  -> {word} (p={prob:.3f}, 得分={new_score:.3f})")
        
        # 排序并选择top-k
        all_candidates.sort(key=lambda x: x['score'], reverse=True)
        beams = all_candidates[:beam_width]
        
        print(f"\n当前最佳候选:")
        for i, beam in enumerate(beams):
            seq_str = ' '.join(beam['sequence'])
            print(f"  Beam {i+1}: {seq_str:20s} 得分: {beam['score']:.3f}")
    
    # 最终结果
    print("\n" + "=" * 40)
    print("最终结果:")
    best_beam = max(beams, key=lambda x: x['score'])
    final_seq = [w for w in best_beam['sequence'] if w not in ['<start>', '<end>']]
    print(f"生成序列: {' '.join(final_seq)}")
    print(f"最终得分: {best_beam['score']:.3f}")

# 运行可视化
model = SimpleLanguageModel()
visualize_beam_search(model, beam_width=2, max_length=4)

输出如下所示

Beam Search 搜索过程可视化:

========================================

时间步 1:


扩展 Beam 1: <start>

-> 春 (p=0.400, 得分=-0.916)

-> 风 (p=0.300, 得分=-1.204)

当前最佳候选:

Beam 1: <start> 春 得分: -0.916

Beam 2: <start> 风 得分: -1.204

时间步 2:


扩展 Beam 1: <start> 春

-> 风 (p=0.500, 得分=-1.609)

-> 花 (p=0.300, 得分=-2.120)

扩展 Beam 2: <start> 风

-> 花 (p=0.400, 得分=-2.120)

-> 月 (p=0.300, 得分=-2.408)

当前最佳候选:

Beam 1: <start> 春 风 得分: -1.609

Beam 2: <start> 春 花 得分: -2.120

时间步 3:


扩展 Beam 1: <start> 春 风

-> 花 (p=0.400, 得分=-2.526)

-> 月 (p=0.300, 得分=-2.813)

扩展 Beam 2: <start> 春 花

-> 月 (p=0.500, 得分=-2.813)

-> 夜 (p=0.300, 得分=-3.324)

当前最佳候选:

Beam 1: <start> 春 风 花 得分: -2.526

Beam 2: <start> 春 风 月 得分: -2.813

时间步 4:


扩展 Beam 1: <start> 春 风 花

-> 月 (p=0.500, 得分=-3.219)

-> 夜 (p=0.300, 得分=-3.730)

扩展 Beam 2: <start> 春 风 月

-> 夜 (p=0.600, 得分=-3.324)

-> 光 (p=0.200, 得分=-4.423)

当前最佳候选:

Beam 1: <start> 春 风 花 月 得分: -3.219

Beam 2: <start> 春 风 月 夜 得分: -3.324

========================================

最终结果:

生成序列: 春 风 花 月

最终得分: -3.219

Click to add a cell.

reference


相关推荐
算法与编程之美2 小时前
深度学习任务中的多层卷积与全连接输出方法
人工智能·深度学习
Deepoch2 小时前
具身智能产业新范式:Deepoc开发板如何破解机器人智能化升级难题
人工智能·科技·机器人·开发板·具身模型·deepoc
浪子不回头4152 小时前
SGLang学习笔记
人工智能·笔记·学习
王琦03183 小时前
Python 函数详解
开发语言·python
胡伯来了3 小时前
13. Python打包工具- setuptools
开发语言·python
小鸡吃米…3 小时前
Python 中的多层继承
开发语言·python
飞哥数智坊3 小时前
TRAE 国内版 SOLO 全放开
人工智能·ai编程·trae
中國移动丶移不动3 小时前
Python MySQL 数据库操作完整示例
数据库·python·mysql
落叶,听雪3 小时前
AI建站推荐
大数据·人工智能·python