Training and Fine-Tuning Large Models with RLHF, Building Your Own GPT-4 (Part 4): Direct Preference Optimization (DPO)

Fine-tuning a large model mainly involves the following stages:

  • Supervised fine-tuning (SFT)
  • Reward / preference modeling (RM)
  • Reinforcement learning from human feedback (RLHF)

The related code is available on GitHub: github.com/night-is-yo...

This series implements four models:

  1. baichuan
  2. chatglm3
  3. qwen
  4. yi

DPO directly optimizes for the policy that best satisfies the preferences via a simple classification objective, with no explicit reward model and no reinforcement learning loop. This article explains how DPO is implemented.
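For reference, the objective that DPO minimizes can be written as follows (this matches the `sigmoid` loss_type implemented in `dpo_loss` later in this post; $\beta$ is the `beta` hyperparameter, $y_w$ the chosen response, $y_l$ the rejected response, and $\pi_{\mathrm{ref}}$ the frozen reference model):

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})=-\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}-\beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$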

zephyr-7b was produced by the Hugging Face team by fine-tuning Mistral 7B with DPO.

DPO fine-tuning consists of two steps:

  1. Supervised fine-tuning (SFT)
  2. Direct preference optimization (DPO)

The official zephyr-7b training code is:

```python
trainer = DPOTrainer(
    model,
    ref_model,
    model_init_kwargs=model_kwargs,
    ref_model_init_kwargs=ref_model_kwargs,
    args=training_args,
    beta=training_args.beta,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    peft_config=get_peft_config(model_args),
    loss_type=training_args.loss_type,
)

###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
```
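In the official script, `model` and `ref_model` may be passed as name-or-path strings together with `model_init_kwargs` / `ref_model_init_kwargs`. As a rough, more explicit sketch (not from the official code; the checkpoint path is a placeholder), you could also prepare them yourself:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: point this at your own SFT checkpoint.
sft_checkpoint = "path/to/your-sft-model"

model = AutoModelForCausalLM.from_pretrained(sft_checkpoint)
ref_model = AutoModelForCausalLM.from_pretrained(sft_checkpoint)  # frozen reference copy
tokenizer = AutoTokenizer.from_pretrained(sft_checkpoint)
if tokenizer.pad_token is None:
    # Many decoder-only tokenizers ship without a pad token.
    tokenizer.pad_token = tokenizer.eos_token
```

When training with PEFT (a non-empty `peft_config`), `ref_model` can instead be left as `None`; as discussed below, the trainer then uses the adapter-disabled base model as the reference.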

## DPOTrainer source code walkthrough

### Data loading: DPODataCollatorWithPadding

```python
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
    tokenized_batch = []

    for feature in features:
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
        tokenized_batch.append(batch_element)

    # return collated batch
    return self.collate(tokenized_batch)
```

DPO data loading has two parts:

  1. tokenize_batch_element first tokenizes each example and trims it to the configured maximum lengths.
  2. collate then gathers the batch together and pads everything to the same length.

Note that if you want to use a custom dataset, each example must provide the following three fields (a minimal example follows the list):

  • prompt
  • chosen
  • rejected
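A minimal sketch of one such record (the text is invented purely for illustration):

```python
# Hypothetical training record: the three keys are what DPODataCollatorWithPadding
# reads; the strings themselves are made up.
example = {
    "prompt": "Question: What is the capital of France?\nAnswer:",
    "chosen": " The capital of France is Paris.",
    "rejected": " France does not have a capital.",
}
```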

The data-processing source below converts the strings into tokens. For a custom dataset, this logic usually lives in the dataset's __getitem__() method, as sketched next.
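As a rough illustration (this wrapper is hypothetical, not part of TRL), a custom dataset can simply hand back the three raw strings from `__getitem__`, or do the tokenization there if you prefer:

```python
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """Hypothetical wrapper: each item exposes prompt / chosen / rejected strings."""

    def __init__(self, records):
        self.records = records  # list of dicts like the example above

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        return {
            "prompt": rec["prompt"],
            "chosen": rec["chosen"],
            "rejected": rec["rejected"],
        }
```

TRL's tokenize_batch_element then performs the tokenization itself: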

```python
def tokenize_batch_element(
    self,
    prompt: str,
    chosen: str,
    rejected: str,
) -> Dict:

    batch = {}

    if not self.is_encoder_decoder:
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

        eos_token_id = self.tokenizer.eos_token_id
        # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0)
        eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
        # attention mask these indices to eos_token_id
        new_attention_mask = [
            0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
        ]
        prompt_tokens["attention_mask"] = new_attention_mask

        # do the same for chosen and rejected
        eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_c = [
            0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
        ]
        chosen_tokens["attention_mask"] = new_attention_mask_c

        eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_r = [
            0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
        ]
        rejected_tokens["attention_mask"] = new_attention_mask_r

        # add EOS token to end of prompt
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {
                k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
            }

        # Create labels
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )

        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens

    batch["prompt"] = prompt
    batch["chosen"] = prompt + chosen
    batch["rejected"] = prompt + rejected
    batch["chosen_response_only"] = chosen
    batch["rejected_response_only"] = rejected

    return batch
```

After processing, each batch element contains the tokenized fields

'chosen_input_ids', 'chosen_attention_mask', 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', 'prompt_input_ids', 'prompt_attention_mask'

as well as the raw strings 'prompt', 'chosen', 'rejected', 'chosen_response_only', 'rejected_response_only', which are kept for logging.

```python
def collate(self, batch):
    # first, pad everything to the same length
    padded_batch = {}
    for k in batch[0].keys():
        if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
            # adapted from https://stackoverflow.com/questions/73256206
            if "prompt" in k:
                to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
            else:
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]
            if k.endswith("_input_ids"):
                padding_value = self.tokenizer.pad_token_id
            elif k.endswith("_labels"):
                padding_value = self.label_pad_token_id
            elif k.endswith("_attention_mask"):
                padding_value = self.padding_value
            else:
                raise ValueError(f"Unexpected key in batch '{k}'")

            padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            # for the prompt, flip back so padding is on left side
            if "prompt" in k:
                padded_batch[k] = padded_batch[k].flip(dims=[1])
        else:
            padded_batch[k] = [ex[k] for ex in batch]

    return padded_batch
```

Note that the prompt is padded on the left. This is achieved by reversing each prompt, right-padding it, and flipping it back, as in the snippet below:

```python
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
padded_batch[k] = padded_batch[k].flip(dims=[1])
```

### The models inside DPOTrainer

Training uses two models: the policy model and the reference model (ref model).

For PEFT fine-tuning, the ref model is obtained with the following code:

```python
with self.accelerator.unwrap_model(self.model).disable_adapter():
```

In other words, the policy model and the ref model share the same base weights; the policy model simply adds trainable LoRA modules, while the ref model (the base model with adapters disabled) is not trained.

If you are not fine-tuning with PEFT, you must pass in a separate ref model, which roughly doubles the GPU memory usage.

### DPOTrainer loss computation

```python
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")

    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    if return_outputs:
        return (loss, metrics)
    return loss
```

Note that compute_loss delegates to get_batch_metrics, so let's look at that next:

```python
def get_batch_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = self.concatenated_forward(model, batch)

    with torch.no_grad():
        if self.ref_model is None:
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, batch)

    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    return losses.mean(), metrics
```

Computing the loss takes three steps:

  1. Run the policy model to get policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits.
  2. Run the ref model (or the adapter-disabled policy model) to get reference_chosen_logps, reference_rejected_logps.
  3. Compute the DPO loss.
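As background (not TRL code): each "logps" value is the sequence-level log-probability of a response, i.e. the per-token log-probabilities summed over the response tokens only, since the prompt positions in the labels are masked with label_pad_token_id:

$$
\log \pi(y \mid x) = \sum_{t=1}^{|y|} \log \pi\big(y_t \mid x,\, y_{<t}\big)
$$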
Let's first look at how concatenated_inputs and concatenated_forward produce the model's logps and logits:

```python
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    concatenated_batch = {}

    max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(self.accelerator.device)

    if self.is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1)
        concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1)

    return concatenated_batch
```

concatenated_inputs first concatenates the chosen and rejected examples along the batch dimension.

```python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    concatenated_batch = self.concatenated_inputs(batch)
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        **model_kwargs,
    ).logits.to(torch.float32)

    all_logps = self._get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=False,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
```

concatenated_forward then calls the model once to compute the logps and logits for both halves in a single pass.

The source of _get_batch_logps is omitted here; it is a fairly generic helper, and you can read it in the official GitHub repository.
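Since that helper is skipped above, here is a rough sketch of what it computes, written from the description rather than copied from the TRL source (the name get_batch_logps_sketch is mine, and it assumes a decoder-only model, so logits and labels are shifted by one position):

```python
import torch

def get_batch_logps_sketch(logits: torch.FloatTensor,
                           labels: torch.LongTensor,
                           label_pad_token_id: int = -100,
                           average_log_prob: bool = False) -> torch.FloatTensor:
    """Sum (or average) the log-probabilities of the label tokens per sequence."""
    # Shift so that tokens < t predict token t (causal LM convention).
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # Dummy index for masked positions; their contribution is zeroed out below.
    labels[labels == label_pad_token_id] = 0

    # Log-probability of each label token under the model.
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```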
#### Finally, the loss computation

```python
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = 0

    logits = pi_logratios - ref_logratios

    if self.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")

    chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards
```

The loss computation itself is simple, even though its mathematical derivation is involved. DPO training does not need a reward model, yet it still has to keep the model from drifting off course, which is exactly what this simplified loss achieves.

```python
_chosen_logps - _rejected_logps
```

The log-ratio above takes the place of the reward model.

```python
pi_logratios - ref_logratios
```

The difference above keeps the policy from drifting too far away from the reference model.
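As a quick sanity check, the sigmoid branch can be replayed with made-up log-probabilities (the numbers and beta = 0.1 are arbitrary):

```python
import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen_logps = torch.tensor([-10.0])
policy_rejected_logps = torch.tensor([-14.0])
reference_chosen_logps = torch.tensor([-11.0])
reference_rejected_logps = torch.tensor([-13.0])

pi_logratios = policy_chosen_logps - policy_rejected_logps            # 4.0
ref_logratios = reference_chosen_logps - reference_rejected_logps     # 2.0
logits = pi_logratios - ref_logratios                                  # 2.0
loss = -F.logsigmoid(beta * logits)                                    # ~0.598

chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)        # 0.1
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)  # -0.1
print(loss, chosen_rewards > rejected_rewards)
```

The more the policy improves the chosen response relative to the rejected one (compared with the reference model), the larger the margin and the smaller the loss.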
