机器学习|从0开发大模型之SFT训练

继续写《从0开发大模型》系列文章,上一章主要数据数据预训练,让模型能学到句子接龙和部分语言理解能力,获取基座版本,但是用基座版本的模型的对话能力太弱了,需要用大量的数据微调,本文主要介绍如何用SFT训练模型。

1、什么是SFT

SFT是有监督微调(Supervised Fine-Tuning),指采用预先训练好的网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。

SFT在大语言模型中的应用有以下重要原因:

任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示,然而它在特定任务下的效果可能并不令人满意,通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。

领域适应性:预训练语言模型可能在不同领域的数据上表现不一致,通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。

数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据,监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。

防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险,这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。

2、整理SFT数据

整理SFT数据需要遵循以下原则:

  • 按照QA的格式整理数据
  • 如果训练多语言模型,需要准备其他语言的数据,本文训练的模型是中文的,所以只准备中文的数据
  • SFT的数据需要确保QA的数据回答是正确的,否则模型无法学习到正确的答案

(1)数据格式如下(CSV):

bash 复制代码
history,q,a
[],好的。现在请你将这个文本中的所有的逗号都替换成空格。,"好的,请稍等一下,现在我会将文本中的所有逗号替换为空格。处理后文本为:""这是一个句子 目的是看看是否可以正确地从这个句子中删除关键词。""。处理结果如何?"

其中history是历史的输入,q是问题,a是答案,但是以上数据无法直接用于微调,需要会拼接,比如翻译类型的会这样处理:

css 复制代码
instruction:
[USR]:将下列内容翻译成英语:{待翻译文本}
answer
[BOT]:{翻译结果}

拼接后的文本:
<bos_token>[USER]:将下列内容翻译成英语:{待翻译文本}<special token>[BOT]:{翻译结果} <eos_token>

(2)SFT的数据集可以参考以下数据集:

  • BelleGroup/train_3.5M_CN
  • LinkSoul/instruction_merge_set
  • stingning/ultrachat
  • BAAI/COIG-PC-core
  • shibing624/sharegpt_gpt4
  • shareAI/ShareGPT-Chinese-English-90k
  • Tiger Research
  • BelleGroup/school_math_0.25M
  • YeungNLP/moss-003-sft-data

(3)整理数据,将数据压缩:

这里为了数据处理方便,这里数据直接使用:www.modelscope.cn/datasets/de... sft_process_and_write_data 将数据转换为token。

python 复制代码
def sft_process_and_write_data(data, max_length = 1024, padding = 0):
    doc_ids = []
    for per in data:
        history, q, a = per['history'], per['q'], per['a']
        if len(q) < 10 or len(a) < 5:
            continue
        if len(q) > 512 or len(a) > 512:
            continue

        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append(
                {"role": 'user', "content": history_message[0][:max_length // 2]}
            )
            messages.append(
                {"role": 'assistant', "content": history_message[1][:max_length // 2]}
            )

        messages += [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a},
        ]
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_id = tokenizer(new_prompt).data['input_ids'][:max_length]
        padding_len = max_length - len(input_id)
        input_id = input_id + [padding] * padding_len
        if len(input_id) >= 5:
            doc_ids += input_id
    
    return doc_ids

def sft_process():
    file_name = 'sft_data.bin'
    chunk_size = 2000  # 每次处理的记录数

    input_doc_ids = []
    datalist = []
    sft_datasets = [f'{basepath}/sft_data_zh.jsonl']
    chunk_num = 0
    for path in sft_datasets:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    datalist.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        'a': obj.get('output', '') + obj.get('a', '')
                    })

                    if len(datalist) >= chunk_size:
                        chunk_num += 1
                        input_doc_ids += sft_process_and_write_data(datalist)
                        arr = np.array(input_doc_ids, dtype=np.uint16)
                        with open(f'{basepath}/{file_name}', 'wb') as f:
                            f.write(arr.tobytes())
                        datalist = []
                        if chunk_num % 100 == 0:
                            print(f'chunk:{chunk_num} process end, and input_doc_ids length:{len(input_doc_ids)}')
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue
                    
            if len(datalist) > 0:
                input_doc_ids += sft_process_and_write_data(datalist)
                arr = np.array(input_doc_ids, dtype=np.uint16)
                with open(f'{basepath}/{file_name}', 'wb') as f:
                    f.write(arr.tobytes())
                datalist = []

3、SFT训练

SFT训练的代码和上一篇预训练的代码差别不大,区别是加载SFT数据集,代码如下(替换上一篇预训练的 PretrainDataset 函数):

python 复制代码
class SFTDataset(Dataset):
    def __init__(self, data_path_lst, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        super().__init__()
        self.max_length = max_length
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len

        data_lst = []
        for data_path in data_path_lst:
            with open(data_path, 'rb') as f:
                data = np.fromfile(f, dtype=np.uint16)
                data_lst.append(data)
        data = np.concatenate(data_lst)
        data = data[:max_length * int(len(data) / max_length)]
        self.data = data.reshape(-1, max_length)
        print("train data.shape:{}".format(self.data.shape))
        print("SFTDataset finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)

SFT训练的数据集很大,训练时间较长,大概需要2-3天的时间,其中部分输出(从这里看loss值已经开始下降了):

ini 复制代码
...
Epoch:[7/20](307000/341718) loss:0.644 lr:0.0000070 epoch_Time:13.0min:
Epoch:[7/20](308000/341718) loss:0.856 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](309000/341718) loss:0.424 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](310000/341718) loss:0.524 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](311000/341718) loss:0.272 lr:0.0000070 epoch_Time:11.0min:
Epoch:[7/20](312000/341718) loss:0.373 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](313000/341718) loss:0.387 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](314000/341718) loss:0.560 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](315000/341718) loss:0.365 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](316000/341718) loss:0.226 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](317000/341718) loss:0.666 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](318000/341718) loss:0.504 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](319000/341718) loss:0.534 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](320000/341718) loss:0.403 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](321000/341718) loss:0.445 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](322000/341718) loss:0.581 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](323000/341718) loss:0.655 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](324000/341718) loss:0.606 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](325000/341718) loss:0.480 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](326000/341718) loss:0.696 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](327000/341718) loss:0.634 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](328000/341718) loss:0.852 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](329000/341718) loss:0.717 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](330000/341718) loss:0.680 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](331000/341718) loss:0.415 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](332000/341718) loss:0.617 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](333000/341718) loss:0.647 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](334000/341718) loss:0.554 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](335000/341718) loss:0.746 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](336000/341718) loss:0.499 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](337000/341718) loss:0.318 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](338000/341718) loss:0.651 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](339000/341718) loss:0.424 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](340000/341718) loss:0.567 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](341000/341718) loss:0.568 lr:0.0000069 epoch_Time:1.0min:
...

参考

(1)github.com/karpathy/ll...

相关推荐
艾思科蓝-何老师【H8053】13 分钟前
【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)
人工智能·信号处理·论文发表·香港中文大学
weixin_4526006941 分钟前
《青牛科技 GC6125:驱动芯片中的璀璨之星,点亮 IPcamera 和云台控制(替代 BU24025/ROHM)》
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·智能充电枪
学术搬运工41 分钟前
【珠海科技学院主办,暨南大学协办 | IEEE出版 | EI检索稳定 】2024年健康大数据与智能医疗国际会议(ICHIH 2024)
大数据·图像处理·人工智能·科技·机器学习·自然语言处理
右恩1 小时前
AI大模型重塑软件开发:流程革新与未来展望
人工智能
图片转成excel表格1 小时前
WPS Office Excel 转 PDF 后图片丢失的解决方法
人工智能·科技·深度学习
ApiHug2 小时前
ApiSmart x Qwen2.5-Coder 开源旗舰编程模型媲美 GPT-4o, ApiSmart 实测!
人工智能·spring boot·spring·ai编程·apihug
哇咔咔哇咔2 小时前
【科普】简述CNN的各种模型
人工智能·神经网络·cnn
李歘歘2 小时前
万字长文解读深度学习——多模态模型CLIP、BLIP、ViLT
人工智能·深度学习
Chatopera 研发团队2 小时前
机器学习 - 为 Jupyter Notebook 安装新的 Kernel
人工智能·机器学习·jupyter