DPO微调

1.基本原理

直接利用"优选回答 vs. 次优回答"的成对偏好数据,对策略模型进行对数概率差分优化。

2.代码实现

python 复制代码
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM

DATA_PATH = "train.json"
"""
{"prompt": "解释梯度下降。", "chosen": "梯度下降是一种优化算法", "rejected": "梯度下降是天气现象"}
{"prompt": "什么是DPO?", "chosen": "DPO是一种偏好优化方法", "rejected": "DPO是一种显卡接口"}
"""
MODEL_NAME = "Qwen/Qwen3-8B"
OUTPUT_DIR = "./qwen3_lora_dpo"

BATCH_SIZE = 1
EPOCHS = 500
LR = 2e-5
BETA = 0.1
MAX_LEN = 512

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess(example):
    prompt = example["prompt"]
    chosen = example["chosen"]
    rejected = example["rejected"]

    chosen_text = prompt + "\n" + chosen
    rejected_text = prompt + "\n" + rejected

    chosen_enc = tokenizer(
        chosen_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )

    rejected_enc = tokenizer(
        rejected_text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )

    return {
        "chosen_input_ids": chosen_enc["input_ids"],
        "chosen_attention_mask": chosen_enc["attention_mask"],
        "rejected_input_ids": rejected_enc["input_ids"],
        "rejected_attention_mask": rejected_enc["attention_mask"],
    }

dataset = load_dataset("json", data_files=DATA_PATH)["train"]
dataset = dataset.map(preprocess)
dataset.set_format(type="torch")

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

reference_model.eval()
for p in reference_model.parameters():
    p.requires_grad = False

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)
model.gradient_checkpointing_enable()
model.print_trainable_parameters()
model.train()

device = next(model.parameters()).device

def compute_log_probs(model, input_ids, attention_mask):
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:]

    log_probs = F.log_softmax(logits, dim=-1)

    per_token_logps = torch.gather(
        log_probs,
        dim=2,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    mask = attention_mask[:, 1:]
    return (per_token_logps * mask).sum(dim=1)

def dpo_loss(policy_c, policy_r, ref_c, ref_r, beta):
    pi_logratios = policy_c - policy_r
    ref_logratios = ref_c - ref_r
    logits = beta * (pi_logratios - ref_logratios)
    return -F.logsigmoid(logits).mean()

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        chosen_input_ids = batch["chosen_input_ids"].to(device)
        chosen_attention_mask = batch["chosen_attention_mask"].to(device)
        rejected_input_ids = batch["rejected_input_ids"].to(device)
        rejected_attention_mask = batch["rejected_attention_mask"].to(device)

        with torch.cuda.amp.autocast():

            policy_c = compute_log_probs(
                model,
                chosen_input_ids,
                chosen_attention_mask
            )

            policy_r = compute_log_probs(
                model,
                rejected_input_ids,
                rejected_attention_mask
            )

            with torch.no_grad():
                ref_c = compute_log_probs(
                    reference_model,
                    chosen_input_ids,
                    chosen_attention_mask
                )

                ref_r = compute_log_probs(
                    reference_model,
                    rejected_input_ids,
                    rejected_attention_mask
                )

            loss = dpo_loss(policy_c, policy_r, ref_c, ref_r, BETA)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

    print(f"Epoch {epoch} Mean Loss {total_loss/len(dataloader):.4f}")

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("训练完成")

运行效果:

bash 复制代码
Loading checkpoint shards: 100%|██████████| 8/8 [00:15<00:00, 1.87s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00, 1.75s/it]
trainable params: 1,572,864 || all params: 8,000,000,000 || trainable%: 0.0197
Epoch 0 Step 0 Loss 0.6931
Epoch 0 Step 10 Loss 0.6905
Epoch 0 Step 20 Loss 0.6872
Epoch 0 Step 30 Loss 0.6714
Epoch 0 Step 40 Loss 0.6589
Epoch 0 Step 50 Loss 0.6231
...
Epoch 0 Step 490 Loss 0.4521
Epoch 0 Step 500 Loss 0.4103
Epoch 0 Mean Loss 0.5842
相关推荐
汽车仪器仪表相关领域1 分钟前
Kvaser Memorator R SemiPro:双通道CAN总线记录仪,汽车与工业测试的高性价比之选
大数据·网络·人工智能·功能测试·汽车·安全性测试
天天爱吃肉82182 分钟前
空间智能上车:新能源OEM决胜「第三空间」的底层技术革命|研发工程师深度解析
大数据·人工智能·嵌入式硬件·汽车
初圣魔门首席弟子2 分钟前
深度学习 欠拟合、过拟合讲透
人工智能
开开心心就好3 分钟前
支持批量添加水印的实用工具推荐
人工智能·游戏·ci/cd·docker·音视频·语音识别·媒体
毕胜客源码4 分钟前
卷积神经网络的手势识别系统(有技术文档)深度学习 图像识别 卷积神经网络 Django python 人工智能
人工智能·python·深度学习·cnn·django
戏言zare5 分钟前
基于改进EfficientNet的植物性状预测系统设计
人工智能
Elastic 中国社区官方博客8 分钟前
通过受管控的控制平面加速商品陈列优化
大数据·数据库·人工智能·elasticsearch·搜索引擎·平面·ai
CoderJia程序员甲10 分钟前
GitHub 热榜项目 - 日榜(2026-04-28)
人工智能·ai·大模型·github·ai教程
我是大聪明.13 分钟前
大模型Tokenizer原理:BPE、WordPiece与子词编码的核心机制深度解析
人工智能·线性代数·算法·机器学习·矩阵
hhhhhh_we15 分钟前
再定义“皮肤人格”:从Baumann 16型分型到预颜美历的AI时序人格
前端·图像处理·人工智能·python·aigc