Fine-tuning large language models involves the following main stages:
- Supervised fine-tuning (SFT).
- Reward / preference modeling (RM).
- Reinforcement learning from human feedback (RLHF).

The related code is available on GitHub: github.com/night-is-yo...

This post implements fine-tuning for four model families:
- baichuan
- chatglm3
- qwen
- yi

This post covers the first stage: supervised fine-tuning.
The official TRL SFT example: github.com/huggingface...
```python
# Imports used by the example. ScriptArguments is a small dataclass defined in
# the example script itself (holding fields such as dataset_name and max_seq_length).
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import (
    ModelConfig,
    SFTTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
################
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]

################
# Training
################
trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    peft_config=get_peft_config(model_config),
)
trainer.train()
trainer.save_model(training_args.output_dir)
```
This post does not recommend writing it this way (passing the model name as a string to SFTTrainer).

SFTTrainer source code walkthrough

LLM fine-tuning is mostly done with SFTTrainer, which makes a few changes compared to the standard Trainer.

At initialization it will load the model automatically if you pass a model id string, but it is better to initialize the model yourself and pass in the instance, as the sketch after the next snippet shows:
```python
if isinstance(model, str):
    warnings.warn(
        "You passed a model_id to the SFTTrainer. This will automatically create an "
        "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
    )
    model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
```
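For contrast, here is a minimal sketch of the recommended pattern: load the model and tokenizer yourself, then hand the instances to SFTTrainer. The checkpoint name and the tiny in-memory dataset are placeholders, and the keyword arguments follow the TRL version quoted in this post; treat it as an illustration rather than an exact training script.

```python
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer

model_name = "Qwen/Qwen-7B"  # placeholder checkpoint; any causal LM works

# Initialize the model yourself so dtype, quantization, attention implementation,
# etc. stay explicit and under your control.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Placeholder dataset exposing the "text" field used below
train_dataset = Dataset.from_dict({"text": ["example instruction and answer"] * 32})

trainer = SFTTrainer(
    model=model,  # pass the model instance, not a string
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=64,
    packing=True,
)
trainer.train()
```

Constructing the model yourself avoids the automatic-loading warning above and makes the loading configuration visible in your own code.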
If a peft_config is passed, SFTTrainer also initializes the PEFT fine-tuning model automatically:
```python
if is_peft_available() and peft_config is not None:
    if not isinstance(peft_config, PeftConfig):
        raise ValueError(
            "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer."
            f" and you passed a {type(peft_config)}."
        )

    if not isinstance(model, PeftModel):
        _support_gc_kwargs = hasattr(
            args, "gradient_checkpointing_kwargs"
        ) and "gradient_checkpointing_kwargs" in list(
            inspect.signature(prepare_model_for_kbit_training).parameters
        )
        gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            preprare_model_kwargs = {
                "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
            }

            if _support_gc_kwargs:
                preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

            if args is not None:
                args = dataclasses.replace(args, gradient_checkpointing=False)
        elif getattr(args, "gradient_checkpointing", False) and (
            "use_reentrant" not in gradient_checkpointing_kwargs
            or gradient_checkpointing_kwargs["use_reentrant"]
        ):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model = get_peft_model(model, peft_config)

        if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
```
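For reference, a hedged sketch of what such a peft_config might look like. The LoRA hyperparameters and target modules below are illustrative choices, not values from this post:

```python
from peft import LoraConfig

# Illustrative LoRA settings; target_modules must match the model architecture
# (e.g. ["q_proj", "v_proj"] fits many Llama-style models).
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
# Passing SFTTrainer(..., peft_config=peft_config) triggers the
# get_peft_model(...) wrapping shown in the source above.
```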
Data loading is one of the trickier parts.

To use the data efficiently, a technique called packing is applied: instead of each sample in a batch being a single text, padded to the length of the longest text in the batch, many texts are concatenated together, separated by EOS tokens, and then split into fixed-length chunks to form batches, avoiding padding altogether.

ConstantLengthDataset implements this packing; its source code is as follows:
```python
class ConstantLengthDataset(IterableDataset):
    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(self.formatting_func(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                if self.append_concat_token:
                    tokenized_input = tokenized_input + [self.concat_token_id]
                all_token_ids.extend(tokenized_input)
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }
```
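A hedged usage sketch: ConstantLengthDataset can also be constructed directly (it is importable from trl.trainer in the TRL version quoted here). The tokenizer checkpoint and texts below are placeholders:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
raw = Dataset.from_dict({"text": ["some long training text ..."] * 64})

packed = ConstantLengthDataset(
    tokenizer,
    raw,
    dataset_text_field="text",  # or pass formatting_func=... instead
    seq_length=128,
    infinite=False,
)
batch = next(iter(packed))
print(batch["input_ids"].shape)  # torch.Size([128]); labels are identical
```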
1. First, to avoid loading the entire dataset into memory at once (which could run out of memory), only one buffer of data is loaded at a time:
```python
while more_examples:
    buffer, buffer_len = [], 0
    while True:
        if buffer_len >= self.max_buffer_size:
            break
        try:
            buffer.append(self.formatting_func(next(iterator)))
            buffer_len += len(buffer[-1])
        except StopIteration:
            if self.infinite:
                iterator = iter(self.dataset)
                warnings.warn("The dataset reached end and the iterator is reset to the start.")
            else:
                more_examples = False
                break
```
The first (outer) while loop walks through the entire dataset; the second (inner) while loop fills one buffer at a time. The buffer size is set in the initializer:

```python
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
```

Here chars_per_token is the estimated number of characters per token; since buffer_len counts characters before tokenization, it converts the token budget (seq_length tokens for each of num_of_sequences sequences) into a character budget.
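A quick worked example, assuming TRL's defaults of num_of_sequences=1024 and chars_per_token=3.6 (both values are assumptions; check your TRL version):

```python
seq_length = 1024
chars_per_token = 3.6    # assumed TRL default
num_of_sequences = 1024  # assumed TRL default
max_buffer_size = seq_length * chars_per_token * num_of_sequences
print(max_buffer_size)   # 3774873.6, i.e. roughly 3.8M characters per buffer fill
```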
2. Tokenize the buffer and concatenate everything into one long token sequence, appending the EOS (concat) token after each text:
```python
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
    "input_ids"
]
all_token_ids = []
for tokenized_input in tokenized_inputs:
    if self.append_concat_token:
        tokenized_input = tokenized_input + [self.concat_token_id]
    all_token_ids.extend(tokenized_input)
```
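A toy illustration of this step with made-up token ids (assume the concat token id is 0):

```python
tokenized_inputs = [[5, 6, 7], [8, 9]]  # hypothetical token ids for two texts
concat_token_id = 0                      # hypothetical EOS id
all_token_ids = []
for tokenized_input in tokenized_inputs:
    all_token_ids.extend(tokenized_input + [concat_token_id])
print(all_token_ids)  # [5, 6, 7, 0, 8, 9, 0]
```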
3. Split the concatenated sequence into chunks of seq_length; any incomplete final chunk is dropped:
```python
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
    input_ids = all_token_ids[i : i + self.seq_length]
    if len(input_ids) == self.seq_length:
        examples.append(input_ids)
```
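A toy run of the chunking logic (hypothetical ids again) shows that the trailing partial chunk is discarded:

```python
all_token_ids = list(range(10))  # 10 hypothetical token ids
seq_length = 4
examples = []
for i in range(0, len(all_token_ids), seq_length):
    input_ids = all_token_ids[i : i + seq_length]
    if len(input_ids) == seq_length:
        examples.append(input_ids)
print(examples)  # [[0, 1, 2, 3], [4, 5, 6, 7]]; the last 2 tokens are dropped
```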