[LLM] Constrained Decoding

Constrained decoding

Trie implementation, for reference: https://github.com/HonghuiBao2000/LETTER/blob/master/LETTER-TIGER/generation_trie.py
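To make the lookup concrete, here is a minimal sketch of a token-ID trie. `SimpleTrie` and all token IDs below are hypothetical; the linked `generation_trie.py` may differ in detail. The sketch only assumes what the rest of this post relies on: a `get(prefix)` method that returns the token IDs allowed after a given prefix.

```python
from typing import Dict, Iterable, List


class SimpleTrie:
    """Minimal token-ID trie (hypothetical stand-in, not the LETTER class)."""

    def __init__(self, sequences: Iterable[List[int]] = ()):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            self.add(seq)

    def add(self, sequence: List[int]) -> None:
        """Insert one full token-ID sequence."""
        node = self.root
        for token_id in sequence:
            node = node.setdefault(token_id, {})

    def get(self, prefix: List[int]) -> List[int]:
        """Return the token IDs that may follow `prefix`, or [] if there are none."""
        node = self.root
        for token_id in prefix:
            if token_id not in node:
                return []
            node = node[token_id]
        return list(node.keys())


# Hypothetical token IDs: 1 = <|im_start|>, 2 = assistant, 3 = "\n";
# 10..13 stand in for item tokens.
HEADER = [1, 2, 3]
trie = SimpleTrie([HEADER + [10, 11, 12], HEADER + [10, 13, 12]])

print(trie.get(HEADER))         # tokens allowed right after the header
print(trie.get(HEADER + [10]))  # tokens allowed after item token 10
print(trie.get(HEADER + [99]))  # unknown prefix: nothing is allowed
```

Storing the chat header inside every candidate sequence is one possible design; it matches a lookup that starts at `<|im_start|>` rather than at the first generated token.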

Prefix-allowed-tokens function:

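```python
from typing import Callable, List

import torch

# Assumption: the Trie class comes from the generation_trie.py linked above.
from generation_trie import Trie

# Assumption: TOKENID_assistant is a module-level constant holding the token ID
# of "assistant" in the chat template; define it before calling generate.


def prefix_allowed_tokens_fn(candidate_trie: Trie) -> Callable[[int, torch.Tensor], List[int]]:
    """
    Build a prefix-allowed-tokens function for constrained decoding.

    Args:
        candidate_trie: a fully built Trie

    Returns:
        A function to pass as the prefix_allowed_tokens_fn argument of model.generate
    """
    def prefix_allowed_tokens(batch_id: int, sentence: torch.Tensor) -> List[int]:
        """
        Return the token IDs allowed as the next token of one sequence in the batch.

        Args:
            batch_id: index of the sequence within the batch
            sentence: token IDs so far (prompt included)

        Returns:
            List of allowed next-token IDs
        """
        sentence = sentence.tolist()
        # First occurrence of "assistant"; assumes the marker appears only once.
        idx = sentence.index(TOKENID_assistant)
        # idx-1: <|im_start|>, idx: assistant, idx+1: \n
        # right_part = sentence[idx+2:]  # alternative: generated tokens only
        right_part = sentence[idx-1:]  # keep the "<|im_start|> assistant \n" header
        trie_out = candidate_trie.get(right_part)
        # log(f'batch_id: {batch_id}. right_part: {right_part}. trie_out: {trie_out}')
        return trie_out

    return prefix_allowed_tokens
```
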
Constrained inference with transformers generate

Constrained inference

```python
from typing import List


def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_length: int = 150,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    Run inference with constrained decoding.

    Args:
        model: the trained language model
        tokenizer: the tokenizer
        prompt: input prompt text
        prefix_allowed_tokens: the prefix-allowed-tokens function
        max_length: maximum total length in tokens (generate counts the prompt too)
        num_beams: number of beams for beam search
        num_return_sequences: number of sequences to return
        device: device (cuda or cpu)

    Returns:
        List of generated texts
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        prefix_allowed_tokens_fn=prefix_allowed_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # skip_special_tokens must be False because I added custom tokens
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    return outputs
```

A big pitfall:

https://blog.csdn.net/h10666/article/details/140432219

https://github.com/huggingface/transformers/issues/27676

https://github.com/huggingface/transformers/issues/15169#issuecomment-1018617055

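For no apparent reason, generation picks tokens that are not among my candidate items.

It seems to be because the scores of the candidate tokens are all -inf, and -inf is not handled properly?

The maintainers have no bandwidth to fix it: https://github.com/huggingface/transformers/issues/22890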
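The -inf hypothesis can be illustrated without transformers. The sketch below is a toy model of token masking in plain Python, not a reproduction of the library's internals; `mask_scores` and `argmax` are hypothetical helpers. It shows one way the symptom can arise: once every score in a row is -inf (for example because the allowed list is empty), all tokens tie and the selected token no longer reflects the constraint.

```python
import math
from typing import List


def mask_scores(log_probs: List[float], allowed: List[int]) -> List[float]:
    """Keep the scores of allowed token IDs and set everything else to -inf."""
    allowed_set = set(allowed)
    return [lp if i in allowed_set else -math.inf for i, lp in enumerate(log_probs)]


def argmax(scores: List[float]) -> int:
    """Index of the first maximum (ties resolve to the lowest index)."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Toy 4-token vocabulary.
log_probs = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]

# Normal case: only tokens 2 and 3 are allowed, so the best allowed token wins.
ok = mask_scores(log_probs, allowed=[2, 3])
print(argmax(ok))

# Degenerate case: the trie returns no continuation for this prefix.
dead = mask_scores(log_probs, allowed=[])
print(all(s == -math.inf for s in dead))  # every score is -inf
print(argmax(dead))  # a pure tie-break, unrelated to the candidate set
```

If this is what happens, checking that the trie never returns an empty list for a reachable prefix is a cheap first diagnostic.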

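  • In the end I created a separate environment with an older transformers release to run this; following https://blog.csdn.net/h10666/article/details/140432219, I pinned version 4.26.0.
  • Problem: I am using Qwen2.5, and transformers 4.26.0 does not support Qwen2.5...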
```bash
conda create -n beam python=3.8
conda activate beam
pip index versions transformers
pip install transformers==4.26.0
```

Constrained inference with a step-by-step transformers loop

This is based on MTGRec's beam search, with constrained decoding added on top.

MTGRec handles batches of any size; to keep things simple I fixed batch_size to 1.

MTGRec: https://github.com/RUCAIBox/MTGRec/blob/main/model.py#L125

```python
from itertools import chain
from typing import List

import torch

# Assumption: TOKENID_im_start, TOKENID_assistant and TOKENID_empty are module-level
# constants for the chat-template tokens (TOKENID_empty is presumably the "\n"
# that follows "assistant").


def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_prediction_length: int = 3,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    Run inference with constrained decoding.
    https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/generation/utils.py#2823
    Args:
        model: the trained language model
        tokenizer: the tokenizer
        prompt: input prompt text
        prefix_allowed_tokens: the prefix-allowed-tokens function
        max_prediction_length: number of tokens to generate (the loop always runs this many steps)
        num_beams: number of beams for beam search
        num_return_sequences: number of sequences to return
        device: device (cuda or cpu)

    Returns:
        List of generated texts
    """
    # (1, n_token); the batch size must be 1 because the trie does not account for padding
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # beam search
    input_ids = input_ids.repeat_interleave(num_beams, dim=0)  # (num_beams, n_token)
    scores = torch.zeros((input_ids.shape[0], 1), dtype=torch.float).to(device)  # (num_beams, 1)
    scores[1:, :] = -1e9  # all beams start identical, so only beam 0 is live at step 0
    with torch.no_grad():
        for step in range(max_prediction_length):
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :]  # logits of the last position: (num_beams, n_vocabulary)
            if num_beams <= 1:
                # NOTE: this greedy branch does not consult prefix_allowed_tokens,
                # so decoding with num_beams=1 is unconstrained.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
            else:
                vocab_size = logits.shape[-1]
                # Constrained decoding: rebuild the trie prefix for every beam
                prefix = torch.tensor([TOKENID_im_start, TOKENID_assistant, TOKENID_empty], dtype=input_ids.dtype).to(device)  # [3]
                prefix = prefix.unsqueeze(0).repeat_interleave(num_beams, dim=0)  # (num_beams, 3)
                if step != 0:
                    prefix = torch.cat([prefix, input_ids[:, -step:]], dim=-1)  # (num_beams, 3+step)
                allowed_tokens = [
                    prefix_allowed_tokens(i, prefix[i, :])
                    for i in range(num_beams)
                ]  # len == num_beams
                filter_beams = [[i] * len(allowed_tokens[i]) for i in range(num_beams)]
                filter_beams = torch.tensor(list(chain.from_iterable(filter_beams)), dtype=input_ids.dtype).to(device)
                allowed_tokens = torch.tensor(list(chain.from_iterable(allowed_tokens)), dtype=input_ids.dtype).to(device)
                # Mask: 0 for allowed (beam, token) pairs, -1e9 everywhere else
                base_scores = -1e9 * torch.ones((num_beams, vocab_size), dtype=torch.float).to(device)
                base_scores[filter_beams, allowed_tokens] = 0

                # Take the top 2*num_beams candidates
                n_topk = 2 * num_beams  # headroom reserved for constrained decoding
                new_scores = scores.repeat_interleave(vocab_size, dim=-1)  # (num_beams, n_vocabulary)
                new_scores += base_scores  # constraint mask
                new_scores += torch.log_softmax(logits.float(), dim=-1)
                new_scores = new_scores.view(1, -1)  # (1, num_beams * n_vocabulary)

                next_token_scores, _next_tokens = torch.topk(new_scores, n_topk, dim=-1, largest=True, sorted=True)
                # next_token_scores: (1, n_topk), _next_tokens: (1, n_topk)
                indices_beam = torch.div(_next_tokens, vocab_size, rounding_mode="floor")
                indices_token = _next_tokens % vocab_size

                scores = new_scores.view(num_beams, -1)[indices_beam, indices_token].reshape(-1, 1)  # (n_topk, 1)
                beam_next_tokens = indices_token[:, :n_topk].reshape(-1, 1)  # (n_topk, 1)
                input_ids = torch.cat([input_ids[indices_beam, :].squeeze(0), beam_next_tokens], dim=-1)  # (n_topk, n_token+1)

                # Keep only the best num_beams candidates
                scores = scores[:num_beams, :]
                input_ids = input_ids[:num_beams, :]

    # Decode only the generated tokens; List[str] with len(outputs) == num_beams
    outputs = tokenizer.batch_decode(input_ids[:, -max_prediction_length:], skip_special_tokens=False)
    outputs = outputs[:num_return_sequences]
    return outputs
```

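I think the point of reserving 2*num_beams candidates is to decode first, filter out the invalid ones, and then keep num_beams of them.

But I already mask invalid tokens when computing the scores, so the 2*num_beams headroom is probably unnecessary here.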
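The argument can be checked on a toy example. The following is a plain-Python sketch of one masked beam step, not the tensor code above; `masked_beam_step` and all numbers are hypothetical. When the mask is applied before the top-k, asking for exactly `num_beams` candidates already yields only trie-valid ones, provided at least that many valid continuations exist.

```python
import heapq
import math
from typing import List, Tuple

NEG = -1e9  # large negative penalty, same idea as the -1e9 mask


def masked_beam_step(
    beam_scores: List[float],
    log_probs: List[List[float]],
    allowed: List[List[int]],
    k: int,
) -> List[Tuple[float, int, int]]:
    """Return the k best (score, beam_index, token_id) candidates after masking."""
    candidates = []
    for b, (score, row) in enumerate(zip(beam_scores, log_probs)):
        allowed_set = set(allowed[b])
        for t, lp in enumerate(row):
            penalty = 0.0 if t in allowed_set else NEG
            candidates.append((score + lp + penalty, b, t))
    return heapq.nlargest(k, candidates)


row = [math.log(p) for p in (0.6, 0.3, 0.1)]  # toy 3-token vocabulary
beam_scores = [0.0, -0.5]                      # two live beams
allowed = [[1, 2], [0]]                        # hypothetical trie output per beam

top = masked_beam_step(beam_scores, [row, row], allowed, k=2)
for score, b, t in top:
    print(b, t, round(score, 3))
print(all(t in allowed[b] for _, b, t in top))  # every kept candidate is valid

# Caveat: only 3 valid continuations exist, so asking for 4 lets a masked one in.
top4 = masked_beam_step(beam_scores, [row, row], allowed, k=4)
print(sum(t in allowed[b] for _, b, t in top4))
```

The caveat at the end also applies to the loop above: if the trie offers fewer valid continuations than `num_beams`, masked candidates with scores around -1e9 fill the remaining slots, so it may be worth dropping such beams before decoding.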
