DPO微调

1.基本原理

直接利用"优选回答 vs. 次优回答"的成对偏好数据,对策略模型进行对数概率差分优化。

2.代码实现

python
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM

# Path to the training data: JSON-lines, one preference pair per line with
# "prompt", "chosen" (preferred answer) and "rejected" (dispreferred answer).
DATA_PATH = "train.json"
"""
{"prompt": "解释梯度下降。", "chosen": "梯度下降是一种优化算法", "rejected": "梯度下降是天气现象"}
{"prompt": "什么是DPO?", "chosen": "DPO是一种偏好优化方法", "rejected": "DPO是一种显卡接口"}
"""
MODEL_NAME = "Qwen/Qwen3-8B"      # HF checkpoint used for BOTH policy and reference model
OUTPUT_DIR = "./qwen3_lora_dpo"   # where the LoRA adapter and tokenizer are saved

BATCH_SIZE = 1   # preference pairs per optimizer step
EPOCHS = 500     # NOTE(review): 500 epochs is unusually high for DPO fine-tuning — confirm intended
LR = 2e-5        # AdamW learning rate for the trainable (LoRA) parameters
BETA = 0.1       # DPO temperature: scales the policy-vs-reference log-ratio gap
MAX_LEN = 512    # tokenized length cap (prompt + answer); longer inputs are truncated

# Load the tokenizer matching the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Some causal-LM tokenizers ship without a pad token; reuse EOS so that
# padding="max_length" in preprocess() has a valid id to pad with.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess(example):
    """Tokenize one preference pair into fixed-length encodings.

    The prompt is prepended (newline-separated) to each candidate answer,
    and both resulting texts are padded/truncated to MAX_LEN. Returns the
    input ids and attention masks for the chosen and rejected sequences.
    """
    def _encode(answer):
        # Full sequence the model will score: prompt + "\n" + answer.
        return tokenizer(
            example["prompt"] + "\n" + answer,
            max_length=MAX_LEN,
            padding="max_length",
            truncation=True,
        )

    chosen_enc = _encode(example["chosen"])
    rejected_enc = _encode(example["rejected"])

    return {
        "chosen_input_ids": chosen_enc["input_ids"],
        "chosen_attention_mask": chosen_enc["attention_mask"],
        "rejected_input_ids": rejected_enc["input_ids"],
        "rejected_attention_mask": rejected_enc["attention_mask"],
    }

# Load the JSONL preference data and tokenize every example up front.
dataset = load_dataset("json", data_files=DATA_PATH)["train"]
dataset = dataset.map(preprocess)
# Return PyTorch tensors from __getitem__ so the DataLoader can collate them.
dataset.set_format(type="torch")

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Policy model: the copy that receives LoRA adapters and gets trained.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

# Reference model: a second, frozen copy of the same checkpoint, used as the
# baseline in the DPO log-ratio. NOTE(review): this doubles GPU memory for an
# 8B model — sharing weights with the pre-LoRA model (adapters disabled)
# would be cheaper; confirm resources allow two full copies.
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)

# Freeze the reference model completely: eval mode and no gradients.
reference_model.eval()
for p in reference_model.parameters():
    p.requires_grad = False

# LoRA: train only low-rank adapters on the attention q/v projections.
lora_config = LoraConfig(
    r=8,                                  # adapter rank
    lora_alpha=16,                        # scaling factor (alpha / r = 2)
    target_modules=["q_proj", "v_proj"],  # modules to wrap with adapters
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)
# Trade compute for memory: recompute activations during backward.
model.gradient_checkpointing_enable()
model.print_trainable_parameters()
model.train()

# device_map="auto" may shard the model; use the first parameter's device
# as the destination for input batches.
device = next(model.parameters()).device

def compute_log_probs(model, input_ids, attention_mask):
    """Summed per-token log-probability of `input_ids` under `model`.

    Applies the standard causal shift — the logits at position t score the
    token at position t+1 — then gathers the log-probability of each actual
    next token. Positions whose shifted attention mask is 0 (padding) are
    excluded from the sum. Returns a tensor of shape (batch,).
    """
    out = model(input_ids=input_ids, attention_mask=attention_mask)

    # Align predictions with targets: drop the last logit, drop the first token.
    shifted_logits = out.logits[:, :-1, :]
    target_ids = input_ids[:, 1:]
    target_mask = attention_mask[:, 1:]

    token_logps = torch.gather(
        F.log_softmax(shifted_logits, dim=-1),
        2,
        target_ids.unsqueeze(-1),
    ).squeeze(-1)

    # Zero out padded positions before summing over the sequence.
    return (token_logps * target_mask).sum(dim=-1)

def dpo_loss(policy_c, policy_r, ref_c, ref_r, beta):
    """DPO objective: -log sigmoid(beta * margin), averaged over the batch.

    The margin is how much more the policy prefers the chosen answer over
    the rejected one, relative to the frozen reference model:
        margin = (policy_c - policy_r) - (ref_c - ref_r)
    """
    margin = (policy_c - policy_r) - (ref_c - ref_r)
    return -F.logsigmoid(beta * margin).mean()

# Only the LoRA parameters are trainable; AdamW updates just those.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# NOTE(review): torch.cuda.amp.GradScaler/autocast are the legacy spellings
# (newer torch prefers torch.amp with device_type="cuda"); the base weights
# are already fp16, so the scaler mainly guards the fp32 LoRA params —
# confirm this setup is intended.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    total_loss = 0.0
    for step, batch in enumerate(dataloader):
        # Move the pre-tokenized preference pair onto the model's device.
        chosen_input_ids = batch["chosen_input_ids"].to(device)
        chosen_attention_mask = batch["chosen_attention_mask"].to(device)
        rejected_input_ids = batch["rejected_input_ids"].to(device)
        rejected_attention_mask = batch["rejected_attention_mask"].to(device)

        with torch.cuda.amp.autocast():

            # Policy log-probs for the preferred / dispreferred sequences.
            policy_c = compute_log_probs(
                model,
                chosen_input_ids,
                chosen_attention_mask
            )

            policy_r = compute_log_probs(
                model,
                rejected_input_ids,
                rejected_attention_mask
            )

            # Frozen reference log-probs: no gradients needed.
            with torch.no_grad():
                ref_c = compute_log_probs(
                    reference_model,
                    chosen_input_ids,
                    chosen_attention_mask
                )

                ref_r = compute_log_probs(
                    reference_model,
                    rejected_input_ids,
                    rejected_attention_mask
                )

            loss = dpo_loss(policy_c, policy_r, ref_c, ref_r, BETA)

        # Standard AMP step: scale loss, backprop, step via scaler, update scale.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        if step % 10 == 0:
            print(f"Epoch {epoch} Step {step} Loss {loss.item():.4f}")

    print(f"Epoch {epoch} Mean Loss {total_loss/len(dataloader):.4f}")

# Persist only the LoRA adapter weights (plus the tokenizer), not the base model.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("训练完成")

运行效果:

bash
Loading checkpoint shards: 100%|██████████| 8/8 [00:15<00:00, 1.87s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00, 1.75s/it]
trainable params: 1,572,864 || all params: 8,000,000,000 || trainable%: 0.0197
Epoch 0 Step 0 Loss 0.6931
Epoch 0 Step 10 Loss 0.6905
Epoch 0 Step 20 Loss 0.6872
Epoch 0 Step 30 Loss 0.6714
Epoch 0 Step 40 Loss 0.6589
Epoch 0 Step 50 Loss 0.6231
...
Epoch 0 Step 490 Loss 0.4521
Epoch 0 Step 500 Loss 0.4103
Epoch 0 Mean Loss 0.5842
相关推荐
天使Di María1 小时前
脑电大模型系列——第二弹:BrainBERT
人工智能·深度学习·机器学习·大模型·迁移学习·脑机接口·脑电解码
Dev7z1 小时前
基于LSTM神经网络的金属材料机器学习本构模型研究(硕士级别)
人工智能·神经网络·机器学习·机器学习本构
AI_Auto1 小时前
人工智能 - AI重构企业数字化格局
人工智能·重构
雪碧聊技术1 小时前
什么是Seedance 2.0?字节自研多模态AI视频生成引擎全解析
人工智能·音视频·seedance2.0
陈天伟教授1 小时前
人工智能应用- 材料微观:08.SliceGAN 的学习过程
人工智能·深度学习·学习
心本无晴.2 小时前
RAG中的混合检索(Hybrid Search):稀疏检索与稠密检索的强强联合
人工智能·python·算法
咚咚王者2 小时前
人工智能之视觉领域 计算机视觉 第十三章 视频背景减除
人工智能·计算机视觉·音视频
你的论文学长2 小时前
对抗知网的 N-Gram 算法:基于语义解耦的【文本重构】与【事实性核验】架构设计
人工智能·算法·重构
一水鉴天2 小时前
关于“整体设计定稿” 的高阶表述 20260222
人工智能·架构