[LLM-related] Constrained Decoding

Constrained decoding

Trie, see: https://github.com/HonghuiBao2000/LETTER/blob/master/LETTER-TIGER/generation_trie.py
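The linked file implements a trie over token-ID sequences. A minimal sketch of the idea (my own simplification, not the repo's exact code):

```python
from typing import Dict, List

class Trie:
    """Minimal token-ID trie: stores candidate token sequences and, given a
    prefix, returns the token ids that are allowed to come next."""

    def __init__(self, sequences: List[List[int]] = None):
        self.root: Dict[int, dict] = {}
        for seq in (sequences or []):
            self.add(seq)

    def add(self, sequence: List[int]) -> None:
        node = self.root
        for token_id in sequence:
            node = node.setdefault(token_id, {})

    def get(self, prefix: List[int]) -> List[int]:
        node = self.root
        for token_id in prefix:
            if token_id not in node:
                return []          # prefix not in the trie -> nothing is allowed
            node = node[token_id]
        return list(node.keys())   # allowed next token ids (empty at the end of a stored sequence)
```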

The prefix-allowed-tokens function:

```python
import torch

# `Trie` is the candidate trie (see the link / sketch above); TOKENID_assistant is
# assumed to be the tokenizer id of the "assistant" token in the chat template.
def prefix_allowed_tokens_fn(candidate_trie:Trie):
    """
    创建一个前缀允许函数,用于约束解码
    
    Args:
        candidate_trie: 构建好的Trie树
    
    Returns:
        prefix_allowed_tokens函数,用于model.generate的prefix_allowed_tokens_fn参数
    """
    def prefix_allowed_tokens(batch_id:int, sentence:torch.Tensor):
        """
        对每个batch中的序列,返回允许的下一个token ID列表
        
        Args:
            batch_id: 批次中的序列索引
            sentence: 到目前为止生成的token序列
        
        Returns:
            允许的下一个token ID列表
        """
        sentence = sentence.tolist()
        idx = sentence.index(TOKENID_assistant)  # position of the "assistant" token
        # idx-1: <|im_start|>, idx: assistant, idx+1: \n
        # slicing from idx-1 keeps <|im_start|>assistant\n plus everything generated after it
        # right_part = sentence[idx+3:]
        right_part = sentence[idx-1:]
        trie_out = candidate_trie.get(right_part)
        # log(f'batch_id: {batch_id}. right_part: {right_part}. trie_out: {trie_out}')
        return trie_out

    return prefix_allowed_tokens
```
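
A sketch of how the trie and the prefix function might be wired up (the item strings below are made up, and I assume a Qwen/ChatML-style template where "assistant" is a single token). Every stored sequence starts at `<|im_start|>assistant\n`, which is exactly where `prefix_allowed_tokens` starts slicing:

```python
# Hypothetical candidate items (e.g. semantic-ID codes) -- replace with your own.
candidate_items = ["<a_12><b_34><c_56>", "<a_12><b_78><c_90>"]

sequences = [
    tokenizer.encode(f"<|im_start|>assistant\n{item}<|im_end|>")  # add_special_tokens=False may be needed
    for item in candidate_items
]
candidate_trie = Trie(sequences)  # Trie as sketched above (or the one from the LETTER repo)

TOKENID_assistant = tokenizer.convert_tokens_to_ids("assistant")  # assumes "assistant" is one token
prefix_fn = prefix_allowed_tokens_fn(candidate_trie)
```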
Constrained inference, transformers generate version

Constrained inference

```python
from typing import List

def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_length: int = 150,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    使用受限解码进行推理
    
    Args:
        model: 训练好的语言模型
        tokenizer: 分词器
        prompt: 输入提示文本
        prefix_allowed_tokens: 前缀允许函数
        max_length: 最大生成长度
        num_beams: beam search的beam数
        num_return_sequences: 返回序列数
        device: 设备(cuda或cpu)
    
    Returns:
        生成的文本列表
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        prefix_allowed_tokens_fn=prefix_allowed_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)  # must be False, because I added custom tokens
    return outputs
```
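
A usage sketch (not from the original post): the prompt is built with the tokenizer's chat template so that it ends with `<|im_start|>assistant\n`, matching what the prefix function expects; `prefix_fn` is the function built above.

```python
# Hypothetical usage sketch; assumes a ChatML-style (e.g. Qwen) chat template.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Recommend the next item for this user."}],
    tokenize=False,
    add_generation_prompt=True,   # appends "<|im_start|>assistant\n"
)
results = constrained_inference(
    model, tokenizer, prompt, prefix_fn,
    num_beams=5, num_return_sequences=5,
)
for r in results:
    print(r)   # full decoded sequences, including the prompt and special tokens
```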

A big pitfall:

https://blog.csdn.net/h10666/article/details/140432219

https://github.com/huggingface/transformers/issues/27676

https://github.com/huggingface/transformers/issues/15169#issuecomment-1018617055

For no obvious reason it stops picking tokens from my candidate items.

It seems to be because the scores of the candidate tokens are all -inf, and -inf is not allowed?

They haven't had time to fix it: https://github.com/huggingface/transformers/issues/22890
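
One hedged way to check this (not from the original post): ask generate to return the processed per-step scores and inspect the allowed candidate ids, assuming `prompt` and `prefix_fn` from above:

```python
# Diagnostic sketch: look at the processed scores of the first generation step
# and check whether even the allowed candidate tokens are at -inf.
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
out = model.generate(
    input_ids,
    max_new_tokens=3,
    num_beams=5,
    prefix_allowed_tokens_fn=prefix_fn,
    output_scores=True,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
)
allowed = prefix_fn(0, input_ids[0])     # allowed ids for the prompt prefix
print(out.scores[0][0, allowed])         # all -inf here reproduces the issue
```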

The workaround: pin an older transformers version in a fresh environment.

```bash
conda create -n beam python=3.8
pip index versions transformers
pip install transformers==4.26.0
```
Constrained inference, step-by-step transformers version

Following MTGRec's beam search, I added constrained decoding on top of it.

MTGRec handles batch sizes larger than 1; for simplicity I fixed batch_size to 1 (hehe).

MTGRec:https://github.com/RUCAIBox/MTGRec/blob/main/model.py#L125

```python
from itertools import chain
from typing import List

import torch

# Assumed globals: TOKENID_im_start, TOKENID_assistant, TOKENID_empty -- the tokenizer
# ids of <|im_start|>, "assistant", and the "\n" that follows it in the chat template.
def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_prediction_length: int = 3,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    使用受限解码进行推理
    https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/generation/utils.py#2823
    Args:
        model: 训练好的语言模型
        tokenizer: 分词器
        prompt: 输入提示文本
        prefix_allowed_tokens: 前缀允许函数
        max_prediction_length: 最大生成长度
        num_beams: beam search的beam数
        num_return_sequences: 返回序列数
        device: 设备(cuda或cpu)
    
    Returns:
        生成的文本列表
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)  # (1, n_token); batch size must be 1 because the trie does not handle padding

    # beam search
    input_ids = input_ids.repeat_interleave(num_beams, dim=0)  # (num_beams, n_token)
    scores = torch.zeros((input_ids.shape[0], 1), dtype=torch.float).to(device)  # (num_beams, 1)
    scores[1:, :] = -1e9  # at step 0 all beams are identical, so only keep the first one "live"
    with torch.no_grad():
        for step in range(max_prediction_length):
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :]  # logits at the last position: (num_beams, n_vocabulary)
            if num_beams <= 1:
                # plain greedy decoding (note: the trie constraint is not applied in this branch)
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
            else:
                vocab_size = logits.shape[-1]
                # constrained decoding: rebuild the prefix that prefix_allowed_tokens expects,
                # i.e. <|im_start|> assistant \n followed by the tokens generated so far
                prefix = torch.tensor([TOKENID_im_start, TOKENID_assistant, TOKENID_empty], dtype=input_ids.dtype).to(device)  # [3]
                prefix = prefix.unsqueeze(0).repeat_interleave(num_beams, dim=0)  # (num_beams, 3)
                if step != 0:
                    prefix = torch.cat([prefix, input_ids[:, -step:]], dim=-1)  # (num_beams, 3+step)
                allowed_tokens = [
                    prefix_allowed_tokens(i, prefix[i, :]) 
                    for i in range(num_beams)
                ]  # len == num_beams
                filter_beams = [[i] * len(allowed_tokens[i]) for i in range(num_beams)]  # beam index for each allowed token
                filter_beams = torch.tensor(list(chain.from_iterable(filter_beams)), dtype=input_ids.dtype).to(device)
                allowed_tokens = torch.tensor(list(chain.from_iterable(allowed_tokens)), dtype=input_ids.dtype).to(device)
                base_scores = -1e9 * torch.ones((num_beams, vocab_size), dtype=torch.float).to(device)
                base_scores[filter_beams, allowed_tokens] = 0  # allowed tokens keep their log-prob; everything else stays at -1e9

                # take the top 2*num_beams candidates
                n_topk = 2 * num_beams  # headroom, mirroring the HF beam-search convention
                new_scores = scores.repeat_interleave(vocab_size, dim=-1)  # (num_beams, n_vocabulary)
                new_scores += base_scores  # constrained-decoding mask
                new_scores += torch.log_softmax(logits, dim=-1)
                new_scores = new_scores.view(1, -1)  # (1, num_beams * n_vocabulary)

                next_token_scores, _next_tokens = torch.topk(new_scores, n_topk, dim=-1, largest=True, sorted=True)
                # next_token_scores: (1, n_topk), next_tokens: (1, n_topk)
                indices_beam = torch.div(_next_tokens, vocab_size, rounding_mode="floor")
                indices_token = _next_tokens % vocab_size

                scores = new_scores.view(num_beams, -1)[indices_beam, indices_token].reshape(-1, 1)  # (n_topk, 1)
                beam_next_tokens = indices_token[:, :n_topk].reshape(-1,1)  # (n_topk, 1)
                input_ids = torch.cat([input_ids[indices_beam, :].squeeze(0), beam_next_tokens], dim=-1)  # (n_topk, n_token+1)

                # keep only the top num_beams candidates
                scores = scores[:num_beams, :]
                input_ids = input_ids[:num_beams, :]

    # decode only the newly generated tokens
    outputs = tokenizer.batch_decode(input_ids[:, -max_prediction_length:], skip_special_tokens=False)  # List[str], len(outputs) == num_beams
    outputs = outputs[:num_return_sequences]
    return outputs
```

I think the point of the 2*num_beams headroom is to decode first, filter out invalid candidates afterwards, and then keep num_beams of them (in the HF implementation the extra slots mainly cover candidates that hit EOS and are moved to the finished-hypotheses pool).

But since I already filter invalid tokens while computing the scores (via base_scores), the 2*num_beams headroom probably isn't needed here and taking the top num_beams directly would be enough.
