机器学习|从0开发大模型之SFT训练

继续写《从0开发大模型》系列文章,上一章主要数据数据预训练,让模型能学到句子接龙和部分语言理解能力,获取基座版本,但是用基座版本的模型的对话能力太弱了,需要用大量的数据微调,本文主要介绍如何用SFT训练模型。

1、什么是SFT

SFT是有监督微调(Supervised Fine-Tuning),指采用预先训练好的网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。

SFT在大语言模型中的应用有以下重要原因:

任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示,然而它在特定任务下的效果可能并不令人满意,通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。

领域适应性:预训练语言模型可能在不同领域的数据上表现不一致,通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。

数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据,监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。

防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险,这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。

2、整理SFT数据

整理SFT数据需要遵循以下原则:

  • 按照QA的格式整理数据
  • 如果训练多语言模型,需要准备其他语言的数据,本文训练的模型是中文的,所以只准备中文的数据
  • SFT的数据需要确保QA的数据回答是正确的,否则模型无法学习到正确的答案

(1)数据格式如下(CSV):

bash 复制代码
history,q,a
[],好的。现在请你将这个文本中的所有的逗号都替换成空格。,"好的,请稍等一下,现在我会将文本中的所有逗号替换为空格。处理后文本为:""这是一个句子 目的是看看是否可以正确地从这个句子中删除关键词。""。处理结果如何?"

其中history是历史的输入,q是问题,a是答案,但是以上数据无法直接用于微调,需要会拼接,比如翻译类型的会这样处理:

css 复制代码
instruction:
[USR]:将下列内容翻译成英语:{待翻译文本}
answer
[BOT]:{翻译结果}

拼接后的文本:
<bos_token>[USER]:将下列内容翻译成英语:{待翻译文本}<special token>[BOT]:{翻译结果} <eos_token>

(2)SFT的数据集可以参考以下数据集:

  • BelleGroup/train_3.5M_CN
  • LinkSoul/instruction_merge_set
  • stingning/ultrachat
  • BAAI/COIG-PC-core
  • shibing624/sharegpt_gpt4
  • shareAI/ShareGPT-Chinese-English-90k
  • Tiger Research
  • BelleGroup/school_math_0.25M
  • YeungNLP/moss-003-sft-data

(3)整理数据,将数据压缩:

这里为了数据处理方便,这里数据直接使用:www.modelscope.cn/datasets/de... sft_process_and_write_data 将数据转换为token。

python 复制代码
def sft_process_and_write_data(data, max_length = 1024, padding = 0):
    doc_ids = []
    for per in data:
        history, q, a = per['history'], per['q'], per['a']
        if len(q) < 10 or len(a) < 5:
            continue
        if len(q) > 512 or len(a) > 512:
            continue

        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append(
                {"role": 'user', "content": history_message[0][:max_length // 2]}
            )
            messages.append(
                {"role": 'assistant', "content": history_message[1][:max_length // 2]}
            )

        messages += [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a},
        ]
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_id = tokenizer(new_prompt).data['input_ids'][:max_length]
        padding_len = max_length - len(input_id)
        input_id = input_id + [padding] * padding_len
        if len(input_id) >= 5:
            doc_ids += input_id
    
    return doc_ids

def sft_process():
    file_name = 'sft_data.bin'
    chunk_size = 2000  # 每次处理的记录数

    input_doc_ids = []
    datalist = []
    sft_datasets = [f'{basepath}/sft_data_zh.jsonl']
    chunk_num = 0
    for path in sft_datasets:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    datalist.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        'a': obj.get('output', '') + obj.get('a', '')
                    })

                    if len(datalist) >= chunk_size:
                        chunk_num += 1
                        input_doc_ids += sft_process_and_write_data(datalist)
                        arr = np.array(input_doc_ids, dtype=np.uint16)
                        with open(f'{basepath}/{file_name}', 'wb') as f:
                            f.write(arr.tobytes())
                        datalist = []
                        if chunk_num % 100 == 0:
                            print(f'chunk:{chunk_num} process end, and input_doc_ids length:{len(input_doc_ids)}')
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue
                    
            if len(datalist) > 0:
                input_doc_ids += sft_process_and_write_data(datalist)
                arr = np.array(input_doc_ids, dtype=np.uint16)
                with open(f'{basepath}/{file_name}', 'wb') as f:
                    f.write(arr.tobytes())
                datalist = []

3、SFT训练

SFT训练的代码和上一篇预训练的代码差别不大,区别是加载SFT数据集,代码如下(替换上一篇预训练的 PretrainDataset 函数):

python 复制代码
class SFTDataset(Dataset):
    def __init__(self, data_path_lst, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        super().__init__()
        self.max_length = max_length
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len

        data_lst = []
        for data_path in data_path_lst:
            with open(data_path, 'rb') as f:
                data = np.fromfile(f, dtype=np.uint16)
                data_lst.append(data)
        data = np.concatenate(data_lst)
        data = data[:max_length * int(len(data) / max_length)]
        self.data = data.reshape(-1, max_length)
        print("train data.shape:{}".format(self.data.shape))
        print("SFTDataset finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)

SFT训练的数据集很大,训练时间较长,大概需要2-3天的时间,其中部分输出(从这里看loss值已经开始下降了):

ini 复制代码
...
Epoch:[7/20](307000/341718) loss:0.644 lr:0.0000070 epoch_Time:13.0min:
Epoch:[7/20](308000/341718) loss:0.856 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](309000/341718) loss:0.424 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](310000/341718) loss:0.524 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](311000/341718) loss:0.272 lr:0.0000070 epoch_Time:11.0min:
Epoch:[7/20](312000/341718) loss:0.373 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](313000/341718) loss:0.387 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](314000/341718) loss:0.560 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](315000/341718) loss:0.365 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](316000/341718) loss:0.226 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](317000/341718) loss:0.666 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](318000/341718) loss:0.504 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](319000/341718) loss:0.534 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](320000/341718) loss:0.403 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](321000/341718) loss:0.445 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](322000/341718) loss:0.581 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](323000/341718) loss:0.655 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](324000/341718) loss:0.606 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](325000/341718) loss:0.480 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](326000/341718) loss:0.696 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](327000/341718) loss:0.634 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](328000/341718) loss:0.852 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](329000/341718) loss:0.717 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](330000/341718) loss:0.680 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](331000/341718) loss:0.415 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](332000/341718) loss:0.617 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](333000/341718) loss:0.647 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](334000/341718) loss:0.554 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](335000/341718) loss:0.746 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](336000/341718) loss:0.499 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](337000/341718) loss:0.318 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](338000/341718) loss:0.651 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](339000/341718) loss:0.424 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](340000/341718) loss:0.567 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](341000/341718) loss:0.568 lr:0.0000069 epoch_Time:1.0min:
...

参考

(1)github.com/karpathy/ll...

相关推荐
lovelin+v175030409664 分钟前
智能电商:API接口如何驱动自动化与智能化转型
大数据·人工智能·爬虫·python
rpa_top5 分钟前
RPA 助力电商:自动化商品信息上传,节省人力资源 —— 以影刀 RPA 为例【rpa.top】
大数据·前端·人工智能·自动化·rpa
视觉语言导航20 分钟前
arXiv-2024 | STMR:语义拓扑度量表示引导的大模型推理无人机视觉语言导航
人工智能·具身智能
咯咯咯伦1 小时前
AI神了,一键视频下载+翻译+配音+字幕!(整合包)
人工智能
愚者大大1 小时前
优化算法(SGD,RMSProp,Ada)
人工智能·算法·机器学习
人类群星闪耀时2 小时前
基于AI的网络流量分析:构建智能化运维体系
运维·人工智能
dundunmm2 小时前
数据挖掘之认识数据
人工智能·机器学习·信息可视化·数据挖掘
FBI78098045942 小时前
API接口在电商行业中的创新应用与趋势
运维·网络·人工智能·爬虫·python
Sword992 小时前
豆包 MarsCode AI Apply功能揭秘:自动代码应用与 Diff 实现
前端·人工智能·豆包marscode