[NLP]LLM---大模型指令微调中的“Prompt”

一 指令微调数据集形式太多

大家有没有分析过 prompt对模型训练或者推理的影响?之前推理的时候,发现不加训练的时候prompt,直接输入模型性能会变差的,这个倒是可以理解。假如不加prompt直接训练,是不是测试的时候不加prompt也可以?还有一个就是多轮prompt和单轮prompt怎么构造的问题?好多模型训练方式不统一 包括指令数据形式有所不同,选择困难症又来了。。

先说一些观点,假如我们在微调一个大模型,单次实验微调所用的指令微调数据集应该选取"质量高、多样性",在训练资源充足的情况可以加入数量更多,长度更大的数据集。可以基于多个质量比较高的数据,做一份格式统一的多样性数据用来做sft,一次性微调完比较好,多次微调效果可能会折扣。

或者有继续微调比较合适的方案也可以,不损失之前模型的效果(或者损失比较小),目前可以尝试Lora或者Qlora的方式微调底座模型,然后将训练好的Lora权重合并到原始模型,这样可以减轻多次微调对模型的影响。

二 常见指令微调模板

目前常见微调数据集的格式包括以下几种:指令跟随格式、多轮对话格式,以及其他辅助格式。

1. 指令跟随格式

指令跟随形式是指用户输入指令,模型按照指令的要求输出结果的格式。这种形式的数据集通常采用json文件格式存储,典型的如Alpaca-52k数据集[1]。Alpaca-52k是一个开源的遵循指令的,最初用来微调LLaMA模型以得到Alpaca-7B模型的数据集,包含了52000条指令数据。这里的52k代表的是共有五万两千对指令与输出对。

Alpaca的格式有两类,一类是instruction/output格式,例如:

#instruction :什么是三原色?

#output:三原色是红、蓝、黄。这些颜色被称为原色,因为它们不能通过混合其他颜色来创建,而所有其他颜色都可以通过以不同比例组合它们来制作。在用于光的加法色彩系统中,原色是红、绿、蓝(RGB)。

另一类为instruction/input/output格式,例如:

#instruction:检查这个句子的拼写和语法错误

#input:他吃完饭离开了餐厅。

#output:这句话中有两个拼写错误。正确的句子应该是:"他吃完饭离开了餐厅。

然而,由于生成Alpaca数据集的self-instruct [2] 技术得到的数据集本身存在一些瑕疵,因此数据集需要进一步清洗和改进,例如alpaca-cleaned [3] 和alpaca-gpt4 [4]。此外还有中文翻译版本[5]。这些数据集通常包含几万个指令对,文件大小约为40MB左右的json格式文件。

这里补充一句,在相关研究中,数据集长度通常采用token数或指令条数进行计算。由于token数与tokenizer相关,而指令数会因文本长度的不同而有大的影响。为了直观起见,在本文的数据集规模评估中选择文件存储大小作为评估指标。

2. 多轮对话格式

多轮对话形式是指用户和模型之间以对话的形式进行,模型将通过与用户进行多轮的交互最终来达到用户的需求。典型的如训练Vicuna模型 [6] 所使用的ShareGPT数据集,ShareGPT本身是一个与ChatGPT(GPT-4)模型的聊天记录分享平台,它托管了大量由用户挑选的对话数据集,这些聊天记录通常展示的是聊天机器人自然流畅、具有创意的回答。Vicuna模型通过收集该平台的数据,数据大小为 673MB [7],其训练出来的模型具有较好的多轮对话能力,具体格式如下 [6]:

将所有数据集格式化为遵循聊天机器人风格的模式,以统一指令数据集的各种风格和格式,如下图所示。分别在用户指令和助手回复之前添加了特殊token <|user|>和<|assistant|>。并且在每个助手的回复末尾添加一个文本结束标记</s>,在推理时,该标记用以停止模型每轮的响应。

3. 其他形式

除了上述提到的数据格式,还有一些数据格式不易转化为对话形式,例如纯文本文档。另外,还有一些针对特定用途的数据集,例如文本总结数据集以及根据纯文本生成对话的数据集,如RefGPT [8] 文章提到的方案。根据文本的不同功能,它们还包括调用API的格式 [9] 和调用数据库语言的格式 [10] 等。当然,除非以纯文本的形式存在,否则这些格式都可以转换为指令跟随或多轮对话的格式。需要注意的是,这里所提到的微调数据集的格式并不包括基于强化学习训练的所使用的RLHF数据集。

4. 一些常见的模板

stanford_alpaca中模板

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


if example.get('input', '') == '':
    prompt = PROMPT_DICT['prompt_no_input'].format_map(example)
else:
    prompt = PROMPT_DICT['prompt_input'].format_map(example)
example2 = prompt + example['output']

Llama2中的模板

instruction = """[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

            If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""

text_1 = f"".join(["[INST] <<SYS>>\n    "
   "You are a helpful, respectful and honest assistant. "
   "Always answer as helpfully as possible, while being safe."
   " Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, "
   "or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
   "    If a question does not make any sense, or is not factually coherent, "
   "explain why instead of answering something not correct. "
   "If you don't know the answer to a question, please don't share false information.\n"
   "<</SYS>>\n\n{0} [/INST] "]).format(
    data_point.get('instruction', '').strip() +"\t"+ data_point.get('input', '').strip())
    
    
缩短后为
text_1 = f"[INST] <<SYS>>\n    You are a helpful, respectful and honest assistant.<</SYS>>" \
         f"\n\n{0} [/INST] ".format(
    data_point.get('instruction', '').strip() + "\t" + data_point.get('input', '').strip())



[f'[INST] <<SYS>>\n{system_message.strip()}\n<</SYS>>\n\n' + prompt + ' [/INST] ' + response for prompt, response in zip(examples['prompt'], examples['response'])]

Linly-AI中模板

### Instruction:{prompt.strip()}  ### Response:

OpenLLM 排行榜top1的NousResearch

和alpaca模板差不多

### Instruction:
<prompt>

### Response:
<leave a newline blank for model to respond>
### Instruction:
<prompt>

### Input:
<additional context>

### Response:
<leave a newline blank for model to respond>

Yayi模板

https://huggingface.co/wenge-research/yayi-7b-llama2

prompt = "你是谁?"
formatted_prompt = f"""<|System|>:
You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.

<|Human|>:
{prompt}

<|YaYi|>:
"""

StableBeluga2的模板

### System:
This is a system prompt, please behave and help the user.

### User:
Your prompt here

### Assistant:
The output of Stable Beluga 2

比如

system_prompt = "### System:\nYou are Stable Beluga, an AI that follows instructions extremely well. Help as much as you can. Remember, be safe, and don't do anything illegal.\n\n"

message = "Write me a poem please"
prompt = f"{system_prompt}### User: {message}\n\n### Assistant:\n"

Guanaco数据集常用模板

### Human: {prompt}
### Assistant:
prompt = "Introduce yourself"
formatted_prompt = (
    f"A chat between a curious human and an artificial intelligence assistant."
    f"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    f"### Human: {prompt} ### Assistant:"
)

三 多轮对话输入和输出构造

参考yangjianxin1/Firefly项目和LinkSoul-AI/Chinese-Llama-2-7b项目,一般采用的方式是:

在计算loss时,我们通过mask的方式,input部分的loss不参与参数更新,只有"target"部分的loss参与参数更新。 这种方式充分利用了模型并行计算的优势,训练更加高效,且多轮对话中的每个target部分都参与了训练,训练更充分。 否则,就需要把一个n轮对话,拆分成n条数据,且只计算最后一个target的loss,大大降低了训练效率。

具体实现方式1:

# https://github.com/LinkSoul-AI/Chinese-Llama-2-7b/blob/main/train.py
def tokenize(item, tokenizer):
    roles = {"human": "user", "gpt": "assistant"}
    input_ids = []
    labels = []
    if "instruction" in item and len(item["instruction"]) > 0:
        system = item["instruction"]
    else:
        system = dummy_message["system"]
    system = B_SYS + system + E_SYS
    # add system before the first content in conversations
    item["conversations"][0]['value'] = system + item["conversations"][0]['value']
    for i, turn in enumerate(item["conversations"]):
        role = turn['from']
        content = turn['value']
        content = content.strip()
        if role == 'human':
            content = f"{B_INST} {content} {E_INST} "
            content_ids = tokenizer.encode(content)
            labels += [IGNORE_TOKEN_ID] * (len(content_ids))
        else:
            # assert role == "gpt"
            content = f"{content} "
            content_ids = tokenizer.encode(content, add_special_tokens=False) + [tokenizer.eos_token_id]   # add_special_tokens=False remove bos token, and add eos at the end
            labels += content_ids
        input_ids += content_ids

    input_ids = input_ids[:tokenizer.model_max_length]
    labels = labels[:tokenizer.model_max_length]

    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        return tokenize(dummy_message, tokenizer)
    input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
    labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
    return input_ids, labels

具体实现方式1:

# https://github.com/yangjianxin1/Firefly/blob/master/component/dataset.py
class SFTDataset(Dataset):
    def __init__(self, file, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.eos_token = tokenizer.eos_token
        self.bos_token = tokenizer.bos_token
        self.max_seq_length = max_seq_length
        logger.info('Loading data: {}'.format(file))
        with open(file, 'r', encoding='utf8') as f:
            data_list = f.readlines()
        logger.info("there are {} data in dataset".format(len(data_list)))
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        # 每条数据格式为: <s>input1</s>target1</s>input2</s>target2</s>...
        data = self.data_list[index]
        data = json.loads(data)
        conversation = data['conversation']

        # 收集多轮对话
        utterances = []
        for x in conversation:
            utterances.append(x['human'])
            utterances.append(x['assistant'])
        utterances_ids = self.tokenizer(utterances, add_special_tokens=False).input_ids

        # 模型的输入格式为:<s>input1</s>target1</s>input2</s>target2</s>...
        input_ids = [self.bos_token_id]
        target_mask = [0]  # 用于对input进行mask,只计算target部分的loss
        for i, utterances_id in enumerate(utterances_ids):
            input_ids += (utterances_id + [self.eos_token_id])
            if i % 2 == 0:
                target_mask += [0] * (len(utterances_id) + 1)
            else:
                target_mask += [1] * (len(utterances_id) + 1)
        assert len(input_ids) == len(target_mask)
        # 对长度进行截断
        input_ids = input_ids[:self.max_seq_length]
        target_mask = target_mask[:self.max_seq_length]
        attention_mask = [1] * len(input_ids)
        assert len(input_ids) == len(target_mask) == len(attention_mask)
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_mask': target_mask
        }
        return inputs

核心代码就是通过IGNORE_INDEX(-100)遮蔽掉input对应的目标输出即可。

四 微调数据的基本处理

微调数据集需要经过一系列的处理步骤,包括数据收集、数据清洗和数据增强等。数据收集是文本处理的基础,可通过公共数据集、自定义数据集和行业数据集等多种方式获得。在获得数据集后,需要进行数据清洗,去除噪声、重复和低质量的数据,将其统一转化为可训练的格式。另外,为了提高数据集的质量和丰富性,可以采用数据增强技术,如翻译、摘要、同义词替换、随机插入等操作,当然由于大模型本身已经有了很强的文本处理能力,这些数据增强技术都可以使用大模型来辅助完成。

通常,微调数据集的规模比预训练数据集小得多。典型的相比于几个TB的预训练文本数据,预训练的存储大小通常在几MB到1GB左右。在收集和整理数据后,可以将自定义数据集与其他开源数据集混合训练。此外,微调数据集通常还包含一个用于自身认知的数据集,典型的如训练Vicuna模型时提到的Dummy数据集 [6]。自定义数据集与其他开源数据集混合训练有助于提高模型效果和泛化性

五 如何高效率微调大模型

如何短时间、高效率的训练出实际效果不错、综合能力比较强的大模型呢?从指令微调数据集处理工作上,个人认为可以从以下方式进行:

(1) 事先准备多种高质量的指令微调数据集,每个数据集尽量保持差异性。那高质量如何定义呢?我们可以从一些效果不错的模型收集它们训练使用的指令数据集

(2)笔者在实验过程中,发现加入多伦对话的数据有助于提升模型生成能力,如果仅用单轮对话或者单轮指令训练出的模型生成长度可能偏短。

(3)另外通过实验发现,如果模型微调的时候使用模板,那么推理的时候应该也使用模板,否则效果会影响,直观上就是生成效果不理想,生成比较短,甚至"驴唇不对马嘴";训练使用了英文模板,推理的时候未使用提示模板的情况下会出现中英文混杂现象。

六 微调数据构建方式的探索

在上述的讨论中可以看到,微调数据集的构建非常重要,可以说是定制化自有模型时最核心的环节了。微调的目的是以一个预训练的模型为基础,利用一个小数据集,以打磨细节的方法,重新微调一个更为定制化的模型。

在构建微调数据集时,有一些值得注意的事项和构建方法,例如可以基于现有的大模型进行self-instruct,以及利用一些基本原则通过结合self-instruct方法来构建微调数据,如Dromedary-65B模型的微调方法 [13]。虽然从已有模型中构建数据是一种简便的方法,但并不一定能得到高质量的数据集。加之数据量大小不是唯一的评判标准,例如根据LIMA [14] 文章的结果表明,一定数量的微调数据就可以激活大模型预训练的数据,关键在于数据的质量和对模型的启发。基于LIMA文章的思想,出现了一个有意思的模型称为based [15],该模型的指导思想是,大模型本身已经拥有对各种事物的看法了,仅仅需要教会它如何说话就可以了。该模型有意思的地方在于,其微调数据的文件大小仅72.8KB,就可以让大模型流畅表达它的观点了,作为对比,LIMA的微调数据文件大小有2.97MB。

构建微调数据集的目的是,一方面是告知大模型一些新的知识,另一方面是调整大模型以我们期待的方式回复我们。如果需要告诉大模型一种新的知识,可能需要用高质量教导式的方式进行数据扩充。在这里,已有的文章提供了许多启示,例如Orca模型 [16] 的训练方式,从GPT-4获得丰富的解释轨迹,进行逐步思维,从而使得LLaMA-13B模型训练出具有ChatGPT相当的效果。又如 Textbook is all you need [17] 文章所提出的(尽管这篇文章讨论的并非是微调过程),可以构建更加具有教育意义的知识,例如采用教科书级别的数据集,这样能使大模型在编程领域上达到更高的水平。当然,为了增加微调数据的复杂度,也可以基于大模型根据已有的数据通过演化的方法来生成更加复杂的微调数据 [18][19]。

另外,微调数据集的构建和tokenizer也有关系。其中最大的影响是,tokenizer会影响到大模型的学习,例如文章 [20][21] 提到的,不恰当的tokenizer影响会影响大模型在两位数的加法正确性。当然,如果不想更改已经训练好的tokenizer,那么在构建微调数据集时,最好使用tokenizer中已有的词汇。当然,tokenizer本身会影响到token的长度,例如带有更多中文词汇的tokenizer可以使得中文文本经过tokenizer之后更短。同时在数据集处理过程加入StartToken,PadToken,EndToken等标记,也可以帮助模型更好地理解数据,或者帮助下游应用进行编码,一个具体的例子如Vicuna模型,在版本更新后,他们在微调数据集中加入了新的对话结束的标记:</s>,使模型能有效地预测何时停止生成字符。

综上,在构建微调数据集时,需要考虑方方面面的问题,不仅需要注重数据质量和数量的平衡,同时也要让模型了解我们的期望,以及在专有的定制领域获得相应的知识,从而达到在定制领域具有更高的预测准确性。由于微调数据的重要性,因此这方面的努力都是值得的。

本文主要讨论了大语言模型的微调数据集构建技术,并阐述了微调数据集的格式、数据增强和数据整理等步骤。与预训练数据集相比,微调数据集的构建需要更加精益求精。在实践中,采用自定义数据集与其他开源数据集混合训练的方式可以帮助微调模型提高效果和泛化性。然而,构建高质量微调数据集是一项庞杂琐碎的任务,需要耗费大量的时间和精力。期待在未来出现更加友好易用的GUI工具,帮助我们更好地构建微调数据集。

【LLM系列之指令微调】长话短说大模型指令微调的"Prompt" - 知乎 (zhihu.com)大语言模型微调:定制自己的微调数据集-云海天教程 (yht7.com)

相关推荐
花千树-0102 分钟前
LangChain教程 - 表达式语言 (LCEL) -构建智能链
gpt·langchain·prompt·aigc·ai编程·llama·ai-native
云天徽上6 分钟前
【机器学习案列】车牌自动识别系统:基于YOLO11的高效实现
人工智能·机器学习
DashVector16 分钟前
如何通过HTTP API插入Doc
数据库·人工智能·http·阿里云·向量检索
顾道长生'1 小时前
(NIPS-2024)PISSA:大型语言模型的主奇异值和奇异向量适配
人工智能·语言模型·自然语言处理
Macropodus1 小时前
near-synonym反义词生成(2):Prompt +Bert-MLM(FT)
自然语言处理·prompt·反义词生成·中文反义词·bert-mlm
语音之家1 小时前
CultureLLM 与 CulturePark:增强大语言模型对多元文化的理解
人工智能·语言模型·自然语言处理
Tasfa1 小时前
【AI系列】从零开始学习大模型GPT (1)- Build a Large Language Model (From Scratch)
人工智能·gpt·学习
一个处女座的程序猿1 小时前
LLMs之o3:《Deliberative Alignment: Reasoning Enables Safer Language Models》翻译与解读
人工智能·深度学习·机器学习
静静AI学堂1 小时前
动态头部:利用注意力机制统一目标检测头部
人工智能·目标检测·计算机视觉
嵌入式小强工作室1 小时前
stm32能跑人工智能么
人工智能·stm32·嵌入式硬件