Med-R1论文阅读理解

论文介绍

这篇论文介绍了一个名为 Med-R1 的新方法,用于提升多模态视觉语言模型(VLM)在医学图像理解和推理任务中的泛化能力和可解释性。下面是对整篇论文的简洁总结:

🧠 核心思想

复制代码
•	当前医学 VLM 多依赖于监督微调(SFT),容易过拟合,缺乏泛化能力,且推理过程不可解释。
•	Med-R1 引入了强化学习(RL),采用GRPO(Group Relative Policy Optimization)方法,能在无需复杂标注的情况下,引导模型学习具备医学推理逻辑的回答路径。

🧰 方法框架

复制代码
1.	基础模型:以 Qwen2-VL-2B 为底座进行训练。
2.	强化学习方法:采用 GRPO,不依赖价值模型,利用规则和分组比较得到奖励信号,提升训练效率。
3.	奖励机制:
•	格式奖励(结构是否规范,如包含 <think>...</think>)
•	准确性奖励(预测结果是否正确)

🏥 应用范围

复制代码
•	在 8 类医学图像模态 上进行测试:
•	CT、MRI、X-ray、超声、皮肤镜、眼底、OCT、显微图像
•	针对 5 类医学任务:
•	器官识别、疾病诊断、病灶分级、模态识别、生物特征分析

📊 实验亮点

对比项 成绩

相比 Qwen2-VL-2B(SFT) 提升 15.84%

相比原始 Qwen2-VL-2B(Zero-shot) 提升 29.94%

跨任务泛化能力 比 Qwen2-VL-2B 提升 32.06%

模型规模 仅 2B 参数,超过了 Qwen2-VL-72B(72B)大模型

✅ 贡献总结

复制代码
1.	提出 Med-R1:首个支持 8 类医学模态的 RL 医学 VLM,具备结构化推理输出。
2.	使用 GRPO:用规则奖励替代人工标注,提升训练效率和可解释性。
3.	小模型超大模型:在医疗领域实现更优的性能和资源效率。

📌 总结一句话:

Med-R1 利用强化学习实现高效、可解释且泛化性强的医学多模态问答系统,为医学大模型实用化提供了新范式。

GRPO

当然可以!GRPO 是强化学习中一个新颖的策略优化方法,特别适合用于大语言模型(LLM)等需要偏好引导、排序学习、推理强化的任务。

🧠 GRPO 是什么?

GRPO 全称是:Group Relative Policy Optimization(群体相对策略优化)

它是一种 无偏好标注、无价值函数 的强化学习方法,适用于对模型输出进行相对排序优化的场景。

🎯 GRPO 核心思想

不再只是"哪一个回答更好",而是"在一组回答中,哪些更好",以排序优化代替单一偏好训练。

🔍 GRPO 的关键特征:

特性 说明

✅ 群体对比 每次优化考虑一组候选输出(Group),进行相对排序

✅ 无需 reward model 不依赖人工打分或偏好模型(不像 PPO / DPO)

✅ 可插入规则奖励 如结构格式、医学准确性等

✅ 优化方式 类似 ListNet / ListMLE 的排序 loss,用策略梯度更新

🏥 在 Med-R1 中的使用方式:

• 每个问题让模型生成多个回答(构成一个 group)

• 使用规则打分(结构合理性、医学正确性)

• 对回答进行排序比较

• 训练模型使其更倾向于产生更靠前的回答

📌 GRPO vs DPO/PPO 对比表:

方法 偏好来源 是否需要价值函数 支持排序优化 适合任务

PPO Reward model ✅ 是 ❌ 只更新打分高的 RLHF、对齐

DPO 人工偏好对 ❌ 否 ❌ 二选一偏好 LLM 对齐

GRPO 无监督排序规则 ❌ 否 ✅ 群体排序优化 推理类、多步任务

📦 总结一句话:

GRPO 是一种无需人工打分的强化学习方法,基于组内排序信号优化模型输出,在复杂任务中兼顾效率、性能和可解释性。

当然可以!以下是一个简化版的 GRPO(Group Relative Policy Optimization)伪代码实现示例,用于帮助你理解其整体流程和核心逻辑。

🧠 GRPO 实现要点:

• 给定一个 prompt,模型生成多个候选回答(组成一个 group)

• 对这些回答进行打分排序(规则或函数)

• 使用排序 loss(如 ListNet) 计算奖励

• 通过策略梯度更新模型参数

✅ 简化代码框架(伪代码 + PyTorch风格)

python 复制代码
import torch
import torch.nn.functional as F

def grpo_loss(log_probs, rewards):
    """
    GRPO 的核心 loss:排序感知 loss(比如 ListNet)
    log_probs: shape [B, G],B 是 batch,G 是 group size
    rewards: shape [B, G],评分(越高越好)
    """
    # 归一化奖励(softmax 排序分布)
    reward_dist = F.softmax(rewards, dim=1)  # [B, G]
    log_prob_dist = F.log_softmax(log_probs, dim=1)  # [B, G]

    # ListNet 损失:KL(reward_dist || log_prob_dist)
    loss = F.kl_div(log_prob_dist, reward_dist, reduction="batchmean")
    return loss

🚀 模型训练流程(简化)

python 复制代码
for batch in dataloader:
    prompts = batch["prompt"]           # 文本输入
    candidates = generate_group(prompts, group_size=4)  # 模型生成多个候选回答
    log_probs = model.get_logprobs(prompts, candidates) # 每个回答的概率对数
    rewards = score_fn(candidates)      # 用规则/模型对回答打分
    loss = grpo_loss(log_probs, rewards)
    loss.backward()
    optimizer.step()

🧠 score_fn(candidates) 可以怎么写?

• ✅ 结构奖励(是否含 ...)

• ✅ 关键词覆盖率

• ✅ 匹配参考答案

• ✅ BLEU、ROUGE、BERTScore

• ✅ 任务自定义规则

✅ 效果:

你训练出的模型就会:

• 自动倾向生成高评分排序靠前的回答

• 不依赖人工偏好打分

• 支持推理步骤的逐步奖励(比如医学领域)

GRPO和DPO的对比

python 复制代码
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

######################################
# ✅ 通用准备
######################################
model_name = "Qwen/Qwen1.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

######################################
# ✅ DPO Loss 实现
######################################
def dpo_loss(preference_logits, ref_logits):
    """
    preference_logits: 模型偏好的输出对数概率 (chosen - rejected)
    ref_logits: 参考模型对同一对的输出
    """
    diff = preference_logits - ref_logits
    loss = -F.logsigmoid(diff).mean()
    return loss

######################################
# ✅ GRPO Loss 实现(排序式)
######################################
def grpo_loss(group_logps, group_rewards):
    # 归一化奖励分布作为 target 分布
    target_dist = F.softmax(group_rewards, dim=1)
    logp_dist = F.log_softmax(group_logps, dim=1)
    return F.kl_div(logp_dist, target_dist, reduction='batchmean')

######################################
# ✅ 模拟训练数据结构
######################################
python 复制代码
# 对于 DPO: [prompt, chosen_response, rejected_response]
# 对于 GRPO: [prompt, [response1, response2, response3, response4]], rewards

data_dpo = [
    {
        "prompt": "解释什么是高血压?",
        "chosen": "高血压是指血压持续高于正常值的一种慢性病...",
        "rejected": "高血压就是血太多的病..."
    }
]

data_grpo = [
    {
        "prompt": "解释什么是高血压?",
        "responses": [
            "高血压是常见慢性病...",
            "血压升高可能由...",
            "高血压就是血太多了...",
            "一种高发病的状态..."
        ],
        "rewards": [5.0, 4.5, 1.0, 3.0]  # 用规则或参考答案打分
    }
]

######################################
# ✅ 模型输出 logp(简化版)
######################################
def get_log_probs(prompt, responses):
    """给定 prompt 和多个 responses,返回 logp"""
    logps = []
    for r in responses:
        inputs = tokenizer(prompt + r, return_tensors="pt")
        with torch.no_grad():
            output = model(**inputs)
            logp = output.logits[:, :-1, :].log_softmax(-1)
            token_ids = inputs.input_ids[:, 1:]
            score = logp.gather(2, token_ids.unsqueeze(-1)).sum() / token_ids.size(1)
            logps.append(score)
    return torch.stack(logps)

######################################
# ✅ DPO 一次训练
######################################
def train_dpo_step():
    sample = data_dpo[0]
    prompt = sample["prompt"]
    logp_chosen = get_log_probs(prompt, [sample["chosen"]])[0]
    logp_rejected = get_log_probs(prompt, [sample["rejected"]])[0]
    loss = dpo_loss(logp_chosen, logp_rejected)
    print("[DPO] Loss:", loss.item())

######################################
# ✅ GRPO 一次训练
######################################
def train_grpo_step():
    sample = data_grpo[0]
    prompt = sample["prompt"]
    rewards = torch.tensor(sample["rewards"])
    logps = get_log_probs(prompt, sample["responses"])
    loss = grpo_loss(logps.unsqueeze(0), rewards.unsqueeze(0))
    print("[GRPO] Loss:", loss.item())

######################################
# ✅ 执行对比训练
######################################
train_dpo_step()
train_grpo_step()
相关推荐
寻丶幽风5 小时前
论文阅读笔记——Generating Long Sequences with Sparse Transformers
论文阅读·笔记·语言模型·transformer·稀疏自注意力
水深00安东尼14 小时前
GAT-GRAPH ATTENTION NETWORKS(论文笔记)
论文阅读·python
人有一心16 小时前
【论文阅读】MOE奠基论文《Adaptive Mixtures of Local Experts》
论文阅读
寻丶幽风1 天前
论文阅读笔记——Reactive Diffusion Policy
论文阅读·笔记·机器人·dp·具身智能
L-含光承影1 天前
【第三十一周】ViT 论文阅读笔记
论文阅读·计算机视觉
Ayakanoinu1 天前
【论文阅读】UniAD: Planning-oriented Autonomous Driving
论文阅读
Allen_LVyingbo2 天前
数智读书笔记系列028 《奇点更近》
论文阅读·笔记
不是吧这都有重名2 天前
[论文阅读]Transformers without Normalization
论文阅读
踏雪亦无痕2 天前
论文笔记:Dynamic Spectral Graph Anomaly Detection
论文阅读·深度学习·图论·异常检测