Training and Fine-tuning Large Models with RLHF, Building Your Own GPT-4 (Part 4): Direct Preference Optimization (DPO)

Fine-tuning a large model mainly involves the following stages:

  • Supervised fine-tuning (SFT).
  • Reward / preference modeling (RM).
  • Reinforcement learning from human feedback (RLHF).

The accompanying code is available on GitHub: github.com/night-is-yo...

The accompanying code implements four model families:

  1. baichuan
  2. chatglm3
  3. qwen
  4. yi

DPO directly optimizes the policy that best satisfies the preferences through a simple classification-style objective, with no explicit reward model and no reinforcement learning loop. This post walks through how DPO is implemented.

zephyr-7b was obtained by the Hugging Face team by fine-tuning Mistral 7B with DPO.

DPO fine-tuning consists of two steps (a sketch of the SFT stage follows this list):

  1. Supervised fine-tuning (SFT).
  2. Direct preference optimization (DPO).
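
Before any preference data is used, the SFT step is typically run with trl's SFTTrainer. Below is a minimal sketch, assuming model, tokenizer, sft_dataset and training_args are prepared elsewhere; argument names such as dataset_text_field vary across trl versions:

python
from trl import SFTTrainer

# Step 1 (SFT): plain supervised fine-tuning on instruction data.
# `sft_dataset` is assumed to hold the instruction text in a "text" column.
sft_trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=sft_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)
sft_trainer.train()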

The official zephyr-7b training code for the DPO stage looks like this:

python
trainer = DPOTrainer(
    model,
    ref_model,
    model_init_kwargs=model_kwargs,
    ref_model_init_kwargs=ref_model_kwargs,
    args=training_args,
    beta=training_args.beta,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    peft_config=get_peft_config(model_args),
    loss_type=training_args.loss_type,
)

###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

Reading the DPOTrainer source

Data loading: DPODataCollatorWithPadding

python
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
    tokenized_batch = []

    for feature in features:
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
        tokenized_batch.append(batch_element)

    # return collated batch
    return self.collate(tokenized_batch)

DPO data loading has two parts:

  1. tokenize_batch_element first converts each example into token ids of bounded length.
  2. collate then pads the processed examples and gathers them into a batch.

Note: to use a custom dataset, each example must provide the following three fields (a sketch follows this list):

  • prompt
  • chosen
  • rejected
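
A minimal sketch of such a dataset built with the datasets library; the example texts are made up purely for illustration:

python
from datasets import Dataset

# Each example carries exactly the three string fields the collator expects.
train_dataset = Dataset.from_list([
    {
        "prompt": "What is DPO?",
        "chosen": "DPO trains the policy directly on preference pairs with a classification-style loss.",
        "rejected": "I don't know.",
    },
    {
        "prompt": "Translate 'hello' into French.",
        "chosen": "Bonjour.",
        "rejected": "Hola.",
    },
])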

The code below turns the strings into tokens; with a custom dataset, this logic usually lives in the dataset's __getitem__() method.

python
def tokenize_batch_element(
    self,
    prompt: str,
    chosen: str,
    rejected: str,
) -> Dict:

    batch = {}

    if not self.is_encoder_decoder:
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

        eos_token_id = self.tokenizer.eos_token_id
        # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0)
        eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
        # attention mask these indices to eos_token_id
        new_attention_mask = [
            0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
        ]
        prompt_tokens["attention_mask"] = new_attention_mask

        # do the same for chosen and rejected
        eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_c = [
            0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
        ]
        chosen_tokens["attention_mask"] = new_attention_mask_c

        eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_r = [
            0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
        ]
        rejected_tokens["attention_mask"] = new_attention_mask_r

        # add EOS token to end of prompt
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {
                k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
            }

        # Create labels
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )

        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens

    batch["prompt"] = prompt
    batch["chosen"] = prompt + chosen
    batch["rejected"] = prompt + rejected
    batch["chosen_response_only"] = chosen
    batch["rejected_response_only"] = rejected

    return batch

After processing, each element contains:

['chosen_input_ids', 'chosen_attention_mask', 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', 'prompt_input_ids', 'prompt_attention_mask']

plus the raw strings ['prompt', 'chosen', 'rejected', 'chosen_response_only', 'rejected_response_only'], which are kept for logging.

python
def collate(self, batch):
    # first, pad everything to the same length
    padded_batch = {}
    for k in batch[0].keys():
        if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
            # adapted from https://stackoverflow.com/questions/73256206
            if "prompt" in k:
                to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
            else:
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]
            if k.endswith("_input_ids"):
                padding_value = self.tokenizer.pad_token_id
            elif k.endswith("_labels"):
                padding_value = self.label_pad_token_id
            elif k.endswith("_attention_mask"):
                padding_value = self.padding_value
            else:
                raise ValueError(f"Unexpected key in batch '{k}'")

            padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            # for the prompt, flip back so padding is on left side
            if "prompt" in k:
                padded_batch[k] = padded_batch[k].flip(dims=[1])
        else:
            padded_batch[k] = [ex[k] for ex in batch]

    return padded_batch

Note that the prompt is padded on the left. The trick is in these lines of the snippet above:

python
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
padded_batch[k] = padded_batch[k].flip(dims=[1])
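
A tiny worked example of this reverse-pad-flip trick (the token ids are made up):

python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two prompts of different lengths.
prompts = [[11, 12, 13], [21, 22]]

# Reverse each sequence, right-pad, then flip back: the padding ends up on the left.
to_pad = [torch.LongTensor(p[::-1]) for p in prompts]
padded = pad_sequence(to_pad, batch_first=True, padding_value=0)
padded = padded.flip(dims=[1])
print(padded)
# tensor([[11, 12, 13],
#         [ 0, 21, 22]])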

The models used by DPOTrainer

Training involves two models: the policy model (model) and the reference model (ref_model).

With PEFT fine-tuning, the reference model is obtained via the following code:

python
with self.accelerator.unwrap_model(self.model).disable_adapter():

In other words, model and ref_model share the same base weights; model adds trainable LoRA modules, while the reference (with the adapters disabled) is not trained.

Without PEFT, a separate ref_model has to be passed in, which costs roughly an extra copy of the model in GPU memory.
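
A minimal sketch of the PEFT setup, assuming model, tokenizer, train_dataset and training_args already exist; the LoRA hyperparameters and target_modules are illustrative and depend on the architecture:

python
from peft import LoraConfig
from trl import DPOTrainer

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # illustrative; pick modules for your model
    task_type="CAUSAL_LM",
)

trainer = DPOTrainer(
    model,                 # base model; LoRA adapters are injected via peft_config
    ref_model=None,        # no separate reference copy is needed with PEFT
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,  # the trainer disables the adapter to get reference logps
)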

DPOTrainer's loss computation

python
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:

    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    if return_outputs:
        return (loss, metrics)
    return loss

Note that compute_loss delegates to get_batch_metrics; let's look at that source next.

python
def get_batch_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = self.concatenated_forward(model, batch)
    with torch.no_grad():
        if self.ref_model is None:
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, batch)
        else:
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
            ) = self.concatenated_forward(self.ref_model, batch)

    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
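    # NOTE: the full trl implementation also records the rewards, accuracies and
    # logps in `metrics` here; this excerpt keeps only the loss.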

    return losses.mean(), metrics

Loss computation takes three steps:

  1. Run the policy model to get

    policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits,

  2. Run the reference model to get

    reference_chosen_logps, reference_rejected_logps,

  3. Compute the DPO loss

Let's first look at how concatenated_inputs and concatenated_forward produce the model's logps and logits.

python
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    concatenated_batch = {}
    max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(self.accelerator.device)

    if self.is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1)
        concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1)

    return concatenated_batch

First, the chosen and rejected sequences are padded to the same length and stacked along the batch dimension, as illustrated below.
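
A quick illustration of the effect, with made-up shapes and a simplified stand-in for trl's pad_to_length helper:

python
import torch

def pad_to_length(t, length, pad_value):
    # right-pad along the sequence dimension (simplified stand-in for trl's helper)
    if t.shape[1] >= length:
        return t
    pad = t.new_full((t.shape[0], length - t.shape[1]), pad_value)
    return torch.cat([t, pad], dim=1)

chosen_input_ids = torch.randint(0, 1000, (4, 10))    # 4 chosen sequences, length 10
rejected_input_ids = torch.randint(0, 1000, (4, 13))  # 4 rejected sequences, length 13
max_length = max(chosen_input_ids.shape[1], rejected_input_ids.shape[1])

concatenated = torch.cat(
    [pad_to_length(chosen_input_ids, max_length, 0),
     pad_to_length(rejected_input_ids, max_length, 0)],
    dim=0,
)
print(concatenated.shape)  # torch.Size([8, 13]); rows 0-3 are chosen, rows 4-7 rejected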

python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
 
    concatenated_batch = self.concatenated_inputs(batch)
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        **model_kwargs,
    ).logits.to(torch.float32)

    all_logps = self._get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=False,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

concatenated_forward then runs the model once over the concatenated batch to get the logits and converts them into per-sequence logps.

The source of _get_batch_logps is omitted here; it is a fairly generic routine, and the full version is in the official repository on GitHub.
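
For completeness, here is a sketch of the core logic (sum the log-probabilities of the label tokens, skipping positions whose label equals label_pad_token_id); it follows the idea of the trl implementation rather than reproducing it verbatim:

python
import torch

def get_batch_logps(logits, labels, label_pad_token_id=-100, average_log_prob=False):
    """Sum (or average) the log-probabilities of the label tokens.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len)
    """
    # Shift so that tokens at position t are predicted from positions < t.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # Replace masked labels with 0 so gather() is not indexed with -100.
    labels[labels == label_pad_token_id] = 0
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)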

Finally, the loss computation:

python
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:

    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = 0

    logits = pi_logratios - ref_logratios

    if self.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")

    chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

The loss itself is simple, even though its mathematical derivation takes some work. DPO training needs no reward model, yet it still has to keep the policy from drifting, and this simplified loss handles both.
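
For reference, this is the DPO objective from Rafailov et al. (2023) that the sigmoid branch implements, with y_w the chosen and y_l the rejected response:

latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
    \left[ \log \sigma\!\left(
        \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right) \right]

In the code, pi_logratios - ref_logratios is exactly the bracketed log-ratio difference, and -F.logsigmoid(self.beta * logits) is the outer loss.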

python
_chosen_logps - _rejected_logps

This difference plays the role that the reward model would otherwise play.

python
pi_logratios - ref_logratios

This term keeps the policy from drifting away from the reference model.
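
This works because DPO's implicit reward is the beta-scaled log-ratio against the reference policy, which is exactly what chosen_rewards and rejected_rewards compute:

latex
\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}

Subtracting ref_logratios therefore anchors the update to the reference model: the policy is only rewarded for preferring chosen over rejected more strongly than the reference already does.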
