用 RLHF 训练、微调大模型,训练自己的gpt4(一):模型微调(SFT)

大模型的微调主要有以下几个方面:

  • 有监督的微调 (Supervised Fine-tuning,SFT)。
  • 奖励 / 偏好建模 (Reward / preference modeling,RM)。
  • 基于人类反馈的强化学习 (RLHF)。

相关的代码可以在github上访问:github.com/night-is-yo...

本文主要实现了4种模型:

  1. baichuan
  2. chatglm3
  3. qwen
  4. yi

本文主要是介绍第一部分, 微调

sft官方的例子:github.com/huggingface...

python 复制代码
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
################
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]

################
# Training
################
trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    peft_config=get_peft_config(model_config),
)
trainer.train()
trainer.save_model(training_args.output_dir)

本文不建议这么写。

SFTTrainer源码解读

大模型微调主要是使用SFTTrainer,相比于标准的Train,作了一些改变

在初始化时,会自动加载模型,不过建议自己初始化模型,传入

python 复制代码
if isinstance(model, str):
    warnings.warn(
        "You passed a model_id to the SFTTrainer. This will automatically create an "
        "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
    )
    model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

如果传入peft_config,会自动初始化peft微调模型

python 复制代码
if is_peft_available() and peft_config is not None:
    if not isinstance(peft_config, PeftConfig):
        raise ValueError(
            "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer."
            f" and you passed a {type(peft_config)}."
        )

    if not isinstance(model, PeftModel):
        _support_gc_kwargs = hasattr(
            args, "gradient_checkpointing_kwargs"
        ) and "gradient_checkpointing_kwargs" in list(
            inspect.signature(prepare_model_for_kbit_training).parameters
        )
        gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            preprare_model_kwargs = {
                "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
            }

            if _support_gc_kwargs:
                preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

            if args is not None:
                args = dataclasses.replace(args, gradient_checkpointing=False)
        elif getattr(args, "gradient_checkpointing", False) and (
            "use_reentrant" not in gradient_checkpointing_kwargs
            or gradient_checkpointing_kwargs["use_reentrant"]
        ):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model = get_peft_model(model, peft_config)
        if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)

数据加载是一个比较麻烦的地方

为了高效利用数据,我们采用了称之为 打包 的技术: 与 batch 中的每个样本均由单一文本组成,最后基于最长的文本来 padding (填充),我们把很多文本拼接起来,用 EOS token 来隔开,然后分割成一些 chunk (切块) 来做成 batch,避免 padding。

ConstantLengthDataset实现了 "打包" 功能,ConstantLengthDataset的源码如下

python 复制代码
class ConstantLengthDataset(IterableDataset):
    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(self.formatting_func(next(iterator)))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                if self.append_concat_token:
                    tokenized_input = tokenized_input + [self.concat_token_id]
                all_token_ids.extend(tokenized_input)
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }

1.首先为了避免数据量过大,一次加载到内存会内存溢出,因此,每次加载一部分数据

python 复制代码
while more_examples:
    buffer, buffer_len = [], 0
    while True:
        if buffer_len >= self.max_buffer_size:
            break
        try:
            buffer.append(self.formatting_func(next(iterator)))
            buffer_len += len(buffer[-1])
        except StopIteration:
            if self.infinite:
                iterator = iter(self.dataset)
                warnings.warn("The dataset reached end and the iterator is reset to the start.")
            else:
                more_examples = False
                break

上面第一个while是为了完整加载数据,第二个while是为了分批量加载,批量的设置在初始化方法中

ini 复制代码
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

这里的chars_per_token是为了把字符串转为token数字,一个字符转占用的token数目

2.将所有的数据拼接在一起

python 复制代码
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
all_token_ids = []
for tokenized_input in tokenized_inputs:
    if self.append_concat_token:
        tokenized_input = tokenized_input + [self.concat_token_id]
    all_token_ids.extend(tokenized_input)

3.将拼接的数据切块(chunk)

python 复制代码
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
    input_ids = all_token_ids[i : i + self.seq_length]
    if len(input_ids) == self.seq_length:
        examples.append(input_ids)
相关推荐
董厂长5 小时前
langchain :记忆组件混淆概念澄清 & 创建Conversational ReAct后显示指定 记忆组件
人工智能·深度学习·langchain·llm
G皮T9 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
雷羿 LexChien10 小时前
从 Prompt 管理到人格稳定:探索 Cursor AI 编辑器如何赋能 Prompt 工程与人格风格设计(上)
人工智能·python·llm·编辑器·prompt
堆栈future10 小时前
上下文工程(Context-Engineering): AI应用核心技术剖析
llm·ai编程·mcp
亚里随笔12 小时前
L0:让大模型成为通用智能体的强化学习新范式
人工智能·llm·大语言模型·rlhf
吴佳浩13 小时前
Python入门指南-番外-LLM-Fingerprint(大语言模型指纹):从技术视角看AI开源生态的边界与挑战
python·llm·mcp
吴佳浩13 小时前
Python入门指南-AI模型相似性检测方法:技术原理与实现
人工智能·python·llm
Spider_Man15 小时前
🚀 从阻塞到丝滑:React中DeepSeek LLM流式输出的实现秘籍
前端·react.js·llm
大模型开发16 小时前
Java开发者LLM实战——使用LangChain4j构建本地RAG系统
程序员·langchain·llm
用户307429716715816 小时前
LLM-as-a-Judge :构建可扩展的自动化 AI 评估体系
llm·aigc