束搜索算法及其实现

文章目录

束搜索算法及其实现

一、束搜索(Beam Search)算法详解

(一) 核心思想

束搜索是一种启发式图搜索算法 ,用于序列生成任务(如机器翻译、文本生成)。它通过维护固定数量的候选序列(称为 束宽 k ),在每一步扩展时仅保留概率最高的 k 个路径,避免穷举所有可能序列,平衡效率与质量。

(二)工作流程
  1. 初始化 :起始标记(如 <start>)加入候选序列,初始得分(概率对数)为 0。
  2. 迭代扩展
    • 对当前每个候选序列,生成所有可能的下一词元。
    • 计算新序列的累积对数概率:新得分 = 原得分 + log(当前词概率)
  3. 剪枝 :从所有新候选序列中选择得分最高的 k 个,其余丢弃。
  4. 终止条件
    • 所有候选序列生成结束标记 <end>
    • 达到预设最大长度。
  5. 输出:选择最终得分最高的序列。
(三)与贪心搜索的对比**
特性 贪心搜索 (Greedy) 束搜索 (Beam Search)
搜索宽度 1(每步选最优词) k(保留多个候选)
全局最优性 易陷局部最优 更接近全局最优
计算成本 较高(约为贪心的 k 倍)
输出质量 可能重复或不连贯 更流畅、合理

关键点 :束宽 k 越大,生成质量越高,但计算开销越大;k=1 时退化为贪心搜索。


二、PyTorch 实现束搜索

以下是一个支持批量处理的束搜索实现,适用于 Seq2Seq 模型(如 LSTM 或 Transformer):

python 复制代码
import torch
import torch.nn.functional as F

def beam_search(model, encoder_output, beam_width=3, max_len=20, device="cuda"):
    """
    参数:
        model: 解码器模型(需实现 `get_next_token_probs` 方法)
        encoder_output: 编码器输出张量 (batch_size, seq_len, hidden_size)
        beam_width: 束宽 k
        max_len: 最大生成长度
    """
    batch_size = encoder_output.size(0)
    start_token = torch.tensor([0]).to(device)  # 起始标记(假设0)
    end_token = 1  # 结束标记(假设1)

    # 初始化束:[(序列, 得分)] * batch_size
    beams = [[[start_token], 0.0]] * batch_size
    active_beams = [True] * batch_size  # 标记活跃的束

    for _ in range(max_len):
        new_beams = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            if not active_beams[i]:
                new_beams[i] = beams[i]
                continue

            seq, score = beams[i][-1] if beams[i] else ([start_token], 0.0)
            last_token = seq[-1]
            
            # 获取下一词概率分布 (vocab_size)
            with torch.no_grad():
                logits = model(last_token.unsqueeze(0), encoder_output[i].unsqueeze(0))
                probs = F.log_softmax(logits, dim=-1).squeeze(0)  # 对数概率
            
            # 扩展候选:当前得分 + log(新词概率)
            topk_probs, topk_tokens = torch.topk(probs, beam_width)
            for j in range(beam_width):
                new_seq = seq + [topk_tokens[j].item()]
                new_score = score + topk_probs[j].item()
                candidate = (new_seq, new_score)
                new_beams[i].append(candidate)

            # 剪枝:保留得分最高的 k 个候选
            new_beams[i] = sorted(new_beams[i], key=lambda x: x[1], reverse=True)[:beam_width]
        
        # 检查是否所有束均终止
        for i in range(batch_size):
            active_beams[i] = any(token[-1][0][-1] != end_token for token in new_beams[i])
        
        beams = new_beams
        if not any(active_beams):
            break

    # 返回每个样本的最优序列
    return [max(beam, key=lambda x: x[1])[0] for beam in beams]
代码说明:
  1. 概率处理 :使用 log_softmax 避免数值下溢,得分累加使用对数概率。
  2. 剪枝优化 :每步仅保留 k 个高分候选,降低计算复杂度。
  3. 终止条件 :遇到 <end> 标记或达到最大长度时停止。
  4. 批量支持:独立处理每个样本的束,适合 GPU 并行。

三、使用示例:文本生成

场景 :使用 LSTM 模型生成句子,束宽 k=3
python 复制代码
# 假设已定义 LSTM 解码器模型
class DecoderLSTM(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.lstm = torch.nn.LSTM(hidden_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        embedded = self.embedding(x)
        output, hidden = self.lstm(embedded, hidden)
        logits = self.fc(output)
        return logits, hidden

# 模拟输入
vocab_size = 10000
hidden_size = 256
encoder_output = torch.randn(2, 10, hidden_size).to("cuda")  # 批量大小=2
model = DecoderLSTM(vocab_size, hidden_size).to("cuda")

# 执行束搜索
generated_seqs = beam_search(model, encoder_output, beam_width=3, max_len=15)
print(generated_seqs)  # 输出:[[0, 42, 17, 1], [0, 88, 3, 7, 1]] 
效果分析
  • 束宽 k=1 (贪心搜索):可能生成 ["I am am good"] 等重复句子。
  • 束宽 k=3 :生成 ["I am good"]["You are great"] 等更合理序列。

四、进阶技巧与注意事项

  1. 长度归一化

    • 长序列因概率连乘得分偏低,需除以长度:score = score / len(seq)^αα 常取 0.7)。
  2. 多样性优化

    • 组束搜索(Group Beam Search):将束分为多组,添加多样性惩罚避免相似输出。
    python 复制代码
    # Hugging Face Transformers 示例
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    outputs = model.generate(
        inputs,
        num_beams=8,
        num_beam_groups=4,  # 分4组
        diversity_penalty=1.0  # 多样性惩罚强度
    )
  3. 与采样策略结合

    • Top-K 采样 :从概率最高的 K 个词中随机选择,增加多样性(如创意写作)。
    • Top-p(核)采样 :动态选择累积概率 ≥ p 的最小词集(如对话生成)。

五、适用场景对比

解码策略 特点 适用场景
贪心搜索 速度快,质量低 实时响应、结构化生成
束搜索 质量高,计算中等 机器翻译、图像描述
Top-K 采样 随机性高,多样性好 诗歌生成、故事创作
Top-p 采样 自适应多样性 对话系统、代码生成

实践建议 :多数任务优先用束搜索(k=3~10);需创造性输出时结合采样策略。

束搜索通过高效平衡搜索质量与计算成本,成为序列生成任务的主流解码方法。理解其原理及实现细节,能显著提升生成模型的输出质量。