【llm相关】受限解码

受限解码

Trie树,参考:https://github.com/HonghuiBao2000/LETTER/blob/master/LETTER-TIGER/generation_trie.py

前缀允许函数:

python 复制代码
def prefix_allowed_tokens_fn(candidate_trie:Trie):
    """
    创建一个前缀允许函数,用于约束解码
    
    Args:
        candidate_trie: 构建好的Trie树
    
    Returns:
        prefix_allowed_tokens函数,用于model.generate的prefix_allowed_tokens_fn参数
    """
    def prefix_allowed_tokens(batch_id:int, sentence:torch.Tensor):
        """
        对每个batch中的序列,返回允许的下一个token ID列表
        
        Args:
            batch_id: 批次中的序列索引
            sentence: 到目前为止生成的token序列
        
        Returns:
            允许的下一个token ID列表
        """
        sentence = sentence.tolist()
        idx = sentence.index(TOKENID_assistant)  # assistant
        # idx: <|im_start|>, idx+1: assistant, idx+2: \n
        # right_part = sentence[idx+3:]
        right_part = sentence[idx-1:]
        trie_out = candidate_trie.get(right_part)
        # log(f'batch_id: {batch_id}. right_part: {right_part}. trie_out: {trie_out}')
        return trie_out

    return prefix_allowed_tokens
transformers generate版受限推理

受限推理

python 复制代码
def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_length: int = 150,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    使用受限解码进行推理
    
    Args:
        model: 训练好的语言模型
        tokenizer: 分词器
        prompt: 输入提示文本
        prefix_allowed_tokens: 前缀允许函数
        max_length: 最大生成长度
        num_beams: beam search的beam数
        num_return_sequences: 返回序列数
        device: 设备(cuda或cpu)
    
    Returns:
        生成的文本列表
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        prefix_allowed_tokens_fn=prefix_allowed_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)  # 必须是False 因为我添加了自定义的token
    return outputs

大坑:

https://blog.csdn.net/h10666/article/details/140432219

https://github.com/huggingface/transformers/issues/27676

https://github.com/huggingface/transformers/issues/15169#issuecomment-1018617055

莫名其妙地不从我的候选item里选token。

好像是因为候选token的概率分数都是-inf,而-inf是不允许的?

他们没空修:https://github.com/huggingface/transformers/issues/22890

bash 复制代码
conda create -n beam python=3.8
pip index versions transformers
pip install transformers==4.26.0
transformers 逐步版受限推理

参考MTGRec的beam search,我再加了个受限推理

MTGRec是多batch_size的,我为了简单改成了batch_size固定为1(嘿嘿

MTGRec:https://github.com/RUCAIBox/MTGRec/blob/main/model.py#L125

python 复制代码
def constrained_inference(
    model,
    tokenizer,
    prompt: str,
    prefix_allowed_tokens,
    max_prediction_length: int = 3,
    num_beams: int = 5,
    num_return_sequences: int = 1,
    device: str = "cuda:0"
) -> List[str]:
    """
    使用受限解码进行推理
    https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/generation/utils.py#2823
    Args:
        model: 训练好的语言模型
        tokenizer: 分词器
        prompt: 输入提示文本
        prefix_allowed_tokens: 前缀允许函数
        max_prediction_length: 最大生成长度
        num_beams: beam search的beam数
        num_return_sequences: 返回序列数
        device: 设备(cuda或cpu)
    
    Returns:
        生成的文本列表
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)  # (1, n_token) 注意这里只能是1 因为trie树没有考虑pad

    # beam search
    input_ids = input_ids.repeat_interleave(num_beams, dim=0)  # (num_beams, n_token)
    scores = torch.zeros((input_ids.shape[0], 1), dtype=torch.float).to(device)  # (num_beams, 1)
    scores[1:, :] = -1e9
    with torch.no_grad():
        for step in range(max_prediction_length):
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :]  # 取最后一个token的logits  (1, n_vocabulary) ==> (num_beams, n_vocabulary)
            if num_beams <= 1:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (1, 1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
            else:
                vocab_size = logits.shape[-1]
                # 受限解码
                prefix = torch.tensor([TOKENID_im_start, TOKENID_assistant, TOKENID_empty], dtype=input_ids.dtype).to(device)  # [3]
                prefix = prefix.unsqueeze(0).repeat_interleave(num_beams, dim=0)  # (num_beams, 3)
                if step != 0:
                    prefix = torch.cat([prefix, input_ids[:, -step:]], dim=-1)  # (num_beams, 3+step)
                allowed_tokens = [
                    prefix_allowed_tokens(i, prefix[i, :]) 
                    for i in range(num_beams)
                ]  # len == num_beams
                filter_beams = [[i] * len(allowed_tokens[i]) for i in range(num_beams)]
                filter_beams = torch.tensor(list(chain.from_iterable(filter_beams)), dtype=input_ids.dtype).to(device)
                allowed_tokens = torch.tensor(list(chain.from_iterable(allowed_tokens)), dtype=input_ids.dtype).to(device)
                base_scores = -1e9 * torch.ones((num_beams, vocab_size), dtype=torch.float).to(device)
                base_scores[filter_beams, allowed_tokens] = 0

                # 取2*beams的topk
                n_topk = 2 * num_beams  # 预留受限解码空间
                new_scores = scores.repeat_interleave(vocab_size, dim=-1)  # (num_beams, n_vocabulary)
                new_scores += base_scores  # 受限解码分数
                new_scores += torch.log_softmax(logits, dim=-1)
                new_scores = new_scores.view(1, -1)  # (1, num_beams * n_vocabulary)

                next_token_scores, _next_tokens = torch.topk(new_scores, n_topk, dim=-1, largest=True, sorted=True)
                # next_token_scores: (1, n_topk), next_tokens: (1, n_topk)
                indices_beam = torch.div(_next_tokens, vocab_size, rounding_mode="floor")
                indices_token = _next_tokens % vocab_size

                scores = new_scores.view(num_beams, -1)[indices_beam, indices_token].reshape(-1, 1)  # (n_topk, 1)
                beam_next_tokens = indices_token[:, :n_topk].reshape(-1,1)  # (n_topk, 1)
                input_ids = torch.cat([input_ids[indices_beam, :].squeeze(0), beam_next_tokens], dim=-1)  # (n_topk, n_token+1)

                # 受限解码
                scores = scores[:num_beams,:]
                input_ids = input_ids[:num_beams, :]

    outputs = tokenizer.batch_decode(input_ids[:, -3:], skip_special_tokens=False)  # List[str] len(outputs)==num_beams
    outputs = outputs[:num_return_sequences]
    return outputs

感觉预留的目的应该是 先推理然后筛选掉不合法的 再保留num_beams个

但是我算分数的时候就直接过滤了,所以好像其实不用搞2*num_beams了

相关推荐
中杯可乐多加冰17 小时前
RAG 深度实践系列(七):从“能用”到“好用”——RAG 系统优化与效果评估
人工智能·大模型·llm·大语言模型·rag·检索增强生成
山顶夕景1 天前
【LLM】大模型数据清洗&合成&增强方法
大模型·llm·训练数据
tiger1191 天前
FPGA 在大模型推理中的应用
人工智能·llm·fpga·大模型推理
AndrewHZ1 天前
【AI黑话日日新】什么是大模型的test-time scaling?
人工智能·深度学习·大模型·llm·推理加速·测试时缩放
GPUStack1 天前
vLLM、SGLang 融资背后,AI 推理正在走向系统化与治理
大模型·llm·vllm·模型推理·sglang·高性能推理
Tadas-Gao1 天前
大模型幻觉治理新范式:SCA与[PAUSE]注入技术的深度解析与创新设计
人工智能·深度学习·机器学习·架构·大模型·llm
猿小羽1 天前
基于 Spring AI 与 Streamable HTTP 构建 MCP Server 实践
java·llm·spring ai·mcp·streamable http
AndrewHZ1 天前
【AI黑话日日新】什么是隐式CoT?
人工智能·深度学习·算法·llm·cot·复杂推理
一个处女座的程序猿2 天前
CV之VLM之LLM-OCR:《DeepSeek-OCR 2: Visual Causal Flow》翻译与解读
llm·ocr·cv·vlm
dawdo2222 天前
自己动手从头开始编写LLM推理引擎(9)-KV缓存实现和优化
缓存·llm·transformer·qwen·kv cache