论文介绍
这篇论文介绍了一个名为 Med-R1 的新方法,用于提升多模态视觉语言模型(VLM)在医学图像理解和推理任务中的泛化能力和可解释性。下面是对整篇论文的简洁总结:
⸻
🧠 核心思想
• 当前医学 VLM 多依赖于监督微调(SFT),容易过拟合,缺乏泛化能力,且推理过程不可解释。
• Med-R1 引入了强化学习(RL),采用GRPO(Group Relative Policy Optimization)方法,能在无需复杂标注的情况下,引导模型学习具备医学推理逻辑的回答路径。
⸻
🧰 方法框架
1. 基础模型:以 Qwen2-VL-2B 为底座进行训练。
2. 强化学习方法:采用 GRPO,不依赖价值模型,利用规则和分组比较得到奖励信号,提升训练效率。
3. 奖励机制:
• 格式奖励(结构是否规范,如包含 <think>...</think>)
• 准确性奖励(预测结果是否正确)
⸻
🏥 应用范围
• 在 8 类医学图像模态 上进行测试:
• CT、MRI、X-ray、超声、皮肤镜、眼底、OCT、显微图像
• 针对 5 类医学任务:
• 器官识别、疾病诊断、病灶分级、模态识别、生物特征分析
⸻
📊 实验亮点
对比项 成绩
相比 Qwen2-VL-2B(SFT) 提升 15.84%
相比原始 Qwen2-VL-2B(Zero-shot) 提升 29.94%
跨任务泛化能力 比 Qwen2-VL-2B 提升 32.06%
模型规模 仅 2B 参数,超过了 Qwen2-VL-72B(72B)大模型
⸻
✅ 贡献总结
1. 提出 Med-R1:首个支持 8 类医学模态的 RL 医学 VLM,具备结构化推理输出。
2. 使用 GRPO:用规则奖励替代人工标注,提升训练效率和可解释性。
3. 小模型超大模型:在医疗领域实现更优的性能和资源效率。
⸻
📌 总结一句话:
Med-R1 利用强化学习实现高效、可解释且泛化性强的医学多模态问答系统,为医学大模型实用化提供了新范式。
GRPO
当然可以!GRPO 是强化学习中一个新颖的策略优化方法,特别适合用于大语言模型(LLM)等需要偏好引导、排序学习、推理强化的任务。
⸻
🧠 GRPO 是什么?
GRPO 全称是:Group Relative Policy Optimization(群体相对策略优化)
它是一种 无偏好标注、无价值函数 的强化学习方法,适用于对模型输出进行相对排序优化的场景。
⸻
🎯 GRPO 核心思想
不再只是"哪一个回答更好",而是"在一组回答中,哪些更好",以排序优化代替单一偏好训练。
⸻
🔍 GRPO 的关键特征:
特性 说明
✅ 群体对比 每次优化考虑一组候选输出(Group),进行相对排序
✅ 无需 reward model 不依赖人工打分或偏好模型(不像 PPO / DPO)
✅ 可插入规则奖励 如结构格式、医学准确性等
✅ 优化方式 类似 ListNet / ListMLE 的排序 loss,用策略梯度更新
⸻
🏥 在 Med-R1 中的使用方式:
• 每个问题让模型生成多个回答(构成一个 group)
• 使用规则打分(结构合理性、医学正确性)
• 对回答进行排序比较
• 训练模型使其更倾向于产生更靠前的回答
⸻
📌 GRPO vs DPO/PPO 对比表:
方法 偏好来源 是否需要价值函数 支持排序优化 适合任务
PPO Reward model ✅ 是 ❌ 只更新打分高的 RLHF、对齐
DPO 人工偏好对 ❌ 否 ❌ 二选一偏好 LLM 对齐
GRPO 无监督排序规则 ❌ 否 ✅ 群体排序优化 推理类、多步任务
⸻
📦 总结一句话:
GRPO 是一种无需人工打分的强化学习方法,基于组内排序信号优化模型输出,在复杂任务中兼顾效率、性能和可解释性。
⸻
当然可以!以下是一个简化版的 GRPO(Group Relative Policy Optimization)伪代码实现示例,用于帮助你理解其整体流程和核心逻辑。
⸻
🧠 GRPO 实现要点:
• 给定一个 prompt,模型生成多个候选回答(组成一个 group)
• 对这些回答进行打分排序(规则或函数)
• 使用排序 loss(如 ListNet) 计算奖励
• 通过策略梯度更新模型参数
⸻
✅ 简化代码框架(伪代码 + PyTorch风格)
python
import torch
import torch.nn.functional as F
def grpo_loss(log_probs, rewards):
"""
GRPO 的核心 loss:排序感知 loss(比如 ListNet)
log_probs: shape [B, G],B 是 batch,G 是 group size
rewards: shape [B, G],评分(越高越好)
"""
# 归一化奖励(softmax 排序分布)
reward_dist = F.softmax(rewards, dim=1) # [B, G]
log_prob_dist = F.log_softmax(log_probs, dim=1) # [B, G]
# ListNet 损失:KL(reward_dist || log_prob_dist)
loss = F.kl_div(log_prob_dist, reward_dist, reduction="batchmean")
return loss
⸻
🚀 模型训练流程(简化)
python
for batch in dataloader:
prompts = batch["prompt"] # 文本输入
candidates = generate_group(prompts, group_size=4) # 模型生成多个候选回答
log_probs = model.get_logprobs(prompts, candidates) # 每个回答的概率对数
rewards = score_fn(candidates) # 用规则/模型对回答打分
loss = grpo_loss(log_probs, rewards)
loss.backward()
optimizer.step()
⸻
🧠 score_fn(candidates) 可以怎么写?
• ✅ 结构奖励(是否含 ...)
• ✅ 关键词覆盖率
• ✅ 匹配参考答案
• ✅ BLEU、ROUGE、BERTScore
• ✅ 任务自定义规则
⸻
✅ 效果:
你训练出的模型就会:
• 自动倾向生成高评分排序靠前的回答
• 不依赖人工偏好打分
• 支持推理步骤的逐步奖励(比如医学领域)
⸻
GRPO和DPO的对比
python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
######################################
# ✅ 通用准备
######################################
model_name = "Qwen/Qwen1.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
######################################
# ✅ DPO Loss 实现
######################################
def dpo_loss(preference_logits, ref_logits):
"""
preference_logits: 模型偏好的输出对数概率 (chosen - rejected)
ref_logits: 参考模型对同一对的输出
"""
diff = preference_logits - ref_logits
loss = -F.logsigmoid(diff).mean()
return loss
######################################
# ✅ GRPO Loss 实现(排序式)
######################################
def grpo_loss(group_logps, group_rewards):
# 归一化奖励分布作为 target 分布
target_dist = F.softmax(group_rewards, dim=1)
logp_dist = F.log_softmax(group_logps, dim=1)
return F.kl_div(logp_dist, target_dist, reduction='batchmean')
######################################
# ✅ 模拟训练数据结构
######################################
python
# 对于 DPO: [prompt, chosen_response, rejected_response]
# 对于 GRPO: [prompt, [response1, response2, response3, response4]], rewards
data_dpo = [
{
"prompt": "解释什么是高血压?",
"chosen": "高血压是指血压持续高于正常值的一种慢性病...",
"rejected": "高血压就是血太多的病..."
}
]
data_grpo = [
{
"prompt": "解释什么是高血压?",
"responses": [
"高血压是常见慢性病...",
"血压升高可能由...",
"高血压就是血太多了...",
"一种高发病的状态..."
],
"rewards": [5.0, 4.5, 1.0, 3.0] # 用规则或参考答案打分
}
]
######################################
# ✅ 模型输出 logp(简化版)
######################################
def get_log_probs(prompt, responses):
"""给定 prompt 和多个 responses,返回 logp"""
logps = []
for r in responses:
inputs = tokenizer(prompt + r, return_tensors="pt")
with torch.no_grad():
output = model(**inputs)
logp = output.logits[:, :-1, :].log_softmax(-1)
token_ids = inputs.input_ids[:, 1:]
score = logp.gather(2, token_ids.unsqueeze(-1)).sum() / token_ids.size(1)
logps.append(score)
return torch.stack(logps)
######################################
# ✅ DPO 一次训练
######################################
def train_dpo_step():
sample = data_dpo[0]
prompt = sample["prompt"]
logp_chosen = get_log_probs(prompt, [sample["chosen"]])[0]
logp_rejected = get_log_probs(prompt, [sample["rejected"]])[0]
loss = dpo_loss(logp_chosen, logp_rejected)
print("[DPO] Loss:", loss.item())
######################################
# ✅ GRPO 一次训练
######################################
def train_grpo_step():
sample = data_grpo[0]
prompt = sample["prompt"]
rewards = torch.tensor(sample["rewards"])
logps = get_log_probs(prompt, sample["responses"])
loss = grpo_loss(logps.unsqueeze(0), rewards.unsqueeze(0))
print("[GRPO] Loss:", loss.item())
######################################
# ✅ 执行对比训练
######################################
train_dpo_step()
train_grpo_step()