DPO微调

1.基本原理

直接利用"优选回答 vs. 次优回答"的成对偏好数据,对策略模型进行对数概率差分优化。

2.代码实现

Python 代码如下:
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM

# Training data: one JSON object per line (JSONL) with "prompt",
# "chosen" and "rejected" fields, as in the sample below.
DATA_PATH = "train.json"
"""
{"prompt": "解释梯度下降。", "chosen": "梯度下降是一种优化算法", "rejected": "梯度下降是天气现象"}
{"prompt": "什么是DPO?", "chosen": "DPO是一种偏好优化方法", "rejected": "DPO是一种显卡接口"}
"""
MODEL_NAME = "Qwen/Qwen3-8B"     # base model; loaded twice (policy + frozen reference)
OUTPUT_DIR = "./qwen3_lora_dpo"  # where the LoRA adapter and tokenizer are saved

BATCH_SIZE = 1   # preference pairs per optimization step
EPOCHS = 500     # NOTE(review): 500 epochs is unusually high for DPO fine-tuning — confirm intended
LR = 2e-5        # AdamW learning rate
BETA = 0.1       # DPO temperature: scales the policy-vs-reference log-ratio gap
MAX_LEN = 512    # fixed token length for each (prompt + answer) sequence

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Causal-LM tokenizers often ship without a pad token; reuse EOS so that
# padding="max_length" in preprocess() works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess(example):
    """Tokenize one preference pair into fixed-length encodings.

    Each answer is appended to its prompt as "prompt\nanswer", then encoded
    with padding and truncation to MAX_LEN. Returns the four arrays consumed
    by the training loop. Note the prompt tokens are included in both
    sequences; their log-probs cancel in the pairwise DPO difference.
    """
    def encode(answer):
        full_text = example["prompt"] + "\n" + answer
        return tokenizer(
            full_text,
            max_length=MAX_LEN,
            padding="max_length",
            truncation=True,
        )

    chosen_enc = encode(example["chosen"])
    rejected_enc = encode(example["rejected"])

    return {
        "chosen_input_ids": chosen_enc["input_ids"],
        "chosen_attention_mask": chosen_enc["attention_mask"],
        "rejected_input_ids": rejected_enc["input_ids"],
        "rejected_attention_mask": rejected_enc["attention_mask"],
    }

# Load the JSONL preference data, tokenize every pair up front, and expose
# the tokenized columns as torch tensors to the DataLoader.
dataset = load_dataset("json", data_files=DATA_PATH)["train"]
dataset = dataset.map(preprocess)
dataset.set_format(type="torch")

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Policy model: fp16 base weights that will receive the LoRA adapters.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
# KV caching is useless during training and incompatible with gradient
# checkpointing; disable it explicitly.
base_model.config.use_cache = False

# Frozen reference model: a second copy of the base weights, used only to
# compute the reference log-probabilities in the DPO loss.
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

reference_model.eval()
for p in reference_model.parameters():
    p.requires_grad = False

# LoRA: train rank-8 adapters on the attention q/v projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)
# FIX: with gradient checkpointing and a frozen embedding layer the inputs
# to the checkpointed blocks carry no grad, so backprop through them is
# silently skipped and the LoRA weights never update. Making the embedding
# outputs require grad is the documented PEFT workaround.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.print_trainable_parameters()
model.train()

device = next(model.parameters()).device

def compute_log_probs(model, input_ids, attention_mask):
    """Return the summed log-probability of `input_ids` under `model`.

    Uses the standard next-token shift: the logits at position t score the
    token at position t+1. Positions whose (shifted) attention mask is 0
    contribute nothing to the sum. Output shape is (batch,).
    """
    out = model(input_ids=input_ids, attention_mask=attention_mask)

    # Drop the last logit / first token to align predictions with targets.
    shifted_logits = out.logits[:, :-1, :]
    target_ids = input_ids[:, 1:]
    target_mask = attention_mask[:, 1:]

    token_logps = (
        F.log_softmax(shifted_logits, dim=-1)
        .gather(2, target_ids.unsqueeze(-1))
        .squeeze(-1)
    )

    return (token_logps * target_mask).sum(dim=1)

def dpo_loss(policy_c, policy_r, ref_c, ref_r, beta):
    """Direct Preference Optimization loss (Rafailov et al., 2023).

    Measures how much more the policy prefers the chosen answer over the
    rejected one, relative to the frozen reference model, and pushes that
    margin up via -log(sigmoid(beta * margin)). Returns a scalar mean loss.
    """
    margin = (policy_c - policy_r) - (ref_c - ref_r)
    return -F.logsigmoid(beta * margin).mean()

# FIX: the LoRA adapters inherit the base model's fp16 dtype, and GradScaler
# refuses to unscale fp16 gradients ("Attempting to unscale FP16 gradients.").
# Keep the trainable parameters in fp32 (autocast still runs the matmuls in
# half precision) and hand the optimizer only those parameters.
trainable_params = [p for p in model.parameters() if p.requires_grad]
for p in trainable_params:
    p.data = p.data.float()

optimizer = torch.optim.AdamW(trainable_params, lr=LR)
# torch.amp is the current API; torch.cuda.amp.* is deprecated.
scaler = torch.amp.GradScaler("cuda")

for epoch in range(EPOCHS):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        chosen_input_ids = batch["chosen_input_ids"].to(device)
        chosen_attention_mask = batch["chosen_attention_mask"].to(device)
        rejected_input_ids = batch["rejected_input_ids"].to(device)
        rejected_attention_mask = batch["rejected_attention_mask"].to(device)

        with torch.amp.autocast("cuda"):
            # Policy log-probs (with grad) for both answers of the pair.
            policy_c = compute_log_probs(
                model,
                chosen_input_ids,
                chosen_attention_mask
            )
            policy_r = compute_log_probs(
                model,
                rejected_input_ids,
                rejected_attention_mask
            )

            # Frozen reference log-probs — no autograd graph needed.
            with torch.no_grad():
                ref_c = compute_log_probs(
                    reference_model,
                    chosen_input_ids,
                    chosen_attention_mask
                )
                ref_r = compute_log_probs(
                    reference_model,
                    rejected_input_ids,
                    rejected_attention_mask
                )

            loss = dpo_loss(policy_c, policy_r, ref_c, ref_r, BETA)

        # Scaled backward + step; set_to_none frees grad memory between steps.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

    print(f"Epoch {epoch} Mean Loss {total_loss/len(dataloader):.4f}")

# Persist only the LoRA adapter weights plus the tokenizer files.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("训练完成")

运行效果:

运行输出 (bash):
Loading checkpoint shards: 100%|██████████| 8/8 [00:15<00:00, 1.87s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00, 1.75s/it]
trainable params: 1,572,864 || all params: 8,000,000,000 || trainable%: 0.0197
Epoch 0 Step 0 Loss 0.6931
Epoch 0 Step 10 Loss 0.6905
Epoch 0 Step 20 Loss 0.6872
Epoch 0 Step 30 Loss 0.6714
Epoch 0 Step 40 Loss 0.6589
Epoch 0 Step 50 Loss 0.6231
...
Epoch 0 Step 490 Loss 0.4521
Epoch 0 Step 500 Loss 0.4103
Epoch 0 Mean Loss 0.5842
相关推荐
IT_陈寒3 分钟前
Redis性能提升3倍的5个冷门技巧,90%开发者都不知道!
前端·人工智能·后端
Rsun045516 分钟前
SpringAI相关内容
人工智能
yc_Blog12 分钟前
卷积神经网络是什么:从图像识别问题说起
人工智能·神经网络·cnn
love530love22 分钟前
ComfyUI rgthree-comfy Image Comparer 节点无输出问题排查与解决
人工智能·windows·python·comfyui·rgthree-comfy·nodes 2.0·vue 节点
新缸中之脑34 分钟前
应该使用AI构建内部工具吗?
人工智能
badhope38 分钟前
Docker从零开始安装配置全攻略
运维·人工智能·vscode·python·docker·容器·github
AI攻城狮1 小时前
lossless-claw vs mem0:别再把上下文管理和长期记忆混为一谈
人工智能·云原生·aigc
qq_349523261 小时前
OpenClaw 架构全解析:本地优先的开源 AI Agent 框架
人工智能·架构·开源
寻见9031 小时前
智能体开发_07Function Calling道法术器拆解,一文搞懂大模型如何“做事”
人工智能·后端·ai编程
未来之窗软件服务1 小时前
vosk-ASR asterisk调用[AI人工智能(五十三)]—东方仙盟
人工智能·语音识别·vosk·仙盟创梦ide·东方仙盟