代码的摘要说明
一、整体功能概述
这段 Python 代码主要实现了基于 Hugging Face Transformers 库对预训练语言模型(具体为 TAIDE-LX-7B-Chat 模型)进行微调(Fine-tuning)的功能,使其能更好地应用于生成唐诗相关内容的任务。整个流程涵盖了数据加载与预处理、模型配置、模型训练以及训练后模型的测试与结果保存等环节。二、代码各部分详细说明
导入模块部分
导入了多个常用的 Python 库,如 os、sys、argparse 等基础库,以及和深度学习、自然语言处理相关的库,像 torch、transformers、peft 等,还导入了用于忽略警告信息、文本颜色控制等功能的相关模块,为后续代码实现做准备。
基础配置与参数设置部分
模型相关路径:定义了模型名称、模型所在目录、工作目录、输出目录以及检查点目录等路径相关的变量,确保后续文件操作(如加载模型、保存训练结果等)能找到正确的位置。
训练参数设置:设定了训练的总轮数(num_epoch)、学习率(LEARNING_RATE),还包括一系列影响模型训练和推理的超参数,例如用于训练的数据集数量(num_train_data)、文本截断长度(CUTOFF_LEN)、LORA 相关参数(LORA_R、LORA_ALPHA、LORA_DROPOUT)、验证集大小(VAL_SET_SIZE)、微批次大小(MICRO_BATCH_SIZE)、梯度累积步数(GRADIENT_ACCUMULATION_STEPS)等,这些参数将用于控制模型训练过程和微调效果。
分布式训练配置:根据环境变量 WORLD_SIZE 判断是否进行分布式训练,若需要则配置相应的设备映射(device_map)。
数据加载与预处理部分
创建目录:首先确保输出目录和检查点目录存在,若不存在则创建它们,用于后续保存训练结果和模型检查点。
加载数据集:从指定的 json 文件路径(dataset_dir)读取原始数据集,先将其加载为 Python 的字典数据结构(通过 json.load),然后从中截取指定数量(由 num_train_data 决定)的数据保存为临时数据集文件(tmp_dataset_path),再通过 load_dataset 函数以 json 格式重新加载这个临时数据集用于后续处理。
数据处理函数定义:定义了 generate_training_data 函数,用于将包含指令、输入和输出文本的数据点转换为适合模型训练的格式,具体是通过 tokenizer 将文本转换为令牌(token)形式,并处理好输入令牌、对应标签以及注意力掩码等信息,构建训练样本。
划分数据集(可选):根据 VAL_SET_SIZE 参数的值决定是否划分训练集和验证集,如果 VAL_SET_SIZE 大于 0,则从原始训练集中划分出验证集,对训练集和验证集的数据都应用 generate_training_data 函数进行预处理,使其格式符合模型训练要求;若 VAL_SET_SIZE 为 0,则整个数据集都作为训练集进行处理。
模型相关配置部分
创建模型和令牌器:从指定的模型路径(model_path)加载预训练的语言模型,并根据相关配置(如量化配置 BitsAndBytesConfig 用于以特定量化方式加载模型以节省内存等)进行初始化;同时加载对应的分词器(tokenizer),并设置其填充令牌(pad_token)与结束令牌(eos_token)相同,方便后续文本处理。
模型准备与微调配置:将加载的模型通过 prepare_model_for_int8_training 函数进行处理,使其适合 INT8 训练方式,接着使用 LoraConfig 配置 LORA(低秩适配,一种参数高效微调方法)相关参数,并通过 get_peft_model 函数应用 LORA 配置到模型上,以便后续进行参数高效的微调训练。
生成配置定义:定义了 GenerationConfig 对象,用于指定模型在生成文本时的一些解码参数,例如是否采样(do_sample)、温度参数(temperature)、束搜索的束数量(num_beams)等,这些参数会影响模型生成唐诗内容时的质量和多样性等特性。
模型训练部分
使用 transformers.Trainer 类来实例化训练器对象(trainer),传入模型、训练数据集、验证数据集(可为 None)以及训练相关的各种参数(如批次大小、学习率、训练轮数、梯度累积步数、日志记录和模型保存相关的配置等),同时指定了数据整理器(data_collator)用于整理训练数据的格式。在训练前还禁用了模型的缓存功能(model.config.use_cache = False),然后通过 try-except 语句块尝试执行模型训练过程,若训练出现异常则打印出错误信息。
模型保存部分
同样使用 try-except 语句块尝试将训练后的模型保存到指定的检查点目录(ckpt_dir)中,若保存过程出现异常则打印相应的错误提示。
模型测试与结果保存部分
从指定的测试数据文件(test_data_path)中加载测试数据集(通过 json.load),然后循环遍历每条测试数据,调用 evaluate 函数(该函数基于给定的指令、生成配置和输入等信息,利用已训练好的模型生成相应回复内容)获取模型针对测试数据的回复,将测试数据的输入以及模型生成的回复拼接后写入到指定的结果文件(output_path)中保存起来,同时也会打印出来方便查看,用于评估模型在测试集上的表现。原文链接:https://blog.csdn.net/chenchihwen/article/details/144000079
整体代码参照这篇文章
导入必要的库
import os
import sys
import argparse
import json
import warnings
import logging
-
导入了多个 Python 标准库,
os
用于操作系统相关的操作(如文件路径处理、目录创建等),sys
用于处理 Python 运行时环境相关的配置和参数,argparse
常用于命令行参数解析(但此代码中未体现其使用),json
用于处理 JSON 格式的数据读写,warnings
用于控制警告信息的显示,logging
用于记录日志信息。忽略警告信息
warnings.filterwarnings("ignore")
-
配置
warnings
模块,使其忽略所有警告信息,避免在程序运行过程中输出大量警告干扰正常的输出和分析。import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset, load_from_disk
import transformers
from peft import PeftModel
from colorama import Fore, Back, Style -
导入了
torch
(深度学习常用的张量计算库)及其nn
模块(用于构建神经网络),bitsandbytes
(可能用于模型量化等相关功能),load_dataset
和load_from_disk
函数用于加载数据集,transformers
库(用于自然语言处理相关的预训练模型、工具等),PeftModel
(可能是某种特定的模型扩展相关类),colorama
(用于在终端输出添加颜色等样式,不过此处代码后续未看到其使用)。以下是根据Hugging Face Transformers库的相关功能进行的设置
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import GenerationConfig
from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_kbit_training
) -
从
transformers
库中导入了多个用于创建模型、配置模型以及处理特定训练相关设置的类和函数,比如AutoTokenizer
用于自动加载合适的文本令牌化器,AutoConfig
用于自动加载模型配置,AutoModelForCausalLM
用于加载适用于因果语言建模的预训练模型,BitsAndBytesConfig
用于配置模型量化相关的参数;从peft
模块导入的函数主要用于对模型进行特定的微调(如int8
训练准备、配置LORA
相关参数、获取基于LORA
配置后的模型等)。
模型和路径相关的参数设置
# 模型名称
model_name = "TAIDE-LX-7B-Chat"
# 模型所在目录
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TAIDE-LX-7B-Chat")
-
定义了模型的名称以及模型所在的本地文件系统路径,通过
os.path.join
结合当前脚本文件所在目录(os.path.dirname(os.path.abspath(__file__))
获取)来确定完整的模型路径。设置工作目录
work_dir = os.path.dirname(os.path.abspath(file))
输出目录
output_dir = os.path.join(work_dir, "output")
检查点目录
ckpt_dir = os.path.join(work_dir, "checkpoint")
-
确定了工作目录(即当前脚本所在目录),并基于此设置了后续用于存储输出结果的目录和保存模型检查点的目录路径。
训练的总Epoch数
num_epoch = 1
学习率
LEARNING_RATE = 3e-4
-
定义了模型训练的关键超参数,总训练轮数为
1
轮,学习率设置为3e-4
,学习率决定了模型在训练过程中参数更新的步长大小。配置模型和训练参数
cache_dir = os.path.join(work_dir, "cache")
from_ckpt = False
ckpt_name = None -
设置了模型缓存目录路径,以及两个与是否从检查点加载模型、具体检查点名称相关的变量,这里表示不从检查点加载(
from_ckpt
为False
)且没有指定具体检查点名称(ckpt_name
为None
)。以下是补充的超参数设置
用于训练的数据集数量
num_train_data = 1040 # 可设置的最大值为5000,数据量越多,模型性能可能越好,但训练时间也会增加
以下参数影响模型的训练和推理
CUTOFF_LEN = 256 # 文本截断的最大长度
LORA_R = 8 # LORA的R值
LORA_ALPHA = 16 # LORA的Alpha值
LORA_DROPOUT = 0.05 # LORA的Dropout率
VAL_SET_SIZE = 0 # 验证集的大小,0表示不使用验证集
TARGET_MODULES = ["q_proj", "up_proj", "o_proj", "k_proj", "down_proj", "gate_proj", "v_proj"] # 目标模块,用于LORA训练MICRO_BATCH_SIZE = 4 # 微批次大小,这里设置为和原示例类似的值,你可根据实际情况调整
GRADIENT_ACCUMULATION_STEPS = 4 # 计算每个微批次累积的梯度步数,示例中为16//4,这里假设为4,可按需改
logging_steps = 20 # 定義訓練過程中每隔多少步驟輸出一次訓練誌
save_steps = 65 # 定義訓練過程中每隔多少步驟保存一次模型
save_total_limit = 3 # 控制最多保留幾個模型checkpoint
report_to = ["tensorboard"] # 設定上報實驗指標的目標,預設為無 # 可以根据需求调整报告的对象,比如添加其他日志记录工具等 -
这些都是详细的模型训练相关超参数设置:
num_train_data
指定了用于训练模型的数据集样本数量,这里限制为1040
个,更多数据理论上有助于提升模型性能但会增加训练时长。CUTOFF_LEN
用于限制输入文本的最大长度,超过这个长度会截断,有助于控制输入规模。LORA_R
、LORA_ALPHA
、LORA_DROPOUT
是LORA
(低秩适应)方法相关的参数,用于调整模型微调时的结构和特性。VAL_SET_SIZE
决定是否划分出验证集以及验证集的大小,这里设置为0
表示不划分验证集用于评估训练过程中的模型性能。TARGET_MODULES
明确了在LORA
训练时针对模型的哪些模块进行操作。MICRO_BATCH_SIZE
和GRADIENT_ACCUMULATION_STEPS
一起用于控制每次训练时实际处理的数据量和梯度更新的频率。logging_steps
规定了每隔多少训练步骤记录一次训练日志,save_steps
定义了每隔多少步骤保存一次模型检查点,save_total_limit
控制最多保留的检查点数量,report_to
指定了向哪些目标(如tensorboard
)上报实验指标用于可视化等分析。
分布式训练相关配置
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size!= 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
- 首先获取环境变量
WORLD_SIZE
(通常用于分布式训练,表示总的进程数),如果获取不到则默认值为1
。根据world_size
是否等于1
来判断是否是分布式训练(ddp
变量标记),如果是分布式训练(ddp
为True
),则通过环境变量LOCAL_RANK
(通常表示当前进程在本地的序号)来配置设备映射(这里简单地将其映射到一个整数,用于指定设备,比如 GPU 设备编号等)。
创建目录
# 创建输出目录和检查点目录(如果不存在)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)
- 使用
os.makedirs
函数创建输出目录和检查点目录,如果目录已经存在(exist_ok=True
参数表示存在时不报错)则不会重复创建,确保后续保存输出结果和模型检查点时有对应的目录可用。
数据集加载与处理
# 加载数据集
dataset_dir = os.path.join(work_dir, "data", "training_data", "Tang_training_data.json")
with open(dataset_dir, "r", encoding="utf-8") as f:
data_json = json.load(f)
-
构建数据集文件的完整路径(假设数据集是 JSON 格式存储在特定目录下),然后使用
json.load
函数从打开的文件中读取 JSON 数据并解析为 Python 对象(这里是字典或列表等结构)存储在data_json
变量中。保存处理后的数据集
tmp_dataset_path = os.path.join(work_dir, "tmp_dataset.json")
获取文件所在目录路径
dir_path = os.path.dirname(tmp_dataset_path)
如果目录不存在,则创建目录
if not os.path.exists(dir_path):
os.makedirs(dir_path)明确使用 'w' 模式打开文件用于写入,同时指定 encoding 为 'utf-8'
with open(tmp_dataset_path, 'w', encoding='utf-8') as f:
json.dump(data_json[:num_train_data], f, indent=2, ensure_ascii=False) -
先确定临时数据集文件的路径以及其所在目录路径,若目录不存在就创建它。然后从原始读取的数据集数据(
data_json
)中选取前面num_train_data
个样本,使用json.dump
将这部分数据以缩进为2
、不强制转换 ASCII 编码(ensure_ascii=False
,便于处理中文等非 ASCII 字符)的方式写入到临时数据集文件中。data = load_dataset('json', data_files=tmp_dataset_path, download_mode="force_redownload")
-
使用
load_dataset
函数以 JSON 格式加载刚刚处理并保存的临时数据集文件(data_files
参数指定),并设置下载模式为强制重新下载(不过对于本地文件这里可能没实际的下载行为,只是按照对应的加载逻辑处理),将加载好的数据集对象存储在data
变量中。
创建模型和令牌器
model = AutoModelForCausalLM.from_pretrained(
model_path,
cache_dir=cache_dir,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
),
low_cpu_mem_usage=True
)
-
通过
AutoModelForCausalLM
类从指定的model_path
路径加载预训练的因果语言模型,使用cache_dir
指定缓存目录,同时配置了模型量化相关参数(通过BitsAndBytesConfig
),这里设置为以4
位量化加载(load_in_4bit=True
),采用nf4
量化类型、启用双重量化以及指定计算的数据类型为torch.bfloat16
,并且设置为低 CPU 内存使用模式,以优化模型加载和运行时的内存占用情况。tokenizer = AutoTokenizer.from_pretrained(
model_name,
add_eos_token=True,
cache_dir=cache_dir,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
)
tokenizer.pad_token = tokenizer.eos_token -
类似地,使用
AutoTokenizer
从指定的model_name
加载对应的文本令牌化器,添加结束符令牌(add_eos_token=True
),配置缓存目录和量化相关参数(与模型配置的量化参数一致)。最后将令牌化器的填充令牌(pad_token
)设置为结束符令牌(eos_token
),方便后续处理文本序列长度不一致等情况。
生成训练数据的函数
def generate_training_data(data_point):
"""
此函数用于将数据点(包含指令、输入和输出文本)转换为模型可读取的令牌形式
参数:
data_point (dict): 包含 "instruction"、"input" 和 "output" 字段的字典,所有字段均为字符串
返回:
dict: 包含模型的输入令牌、注意力掩码和相应输出目标的字典
"""
# 构建完整的输入提示
prompt = f"""\
[INST] <<SYS>>
You are a helpful assistant and good at writing Tang poem. 你是一个乐于助人的助手且擅长写唐诗。
<</SYS>>
{data_point["instruction"]}
{data_point["input"]}
[/INST]"""
# 计算输入令牌的数量
len_user_prompt_tokens = (
len(
tokenizer(
prompt,
truncation=True,
max_length=CUTOFF_LEN + 1,
padding="max_length",
)["input_ids"]
) - 1
)
# 将输入提示转换为令牌
full_tokens = tokenizer(
prompt + " " + data_point["output"] + "</s>",
truncation=True,
max_length=CUTOFF_LEN + 1,
padding="max_length",
)["input_ids"][:-1]
return {
"input_ids": full_tokens,
"labels": [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:],
"attention_mask": [1] * (len(full_tokens)),
}
- 这个函数用于将输入的包含指令、输入文本和输出文本的数据点转换为适合模型训练的格式:
- 首先构建完整的输入提示文本(按照特定的格式拼接指令、输入等内容),形成
prompt
字符串。 - 然后通过令牌化器对
prompt
进行处理(截断到CUTOFF_LEN + 1
长度并进行填充等操作),计算出用户提示部分的令牌数量(减去1
可能是去除某个特定的标记等原因)。 - 接着将包含输出文本的完整提示再进行令牌化处理,并去掉最后一个令牌(具体原因可能与模型训练时的输入输出格式要求有关),得到
full_tokens
。 - 最后构建并返回一个字典,包含模型输入的令牌序列(
input_ids
),对应的标签序列(labels
,用户提示部分设为-100
表示不计算这部分的损失,后面接上实际输出对应的令牌)以及注意力掩码(全1
表示都参与注意力计算,长度与令牌序列长度一致)。
- 首先构建完整的输入提示文本(按照特定的格式拼接指令、输入等内容),形成
评估模型的函数
def evaluate(instruction, generation_config, max_len, input="", verbose=True):
"""
此函数用于根据输入的指令、生成配置和最大长度,获取模型的输出
参数:
instruction (str): 描述模型要执行的操作的字符串
generation_config (transformers.GenerationConfig): 用于指定与模型推理相关的解码参数的对象
max_len (int): 模型输出的最大长度
input (str, 可选): 模型需要解决指令的输入字符串,默认值为 ""
verbose (bool, 可选): 是否打印模型的输出,默认值为True
返回:
str: 模型根据指令和输入生成的响应
"""
# 构建完整的输入提示
prompt = f"""\
[INST] <<SYS>>
You are a helpful assistant and good at writing Tang poem. 你是一个乐于助人的助手且擅长写唐诗。
<</SYS>>
{instruction}
{input}
[/INST]"""
# 将提示文本转换为模型所需的数字表示形式
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].cuda() if torch.cuda.is_available() else inputs["input_ids"]
# 使用模型进行生成回复
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_len,
)
# 将生成的回复解码并打印
for s in generation_output.sequences:
output = tokenizer.decode(s)
output = output.split("[/INST]")[1].replace("</s>", "").replace("<s>", "").replace("Assistant:", "").replace("Assistant", "").strip()
if verbose:
print(output)
return output
- 该函数用于对模型进行评估,根据给定的指令、生成配置以及最大输出长度等参数获取模型的输出:
- 先按照特定格式构建完整的输入提示文本
prompt
,包含系统描述、指令和可选的输入内容。 - 接着使用令牌化器将提示文本转换为张量形式(
return_tensors="pt"
表示返回 PyTorch 张量),如果 GPU 可用则将输入的令牌input_ids
移到 GPU 上,方便后续加速计算。
- 先按照特定格式构建完整的输入提示文本
evaluate
函数(续)
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_len,
)
-
调用模型的
generate
方法来生成回复内容,传递了几个重要参数:-
generation_config
:这是一个之前定义好的配置对象(类型为transformers.GenerationConfig
),里面包含了如采样策略(do_sample
等参数)、温度参数(temperature
)、束搜索相关参数(num_beams
等)等诸多与生成文本过程相关的设置,用于控制模型生成回复的方式。 -
return_dict_in_generate=True
:表示让generate
方法以字典形式返回生成的结果,方便后续提取和处理相关信息,比如包含生成的序列、对应的分数等信息。 -
output_scores=True
:指示模型在生成过程中同时输出每个生成步骤的分数信息(例如概率得分等,虽然代码后续没有明显体现对这些分数的进一步处理,但在一些更详细的分析场景中可能会用到)。 -
max_new_tokens=max_len
:指定了模型最多生成的新令牌数量,也就是限制了生成回复的最大长度,避免生成过长的无意义文本。for s in generation_output.sequences:
output = tokenizer.decode(s)
output = output.split("[/INST]")[1].replace("", "").replace("", "").replace("Assistant:", "").replace("Assistant", "").strip()
if verbose:
print(output)
-
-
这段代码对生成的结果进行处理和输出(如果
verbose
为True
):- 首先通过循环遍历
generation_output.sequences
(生成结果中的序列,可能存在多个候选序列等情况,不过这里代码逻辑上似乎只处理了其中一个,没有更复杂的选择逻辑体现)。 - 使用
tokenizer.decode
将生成的令牌序列s
解码为文本形式,得到原始的回复文本。 - 接着对解码后的文本进行一系列字符串处理操作:
split("[/INST]")[1]
:按照[/INST]
这个特定标记进行分割,并取分割后的第二部分内容,可能是为了去除前面的一些指令相关的固定格式内容,提取出真正的回复部分。- 然后通过
replace
方法依次去除一些多余的标记如</s>
(结束符标记)、<s>
(起始符标记)、Assistant:
(可能是回复开头的特定前缀标记)以及Assistant
(类似的前缀情况),并通过strip
方法去除文本两端的空白字符,最终得到比较干净整洁的回复文本内容。
- 如果
verbose
为True
,就将处理后的回复文本打印输出,方便查看模型生成的结果情况;如果verbose
为False
,则只是进行上述处理,不进行打印,用于后续可能的静默式评估或者将结果进一步传递使用等场景。 - 最后,函数返回处理后的回复文本(无论是否打印),使得调用该函数的地方可以获取到模型生成的结果用于后续的操作,比如和真实结果对比评估模型性能等。
- 首先通过循环遍历
划分训练集和验证集(如果需要)
if VAL_SET_SIZE > 0:
train_val = data["train"].train_test_split(
test_size=VAL_SET_SIZE, shuffle=True, seed=42
)
train_data = train_val["train"].shuffle().map(generate_training_data)
val_data = train_val["test"].shuffle().map(generate_training_data)
else:
train_data = data['train'].shuffle().map(generate_training_data)
val_data = None
- 这部分代码根据之前定义的验证集大小(
VAL_SET_SIZE
)来决定是否划分训练集和验证集:- 如果
VAL_SET_SIZE
大于0
,表示需要划分出验证集用于在训练过程中评估模型在未见过的数据上的性能。通过调用数据集对象data["train"]
(假设data
数据集有train
这个子集表示训练数据部分)的train_test_split
方法,按照给定的测试集大小比例(test_size=VAL_SET_SIZE
)、进行随机打乱(shuffle=True
)并设置随机种子(seed=42
,保证每次划分结果可复现)来划分出训练集和验证集两部分,分别存储在train_val
这个包含train
和test
两个子集的对象中。 - 然后对划分后的训练集和验证集进一步处理,通过调用
shuffle
方法再次打乱顺序(进一步打乱确保数据随机性,可能是为了后续训练更稳定等原因),并使用map
方法结合generate_training_data
函数将数据集中的每个样本转换为适合模型训练的格式(如前面generate_training_data
函数中所做的将文本数据转换为包含输入令牌、标签、注意力掩码等的字典形式),最终得到处理后的训练数据train_data
和验证数据val_data
。 - 如果
VAL_SET_SIZE
等于0
,则表示不需要划分验证集,直接将整个训练数据集(data['train']
)进行打乱(shuffle
)并通过map
方法应用generate_training_data
函数处理每个样本,得到训练数据train_data
,同时将验证数据val_data
设为None
,表明没有单独的验证集用于后续评估。
- 如果
将模型准备好以使用 INT8 训练
model = prepare_model_for_int8_training(model)
- 调用
prepare_model_for_int8_training
函数对已加载的模型进行预处理,使其能够适应以INT8
量化精度进行训练的要求,这个函数内部可能会对模型的一些层结构、参数初始化等方面进行调整,以便在后续训练过程中能够基于INT8
量化方式正确地更新参数、计算梯度等,确保训练的顺利进行和模型性能的合理优化。
使用 LoraConfig 配置 LORA 模型
config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
- 首先创建一个
LoraConfig
配置对象,传入之前定义好的相关LORA
超参数:r=LORA_R
:设置LORA
方法中的秩(rank)参数,它决定了低秩矩阵分解的维度等特性,影响模型微调的效果和参数量等。lora_alpha=LORA_ALPHA
:与LORA
的自适应调整等相关的参数,配合r
等参数一起控制模型训练过程中的权重更新等情况。target_modules=TARGET_MODULES
:指定了在模型中哪些模块上应用LORA
技术进行微调,即明确了操作的具体对象范围。lora_dropout=LORA_DROPOUT
:设置LORA
中的丢弃率(Dropout)参数,用于防止过拟合等情况。bias="none"
:表示不处理偏置项(可能是不针对偏置进行LORA
相关的调整等情况)。task_type="CAUSAL_LM"
:表明任务类型是因果语言建模,用于让LORA
配置知晓模型的任务特性以便进行针对性的适配。
- 接着调用
get_peft_model
函数,传入原始模型和LoraConfig
配置对象,对模型进行基于LORA
配置的转换和调整,返回一个应用了LORA
技术后的新模型,替换原来的model
变量,使得后续训练能够基于LORA
微调的方式进行,在尽量少增加参数量的情况下提升模型在特定任务上的性能。
定义 nf4_config
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
- 创建一个
BitsAndBytesConfig
配置对象,重新配置了模型量化相关的参数(和之前模型加载时配置的量化参数类似,但这里单独定义了一个配置对象,可能后续有其他使用场景或者方便修改统一的量化设置等):load_in_4bit=True
:明确表示以4
位量化的方式加载模型,减少模型的内存占用等。bnb_4bit_quant_type="nf4"
:指定量化类型为nf4
,这是一种特定的量化技术选择。bnb_4bit_use_double_quant=True
:启用双重量化机制,进一步优化量化效果和模型存储、计算效率等。bnb_4bit_compute_dtype=torch.bfloat16
:设定在量化计算过程中使用的数据类型为torch.bfloat16
,用于在量化情况下兼顾计算精度和速度等方面的平衡。
定义生成配置
generation_config = GenerationConfig(
do_sample=True,
temperature=0.1,
num_beams=1,
top_p=0.3,
no_repeat_ngram_size=3,
pad_token_id=2
)
- 创建一个
GenerationConfig
对象用于配置模型生成文本时的各种参数:do_sample=True
:表示采用采样的方式生成文本,而不是确定性的生成(例如对比于基于束搜索等确定性的生成策略),这样可以增加生成文本的多样性,但也可能带来一定的不确定性和结果的波动性。temperature=0.1
:温度参数,用于控制采样的随机性程度,较小的值(如这里的0.1
)会让采样更偏向于概率高的词汇,生成的文本相对更保守、更符合常规;较大的值则会增加随机性,使生成的文本更加多样化但可能质量参差不齐。num_beams=1
:束搜索相关参数,这里设置为1
表示不使用束搜索(束搜索通常会保留多个候选路径并综合评估选择最优路径,当num_beams
大于1
时启用,这里仅设为1
相当于就是普通的逐个生成词汇的方式)。top_p=0.3
:一种概率截断采样的参数,叫核采样(nucleus sampling),表示只从概率累计和达到0.3
的那部分词汇中进行采样,有助于避免生成一些概率极低但可能出现的不合理词汇,同时保证一定的词汇选择范围和文本多样性。no_repeat_ngram_size=3
:设置生成文本时避免重复出现的n
元语法(连续的n
个词汇组成的单元)的大小为3
,即尽量不让连续的3
个词汇重复出现,以提高生成文本的流畅性和多样性,防止重复啰嗦的表述。pad_token_id=2
:指定填充令牌对应的id
为2
,用于在处理批次数据等场景下对长度不一致的文本序列进行填充对齐操作时的标识。
训练模型
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=4, # 微批次大小,可根据需要调整
gradient_accumulation_steps=16 // 4, # 计算每个微批次累积的梯度步数
warmup_steps=50,
num_train_epochs=num_epoch,
learning_rate=LEARNING_RATE,
fp16=True, # 使用混合精度训练
logging_steps=20,
save_strategy="steps",
save_steps=65,
output_dir=ckpt_dir,
save_total_limit=3,
ddp_find_unused_parameters=False if ddp else None,
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
- 创建一个
transformers.Trainer
对象用于训练模型,传入了多个关键参数:model=model
:指定要训练的模型,即前面经过INT8
训练准备、LORA
配置后的模型。train_dataset=train_data
:传入训练数据集,是经过前面处理(如转换为合适格式、划分等操作)后的用于训练的样本数据集合。eval_dataset=val_data
:传入验证数据集(如果有,即val_data
不为None
时),用于在训练过程中定期评估模型性能,若val_data
为None
则表示训练过程中不进行基于验证集的评估。args=transformers.TrainingArguments(...)
:创建一个TrainingArguments
对象来配置训练相关的各种超参数:per_device_train_batch_size=4
:设置每个设备(如每个 GPU)上每次训练的微批次大小为4
,即每次处理4
个样本数据,这个大小可以根据硬件资源、模型复杂度等情况进行调整,影响训练时的内存占用和梯度更新频率等。gradient_accumulation_steps=16 // 4
:计算每个微批次累积的梯度步数,这里实际设置为4
(通过16 // 4
计算得出),意味着每处理4
个微批次后才进行一次梯度更新,用于在内存有限等情况下模拟更大的批次大小训练效果,同时控制梯度更新的节奏。warmup_steps=50
:在训练开始阶段设置的热身步数,在这个阶段学习率会从初始值逐渐上升到设定的正常学习率,有助于模型在训练初期更稳定地收敛,避免一开始就使用较大学习率导致训练不稳定等情况。num_train_epochs=num_epoch
:指定训练的总轮数,这里使用之前定义的num_epoch
变量,即训练1
轮(根据前面的定义)。learning_rate=LEARNING_RATE
:使用之前定义好的学习率(3e-4
)来控制模型参数更新的步长大小。fp16=True
:启用混合精度训练,即使用半精度(float16
)的数据类型来存储和计算部分张量,减少内存占用同时在支持的硬件上可以加速计算,又能在一定程度上保证计算精度不至于损失太多,适合深度学习模型训练场景。logging_steps=20
:按照每训练20
步记录一次训练日志,方便查看训练过程中的各项指标变化情况,如损失值、学习率等。save_strategy="steps"
:表示按照训练步数来决定何时保存模型检查点,与之对应的还有按训练轮数等保存策略,这里选择基于步数保存。save_steps=65
:结合save_strategy="steps"
,意味着每隔65
个训练步骤就保存一次模型检查点,用于后续可以恢复训练或者选择不同阶段的模型进行评估等。output_dir=ckpt_dir
:指定保存模型检查点的目录为之前定义的ckpt_dir
,确保模型文件保存到正确的位置。save_total_limit=3
:控制最多保留3
个模型检查点,避免过多的检查点占用大量磁盘空间,当保存的检查点数量超过这个限制时,可能会自动删除较早的检查点。ddp_find_unused_parameters=False if ddp else None
:在分布式训练(ddp
为True
)的场景下,设置是否查找未使用的参数,这里设置为False
表示不进行查找(具体是否查找可能根据实际训练需求和模型特点等决定,不查找可以节省一定的计算资源和时间),如果不是分布式训练(ddp
为False
)则设为None
,可能按照默认行为处理。
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
:指定数据整理器(Data Collator),用于将训练数据整理成合适的批次格式,这里使用的是适用于语言建模任务的整理器,传入了令牌化器tokenizer
,并且设置mlm=False
表示不是做掩码语言建模(而是因果语言建模任务对应的整理方式),它会对批次内的文本序列进行填充、截断等操作,使其符合模型输入的格式要求。
禁用模型的 cache 功能
model.config.use_cache = False
- 将模型配置中的
use_cache
属性设置为False
,禁用模型的缓存功能。在某些情况下,禁用缓存可以避免一些潜在的内存问题或者不正确的结果,尤其是在训练过程中或者模型结构、输入等有动态变化的场景下,确保每次计算都是基于最新的输入和模型状态进行,而不会依赖之前缓存的中间结果。
训练过程
try:
trainer.train()
except Exception as e:
print(f"Training failed: {e}")
- 使用
try-except
语句块来执行模型训练过程,调用trainer.train()
方法启动训练,让模型根据前面配置的训练数据集、超参数等进行训练。如果在训练过程中出现任何异常(比如梯度爆炸、内存不足、运行时错误等),就会捕获该异常并打印出错误信息(Training failed: {e}
),方便排查问题,了解训练失败的原因。
保存训练后的模型
try:
model.save_pretrained(ckpt_dir)
except Exception as e:
print(f"Saving model failed: {e}
整体代码与工作目录,必要的 package
参照这篇文章
训练json数据摘要如下:
Training data:sample, 你可以试着补充
Tang Poem datasethttps://github.com/chinese-poetry/chinese-poetry/tree/master/%E5%85%A8%E5%94%90%E8%AF%97?fbclid=IwAR2bM14S42T-VtrvMi3wywCqKfYJraBtMl7QVTo0qyPMjX9jj9Vj3JepFBA
Tang_training_data.json
[
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "秦川雄帝宅,函谷壯皇居。",
"output": "綺殿千尋起,離宮百雉餘。連甍遙接漢,飛觀迥凌虛。雲日隱層闕,風煙出綺疎。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "巖廊罷機務,崇文聊駐輦。",
"output": "玉匣啓龍圖,金繩披鳳篆。韋編斷仍續,縹帙舒還卷。對此乃淹留,欹案觀墳典。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "移步出詞林,停輿欣武宴。",
"output": "琱弓寫明月,駿馬疑流電。驚雁落虛弦,啼猿悲急箭。閱賞誠多美,於茲乃忘倦。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "鳴笳臨樂館,眺聽歡芳節。",
"output": "急管韻朱絃,清歌凝白雪。彩鳳肅來儀,玄鶴紛成列。去茲鄭衛聲,雅音方可悅。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "芳辰追逸趣,禁苑信多奇。",
"output": "橋形通漢上,峰勢接雲危。煙霞交隱映,花鳥自參差。何如肆轍跡?萬里賞瑤池。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "白日依山盡,黃河入海流。",
"output": "欲窮千里目,更上一層樓。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "落日雙闕昏,回輿九重暮。",
"output": "長煙散初碧,皎月澄輕素。搴幌翫琴書,開軒引雲霧。斜漢耿層閣,清風搖玉樹。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "歡樂難再逢,芳辰良可惜。",
"output": "玉酒泛雲罍,蘭殽陳綺席。千鍾合堯禹,百獸諧金石。得志重寸陰,忘懷輕尺璧。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "建章歡賞夕,二八盡妖妍。",
"output": "羅綺昭陽殿,芬芳玳瑁筵。珮移星正動,扇掩月初圓。無勞上懸圃,即此對神仙。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "新豐停翠輦,譙邑駐鳴笳。",
"output": "園荒一徑斷,苔古半階斜。前池消舊水,昔樹發今花。一朝辭此地,四海遂爲家。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "春蒐馳駿骨,總轡俯長河。",
"output": "霞處流縈錦,風前瀁卷羅。水花翻照樹,堤蘭倒插波。豈必汾陰曲,秋雲發棹歌。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "重巒俯渭水,碧嶂插遙天。",
"output": "出紅扶嶺日,入翠貯巖煙。疊松朝若夜,複岫闕疑全。對此恬千慮,無勞訪九仙。"
},
testing data
Tang_testing_data.json
[
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "雪霽銀妝素,桔高映瓊枝。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "夫子何爲者?栖栖一代中。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "飛蓋去芳園,蘭橈遊翠渚。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "條風開獻節,灰律動初陽。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "昨夜星辰昨夜風,畫樓西畔桂堂東。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "三日入廚下,洗手作羹湯。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "嵩雲秦樹久離居,雙鯉迢迢一紙書。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "慨然撫長劒,濟世豈邀名。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "乘興南遊不戒嚴,九重誰省諫書函。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "猿鳥猶疑畏簡書,風雲常爲護儲胥。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "君問歸期未有期,巴山夜雨漲秋池。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "相見時難別亦難,東風無力百花殘。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "雲母屏風燭影深,長河漸落曉星沈。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "高閣客竟去,小園花亂飛。"
},
{
"instruction": "以下是一首唐詩的第一句話,請用你的知識判斷並完成整首詩。",
"input": "瑤池阿母綺窗開,黃竹歌聲動地哀。"
}
]