下面进入 "持续规划 + 源码精读第一阶段"。这次我们做两件事:
- 先给你制定一套完整吃透路线。
- 正式开始拆解
scripts/train.sh和dataset.py,因为这两个文件决定你能否复现。
一、完整吃透 SegEarth-R2 的学习路线
我建议把整个项目拆成 10 讲,每一讲都有明确目标。
text
第 1 讲:训练入口 scripts/train.sh
目标:知道训练命令每个参数的作用,能改成自己的机器环境。
第 2 讲:数据准备与数据格式
目标:知道 LaSeRS 数据集每个字段含义,知道 image、description、answer、mask 如何对应。
第 3 讲:dataset.py 数据流
目标:吃透图像预处理、mask 预处理、文本 token 化、[SEG] 标记构造、batch 拼接。
第 4 讲:train.py 训练主控
目标:知道模型如何加载,tokenizer 如何扩展,[SEG] 如何加入词表,LoRA 如何注入,哪些模块训练、哪些模块冻结。
第 5 讲:模型整体结构 llava_phi.py
目标:知道 SegEarthR2Model、SegEarthR2、Mipha、SigLIP、Swin、Mask2Former 之间的关系。
第 6 讲:多模态输入拼接
目标:吃透图像 token 如何插入语言模型序列,images_clip 和 images 为什么分两路。
第 7 讲:[SEG] token 到 mask 的核心机制
目标:吃透 [SEG] hidden state 如何变成 mask query。
第 8 讲:loss 计算
目标:吃透 language loss、mask loss、dice loss、attention loss。
第 9 讲:eval.py 推理流程
目标:知道测试时如何输入 prompt,如何生成 [SEG],如何输出 mask。
第 10 讲:复现、调试、改进
目标:在 RTX 5090 上跑通,解决 CUDA / Detectron2 / MSDeformAttn / DeepSpeed / LoRA 合并问题,并思考如何改成你的博士课题方向。
你现在要先建立一个总认识:
text
SegEarth-R2 = 多模态大语言模型 + 遥感图像编码器 + Mask2Former 分割头
输入:
遥感图像 + 文本指令
输出:
文本回答 + 与 [SEG] 对应的分割 mask
核心桥梁:
[SEG] token 的 hidden state
也就是说,这个项目不是简单"图像进、mask 出",而是:
text
图像 + 语言
↓
语言模型理解
↓
生成或使用 [SEG]
↓
取出 [SEG] 的隐藏状态
↓
作为 mask query
↓
Mask2Former 输出分割结果
二、第 1 讲:scripts/train.sh 逐项拆解
官方训练脚本只有一个核心命令:用 deepspeed 启动 segearth_r2/train/train.py,并传入模型路径、视觉塔路径、Mask2Former 权重、数据路径、训练步数、LoRA、DeepSpeed 配置和 mask 配置等参数。(GitHub)
原始逻辑可以概括成:
bash
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"
deepspeed --master_port=29500 --include localhost:4 segearth_r2/train/train.py \
--model_name_or_path "pretrained_model/mllm/Mipha-3B" \
--vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384" \
--vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl" \
--base_data_path '/data1/xzp/data' \
--output_dir output_folder \
--max_steps 5000 \
--per_device_train_batch_size 1 \
--bf16 True \
--learning_rate 5e-5 \
--lora_r 4 \
--deepspeed scripts/zero3.json \
--mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml' \
--data_ratio '1' \
--switch_bs 4
下面逐项解释。
1. export NCCL_P2P_DISABLE="1"
作用:关闭 NCCL 的 P2P 通信。
NCCL 是多 GPU 训练时负责显卡之间通信的库。关闭 P2P 的常见原因是:部分机器多卡拓扑、驱动、PCIe/NVLink 兼容不好时,P2P 通信容易卡死或报错。
如果你是 单卡 RTX 5090,这个参数通常不会成为核心问题;如果是多卡训练,遇到 NCCL 卡死时它很有用。
2. export NCCL_IB_DISABLE="1"
作用:关闭 InfiniBand 通信。
如果你不是多节点集群,而是单机训练,InfiniBand 一般用不到。关闭它可以减少通信初始化问题。
你的情况大概率是:
text
单机 + RTX 5090
所以这两个 NCCL 环境变量可以保留。
3. deepspeed --master_port=29500
deepspeed 是训练启动器。
不用普通:
bash
python segearth_r2/train/train.py
而是用:
bash
deepspeed segearth_r2/train/train.py
原因是这个项目包含:
text
Mipha-3B 大语言模型
SigLIP 视觉塔
Swin / Mask2Former 分割模块
LoRA
大尺寸遥感图像
显存压力比较大,所以用 DeepSpeed 做显存优化、梯度同步和分布式训练。
--master_port=29500 是分布式通信端口。如果端口被占用,会报类似:
text
Address already in use
解决方法:
bash
--master_port=29501
或者换成其他端口。
4. --include localhost:4
这个参数非常重要。
官方脚本写的是:
bash
--include localhost:4
含义是:只使用本机第 4 号 GPU。
如果你只有一张 RTX 5090,应该改成:
bash
--include localhost:0
如果你有 4 张卡:
bash
--include localhost:0,1,2,3
如果这里不改,最常见错误是:
text
No slot '4' specified on host 'localhost'
或者 DeepSpeed 找不到对应 GPU。
5. --model_name_or_path
bash
--model_name_or_path "pretrained_model/mllm/Mipha-3B"
这是基础多模态语言模型路径。
它负责:
text
理解文本 prompt
接收图像 token
生成语言输出
输出 hidden states
提供 [SEG] token 的语义表征
注意,它不是专门做 mask 的模块。真正产生 mask 的是后面的 Mask2Former 分割头。
6. --vision_tower
bash
--vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384"
这是给多模态大语言模型使用的视觉塔,也就是 SigLIP。
它的作用是把图像变成视觉 token,再通过 projector 映射到语言模型空间。
你可以理解为:
text
原始图像
↓
SigLIP
↓
视觉 patch embedding
↓
mm_projector
↓
语言模型可理解的图像 token
7. --vision_tower_mask
bash
--vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl"
这个是 Mask2Former / Swin 分割分支的预训练权重。
注意,这和 vision_tower 不是一个东西。
SegEarth-R2 里有两套视觉处理逻辑:
text
images_clip → SigLIP → 多模态语言理解
images → Swin / Mask2Former → 像素级分割
这是你必须记住的第一个关键点。
8. --base_data_path
bash
--base_data_path '/data1/xzp/data'
这是数据集根目录。
根据 dataset.py 的逻辑,如果是训练集,它会去找:
text
base_data_path/train/images
base_data_path/train/annotations/train_data.json
如果是测试集,它会去找:
text
base_data_path/test/images
base_data_path/test/annotations/xxx.json
LaSeRSDataset 会根据 split 参数判断使用 train 还是 test 子目录。(GitHub)
你的路径必须改成自己的数据目录,例如:
bash
--base_data_path "/home/yourname/datasets/LaSeRS"
9. --output_dir
bash
--output_dir output_folder
训练输出目录。
这里会保存:
text
checkpoint
LoRA 权重
trainer state
模型配置
最终保存模型
建议你不要直接叫 output_folder,而是改成清楚的实验名:
bash
--output_dir outputs/segearth_r2_lora_r4_bs1_5090
这样后面做多组实验不会混乱。
10. --max_steps 5000
训练 5000 step。
注意它不是 epoch,而是 step。因为使用 HuggingFace Trainer / DeepSpeed 时,经常通过 step 控制训练。
如果你第一次复现,不要直接跑 5000 step。建议先改成:
bash
--max_steps 20
先确认:
text
数据能读
模型能加载
forward 能通
loss 正常下降或至少能计算
checkpoint 能保存
跑通后再改成:
bash
--max_steps 1000
最后再跑完整训练。
11. --per_device_train_batch_size 1
每张 GPU 上 batch size 为 1。
这很正常,因为:
text
大语言模型很大
图像尺寸 1024×1024
还有 mask decoder
还要保存 hidden states 和 attention
显存压力非常高。
如果 RTX 5090 显存够,可以尝试 2,但第一次复现建议保持 1。
12. --bf16 True
使用 bfloat16 训练。
RTX 5090 支持 BF16,建议保留。BF16 比 FP16 更稳定,尤其是大模型训练。
不过如果你遇到某些算子 BF16 不支持,可以临时改成:
bash
--bf16 False
--fp16 True
但优先建议 BF16。
13. --learning_rate 5e-5
LoRA 微调时 5e-5 是常见学习率。
如果你训练的是 LoRA + 分割头,这个学习率可以先保持。后续做实验可以试:
text
1e-5
2e-5
5e-5
1e-4
但第一次复现不要乱改。
14. --lora_r 4
LoRA 秩为 4。
LoRA 的作用是:不全量微调大语言模型,而是在部分线性层上增加低秩可训练矩阵。
你可以理解为:
text
冻结大模型大部分参数
只训练少量 LoRA 参数
降低显存
降低训练成本
降低过拟合风险
在 train.py 中,项目会通过 find_linear_layers() 找到可插入 LoRA 的线性层,再用 get_peft_model() 注入 LoRA。(GitHub)
15. --deepspeed scripts/zero3.json
使用 DeepSpeed ZeRO-3。
ZeRO-3 会切分:
text
模型参数
梯度
优化器状态
显存最省,但通信和调试复杂度更高。
如果你单卡 RTX 5090,只是先跑通,可以考虑:
bash
--deepspeed scripts/zero2.json
甚至先不用 DeepSpeed 做最小调试,但官方脚本是 DeepSpeed 路线。
16. --mask_config
bash
--mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml'
这是 Mask2Former 分割模块的配置文件。
它决定:
text
Swin backbone 配置
Mask2Former hidden_dim
query 数量
loss 权重
decoder 层数
训练采样点数
在 llava_phi.py 中,初始化 mask module 时会读取这个配置,构造 pixel decoder、predictor、criterion 等模块。(GitHub)
17. --data_ratio '1'
控制数据集混合比例。
当前只有 LaSeRS 一个数据源,所以是:
bash
--data_ratio '1'
如果以后你把自己的遥感数据加入进去,可以扩展成:
bash
--data_ratio '1||1'
表示两个数据集混合。
train.py 中 make_unify_datamodule() 会读取 data_ratio,然后构建 LaSeRSDataset 和 UnifyDatasetSingleDatasetForBatch。(GitHub)
18. --switch_bs 4
这个参数用于 UnifyDatasetSingleDatasetForBatch。
简单理解:
text
每隔 switch_bs 个 batch 切换一次数据集
当前只有一个数据集时,它的意义不大;多个数据集混训时才重要。
三、你应该把 train.sh 改成这样先跑通
你的 RTX 5090 单卡调试版建议先这样:
bash
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"
deepspeed --master_port=29501 --include localhost:0 segearth_r2/train/train.py \
--model_name_or_path "pretrained_model/mllm/Mipha-3B" \
--vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384" \
--vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl" \
--base_data_path "/your/path/LaSeRS" \
--output_dir outputs/debug_5090_lora_r4 \
--max_steps 20 \
--per_device_train_batch_size 1 \
--save_strategy "steps" \
--save_steps 10 \
--bf16 True \
--save_total_limit 1 \
--learning_rate 5e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 2048 \
--gradient_checkpointing False \
--dataloader_num_workers 2 \
--lora_r 4 \
--deepspeed scripts/zero3.json \
--mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml' \
--data_ratio '1' \
--switch_bs 4
第一次只跑 20 step。不要一开始跑 5000 step。
你的第一目标不是训练出好结果,而是确认:
text
1. 环境没问题
2. 数据能读
3. 模型能加载
4. forward 能通
5. loss 能算
6. checkpoint 能保存
四、第 2 讲:数据链路整体理解
现在开始看 dataset.py。这个文件的作用是:
text
原始 JSON + 图片 + mask
↓
LaSeRSDataset.__getitem__()
↓
构造单样本 data_dict
↓
DataCollatorForCOCODatasetV2.__call__()
↓
拼成 batch
↓
送入 model.forward()
你必须先理解它产生了哪些东西。
一个 batch 大致长这样:
python
{
"input_ids": ...,
"labels": ...,
"attention_mask": ...,
"images_clip": ...,
"images": ...,
"seg_info": ...,
"dataset_type": ...,
"token_refer_id": ...,
"SEG_token_embedding_indices": ...,
"mask_num": ...
}
每个字段含义:
| 字段 | 作用 |
|---|---|
input_ids |
文本 prompt + answer token 序列 |
labels |
语言模型监督标签,人类指令部分会被设为 IGNORE_INDEX |
attention_mask |
标记哪些 token 是有效 token |
images_clip |
给 SigLIP / 多模态 LLM 使用的图像 |
images |
给 Swin / Mask2Former 使用的 1024 图像 |
seg_info |
存放 mask、image_id、height、width 等分割监督信息 |
token_refer_id |
referring instruction 的 token |
SEG_token_embedding_indices |
标记 input_ids 中哪些位置是 [SEG] |
mask_num |
当前样本中 [SEG] 数量,也就是需要预测几个 mask |
五、dataset.py 的关键函数 1:preprocess_mask
这个函数的作用是:把 mask 按比例缩放,并 padding 到固定大小。
核心逻辑:
text
输入 mask
↓
如果是二维 mask,就扩展成 [1, H, W]
↓
逐个 mask 处理
↓
保持长宽比 resize
↓
不足部分 padding 0
↓
输出 [N, image_size, image_size]
为什么 mask resize 要用最近邻插值?
因为 mask 是类别标签,不是自然图像。
如果用双线性插值,0/1 mask 会变成 0.2、0.5、0.8 这种灰度值,类别边界会被污染。
所以代码里使用的是:
python
cv2.INTER_NEAREST
这很关键。
六、dataset.py 的关键函数 2:preprocess_image
这个函数处理的是给 Mask2Former 分割分支使用的图像。
逻辑是:
text
读取图像
↓
按照短边 resize
↓
限制最大边
↓
padding 到 1024×1024
↓
返回 numpy array
这一步得到的图像后面会进入:
text
Swin backbone
pixel decoder
Mask2Former predictor
而不是直接给 SigLIP。
注意这里和 images_clip 是两条不同路线。
七、dataset.py 的关键函数 3:tokenizer_special_tokens
这个函数负责处理特殊 token。
代码里有两个特殊占位:
text
IMAGE_TOKEN_INDEX
REFER_TOKEN_INDEX
它通过正则把 prompt 切开,一旦遇到特殊 token,就不用普通 tokenizer 编码,而是直接替换成特殊 index。(GitHub)
你可以理解为:
text
普通文本 → tokenizer.encode()
图像占位符 → IMAGE_TOKEN_INDEX
refer 占位符 → REFER_TOKEN_INDEX
这一步是多模态模型常见做法。
因为图像不是普通文字,但它要插入语言模型序列,所以需要一个特殊 token 位置告诉模型:
text
这里要插入图像特征
八、dataset.py 的关键函数 4:preprocess_llama2
这个函数是文本监督的核心。
它做了三件事:
第一,构造对话模板
数据被包装成人类和助手的对话:
python
[
{"from": "human", "value": "...instruction..."},
{"from": "gpt", "value": "...answer..."}
]
然后根据 conversation template 拼成完整 prompt。
第二,把 prompt token 化
调用:
python
self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt')
得到:
python
input_ids
第三,构造 labels,并屏蔽人类问题部分
训练语言模型时,不应该让模型学习"人类输入问题",而应该让模型学习"助手回答"。
所以代码会把 human instruction 部分的 label 设置为:
python
IGNORE_INDEX
这样 loss 只在 assistant answer 部分计算。
你可以理解为:
text
Human: segment the airport
Assistant: airport [SEG]
训练时:
Human 部分不算 loss
Assistant 部分算 language loss
这也是 LLaVA、Vicuna、Instruct tuning 类项目的标准做法。
九、LaSeRSDataset.__init__
初始化阶段主要做:
text
设置 pixel mean/std
保存 base_data_path
保存 tokenizer
根据 split 判断 train/test
读取 annotation json
记录 [SEG] token id
代码逻辑中,如果 split 包含 train,就读取 train/images 和 train/annotations;如果 split 包含 test,就读取 test/images 和 test/annotations。(GitHub)
这里要注意一个细节:
python
self.SEG_token_id = self.tokenizer.convert_tokens_to_ids("[SEG]")
这说明 [SEG] 必须已经被加入 tokenizer。
而它是在 train.py 里通过:
python
tokenizer.add_tokens("[SEG]")
model.resize_token_embeddings(len(tokenizer))
加入的。(GitHub)
所以顺序是:
text
train.py 加 [SEG] 到 tokenizer
↓
构建 LaSeRSDataset
↓
dataset.py 才能拿到 [SEG] token id
如果顺序错了,[SEG] 可能会被当成普通未知 token,整个分割桥梁就断了。
十、LaSeRSDataset.__getitem__ 逐步理解
这是数据文件最重要的函数。
1. 取出一条标注
python
data_info = self.reason_file[idx]
一条数据通常包含:
text
image_name
description
answer
id
mask
含义:
| 字段 | 含义 |
|---|---|
image_name |
图像文件名 |
description |
文本指令 |
answer |
模型应该学习的回答 |
id |
样本 id |
mask |
RLE 格式分割 mask |
2. 读取图片路径
python
image_path = os.path.join(self.LaSeRS_image_path, data_info['image_name'])
得到完整图像路径。
如果这里报错,通常是:
text
base_data_path 不对
image_name 与实际文件名不一致
train/images 目录结构不对
3. 读取文本指令和答案
python
ref = data_info['description']
answer = data_info['answer']
description 是给模型看的 instruction。
answer 是模型要学习输出的内容。它里面非常关键的是 [SEG]。
例如:
text
The target is building [SEG].
模型后面会通过 [SEG] 的 hidden state 预测 mask。
4. 解码 RLE mask
如果 data_info 中有 mask 字段,代码会用 pycocotools 解码:
python
mask = M.decode(rle)
然后多个 mask 堆叠成:
python
[N, H, W]
这里的 N 应该和 answer 中 [SEG] 的数量一致。
如果:
text
answer 中有 2 个 [SEG]
mask 只有 1 个
后面就会错位。
所以你调试数据集时必须检查:
python
assert answer.count("[SEG]") == len(mask_list)
5. 读取图像高宽
代码会用 OpenCV 读取 BGR 图像:
python
image_BGR = cv2.imread(image_path)
image_height = image_BGR.shape[0]
image_width = image_BGR.shape[1]
这些原始尺寸会保存到 data_dict,后面用于 mask 还原、评估或可视化。
6. 构造分割分支图像
python
image_RGB = preprocess_image(image_path)
image_tensor = torch.as_tensor(np.ascontiguousarray(image_RGB.transpose(2, 0, 1)))
data_dict['image'] = (image_tensor - self.pixel_mean) / self.pixel_std
这里得到的是给 Mask2Former 分支的图像。
注意几件事:
text
1. 图像变成 C×H×W
2. H 和 W 最终是 1024
3. 使用 ImageNet mean/std 归一化
4. 这个 image 不是给 SigLIP 的
7. 根据 [SEG] 数量构造 annotations
python
mask_num = answer.count("[SEG]")
这行极其重要。
它说明:
text
answer 中有几个 [SEG]
就认为要预测几个 mask
然后循环:
python
for i in range(mask_num):
data_dict['annotations'].append(...)
每个 annotation 里包含:
text
data_id
mask_id
mask
image_path
height
width
image_id
也就是说,[SEG] 和 mask 是一一对应的。
你必须记住这句话:
text
SegEarth-R2 的分割监督不是靠类别 id 对齐,而是靠 answer 中的 [SEG] 顺序和 mask 顺序对齐。
8. 构造 human prompt
代码里有类似:
python
prefix_inst = 'This is an image \n\n, please doing Reasoning Segmentation according to the following instruction:'
instruction = ref.strip()
然后构造:
python
sources = [
[
{'from': 'human', 'value': prefix_inst + '\n <|assistant|>'},
{'from': 'gpt', 'value': '\n' + answer}
]
]
这里要理解:
text
human 部分:告诉模型看图并根据 instruction 做 reasoning segmentation
gpt 部分:标准答案,里面包含 [SEG]
有一个细节:instruction = ref.strip() 得到了文本指令,但你要认真检查官方代码中它是否真正拼入了 human prompt。如果你后面做自己的数据格式,建议确保 instruction 明确进入 prompt,否则模型可能看不到具体目标描述。
9. 构造 input_ids 和 labels
python
text_dict = self.preprocess_llama2(sources, self.tokenizer)
得到:
python
input_ids
labels
其中:
text
input_ids:完整 token 序列
labels:训练监督标签
human 输入部分的 labels 会被设为 IGNORE_INDEX,assistant answer 部分才计算语言模型 loss。
10. 构造 [SEG] 位置标记
python
SEG_token_embedding_indices = torch.zeros_like(input_ids)
SEG_token_embedding_indices[input_ids == self.SEG_token_id] = 1
这一步是整个项目的核心之一。
它的含义是:
text
在 input_ids 中找到所有 [SEG]
标记这些位置
后面从 LLM hidden states 中取这些位置的 hidden state
之后在 llava_phi.py 里会调用:
python
self.get_SEG_embedding(hidden_states, SEG_token_embedding_indices)
取出 [SEG] 对应 hidden state,再经过 SEG_token_projector 映射到 Mask2Former hidden dim。llava_phi.py 中可以看到 get_SEG_embedding 负责根据索引取出 SEG embedding,随后用 SEG_token_projector 处理。(GitHub)
十一、DataCollatorForCOCODatasetV2.__call__
这个类负责把多个样本拼成一个 batch。
你可以理解为:
text
Dataset.__getitem__ 只返回一个样本
DataCollator 把多个样本整理成 batch
它主要做 5 件事。
1. padding 文本
因为每个样本文本长度不同,所以需要 padding。
python
input_ids = pad_sequence(...)
labels = pad_sequence(...)
然后生成:
python
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
2. 构造 images_clip
代码会重新读取原图,然后用 SigLIP 的 image processor 处理:
python
self.clip_image_processor.preprocess(..., return_tensors="pt")
这一路给 SigLIP / 多模态 LLM 用。
注意:这是 images_clip,不是 images。
3. 构造 images
python
images = [instance['image'] for instance in instances]
batch['images'] = torch.stack(images)
这一路是前面 preprocess_image() 得到的 1024 图像,用于 Swin / Mask2Former 分割分支。
所以 batch 中有两个图像字段:
text
images_clip:语言视觉塔
images:分割视觉塔
这个设计非常重要。
4. 构造 seg_info
每个 annotation 被整理到 seg_info 里。
其中 mask 会被转成 tensor:
python
seg['mask'] = torch.as_tensor(seg['mask'], dtype=torch.uint8)
后面 llava_phi.py 会根据 seg_info 构造 mask targets,并传给 criterion 计算 loss_mask 和 loss_dice。相关代码中可以看到 targets 包含 labels 和 masks,之后调用 self.criterion(mask_outputs, targets),并汇总 loss_mask 和 loss_dice。(GitHub)
5. padding [SEG] 位置标记
python
SEG_token_embedding_indices = pad_sequence(...)
因为不同样本长度不同,所以 [SEG] 标记也要 padding。
最终 batch 会包含:
python
batch['SEG_token_embedding_indices']
这个字段后面会直接传入模型 forward。
十二、到这里你应该形成第一张"数据流图"
你现在必须能自己画出下面这张图:
text
train_data.json
↓
data_info = reason_file[idx]
↓
image_name → image_path
description → instruction
answer → answer.count("[SEG]")
mask RLE → M.decode → masks
↓
preprocess_image → images
SigLIP processor → images_clip
preprocess_llama2 → input_ids + labels
input_ids == [SEG] → SEG_token_embedding_indices
↓
DataCollator
↓
batch
↓
model.forward(
input_ids,
labels,
images,
images_clip,
seg_info,
SEG_token_embedding_indices,
mask_num
)
你如果能把这张图讲清楚,就说明你已经吃透了数据部分的 50%。
十三、复现时必须先做的 5 个数据检查
在正式训练前,你一定要写一个小调试脚本,或者在 train.py 里临时打印一个 batch。
你要检查:
python
print(batch.keys())
print(batch["input_ids"].shape)
print(batch["labels"].shape)
print(batch["images"].shape)
print(batch["images_clip"].shape)
print(batch["SEG_token_embedding_indices"].shape)
print(batch["mask_num"])
print(len(batch["seg_info"]))
你希望看到类似:
text
input_ids: [B, L]
labels: [B, L]
images: [B, 3, 1024, 1024]
images_clip: [B, 3, 384, 384] 或类似 SigLIP 输入尺寸
SEG_token_embedding_indices: [B, L]
mask_num: [每个样本的 mask 数量]
seg_info: 总 mask 数量
再检查:
python
print(batch["SEG_token_embedding_indices"].sum())
print(sum(batch["mask_num"]))
这两个值应该相等。
因为:
text
batch 中 [SEG] 的总数 = batch 中需要预测的 mask 总数
如果不相等,后面 [SEG] embedding 和 mask target 就会错位。
十四、你现在应该掌握的关键结论
到这一讲为止,你要记住 8 个结论。
结论 1
train.sh 是复现入口,不是普通配置文件。所有关键路径、训练策略、LoRA、DeepSpeed、mask config 都从这里进入。
结论 2
官方脚本默认 --include localhost:4,你的 RTX 5090 单卡应该改成 localhost:0。
结论 3
SegEarth-R2 有两路图像输入:
text
images_clip → SigLIP / LLM
images → Swin / Mask2Former
结论 4
answer.count("[SEG]") 决定当前样本预测几个 mask。
结论 5
[SEG] token 是语言模型和分割模型之间的桥。
结论 6
SEG_token_embedding_indices 是为了告诉模型:哪些 hidden state 要拿去做 mask query。
结论 7
labels 不是简单等于 input_ids,human instruction 部分会被 IGNORE_INDEX 屏蔽。
结论 8
训练前必须检查:
text
[SEG] 数量 == mask 数量
否则训练一定出问题。
十五、下一阶段的阅读任务
下一阶段进入 train.py 主训练流程。重点看:
text
1. ModelArguments
2. DataArguments
3. TrainingArguments
4. get_mask_config()
5. SegEarthR2.from_pretrained()
6. model.initial_mask_module()
7. tokenizer.add_tokens("[SEG]")
8. model.resize_token_embeddings()
9. find_linear_layers()
10. get_peft_model()
11. train_module_list
12. make_unify_datamodule()
13. LLaVATrainer
下一讲你要重点吃透这句话:
text
train.py 决定了:模型怎么加载、哪些参数训练、哪些参数冻结、[SEG] 怎么加入、LoRA 怎么注入、数据怎么送进 Trainer。
而再下一讲才进入最核心的:
text
llava_phi.py:从 [SEG] hidden state 到 mask 的全过程。