从零到一 | CV转多模态大模型 | week14 | Mini-LLaVA多轮对话稳定性与JSON结构化输出控制
摘要:本周在 week13 的 OpenAI Chat 三角色对齐基础上,继续补齐 Mini-LLaVA 的多轮对话工程能力。主要包括四件事:推理侧增加历史轮数截断,训练侧增加轮次级和 token 级截断,针对 JSON 业务输出补充 few-shot prompt 模板,并封装 JSON 提取、校验和失败重试模块。最后通过固定多轮测试脚本,把每轮问题、回答、JSON 校验状态和截断后的 history 轮数保存为 JSONL,方便后续对比不同 --max-history-turns 配置下的稳定性。
代码地址:
https://github.com/wz940216/From0to1-MLLM-StudyLog
本周目标
week13 已经把数据格式推进到 OpenAI Chat 风格的三角色消息格式,即 system/user/assistant。这一步解决的是输入协议问题。
week14 继续往工程化方向走,重点处理多轮对话和结构化输出:
- 长对话只保留必要上下文,降低 prompt 过长、遗忘图片信息和角色跑偏的风险。
- 针对 JSON 业务输出增加 few-shot prompt 模板,让模型看到明确的字段和示例。
- 封装输出解析模块,统一完成 JSON 提取、校验和失败重试。
- 编写固定多轮对话测试脚本,自动喂问题并把每轮结果保存为 JSONL。
这部分在整个路线里的位置是:
text
单轮 caption / QA
-> 多任务数据混合
-> 多轮图文对话
-> Chat 三角色格式
-> 多轮稳定性和结构化输出控制
对话截断策略
推理侧在 infer.py 中使用 --max-history-turns 控制历史轮数,默认保留最近 3 轮。
python
def truncate_history(history, max_history_turns=None):
"""保留最近 N 轮 user/assistant 历史,降低长对话跑偏风险。"""
if max_history_turns is None or int(max_history_turns) <= 0:
return list(history)
return list(history)[-int(max_history_turns) * 2:]
这里的 history 实际上是 user/assistant 交替保存的消息列表。一轮对话包含一条 user 和一条 assistant,所以保留最近 N 轮时要截取 N * 2 条消息。
训练侧在 dataset.py 中使用两级截断:
- 轮次级:如果样本轮数大于 N,保留第 1 轮图文输入和最近 N-1 轮。
- Token 级:如果文本 token 仍超过
MAX_LENGTH,优先保留最后一段 assistant answer,剩余预算再从最近上下文向前填充。
核心实现如下:
python
def _crop_segment_keep_image(self, segment, budget, image_token_id):
"""裁剪包含 <image> 的片段,至少保留图片 token。"""
if budget <= 0:
return None
input_ids = segment["input_ids"]
labels = segment["labels"]
if image_token_id not in input_ids:
return {
"input_ids": input_ids[-budget:],
"labels": labels[-budget:],
"train": segment["train"],
}
image_pos = input_ids.index(image_token_id)
if len(input_ids) <= budget:
return segment
end = min(len(input_ids), image_pos + budget)
start = max(0, end - budget)
if not (start <= image_pos < end):
start = image_pos
end = min(len(input_ids), image_pos + budget)
return {
"input_ids": input_ids[start:end],
"labels": labels[start:end],
"train": segment["train"],
}
def _fit_segments_to_budget(self, encoded_segments):
"""按 token 预算裁剪:保留 <image>,优先保留最后 assistant answer。"""
max_length = int(self.max_length)
if max_length <= 1:
raise ValueError("max_length 必须大于 1,至少要容纳 <image> 和一个训练 token。")
image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
total_len = sum(len(segment["input_ids"]) for segment in encoded_segments)
if total_len <= max_length:
return encoded_segments
image_idx = None
for idx, segment in enumerate(encoded_segments):
if image_token_id in segment["input_ids"]:
image_idx = idx
break
if image_idx is None:
raise ValueError("样本缺少 <image> token,无法构造多模态输入。")
last_train_idx = None
for idx in range(len(encoded_segments) - 1, -1, -1):
if encoded_segments[idx]["train"]:
last_train_idx = idx
break
if last_train_idx is None:
raise ValueError("样本没有可训练的 assistant answer。")
image_anchor = self._crop_segment_keep_image(encoded_segments[image_idx], 1, image_token_id)
last_answer = encoded_segments[last_train_idx]
answer_budget = max_length - len(image_anchor["input_ids"])
answer_ids = last_answer["input_ids"]
answer_labels = last_answer["labels"]
if len(answer_ids) > answer_budget:
return [
image_anchor,
{
"input_ids": answer_ids[-answer_budget:],
"labels": answer_labels[-answer_budget:],
"train": True,
},
]
selected_middle = []
used = len(image_anchor["input_ids"]) + len(answer_ids)
for idx in range(last_train_idx - 1, -1, -1):
remain = max_length - used
if remain <= 0:
break
segment = encoded_segments[idx]
if idx == image_idx:
cropped_image = self._crop_segment_keep_image(
segment,
remain + len(image_anchor["input_ids"]),
image_token_id,
)
used += len(cropped_image["input_ids"]) - len(image_anchor["input_ids"])
image_anchor = cropped_image
break
seg_len = len(segment["input_ids"])
if seg_len <= remain:
selected_middle.insert(0, segment)
used += seg_len
continue
selected_middle.insert(0, {
"input_ids": segment["input_ids"][-remain:],
"labels": segment["labels"][-remain:],
"train": segment["train"],
})
break
return [image_anchor] + selected_middle + [last_answer]
def _flatten_segments(self, encoded_segments):
input_ids = []
labels = []
for segment in encoded_segments:
input_ids.extend(segment["input_ids"])
labels.extend(segment["labels"])
if self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) not in input_ids:
raise ValueError("截断后样本缺少 <image> token,请增大 MAX_LENGTH 或减小 MAX_TURNS。")
return input_ids, labels
这样做有两个目的:
- 保留图片进入模型的位置,避免多模态输入在截断后退化成纯文本样本。
- 避免简单
[:MAX_LENGTH]把尾部 answer 训练目标截掉。
预算极小时会至少保留一个 <image> token 和最后 answer 的尾部;如果 MAX_LENGTH <= 1,collator 会直接报错。
JSON few-shot 模板
prompt_templates.py 中提供 JSON_FEW_SHOT_PROMPT,固定输出字段:
json
{"answer": "...", "confidence": "...", "evidence": "..."}
推理时,当前用户问题会被拼到 few-shot 模板之后,要求模型只输出合法 JSON 对象。
这里没有假设 prompt 一定能百分百约束住模型。few-shot 的作用是提高模型遵守格式的概率,后面仍然需要解析和校验兜底。
输出解析和失败重试
output_parser.py 提供几个函数:
normalize_model_text():清理 prompt 回显、角色前缀和 markdown 代码块。extract_json_candidate():从模型输出中截取 JSON 对象。parse_json_output():校验 JSON 顶层对象和必需字段。JSON_REPAIR_PROMPT:当 JSON 校验失败时,触发一次重新生成。
python
import json
import re
from dataclasses import dataclass
@dataclass
class ParseResult:
ok: bool
data: object = None
text: str = ""
error: str = ""
JSON_REPAIR_PROMPT = """上一轮回答不是合法 JSON。请只根据原始问题重新输出一个合法 JSON 对象。
不要输出 markdown 代码块,不要输出解释文字。
固定格式:
{"answer": "这里填写对用户问题的回答"}"""
def strip_code_fence(text):
text = str(text).strip()
if not text.startswith("```"):
return text
lines = text.splitlines()
if lines and lines[0].strip().startswith("```"):
lines = lines[1:]
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
return "\n".join(lines).strip()
def extract_json_candidate(text):
"""从模型输出中提取最可能的 JSON 对象片段。"""
text = strip_code_fence(text)
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or start >= end:
return text.strip()
return text[start:end + 1].strip()
def parse_json_output(text, required_keys=None):
"""校验 JSON 输出,并可选检查必需字段。"""
candidate = extract_json_candidate(text)
try:
data = json.loads(candidate)
except json.JSONDecodeError as exc:
return ParseResult(False, text=candidate, error=str(exc))
if not isinstance(data, dict):
return ParseResult(False, data=data, text=candidate, error="JSON 顶层必须是对象。")
missing = [key for key in (required_keys or []) if key not in data]
if missing:
return ParseResult(False, data=data, text=candidate, error=f"JSON 缺少字段: {', '.join(missing)}")
return ParseResult(True, data=data, text=json.dumps(data, ensure_ascii=False))
def normalize_model_text(text, prompt=""):
"""尽量只保留当前轮 assistant 的回答。"""
text = str(text).strip()
if prompt and text.startswith(prompt):
text = text[len(prompt):].strip()
if "ASSISTANT:" in text:
text = text.split("ASSISTANT:")[-1].strip()
if "USER:" in text:
text = text.split("USER:", 1)[0].strip()
return strip_code_fence(text)
def looks_like_json_request(text):
return re.search(r"\bjson\b|结构化|字段|格式", str(text), flags=re.IGNORECASE) is not None
这里的设计比较朴素,但对工程脚本已经够用:
- 模型如果输出 markdown 代码块,先去掉外层 fence。
- 模型如果在 JSON 前后夹了说明文字,就截取第一个
{到最后一个}。 - JSON 能解析后,再检查顶层类型和必需字段。
- 解析失败时,不直接信任这次回答,而是触发一次修复重试。
推理侧失败重试
infer.py 的 answer_one_turn() 默认会校验 answer 字段;失败后自动追加修复提示重试一次。
python
def answer_one_turn(
model,
image,
history,
question,
gen_config,
system_prompt=None,
max_history_turns=None,
retry_on_json_error=True,
):
"""执行一轮带上下文的图文对话推理,并返回 assistant 回答。"""
prompt = build_context_prompt(
history,
question,
system_prompt=system_prompt,
max_history_turns=max_history_turns,
)
answer = _generate_text(model, image, prompt, gen_config)
parsed = parse_json_output(answer, required_keys=["answer"])
if parsed.ok:
return parsed.text
if not retry_on_json_error:
return answer
retry_question = f"{clean_user_text(question)}\n\n{JSON_REPAIR_PROMPT}\n上一轮错误:{parsed.error}"
retry_prompt = build_context_prompt(
history,
retry_question,
system_prompt=system_prompt,
max_history_turns=max_history_turns,
)
retry_answer = _generate_text(model, image, retry_prompt, gen_config)
retry_parsed = parse_json_output(retry_answer, required_keys=["answer"])
return retry_parsed.text if retry_parsed.ok else retry_answer
这里没有引入更复杂的约束解码,只做了一层轻量兜底。原因是当前阶段的重点还是把 mini-LLaVA 的训练、推理、对话格式和基础工程链路跑顺。
多轮对话测试
固定测试脚本:
shell
python week14_dialogue_stability_output_control/code/test_multiturn_dialogue.py \
--config week14_dialogue_stability_output_control/configs/caption_only_cpu.yaml \
--checkpoint none \
--max-history-turns 3
输出默认写入:
text
week14_dialogue_stability_output_control/outputs/logs/multiturn_test.jsonl
每条日志包含:
turnquestionanswerjson_okjson_error- 截断后的 history 轮数
这样后续可以比较不同 --max-history-turns 下的稳定性。
训练命令
shell
accelerate launch --multi_gpu week14_dialogue_stability_output_control/code/train.py \
--config week14_dialogue_stability_output_control/configs/multitask_balanced.yaml
推理命令
shell
python week14_dialogue_stability_output_control/code/infer.py \
--config week14_dialogue_stability_output_control/configs/config.yaml \
--checkpoint week14_dialogue_stability_output_control/outputs/checkpoints/step_2109.pt \
--image dataset/coco128/images/train2017/000000000025.jpg \
--question "请用 JSON 描述这张图片" \
--question "继续用 JSON 说明判断依据" \
--max-history-turns 3
本周小结
这一周主要解决的不是模型能力上限,而是工程稳定性:
- 多轮对话不能无限堆 history,需要有明确的上下文窗口策略。
- 多模态样本截断时不能把
<image>token 截掉。 - 训练侧要优先保留最后的 assistant answer,避免 loss 目标被截掉。
- JSON 输出不能只靠 prompt,需要解析、校验和失败重试。
- 固定测试脚本和 JSONL 日志能让后续对比更容易。
下一周会继续沿着第13-16周的指令对齐阶段往前推进,开始补安全策略和简单对齐思想。
以上笔记来源于我的仓库: https://github.com/wz940216/From0to1-MLLM-StudyLog.git
我正在连载一个从零到一的多模态大模型学习笔记。
如果你对多模态大模型感兴趣,或者也在准备往大模型方向转
可以点赞/Fork我的仓库: https://github.com/wz940216/From0to1-MLLM-StudyLog.git
也可评论区留言交流,后面我会继续把每周的学习记录、踩坑经验陆续更新到仓库和这里。