机器学习|从0开发大模型之SFT训练

继续写《从0开发大模型》系列文章,上一章主要数据数据预训练,让模型能学到句子接龙和部分语言理解能力,获取基座版本,但是用基座版本的模型的对话能力太弱了,需要用大量的数据微调,本文主要介绍如何用SFT训练模型。

1、什么是SFT

SFT是有监督微调(Supervised Fine-Tuning),指采用预先训练好的网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。

SFT在大语言模型中的应用有以下重要原因:

任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示,然而它在特定任务下的效果可能并不令人满意,通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。

领域适应性:预训练语言模型可能在不同领域的数据上表现不一致,通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。

数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据,监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。

防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险,这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。

2、整理SFT数据

整理SFT数据需要遵循以下原则:

  • 按照QA的格式整理数据
  • 如果训练多语言模型,需要准备其他语言的数据,本文训练的模型是中文的,所以只准备中文的数据
  • SFT的数据需要确保QA的数据回答是正确的,否则模型无法学习到正确的答案

(1)数据格式如下(CSV):

bash 复制代码
history,q,a
[],好的。现在请你将这个文本中的所有的逗号都替换成空格。,"好的,请稍等一下,现在我会将文本中的所有逗号替换为空格。处理后文本为:""这是一个句子 目的是看看是否可以正确地从这个句子中删除关键词。""。处理结果如何?"

其中history是历史的输入,q是问题,a是答案,但是以上数据无法直接用于微调,需要会拼接,比如翻译类型的会这样处理:

css 复制代码
instruction:
[USR]:将下列内容翻译成英语:{待翻译文本}
answer
[BOT]:{翻译结果}

拼接后的文本:
<bos_token>[USER]:将下列内容翻译成英语:{待翻译文本}<special token>[BOT]:{翻译结果} <eos_token>

(2)SFT的数据集可以参考以下数据集:

  • BelleGroup/train_3.5M_CN
  • LinkSoul/instruction_merge_set
  • stingning/ultrachat
  • BAAI/COIG-PC-core
  • shibing624/sharegpt_gpt4
  • shareAI/ShareGPT-Chinese-English-90k
  • Tiger Research
  • BelleGroup/school_math_0.25M
  • YeungNLP/moss-003-sft-data

(3)整理数据,将数据压缩:

这里为了数据处理方便,这里数据直接使用:www.modelscope.cn/datasets/de... sft_process_and_write_data 将数据转换为token。

python 复制代码
def sft_process_and_write_data(data, max_length = 1024, padding = 0):
    doc_ids = []
    for per in data:
        history, q, a = per['history'], per['q'], per['a']
        if len(q) < 10 or len(a) < 5:
            continue
        if len(q) > 512 or len(a) > 512:
            continue

        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append(
                {"role": 'user', "content": history_message[0][:max_length // 2]}
            )
            messages.append(
                {"role": 'assistant', "content": history_message[1][:max_length // 2]}
            )

        messages += [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a},
        ]
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_id = tokenizer(new_prompt).data['input_ids'][:max_length]
        padding_len = max_length - len(input_id)
        input_id = input_id + [padding] * padding_len
        if len(input_id) >= 5:
            doc_ids += input_id
    
    return doc_ids

def sft_process():
    file_name = 'sft_data.bin'
    chunk_size = 2000  # 每次处理的记录数

    input_doc_ids = []
    datalist = []
    sft_datasets = [f'{basepath}/sft_data_zh.jsonl']
    chunk_num = 0
    for path in sft_datasets:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    datalist.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        'a': obj.get('output', '') + obj.get('a', '')
                    })

                    if len(datalist) >= chunk_size:
                        chunk_num += 1
                        input_doc_ids += sft_process_and_write_data(datalist)
                        arr = np.array(input_doc_ids, dtype=np.uint16)
                        with open(f'{basepath}/{file_name}', 'wb') as f:
                            f.write(arr.tobytes())
                        datalist = []
                        if chunk_num % 100 == 0:
                            print(f'chunk:{chunk_num} process end, and input_doc_ids length:{len(input_doc_ids)}')
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue
                    
            if len(datalist) > 0:
                input_doc_ids += sft_process_and_write_data(datalist)
                arr = np.array(input_doc_ids, dtype=np.uint16)
                with open(f'{basepath}/{file_name}', 'wb') as f:
                    f.write(arr.tobytes())
                datalist = []

3、SFT训练

SFT训练的代码和上一篇预训练的代码差别不大,区别是加载SFT数据集,代码如下(替换上一篇预训练的 PretrainDataset 函数):

python 复制代码
class SFTDataset(Dataset):
    def __init__(self, data_path_lst, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        super().__init__()
        self.max_length = max_length
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len

        data_lst = []
        for data_path in data_path_lst:
            with open(data_path, 'rb') as f:
                data = np.fromfile(f, dtype=np.uint16)
                data_lst.append(data)
        data = np.concatenate(data_lst)
        data = data[:max_length * int(len(data) / max_length)]
        self.data = data.reshape(-1, max_length)
        print("train data.shape:{}".format(self.data.shape))
        print("SFTDataset finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)

SFT训练的数据集很大,训练时间较长,大概需要2-3天的时间,其中部分输出(从这里看loss值已经开始下降了):

ini 复制代码
...
Epoch:[7/20](307000/341718) loss:0.644 lr:0.0000070 epoch_Time:13.0min:
Epoch:[7/20](308000/341718) loss:0.856 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](309000/341718) loss:0.424 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](310000/341718) loss:0.524 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](311000/341718) loss:0.272 lr:0.0000070 epoch_Time:11.0min:
Epoch:[7/20](312000/341718) loss:0.373 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](313000/341718) loss:0.387 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](314000/341718) loss:0.560 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](315000/341718) loss:0.365 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](316000/341718) loss:0.226 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](317000/341718) loss:0.666 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](318000/341718) loss:0.504 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](319000/341718) loss:0.534 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](320000/341718) loss:0.403 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](321000/341718) loss:0.445 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](322000/341718) loss:0.581 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](323000/341718) loss:0.655 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](324000/341718) loss:0.606 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](325000/341718) loss:0.480 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](326000/341718) loss:0.696 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](327000/341718) loss:0.634 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](328000/341718) loss:0.852 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](329000/341718) loss:0.717 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](330000/341718) loss:0.680 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](331000/341718) loss:0.415 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](332000/341718) loss:0.617 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](333000/341718) loss:0.647 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](334000/341718) loss:0.554 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](335000/341718) loss:0.746 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](336000/341718) loss:0.499 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](337000/341718) loss:0.318 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](338000/341718) loss:0.651 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](339000/341718) loss:0.424 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](340000/341718) loss:0.567 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](341000/341718) loss:0.568 lr:0.0000069 epoch_Time:1.0min:
...

参考

(1)github.com/karpathy/ll...

相关推荐
思通数科多模态大模型4 分钟前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
数据岛8 分钟前
数据集论文:面向深度学习的土地利用场景分类与变化检测
人工智能·深度学习
龙的爹233335 分钟前
论文翻译 | RECITATION-AUGMENTED LANGUAGE MODELS
人工智能·语言模型·自然语言处理·prompt·gpu算力
白光白光36 分钟前
凸函数与深度学习调参
人工智能·深度学习
sp_fyf_202438 分钟前
【大语言模型】ACL2024论文-18 MINPROMPT:基于图的最小提示数据增强用于少样本问答
人工智能·深度学习·神经网络·目标检测·机器学习·语言模型·自然语言处理
weixin_5436628641 分钟前
BERT的中文问答系统33
人工智能·深度学习·bert
爱喝白开水a44 分钟前
Sentence-BERT实现文本匹配【分类目标函数】
人工智能·深度学习·机器学习·自然语言处理·分类·bert·大模型微调
Jack黄从零学c++1 小时前
opencv(c++)---自带的卷积运算filter2D以及应用
c++·人工智能·opencv
封步宇AIGC1 小时前
量化交易系统开发-实时行情自动化交易-4.2.3.指数移动平均线实现
人工智能·python·机器学习·数据挖掘
Mr.谢尔比2 小时前
李宏毅机器学习课程知识点摘要(1-5集)
人工智能·pytorch·深度学习·神经网络·算法·机器学习·计算机视觉