Large Model Series: Speculative Decoding, a Code Walkthrough of Prompt Lookup Decoding

Official code: GitHub - apoorvumang/prompt-lookup-decoding

UPDATE 2 : This method is now available in vLLM as well by setting speculative_model="[ngram]" 🥳

UPDATE : This has been added to the transformers library. Please see this for a code example, or simply add prompt_lookup_num_tokens=10 to your model.generate(...) call.

TLDR : We modify speculative decoding where we replace the draft model with simple string matching in the prompt to generate candidate token sequences. This results in significant speedups (2x-4x) in input-grounded tasks, with no effect on output quality. This method can be used with any decoder model without model changes or external datastore, and with both greedy and sampling techniques.
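The claim of "no effect on output quality" follows from the verification step that speculative decoding keeps: the target model still checks every candidate in one forward pass. Below is a minimal sketch of the greedy acceptance rule. The function name `accept_candidates` and the argument `target_argmax` are hypothetical, introduced only for illustration; the sketch assumes `target_argmax[i]` is the model's greedy token after the context plus the first `i` candidates.

```python
def accept_candidates(candidates, target_argmax):
    """Greedy acceptance for speculative decoding (illustrative sketch).

    candidates:    k draft tokens proposed by prompt lookup.
    target_argmax: k + 1 greedy tokens from one forward pass of the target
                   model; target_argmax[i] is its choice after the context
                   plus candidates[:i]. (Assumed layout, see lead-in.)
    Returns the tokens that can be appended to the output.
    """
    if len(target_argmax) != len(candidates) + 1:
        raise ValueError("expected one more model prediction than candidates")

    accepted = []
    for i, cand in enumerate(candidates):
        # Stop at the first candidate the model would not have produced itself.
        if cand != target_argmax[i]:
            break
        accepted.append(cand)

    # The model's own token at the first mismatch (or after the last accepted
    # candidate) is always valid, so each pass yields at least one token.
    accepted.append(target_argmax[len(accepted)])
    return accepted


print(accept_candidates([5, 6, 7], [5, 6, 9, 1]))
```

Every returned token is one that plain greedy decoding would also have emitted, so a bad candidate costs some wasted compute but does not change the output.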

Intuition : In several LLM use cases where you're doing input grounded generation (summarization, document QA, multi-turn chat, code editing), there is high n-gram overlap between LLM input (prompt) and LLM output. This could be entity names, phrases, or code chunks that the LLM directly copies from the input while generating the output. Prompt lookup exploits this pattern to speed up autoregressive decoding in LLMs.
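One rough way to check whether a task has this property is to measure how many output n-grams already occur in the prompt. The helper `ngram_overlap` below is a hypothetical illustration, not part of the original repository, and it uses whitespace-split words in place of real tokenizer IDs.

```python
def ngram_overlap(prompt_tokens, output_tokens, n=3):
    """Fraction of n-grams in the output that also occur in the prompt."""
    prompt_ngrams = {
        tuple(prompt_tokens[i:i + n])
        for i in range(len(prompt_tokens) - n + 1)
    }
    output_ngrams = [
        tuple(output_tokens[i:i + n])
        for i in range(len(output_tokens) - n + 1)
    ]
    if not output_ngrams:
        return 0.0  # output shorter than n tokens
    hits = sum(1 for gram in output_ngrams if gram in prompt_ngrams)
    return hits / len(output_ngrams)


prompt = "the quick brown fox jumps over the lazy dog".split()
output = "summary : the quick brown fox sleeps".split()
print(ngram_overlap(prompt, output, n=3))  # 2 of 5 output trigrams -> 0.4
```

A higher ratio suggests more opportunities for a lookup to propose tokens the model would copy anyway; it is only a proxy and says nothing about the actual speedup.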

import torch


def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    # input_ids: LongTensor of shape (1, seq_len); only batch size 1 is handled
    input_length = input_ids.size(1)

    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Create sliding windows of size ngram_size
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            if end_idx <= input_length and start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:end_idx]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)
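To trace the function above by hand, a list-based mirror of the same logic can help. This is a sketch for a single sequence of plain Python ints; the name `find_candidates_list` is hypothetical, and the torch version remains the reference.

```python
def find_candidates_list(tokens, max_ngram_size=3, num_pred_tokens=10):
    """List-based mirror of find_candidate_pred_tokens for one sequence."""
    n_tokens = len(tokens)
    for ngram_size in range(max_ngram_size, 0, -1):
        # The last ngram_size tokens are the search pattern.
        ngram = tokens[-ngram_size:]
        # Scan windows left to right, so the earliest match wins.
        for idx in range(n_tokens - ngram_size + 1):
            if tokens[idx:idx + ngram_size] != ngram:
                continue
            start = idx + ngram_size
            end = start + num_pred_tokens
            # Same two checks as the original: a full-length continuation
            # must exist, and the pattern must not match itself.
            if end <= n_tokens and start < n_tokens - ngram_size:
                return tokens[start:end]
    return []


tokens = [1, 2, 3, 4, 5, 6, 7, 1, 2, 3]
print(find_candidates_list(tokens, max_ngram_size=3, num_pred_tokens=2))  # [4, 5]
print(find_candidates_list(tokens))  # [] with the default num_pred_tokens=10
```

The second call returns nothing because of the `end <= n_tokens` check: a match is skipped unless a full `num_pred_tokens` tokens follow it, so short prompts with the default of 10 often produce no candidates.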

TODOs/Thoughts/Future work

  • There are probably better ways to do string matching than the current one, and there are several obvious things to improve, e.g. what to do when there are multiple matches? What's the ideal length of continuation?
  • We haven't yet tried sampling, although there's no reason it shouldn't work.
    • Here, one additional thing to test would be whether prompt lookup while sampling can affect hallucination rates, since this artificially increases the probability of sampling exact sequences from the input (this was suggested by my colleague Shwetha S)
  • Testing actual FLOPs impact and tradeoffs is needed
  • Also need to figure out the best hyperparameters: max_ngram_size=3 and num_pred_tokens=10 were chosen on very little testing
  • It would be an interesting challenge to design the "best lookup function" for decoding, could even be a competition?
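As one possible answer to the multiple-match and continuation-length questions above, the sketch below prefers the most recent earlier match and accepts a continuation shorter than `num_pred_tokens`. This is an assumed variant for illustration, not the author's method; the name `find_candidates_latest` is hypothetical, and whether this policy is faster in practice is untested here.

```python
def find_candidates_latest(tokens, max_ngram_size=3, num_pred_tokens=10):
    """Variant lookup: latest earlier match wins, short continuations allowed."""
    n_tokens = len(tokens)
    # An n-gram can only have an earlier occurrence if it is shorter than
    # the whole sequence, hence the cap at n_tokens - 1.
    for ngram_size in range(min(max_ngram_size, n_tokens - 1), 0, -1):
        ngram = tokens[-ngram_size:]
        # Scan right to left, starting just before the trailing n-gram
        # itself so that it never matches its own position.
        for idx in range(n_tokens - ngram_size - 1, -1, -1):
            if tokens[idx:idx + ngram_size] == ngram:
                start = idx + ngram_size
                # Slicing truncates automatically near the end of the list.
                return tokens[start:start + num_pred_tokens]
    return []


tokens = [1, 2, 3, 9, 9, 1, 2, 3, 4, 5, 1, 2, 3]
print(find_candidates_latest(tokens, num_pred_tokens=2))  # [4, 5]
```

On this input the first-match policy would follow the earliest occurrence of `1, 2, 3` and propose `9, 9`, while the variant follows the later occurrence and proposes `4, 5`. Which choice is better depends on the data; recent context is only a heuristic.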

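This method may still have issues: as the author notes, hallucination is a possible concern, and an n-gram match does not necessarily yield a speedup.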
