以Llama-2为例,在生成模型中使用自定义StoppingCriteria

以Llama-2为例,在生成模型中使用自定义StoppingCriteria

  • [1. 前言](#1. 前言)
  • [2. 场景介绍](#2. 场景介绍)
  • [3. 解决方法](#3. 解决方法)
  • [4. 结语](#4. 结语)

1. 前言

在之前的文章中,介绍了使用transformers模块创建的模型,其generate方法的详细原理和使用方法,文章链接:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

其中提到了用户参与生成过程的两个关键组件,logits_processorstopping_criteria,使用这两个类,是用户控制生成过程的主要手段。其中,logits_processor用来在生成过程中,根据用户设置的指定规则,强行修改当前step在词表空间上的概率分布,而stopping_criteria,根据用户所规定的规则来中止生成。

这两个组件在transformers模块中都有一些预设的类可以直接使用,预设类的基本信息介绍可参考以beam search为例,详解transformers中generate方法(上)

本文将结合实际应用场景,介绍用户如何根据自己的需求来设计并实现一个自定义的stopping_criteria,来控制生成过程提前结束。

2. 场景介绍

这次介绍的场景是,使用Llama-2的生成能力对一段新闻进行概括,希望能够生成一句简短的话,来概括新闻中发生的最核心的事情。

通过给定对话背景,结合历史样例的方式,希望Llama-2能够输出期望的结果。

对话的prompt构造方法可以参考之前的内容:NLP实践------Llama-2 多轮对话prompt构建

然而,即便是采用了in-context learning的方式,Llama-2生成的结果仍然过于冗长。

例如对于这样一篇新闻:

python3 复制代码
text = """, Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ..."""  
# 后边忽略若干内容

模型生成的结果为:

Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This year's exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容

可以看出,并不是模型生成的结果不好,但是它太啰嗦了,而对于我的需求而言,模型只需要输出其中的第一句话就足够了。

这时候可能有人就会觉得:"那我分句然后把第一句话保留下来不就好了?"

------这样做虽然也可以达成效果,但是这个生成过程,时间和算力已经被消耗了。

所以需要采取方法,让模型在生成到第一个句号的时候,就停止生成,返回结果。于是就需要用到今天的主角------Stopping Criteria。

3. 解决方法

transformers模块中内置了几个默认的stopping criteria,然而,在很多情况下,它们并不能满足需求,这时,就需要创建自定义的stopping criteria。

首先需要引用基类:

python3 复制代码
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

其中,

  • StoppingCriteriaList是一个容器,需要将所有的criteria都添加到其中,generate时传入的是这个容器;
  • StoppingCriteria是基础类,自定义的criteria需要继承这个基础类。

接下来就实现一个criteria,效果是,遇到指定的token时,就停止生成:

python3 复制代码
class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    当生成出第一个指定token时,立即停止生成
    ---------------
    ver: 2023-08-02
    by: changhongyu
    """
    def __init__(self, token_id_list: List[int] = None):
        """
        :param token_id_list: 停止生成的指定token的id的列表
        """
        self.token_id_list = token_id_list
        
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list
        # 储存scores会额外占用资源,所以直接用input_ids进行判断
        return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list

那么,如果希望遇到句号就停止生成,那就用句号对应的token_id去实例化一个这样的stopping criteria,并将它添加到容器中:

python3 复制代码
# Llama-2的词表中,英文句号的id是29889
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[29889]))

然后,在生成的时候,假如原本的生成指令是:

python3 复制代码
model.generate(**inputs)

那么再把stopping criteria作为参数传入进去,就可以发挥效果了:

python3 复制代码
model.generate(stopping_criteria=stopping_criteria, **inputs)

4. 结语

Stopping Criteria用于在每一个step的生成结束时,判断生成过程是否要结束,是用户控制生成过程的有效手段,其发挥作用的方式也比较直接,实现自定义criteria也并不复杂,只需要确保该类的调用方法返回值是bool值,并覆盖全部情况即可。

Logits Processor是用户控制生成的另一个有效工具,在接下来的博客中,还将介绍自定义logits processor是如何使用的,欢迎感兴趣的同学继续关注。

相关推荐
Guofu_Liao4 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
AI_小站12 小时前
RAG 示例:使用 langchain、Redis、llama.cpp 构建一个 kubernetes 知识库问答
人工智能·程序人生·langchain·kubernetes·llama·知识库·rag
Guofu_Liao12 小时前
Llama模型文件介绍
人工智能·llama
曼城周杰伦17 小时前
自然语言处理:第六十二章 KAG 超越GraphRAG的图谱框架
人工智能·pytorch·神经网络·自然语言处理·chatgpt·nlp·gpt-3
Donvink17 小时前
多模态大语言模型——《动手学大模型》实践教程第六章
人工智能·深度学习·语言模型·自然语言处理·llama
Donvink20 小时前
大模型安全和越狱攻击——《动手学大模型》实践教程第五章
深度学习·安全·语言模型·llama
Donvink20 小时前
大模型智能体安全——《动手学大模型》实践教程第七章
深度学习·安全·语言模型·prompt·llama
慢热型网友.1 天前
【项目实战】基于 LLaMA-Factory 通过 LoRA 微调 Qwen2
llama
机器学习是魔鬼1 天前
LLaMA-Factory 上手即用教程
llama·模型训练·ai功能岛·矩池云
Galeoto1 天前
fine tuning with llama-factory
llama