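Fine-tuning a large language model generally involves the following stages:
- Supervised fine-tuning (SFT).
- Reward / preference modeling (RM).
- Reinforcement learning from human feedback (RLHF).
The accompanying code is available on GitHub: github.com/night-is-yo...
This article covers implementations for four models: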
- baichuan
- chatglm3
- qwen
- yi
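DPO uses a simple classification objective to directly optimize the policy that best satisfies the preferences, with no explicit reward function and no reinforcement learning loop. This article explains how DPO is implemented.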
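For reference, the classification objective described above is usually written as below. The notation follows the DPO paper and is not defined elsewhere in this article: $\pi_\theta$ is the policy being trained, $\pi_{\mathrm{ref}}$ the frozen reference model, $(x, y_w, y_l)$ a prompt with its chosen and rejected responses, $\sigma$ the sigmoid function, and $\beta$ a scaling coefficient.

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
= -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right) \right]
```

Rearranging the terms inside the sigmoid gives (policy chosen minus policy rejected) minus (reference chosen minus reference rejected), which is the form the loss code later in this article computes.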
zephyr-7b was produced by the Hugging Face team by fine-tuning Mistral 7B with DPO.
DPO fine-tuning proceeds in two steps:
- Supervised fine-tuning (SFT).
- Direct Preference Optimization (DPO).
The official zephyr-7b training source is:
```python
# Excerpt from the training script: model, ref_model, model_kwargs, ref_model_kwargs,
# training_args, model_args, raw_datasets, tokenizer and last_checkpoint are defined earlier.
trainer = DPOTrainer(
    model,
    ref_model,
    model_init_kwargs=model_kwargs,
    ref_model_init_kwargs=ref_model_kwargs,
    args=training_args,
    beta=training_args.beta,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    peft_config=get_peft_config(model_args),
    loss_type=training_args.loss_type,
)

###############
# Training loop
###############
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
```
DPOTrainer source code walkthrough
Data loading: DPODataCollatorWithPadding source code walkthrough
```python
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
    tokenized_batch = []
    for feature in features:
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]
        batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
        tokenized_batch.append(batch_element)
    # return collated batch
    return self.collate(tokenized_batch)
```
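DPO data loading has two parts:
- First, tokenize_batch_element tokenizes each example and truncates it so that it fits within max_length.
- Then collate pads the examples to a common length and gathers them into one batch.
Note that a custom dataset must provide the following three fields: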
- prompt
- chosen
- rejected
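As a minimal sketch of what such a dataset could look like, the snippet below builds two toy records and checks that each has the three fields. The records, the `REQUIRED_FIELDS` constant, and the `missing_fields` helper are made up for illustration and are not part of any library.

```python
from typing import Dict, List

REQUIRED_FIELDS = ("prompt", "chosen", "rejected")

# Hypothetical toy records, made up for illustration only.
records: List[Dict[str, str]] = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France does not have a capital.",
    },
    {
        "prompt": "Translate 'hello' into Spanish.",
        "chosen": "hola",
        "rejected": "bonjour",
    },
]


def missing_fields(record: Dict[str, str]) -> List[str]:
    """Return the required fields that are absent or are not non-empty strings."""
    return [
        field
        for field in REQUIRED_FIELDS
        if not isinstance(record.get(field), str) or not record[field]
    ]


# Fail early if any record cannot be fed to the collator.
for index, record in enumerate(records):
    problems = missing_fields(record)
    if problems:
        raise ValueError(f"record {index} is missing fields: {problems}")
```

Checking the fields up front gives a clearer error than a KeyError raised from inside the collator.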
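The preprocessing source below converts text into tokens. With a custom dataset, this logic usually goes in the dataset's item accessor (the get_item() method, that is, `__getitem__` on a PyTorch-style dataset).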
```python
def tokenize_batch_element(
    self,
    prompt: str,
    chosen: str,
    rejected: str,
) -> Dict:
    batch = {}
    if not self.is_encoder_decoder:
        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
        eos_token_id = self.tokenizer.eos_token_id
        # Get indices in list prompt_tokens["input_ids"] that equal the EOS token (often 0)
        eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
        # zero the attention mask at these EOS positions
        new_attention_mask = [
            0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
        ]
        prompt_tokens["attention_mask"] = new_attention_mask
        # do the same for chosen and rejected
        eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_c = [
            0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
        ]
        chosen_tokens["attention_mask"] = new_attention_mask_c
        eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_r = [
            0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
        ]
        rejected_tokens["attention_mask"] = new_attention_mask_r
        # add EOS token to the end of each response
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)
        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)
        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {
                k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
            }
        # Create labels: copy input_ids, then mask the prompt positions so they do not contribute to the loss
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens
    # (this excerpt shows only the decoder-only path)
    # keep the raw strings alongside the token ids
    batch["prompt"] = prompt
    batch["chosen"] = prompt + chosen
    batch["rejected"] = prompt + rejected
    batch["chosen_response_only"] = chosen
    batch["rejected_response_only"] = rejected
    return batch
```
After processing, each example contains the following tokenized keys:
['chosen_input_ids', 'chosen_attention_mask', 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', 'prompt_input_ids', 'prompt_attention_mask']
together with the raw strings ['prompt', 'chosen', 'rejected', 'chosen_response_only', 'rejected_response_only'], which are kept for logging.
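To make the label construction concrete, here is a small sketch on made-up token ids. It uses plain Python lists rather than a real tokenizer, and assumes an ignore index of -100 for the masked prompt positions; the `build_sequence` helper and the token ids are hypothetical.

```python
from typing import Dict, List

LABEL_PAD_TOKEN_ID = -100  # assumption: the ignore index used for masked label positions


def build_sequence(prompt_ids: List[int], response_ids: List[int], eos_token_id: int) -> Dict[str, List[int]]:
    """Append EOS to the response, prepend the prompt, and mask the prompt in the labels."""
    response = response_ids + [eos_token_id]
    input_ids = prompt_ids + response
    # Labels copy input_ids, except that prompt positions are replaced by the ignore index.
    labels = [LABEL_PAD_TOKEN_ID] * len(prompt_ids) + response
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


example = build_sequence(prompt_ids=[11, 12, 13], response_ids=[21, 22], eos_token_id=2)
print(example["input_ids"])  # [11, 12, 13, 21, 22, 2]
print(example["labels"])     # [-100, -100, -100, 21, 22, 2]
```

Because the prompt positions are masked, only the response tokens (including the EOS) count toward the sequence log-probability.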
```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate(self, batch):
    # first, pad everything to the same length
    padded_batch = {}
    for k in batch[0].keys():
        if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
            # adapted from https://stackoverflow.com/questions/73256206
            if "prompt" in k:
                # reverse the prompt so that right-padding ends up on the left after flipping back
                to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
            else:
                to_pad = [torch.LongTensor(ex[k]) for ex in batch]
            if k.endswith("_input_ids"):
                padding_value = self.tokenizer.pad_token_id
            elif k.endswith("_labels"):
                padding_value = self.label_pad_token_id
            elif k.endswith("_attention_mask"):
                padding_value = self.padding_value
            else:
                raise ValueError(f"Unexpected key in batch '{k}'")
            padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
            # for the prompt, flip back so padding is on left side
            if "prompt" in k:
                padded_batch[k] = padded_batch[k].flip(dims=[1])
        else:
            # raw strings are passed through as plain lists
            padded_batch[k] = [ex[k] for ex in batch]
    return padded_batch
```
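Note that the prompt is padded on the left. The following lines from the snippet above implement this: each prompt is reversed, right-padded, and then flipped back.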
```python
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]  # reverse each prompt
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)  # pad on the right
padded_batch[k] = padded_batch[k].flip(dims=[1])  # flip back: padding is now on the left
```
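The reverse, pad, flip trick can be traced on plain Python lists. This is an illustrative sketch only: the `right_pad` helper is a hypothetical stand-in for `pad_sequence`, and the token ids are made up.

```python
from typing import List


def right_pad(sequences: List[List[int]], padding_value: int) -> List[List[int]]:
    """Pad every sequence on the right to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [padding_value] * (max_len - len(seq)) for seq in sequences]


def left_pad_via_flip(sequences: List[List[int]], padding_value: int) -> List[List[int]]:
    """Left-pad by reversing, right-padding, and reversing again."""
    reversed_seqs = [seq[::-1] for seq in sequences]   # step 1: reverse each sequence
    padded = right_pad(reversed_seqs, padding_value)   # step 2: right-pad the reversed sequences
    return [row[::-1] for row in padded]               # step 3: flip back, padding moves to the left


prompts = [[5, 6, 7], [8]]
print(left_pad_via_flip(prompts, padding_value=0))  # [[5, 6, 7], [0, 0, 8]]
```

With left padding, the last prompt token of every row sits in the final column, directly before where the response begins.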
Models used by DPOTrainer
Training requires two models: model and ref model.
With PEFT fine-tuning, the ref model is obtained with the following code:
```python
with self.accelerator.unwrap_model(self.model).disable_adapter():
    ...  # forward passes inside this block run the base model without the LoRA adapter
```
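In other words, model and ref model share the same underlying weights; model additionally carries the trainable LoRA modules, while ref model is not trained.
Without PEFT fine-tuning, a separate ref model has to be passed in, which roughly doubles the GPU memory needed for model weights.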
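The weight-sharing idea can be illustrated with a toy example. The class below is not PEFT's implementation; `ToyLoraLinear` and its scalar weights are hypothetical and only mimic the behaviour of switching an adapter off inside a context manager.

```python
from contextlib import contextmanager


class ToyLoraLinear:
    """Scalar 'layer' with a frozen base weight and a trainable LoRA-style delta."""

    def __init__(self, base_weight: float, lora_delta: float) -> None:
        self.base_weight = base_weight  # frozen, shared by policy and reference
        self.lora_delta = lora_delta    # trainable part, used by the policy only
        self.adapter_enabled = True

    def forward(self, x: float) -> float:
        weight = self.base_weight
        if self.adapter_enabled:
            weight += self.lora_delta
        return weight * x

    @contextmanager
    def disable_adapter(self):
        """Temporarily switch the adapter off and restore the previous state on exit."""
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield self
        finally:
            self.adapter_enabled = previous


layer = ToyLoraLinear(base_weight=2.0, lora_delta=0.5)
policy_out = layer.forward(3.0)         # base + adapter
with layer.disable_adapter():
    reference_out = layer.forward(3.0)  # base only
print(policy_out, reference_out)  # 7.5 6.0
```

Only one copy of the base weight exists, which is why the PEFT path avoids the memory cost of a second model.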
DPOTrainer loss computation source
```python
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
    # only the main process stores metrics, to avoid duplicate logging
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")
    if return_outputs:
        return (loss, metrics)
    return loss
```
Note that the loss computation calls get_batch_metrics, so we turn to its source next.
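```python
def get_batch_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}
    # step 1: forward pass of the policy model
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = self.concatenated_forward(model, batch)
    # step 2: forward pass of the reference model, without gradients
    with torch.no_grad():
        if self.ref_model is None:
            # PEFT case: disable the adapter to recover the reference model
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, batch)
        else:
            # non-PEFT case: use the separate reference model
            (
                reference_chosen_logps,
                reference_rejected_logps,
                _,
                _,
            ) = self.concatenated_forward(self.ref_model, batch)
    # step 3: DPO loss
    losses, chosen_rewards, rejected_rewards = self.dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    # (filling the metrics dict with reward statistics is omitted in this excerpt)
    return losses.mean(), metrics
```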
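The loss computation has three steps:
- Run model to obtain policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits.
- Run ref model to obtain reference_chosen_logps, reference_rejected_logps.
- Compute the DPO loss.
To see how the log-probabilities and logits are obtained, we start with concatenated_inputs, which prepares the batch for the forward pass: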
```python
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    concatenated_batch = {}
    max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
    # pad the chosen tensors to max_length and store them under "concatenated_*" keys
    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    # pad the rejected tensors and append them along the batch dimension
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            pad_value = self.label_pad_token_id if "labels" in k or self.is_encoder_decoder else self.padding_value
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(self.accelerator.device)
    if self.is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1)
        concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1)
    return concatenated_batch
```
This first concatenates chosen and rejected along the batch dimension, so that one forward pass covers both.
```python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    concatenated_batch = self.concatenated_inputs(batch)
    # number of chosen examples, used to split the outputs again below
    len_chosen = batch["chosen_labels"].shape[0]
    model_kwargs = (
        {
            "labels": concatenated_batch["concatenated_labels"],
            "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
        }
        if self.is_encoder_decoder
        else {}
    )
    all_logits = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        **model_kwargs,
    ).logits.to(torch.float32)
    all_logps = self._get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=False,
    )
    # the first len_chosen rows belong to chosen, the rest to rejected
    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]
    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]
    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
```
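The model is then called once on the concatenated batch to compute the logits and log-probabilities for both halves, which are split back into chosen and rejected.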
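The concatenate, forward, split pattern can be sketched with plain lists. Everything here is illustrative: the list-based `pad_to_length` only imitates the tensor helper, the token ids are made up, and summing each row stands in for the model's forward pass.

```python
from typing import List


def pad_to_length(rows: List[List[int]], length: int, pad_value: int) -> List[List[int]]:
    """Right-pad each row to `length` (list-based stand-in for the tensor helper)."""
    return [row + [pad_value] * (length - len(row)) for row in rows]


def concatenate(chosen: List[List[int]], rejected: List[List[int]], pad_value: int) -> List[List[int]]:
    """Pad both halves to a common length and stack chosen on top of rejected."""
    max_length = max(len(row) for row in chosen + rejected)
    return pad_to_length(chosen, max_length, pad_value) + pad_to_length(rejected, max_length, pad_value)


chosen_ids = [[1, 2, 3], [4, 5, 0]]          # batch of 2, padded to length 3
rejected_ids = [[1, 2, 6, 7], [4, 8, 0, 0]]  # batch of 2, padded to length 4
batch = concatenate(chosen_ids, rejected_ids, pad_value=0)
len_chosen = len(chosen_ids)

# Stand-in for the forward pass: one number per row of the combined batch.
scores = [sum(row) for row in batch]

# Split the outputs back into the two halves.
chosen_scores = scores[:len_chosen]
rejected_scores = scores[len_chosen:]
print(chosen_scores, rejected_scores)  # [6, 9] [16, 12]
```

Running both halves in one pass keeps the chosen and rejected outputs aligned row by row, at the cost of a batch twice as large.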
The source of _get_batch_logps is omitted here. It is a fairly generic routine; see the official source on GitHub for details.
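For readers who want the idea without opening the repository, the sketch below shows the usual approach on plain Python lists for a single sequence: shift the labels by one position, take the log-softmax of each step's logits, and sum the log-probabilities of the unmasked label tokens. It is not the library's code; the helper names, the -100 ignore index, and the toy inputs are assumptions.

```python
import math
from typing import List

LABEL_PAD_TOKEN_ID = -100  # assumption: ignore index for masked label positions


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized list."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def sequence_logp(logits: List[List[float]], labels: List[int], average_log_prob: bool = False) -> float:
    """Sum (or average) the log-probabilities of the unmasked label tokens.

    logits[t] scores the token at position t + 1, so labels are shifted by one.
    """
    shifted_labels = labels[1:]
    shifted_logits = logits[:-1]
    token_logps = [
        log_softmax(step_logits)[label]
        for step_logits, label in zip(shifted_logits, shifted_labels)
        if label != LABEL_PAD_TOKEN_ID
    ]
    total = sum(token_logps)
    if average_log_prob and token_logps:
        return total / len(token_logps)
    return total


# Toy input: vocabulary of 2 tokens, uniform logits, prompt of length 2 masked out.
toy_logits = [[0.0, 0.0]] * 4
toy_labels = [LABEL_PAD_TOKEN_ID, LABEL_PAD_TOKEN_ID, 1, 0]
print(sequence_logp(toy_logits, toy_labels))  # two tokens at probability 0.5 each: 2 * log(0.5)
```

concatenated_forward passes average_log_prob=False, which corresponds to the summed form here.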
Finally, the loss computation itself:
```python
import torch
import torch.nn.functional as F


def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    # how much each model prefers chosen over rejected, in log space
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    if reference_free:
        ref_logratios = 0
    # preference margin of the policy relative to the reference model
    logits = pi_logratios - ref_logratios
    if self.loss_type == "sigmoid":
        losses = -F.logsigmoid(self.beta * logits)
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")
    # implicit rewards, detached because they are only used for logging
    chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards
```
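The loss computation itself is simple, although its mathematical derivation is involved. DPO training needs no reward model, yet the policy must still be kept from drifting away from the reference model, so both roles are folded into this one simplified loss.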
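```python
pi_logratios = policy_chosen_logps - policy_rejected_logps
```
The difference between the chosen and rejected log-probabilities stands in for the reward model: it measures how strongly a model prefers the chosen response over the rejected one.
```python
logits = pi_logratios - ref_logratios
```
Subtracting the reference model's log-ratio keeps the policy from drifting: only the preference margin gained relative to the reference model lowers the loss.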
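A worked scalar example may help make the two roles concrete. The sketch below reimplements the sigmoid and hinge variants with the standard `math` module for a single preference pair; the function names, the log-probability values, and beta = 0.1 are made up for illustration and do not come from a real model.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def scalar_dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    reference_chosen_logp: float,
    reference_rejected_logp: float,
    beta: float = 0.1,
    loss_type: str = "sigmoid",
) -> float:
    """DPO loss for one preference pair, mirroring the tensor version term by term."""
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = reference_chosen_logp - reference_rejected_logp
    logits = pi_logratio - ref_logratio
    if loss_type == "sigmoid":
        return -log_sigmoid(beta * logits)
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * logits)
    raise ValueError(f"Unknown loss type: {loss_type}")


# Made-up log-probabilities for one (chosen, rejected) pair.
ref_chosen, ref_rejected = -12.0, -10.0

# Policy identical to the reference: the margin is 0, so the loss is log(2).
loss_same = scalar_dpo_loss(ref_chosen, ref_rejected, ref_chosen, ref_rejected)

# Policy prefers chosen more than the reference does: the loss drops below log(2).
loss_better = scalar_dpo_loss(-10.0, -14.0, ref_chosen, ref_rejected)

# Policy prefers rejected more than the reference does: the loss rises above log(2).
loss_worse = scalar_dpo_loss(-14.0, -9.0, ref_chosen, ref_rejected)

print(loss_better < loss_same < loss_worse)  # True
```

Note that the reference model here prefers the rejected response, yet a policy equal to it still gets the neutral loss of log(2): the loss depends only on the change relative to the reference, which is what anchors the policy.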