用 RLHF 训练、微调大模型,训练自己的gpt4(一):模型微调(SFT)

大模型的微调主要有以下几个方面:

  • 有监督的微调 (Supervised Fine-tuning,SFT)。
  • 奖励 / 偏好建模 (Reward / preference modeling,RM)。
  • 基于人类反馈的强化学习 (RLHF)。

相关的代码可以在github上访问:github.com/night-is-yo...

本文主要实现了4种模型:

  1. baichuan
  2. chatglm3
  3. qwen
  4. yi

本文主要是介绍第一部分, 微调

sft官方的例子:github.com/huggingface...

python 复制代码
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
################
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]

################
# Training
################
trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    peft_config=get_peft_config(model_config),
)
trainer.train()
trainer.save_model(training_args.output_dir)

本文不建议这么写。

SFTTrainer源码解读

大模型微调主要是使用SFTTrainer,相比于标准的Train,作了一些改变

在初始化时,会自动加载模型,不过建议自己初始化模型,传入

python 复制代码
if isinstance(model, str):
    warnings.warn(
        "You passed a model_id to the SFTTrainer. This will automatically create an "
        "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
    )
    model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

如果传入peft_config,会自动初始化peft微调模型

python 复制代码
if is_peft_available() and peft_config is not None:
    if not isinstance(peft_config, PeftConfig):
        raise ValueError(
            "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer."
            f" and you passed a {type(peft_config)}."
        )

    if not isinstance(model, PeftModel):
        _support_gc_kwargs = hasattr(
            args, "gradient_checkpointing_kwargs"
        ) and "gradient_checkpointing_kwargs" in list(
            inspect.signature(prepare_model_for_kbit_training).parameters
        )
        gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            preprare_model_kwargs = {
                "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
            }

            if _support_gc_kwargs:
                preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

            if args is not None:
                args = dataclasses.replace(args, gradient_checkpointing=False)
        elif getattr(args, "gradient_checkpointing", False) and (
            "use_reentrant" not in gradient_checkpointing_kwargs
            or gradient_checkpointing_kwargs["use_reentrant"]
        ):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model = get_peft_model(model, peft_config)
        if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)

数据加载是一个比较麻烦的地方

为了高效利用数据,我们采用了称之为 打包 的技术: 与 batch 中的每个样本均由单一文本组成,最后基于最长的文本来 padding (填充),我们把很多文本拼接起来,用 EOS token 来隔开,然后分割成一些 chunk (切块) 来做成 batch,避免 padding。

ConstantLengthDataset实现了 "打包" 功能,ConstantLengthDataset的源码如下

python 复制代码
class ConstantLengthDataset(IterableDataset):
    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(self.formatting_func(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                if self.append_concat_token:
                    tokenized_input = tokenized_input + [self.concat_token_id]
                all_token_ids.extend(tokenized_input)
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }

1.首先为了避免数据量过大,一次加载到内存会内存溢出,因此,每次加载一部分数据

python 复制代码
while more_examples:
    buffer, buffer_len = [], 0
    while True:
        if buffer_len >= self.max_buffer_size:
            break
        try:
            buffer.append(self.formatting_func(next(iterator)))
            buffer_len += len(buffer[-1])
        except StopIteration:
            if self.infinite:
                iterator = iter(self.dataset)
                warnings.warn("The dataset reached end and the iterator is reset to the start.")
            else:
                more_examples = False
                break

上面第一个while是为了完整加载数据,第二个while是为了分批量加载,批量的设置在初始化方法中

ini 复制代码
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

这里的chars_per_token是为了把字符串转为token数字,一个字符转占用的token数目

2.将所有的数据拼接在一起

python 复制代码
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
all_token_ids = []
for tokenized_input in tokenized_inputs:
    if self.append_concat_token:
        tokenized_input = tokenized_input + [self.concat_token_id]
    all_token_ids.extend(tokenized_input)

3.将拼接的数据切块(chunk)

python 复制代码
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
    input_ids = all_token_ids[i : i + self.seq_length]
    if len(input_ids) == self.seq_length:
        examples.append(input_ids)
相关推荐
冬奇Lab7 小时前
RAG 系列(八):RAG 评估体系——用数据说话
人工智能·llm
Irissgwe10 小时前
LangChain之核心组件(输出解析器)
ai·langchain·llm·ai编程·输出解析器
阿里云大数据AI技术13 小时前
Qwen3.6、Kimi-K2.6、Minimax-M2.7、GLM-5.1 来啦!PAI支持海量模型一键部署!
人工智能·llm
Irissgwe15 小时前
LangChain之核心组件(少样本提示词)
人工智能·langchain·llm·langgraph
litble16 小时前
如何速成LLM以伪装成一个AI研究者(4)——PPO,GRPO,DAPO,GSPO
人工智能·llm·ppo·grpo·gspo·dapo
强殖装甲凯普17 小时前
我把「3小时播客变成可搜索文本」做成了 Claude Code 的一条命令
llm·skill·播客·claude code
Baihai IDP17 小时前
为什么 AI Agent 重新爱上了文件系统(Filesystems)
人工智能·ai·llm·agi
雪碧聊技术18 小时前
一文讲透AI大模型相关的专业名词
llm·token
山顶夕景20 小时前
【多模态RAG】Purifying Multimodal Retrieval
大模型·llm·mllm·多模态rag
swipe1 天前
别再把 AI 聊天做成纯文本:从 agui 这个前后端项目,拆解“可感知工具调用”的流式 AI UI
后端·langchain·llm