继续写《从0开发大模型》系列文章,上一章主要数据数据预训练,让模型能学到句子接龙和部分语言理解能力,获取基座版本,但是用基座版本的模型的对话能力太弱了,需要用大量的数据微调,本文主要介绍如何用SFT训练模型。
1、什么是SFT
SFT是有监督微调(Supervised Fine-Tuning),指采用预先训练好的网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。
SFT在大语言模型中的应用有以下重要原因:
任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示,然而它在特定任务下的效果可能并不令人满意,通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。
领域适应性:预训练语言模型可能在不同领域的数据上表现不一致,通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。
数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据,监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。
防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险,这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。
2、整理SFT数据
整理SFT数据需要遵循以下原则:
- 按照QA的格式整理数据
- 如果训练多语言模型,需要准备其他语言的数据,本文训练的模型是中文的,所以只准备中文的数据
- SFT的数据需要确保QA的数据回答是正确的,否则模型无法学习到正确的答案
(1)数据格式如下(CSV):
bash
history,q,a
[],好的。现在请你将这个文本中的所有的逗号都替换成空格。,"好的,请稍等一下,现在我会将文本中的所有逗号替换为空格。处理后文本为:""这是一个句子 目的是看看是否可以正确地从这个句子中删除关键词。""。处理结果如何?"
其中history是历史的输入,q是问题,a是答案,但是以上数据无法直接用于微调,需要会拼接,比如翻译类型的会这样处理:
css
instruction:
[USR]:将下列内容翻译成英语:{待翻译文本}
answer
[BOT]:{翻译结果}
拼接后的文本:
<bos_token>[USER]:将下列内容翻译成英语:{待翻译文本}<special token>[BOT]:{翻译结果} <eos_token>
(2)SFT的数据集可以参考以下数据集:
- BelleGroup/train_3.5M_CN
- LinkSoul/instruction_merge_set
- stingning/ultrachat
- BAAI/COIG-PC-core
- shibing624/sharegpt_gpt4
- shareAI/ShareGPT-Chinese-English-90k
- Tiger Research
- BelleGroup/school_math_0.25M
- YeungNLP/moss-003-sft-data
(3)整理数据,将数据压缩:
这里为了数据处理方便,这里数据直接使用:www.modelscope.cn/datasets/de... sft_process_and_write_data
将数据转换为token。
python
def sft_process_and_write_data(data, max_length = 1024, padding = 0):
doc_ids = []
for per in data:
history, q, a = per['history'], per['q'], per['a']
if len(q) < 10 or len(a) < 5:
continue
if len(q) > 512 or len(a) > 512:
continue
messages = []
for history_message in history:
if len(history_message) <= 1:
continue
messages.append(
{"role": 'user', "content": history_message[0][:max_length // 2]}
)
messages.append(
{"role": 'assistant', "content": history_message[1][:max_length // 2]}
)
messages += [
{"role": "user", "content": q},
{"role": "assistant", "content": a},
]
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
input_id = tokenizer(new_prompt).data['input_ids'][:max_length]
padding_len = max_length - len(input_id)
input_id = input_id + [padding] * padding_len
if len(input_id) >= 5:
doc_ids += input_id
return doc_ids
def sft_process():
file_name = 'sft_data.bin'
chunk_size = 2000 # 每次处理的记录数
input_doc_ids = []
datalist = []
sft_datasets = [f'{basepath}/sft_data_zh.jsonl']
chunk_num = 0
for path in sft_datasets:
with jsonlines.open(path) as reader:
for idx, obj in enumerate(reader):
try:
datalist.append({
'history': obj.get('history', ''),
'q': obj.get('input', '') + obj.get('q', ''),
'a': obj.get('output', '') + obj.get('a', '')
})
if len(datalist) >= chunk_size:
chunk_num += 1
input_doc_ids += sft_process_and_write_data(datalist)
arr = np.array(input_doc_ids, dtype=np.uint16)
with open(f'{basepath}/{file_name}', 'wb') as f:
f.write(arr.tobytes())
datalist = []
if chunk_num % 100 == 0:
print(f'chunk:{chunk_num} process end, and input_doc_ids length:{len(input_doc_ids)}')
except jsonlines.InvalidLineError as e:
print(f"Skipping invalid JSON line {idx + 1}: {e}")
continue
if len(datalist) > 0:
input_doc_ids += sft_process_and_write_data(datalist)
arr = np.array(input_doc_ids, dtype=np.uint16)
with open(f'{basepath}/{file_name}', 'wb') as f:
f.write(arr.tobytes())
datalist = []
3、SFT训练
SFT训练的代码和上一篇预训练的代码差别不大,区别是加载SFT数据集,代码如下(替换上一篇预训练的 PretrainDataset
函数):
python
class SFTDataset(Dataset):
def __init__(self, data_path_lst, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
super().__init__()
self.max_length = max_length
self.prompt_max_len = prompt_max_len
self.answer_max_len = answer_max_len
data_lst = []
for data_path in data_path_lst:
with open(data_path, 'rb') as f:
data = np.fromfile(f, dtype=np.uint16)
data_lst.append(data)
data = np.concatenate(data_lst)
data = data[:max_length * int(len(data) / max_length)]
self.data = data.reshape(-1, max_length)
print("train data.shape:{}".format(self.data.shape))
print("SFTDataset finished.....")
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index: int):
sample = self.data[index]
X = np.array(sample[:-1]).astype(np.int64)
Y = np.array(sample[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y)
SFT训练的数据集很大,训练时间较长,大概需要2-3天的时间,其中部分输出(从这里看loss值已经开始下降了):
ini
...
Epoch:[7/20](307000/341718) loss:0.644 lr:0.0000070 epoch_Time:13.0min:
Epoch:[7/20](308000/341718) loss:0.856 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](309000/341718) loss:0.424 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](310000/341718) loss:0.524 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](311000/341718) loss:0.272 lr:0.0000070 epoch_Time:11.0min:
Epoch:[7/20](312000/341718) loss:0.373 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](313000/341718) loss:0.387 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](314000/341718) loss:0.560 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](315000/341718) loss:0.365 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](316000/341718) loss:0.226 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](317000/341718) loss:0.666 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](318000/341718) loss:0.504 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](319000/341718) loss:0.534 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](320000/341718) loss:0.403 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](321000/341718) loss:0.445 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](322000/341718) loss:0.581 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](323000/341718) loss:0.655 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](324000/341718) loss:0.606 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](325000/341718) loss:0.480 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](326000/341718) loss:0.696 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](327000/341718) loss:0.634 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](328000/341718) loss:0.852 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](329000/341718) loss:0.717 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](330000/341718) loss:0.680 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](331000/341718) loss:0.415 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](332000/341718) loss:0.617 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](333000/341718) loss:0.647 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](334000/341718) loss:0.554 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](335000/341718) loss:0.746 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](336000/341718) loss:0.499 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](337000/341718) loss:0.318 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](338000/341718) loss:0.651 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](339000/341718) loss:0.424 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](340000/341718) loss:0.567 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](341000/341718) loss:0.568 lr:0.0000069 epoch_Time:1.0min:
...