Fine-tuning large language models mainly involves the following stages:
- Supervised fine-tuning (SFT).
- Reward / preference modeling (RM).
- Reinforcement learning from human feedback (RLHF).
The accompanying code is available on GitHub: github.com/night-is-yo...
This article's code covers four models:
- baichuan
- chatglm3
- qwen
- yi
DPO directly optimizes for the policy that best satisfies the preferences using a simple classification objective; there is no explicit reward model and no reinforcement-learning loop. This article walks through how DPO is implemented.
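For reference, the objective DPO optimizes (from the original DPO paper) is a binary-classification-style loss over preference pairs, where $\pi_\theta$ is the policy being trained, $\pi_{\mathrm{ref}}$ is the frozen reference model, $y_w$ / $y_l$ are the chosen / rejected responses, and $\beta$ corresponds to the `beta` argument in the training code below:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\!\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$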
zephyr-7b was produced by the Hugging Face team by fine-tuning Mistral 7B with DPO.
DPO fine-tuning consists of two steps:
- Supervised fine-tuning (SFT).
- Direct preference optimization (DPO).
The official zephyr-7b training code is as follows:
```python
trainer = DPOTrainer(
    model,
    ref_model,
    model_init_kwargs=model_kwargs,
    ref_model_init_kwargs=ref_model_kwargs,
    args=training_args,
    beta=training_args.beta,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    peft_config=get_peft_config(model_args),
    loss_type=training_args.loss_type,
)

###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
```
## DPOTrainer source code walkthrough

### Data loading: DPODataCollatorWithPadding
```python
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
    tokenized_batch = []
    for feature in features:
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]
        batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
        tokenized_batch.append(batch_element)
    # return collated batch
    return self.collate(tokenized_batch)
```
DPO data loading therefore has two parts:
- first, `tokenize_batch_element` tokenizes each example and truncates it so that prompt plus response fit within the maximum length;
- then `collate` pads the tokenized examples and gathers them into a batch.
Note that a custom dataset must provide the following three fields (a minimal example follows this list):
- prompt
- chosen
- rejected
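As a minimal sketch (the example texts and the `Dataset.from_list` construction below are illustrative, not taken from the original repository), a training record only needs these three string fields:

```python
# Hypothetical custom preference data; each record carries the three required string fields.
from datasets import Dataset

raw_datasets = {
    "train": Dataset.from_list(
        [
            {
                "prompt": "Explain DPO in one sentence.",
                "chosen": "DPO fine-tunes a model directly on preference pairs with a classification-style loss.",  # preferred answer
                "rejected": "DPO is a kind of tokenizer.",  # dispreferred answer
            },
        ]
    )
}
```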
The data-processing source below converts the text into tokens; for a custom dataset, this logic usually lives in the dataset's `__getitem__()` method.
```python
def tokenize_batch_element(
    self,
    prompt: str,
    chosen: str,
    rejected: str,
) -> Dict:
    batch = {}
    if not self.is_encoder_decoder:
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
        eos_token_id = self.tokenizer.eos_token_id
        # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0)
        eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
        # attention mask these indices to eos_token_id
        new_attention_mask = [
            0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
        ]
        prompt_tokens["attention_mask"] = new_attention_mask
        # do the same for chosen and rejected
        eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_c = [
            0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
        ]
        chosen_tokens["attention_mask"] = new_attention_mask_c
        eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_r = [
            0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
        ]
        rejected_tokens["attention_mask"] = new_attention_mask_r
        # add EOS token to end of prompt
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)
        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)
        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {
                k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
            }
        # Create labels
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens
    batch["prompt"] = prompt
    batch["chosen"] = prompt + chosen
    batch["rejected"] = prompt + rejected
    batch["chosen_response_only"] = chosen
    batch["rejected_response_only"] = rejected
    return batch
```
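For a concrete picture of the result, here is a tiny illustration of the label construction for the chosen sequence; the token IDs are invented, and `label_pad_token_id` is -100 by default:

```python
# Invented token IDs, for illustration only.
prompt_input_ids = [5, 6, 7]   # tokenized prompt
chosen_input_ids = [8, 9, 2]   # tokenized chosen response with EOS (id 2) appended

chosen_sequence_input_ids = prompt_input_ids + chosen_input_ids
# -> [5, 6, 7, 8, 9, 2]
chosen_sequence_labels = [-100, -100, -100, 8, 9, 2]
# prompt positions are set to label_pad_token_id so they are ignored in the log-prob computation
```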
After this processing, each example contains:

`chosen_input_ids`, `chosen_attention_mask`, `chosen_labels`, `rejected_input_ids`, `rejected_attention_mask`, `rejected_labels`, `prompt_input_ids`, `prompt_attention_mask`,

plus the raw strings `prompt`, `chosen`, `rejected`, `chosen_response_only`, `rejected_response_only`, which are kept for logging.

```python
def collate(self, batch):
    # first, pad everything to the same length
    padded_batch = {}
    for k in batch[0].keys():
        if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
            # adapted from https://stackoverflow.com/questions/73256206
            if "prompt" in k:
                to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
            else:
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]
            if k.endswith("_input_ids"):
                padding_value = self.tokenizer.pad_token_id
            elif k.endswith("_labels"):
                padding_value = self.label_pad_token_id
            elif k.endswith("_attention_mask"):
                padding_value = self.padding_value
            else:
                raise ValueError(f"Unexpected key in batch '{k}'")
            padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            # for the prompt, flip back so padding is on left side
            if "prompt" in k:
                padded_batch[k] = padded_batch[k].flip(dims=[1])
        else:
            padded_batch[k] = [ex[k] for ex in batch]
    return padded_batch
```

Note that the prompt is padded on the left. This is achieved by reversing the sequence, right-padding it, and flipping it back, as in these lines from the code above:

```python
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
padded_batch[k] = padded_batch[k].flip(dims=[1])
```

### The models in DPOTrainer

Training requires two models: the policy `model` and a reference model.

With PEFT fine-tuning, the reference model is expressed by the following code:

```python
with self.accelerator.unwrap_model(self.model).disable_adapter():
```

In other words, `model` and the reference model share the same base weights; `model` additionally carries the trainable LoRA modules, while the reference model (the base model with the adapters disabled) is not trained. Without PEFT, a separate `ref_model` has to be passed in, which roughly doubles GPU memory usage.

### The loss computation in DPOTrainer

```python
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")
    if return_outputs:
        return (loss, metrics)
    return loss
```

Note that `compute_loss` delegates to `get_batch_metrics`, so let's turn to that source next:

```python
def get_batch_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = self.concatenated_forward(model, batch)
    with torch.no_grad():
        if self.ref_model is None:
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, batch)
    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    return losses.mean(), metrics
```

The loss computation has three steps:

1. Run the policy model through `concatenated_forward` to get `policy_chosen_logps`, `policy_rejected_logps`, `policy_chosen_logits`, `policy_rejected_logits`.
2. Run the reference model to get `reference_chosen_logps`, `reference_rejected_logps`.
3. Compute the DPO loss.
Let's first look at how `concatenated_inputs` and `concatenated_forward` produce the model's logps and logits.

```python
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    concatenated_batch = {}
    max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(self.accelerator.device)
    if self.is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1)
        concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1)
    return concatenated_batch
```

The chosen and rejected examples are first concatenated along the batch dimension.

```python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    concatenated_batch = self.concatenated_inputs(batch)
    len_chosen = batch["chosen_labels"].shape[0]
    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        **model_kwargs,
    ).logits.to(torch.float32)
    all_logps = self._get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=False,
    )
    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]
    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]
    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
```

The model is then called once to compute the logps and logits for both halves in a single forward pass. The source of `_get_batch_logps` is omitted here; it is a fairly generic helper, and you can read it in the official GitHub repository.
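As a rough sketch of what `_get_batch_logps` essentially does for decoder-only models (simplified and renamed here; the real TRL implementation also handles encoder-decoder models and a few edge cases):

```python
import torch

def get_batch_logps_sketch(
    logits: torch.FloatTensor,   # (batch, seq_len, vocab)
    labels: torch.LongTensor,    # (batch, seq_len), prompt/padding marked with label_pad_token_id
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
) -> torch.FloatTensor:
    # Shift so that tokens < n predict token n (standard causal-LM convention).
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # Replace masked positions with a dummy index; they are zeroed out by loss_mask below.
    labels[labels == label_pad_token_id] = 0

    # Log-probability of each target token under the model.
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    # Sum (or average) the response-token log-probs -> one log-probability per sequence.
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```

The key point is that only the response tokens contribute, because prompt and padding positions were set to `label_pad_token_id` when the labels were built.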
#### Finally, the loss computation:

```python
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    if reference_free:
        ref_logratios = 0
    logits = pi_logratios - ref_logratios
    if self.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")
    chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards
```

The loss computation itself is simple, even though the mathematical derivation behind it is fairly involved. DPO training needs no reward model, yet it still has to keep the policy from drifting off course, which is why the loss takes this simplified form.

```python
_chosen_logps - _rejected_logps
```

The log-probability difference above takes the place of a reward model.

```python
pi_logratios - ref_logratios
```

Subtracting the reference log-ratio above is what keeps the policy from drifting away from the reference model.
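To connect this back to the math: DPO's implicit reward is, up to a term that depends only on the prompt,

$$
\hat r_\theta(x, y) = \beta\,\log\frac{\pi_\theta(y\mid x)}{\pi_{\mathrm{ref}}(y\mid x)}
$$

so `chosen_rewards` and `rejected_rewards` above are exactly these implicit rewards (detached, for logging), and `self.beta * logits` in the loss is the implicit reward margin $\hat r_\theta(x, y_w) - \hat r_\theta(x, y_l)$ that the sigmoid objective pushes to be large.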