1. Basic Principle
DPO (Direct Preference Optimization) optimizes the policy model directly on paired preference data ("preferred response vs. dispreferred response"): it widens the gap between the policy's log-probabilities on the two responses, measured relative to a frozen reference model.
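For reference, the objective being minimized is the DPO loss from Rafailov et al. (2023), which the `dpo_loss` function in the code below implements batch-wise:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$

Here $y_w$ is the preferred (chosen) response, $y_l$ the dispreferred (rejected) one, $\pi_{\mathrm{ref}}$ a frozen copy of the initial model, $\sigma$ the sigmoid, and $\beta$ a temperature controlling how far the policy may drift from the reference. No separate reward model or RL rollout is needed; the policy's own log-probabilities act as an implicit reward.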
2. Code Implementation
python
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
DATA_PATH = "train.json"
# train.json holds one JSON preference pair per line, e.g.:
"""
{"prompt": "Explain gradient descent.", "chosen": "Gradient descent is an optimization algorithm", "rejected": "Gradient descent is a weather phenomenon"}
{"prompt": "What is DPO?", "chosen": "DPO is a preference optimization method", "rejected": "DPO is a graphics card interface"}
"""
MODEL_NAME = "Qwen/Qwen3-8B"
OUTPUT_DIR = "./qwen3_lora_dpo"
BATCH_SIZE = 1
EPOCHS = 500
LR = 2e-5
BETA = 0.1     # DPO temperature: how strongly deviation from the reference is penalized
MAX_LEN = 512  # max token length for prompt + response
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
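# Each example becomes two full sequences, prompt + chosen and prompt +
# rejected, tokenized to a fixed length so they batch cleanly.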
def preprocess(example):
    prompt = example["prompt"]
    chosen = example["chosen"]
    rejected = example["rejected"]
    chosen_text = prompt + "\n" + chosen
    rejected_text = prompt + "\n" + rejected
    chosen_enc = tokenizer(
        chosen_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )
    rejected_enc = tokenizer(
        rejected_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )
    return {
        "chosen_input_ids": chosen_enc["input_ids"],
        "chosen_attention_mask": chosen_enc["attention_mask"],
        "rejected_input_ids": rejected_enc["input_ids"],
        "rejected_attention_mask": rejected_enc["attention_mask"],
    }
dataset = load_dataset("json", data_files=DATA_PATH)["train"]
# Drop the raw string columns so set_format only sees tensor-convertible data
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
dataset.set_format(type="torch")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
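# Two copies of the model are loaded: the policy (made trainable via LoRA
# below) and a frozen reference that anchors the DPO log-ratios.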
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
reference_model.eval()
for p in reference_model.parameters():
    p.requires_grad = False
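# Attach LoRA adapters to the attention q/v projections; only the adapter
# weights train while the 8B base weights stay frozen.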
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, lora_config)
# With only LoRA params trainable, gradient checkpointing needs the inputs
# to require grad, otherwise nothing reaches the adapters during backward.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
# Keep trainable LoRA weights in fp32 so AdamW updates stay numerically
# stable while the frozen base stays fp16.
for p in model.parameters():
    if p.requires_grad:
        p.data = p.data.float()
model.print_trainable_parameters()
model.train()
device = next(model.parameters()).device
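# Sequence log-probability under a model: logits are shifted one position
# (token t is predicted from tokens < t), per-token log-probs are gathered
# at the actual next tokens, and padding is masked out of the sum. Prompt
# tokens are included, but since chosen and rejected share the same prompt,
# their contribution cancels in every log-ratio below.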
def compute_log_probs(model, input_ids, attention_mask):
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)
    per_token_logps = torch.gather(
        log_probs,
        dim=2,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)
    mask = attention_mask[:, 1:]
    return (per_token_logps * mask).sum(dim=1)
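# DPO loss: -log sigmoid(beta * ((policy_c - policy_r) - (ref_c - ref_r))).
# Minimizing it widens the policy's chosen-vs-rejected margin relative to
# the reference model's margin.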
def dpo_loss(policy_c, policy_r, ref_c, ref_r, beta):
    pi_logratios = policy_c - policy_r
    ref_logratios = ref_c - ref_r
    logits = beta * (pi_logratios - ref_logratios)
    return -F.logsigmoid(logits).mean()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=LR
)
scaler = torch.amp.GradScaler("cuda")
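# Training loop: fp16 autocast for the forward passes, reference scores
# under no_grad, and a scaled backward pass via GradScaler.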
for epoch in range(EPOCHS):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        chosen_input_ids = batch["chosen_input_ids"].to(device)
        chosen_attention_mask = batch["chosen_attention_mask"].to(device)
        rejected_input_ids = batch["rejected_input_ids"].to(device)
        rejected_attention_mask = batch["rejected_attention_mask"].to(device)
        with torch.amp.autocast("cuda"):
            policy_c = compute_log_probs(
                model,
                chosen_input_ids,
                chosen_attention_mask
            )
            policy_r = compute_log_probs(
                model,
                rejected_input_ids,
                rejected_attention_mask
            )
            with torch.no_grad():
                ref_c = compute_log_probs(
                    reference_model,
                    chosen_input_ids,
                    chosen_attention_mask
                )
                ref_r = compute_log_probs(
                    reference_model,
                    rejected_input_ids,
                    rejected_attention_mask
                )
            loss = dpo_loss(policy_c, policy_r, ref_c, ref_r, BETA)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")
    print(f"Epoch {epoch} Mean Loss {total_loss/len(dataloader):.4f}")

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Training complete")
Run output:
bash
Loading checkpoint shards: 100%|██████████| 8/8 [00:15<00:00, 1.87s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00, 1.75s/it]
trainable params: 1,572,864 || all params: 8,000,000,000 || trainable%: 0.0197
Epoch 0 Step 0 Loss 0.6931
Epoch 0 Step 10 Loss 0.6905
Epoch 0 Step 20 Loss 0.6872
Epoch 0 Step 30 Loss 0.6714
Epoch 0 Step 40 Loss 0.6589
Epoch 0 Step 50 Loss 0.6231
...
Epoch 0 Step 490 Loss 0.4521
Epoch 0 Step 500 Loss 0.4103
Epoch 0 Mean Loss 0.5842
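The first loss value is ln 2 ≈ 0.6931: the freshly initialized LoRA adapters leave the policy identical to the reference, so every preference margin starts at zero. The steady drop below that line shows the policy learning to separate chosen from rejected responses.

To use the result, load the saved adapter back onto the base model with PEFT. A minimal inference sketch, assuming the adapter was saved to `./qwen3_lora_dpo` as above (the prompt mirrors the `prompt + "\n"` convention used in training; greedy decoding and `max_new_tokens=64` are illustrative choices):

python
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("./qwen3_lora_dpo", trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "./qwen3_lora_dpo")  # attach the LoRA weights
model.eval()

inputs = tokenizer("What is DPO?\n", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))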