Segearth-R2-03

下面进入 "持续规划 + 源码精读第一阶段"。这次我们做两件事:

  1. 先给你制定一套完整吃透路线。
  2. 正式开始拆解 scripts/train.shdataset.py,因为这两个文件决定你能否复现。

一、完整吃透 SegEarth-R2 的学习路线

我建议把整个项目拆成 10 讲,每一讲都有明确目标。

text 复制代码
第 1 讲:训练入口 scripts/train.sh
目标:知道训练命令每个参数的作用,能改成自己的机器环境。

第 2 讲:数据准备与数据格式
目标:知道 LaSeRS 数据集每个字段含义,知道 image、description、answer、mask 如何对应。

第 3 讲:dataset.py 数据流
目标:吃透图像预处理、mask 预处理、文本 token 化、[SEG] 标记构造、batch 拼接。

第 4 讲:train.py 训练主控
目标:知道模型如何加载,tokenizer 如何扩展,[SEG] 如何加入词表,LoRA 如何注入,哪些模块训练、哪些模块冻结。

第 5 讲:模型整体结构 llava_phi.py
目标:知道 SegEarthR2Model、SegEarthR2、Mipha、SigLIP、Swin、Mask2Former 之间的关系。

第 6 讲:多模态输入拼接
目标:吃透图像 token 如何插入语言模型序列,images_clip 和 images 为什么分两路。

第 7 讲:[SEG] token 到 mask 的核心机制
目标:吃透 [SEG] hidden state 如何变成 mask query。

第 8 讲:loss 计算
目标:吃透 language loss、mask loss、dice loss、attention loss。

第 9 讲:eval.py 推理流程
目标:知道测试时如何输入 prompt,如何生成 [SEG],如何输出 mask。

第 10 讲:复现、调试、改进
目标:在 RTX 5090 上跑通,解决 CUDA / Detectron2 / MSDeformAttn / DeepSpeed / LoRA 合并问题,并思考如何改成你的博士课题方向。

你现在要先建立一个总认识:

text 复制代码
SegEarth-R2 = 多模态大语言模型 + 遥感图像编码器 + Mask2Former 分割头

输入:
遥感图像 + 文本指令

输出:
文本回答 + 与 [SEG] 对应的分割 mask

核心桥梁:
[SEG] token 的 hidden state

也就是说,这个项目不是简单"图像进、mask 出",而是:

text 复制代码
图像 + 语言
    ↓
语言模型理解
    ↓
生成或使用 [SEG]
    ↓
取出 [SEG] 的隐藏状态
    ↓
作为 mask query
    ↓
Mask2Former 输出分割结果

二、第 1 讲:scripts/train.sh 逐项拆解

官方训练脚本只有一个核心命令:用 deepspeed 启动 segearth_r2/train/train.py,并传入模型路径、视觉塔路径、Mask2Former 权重、数据路径、训练步数、LoRA、DeepSpeed 配置和 mask 配置等参数。(GitHub)

原始逻辑可以概括成:

bash 复制代码
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"

deepspeed --master_port=29500 --include localhost:4 segearth_r2/train/train.py \
  --model_name_or_path "pretrained_model/mllm/Mipha-3B" \
  --vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384" \
  --vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl" \
  --base_data_path '/data1/xzp/data' \
  --output_dir output_folder \
  --max_steps 5000 \
  --per_device_train_batch_size 1 \
  --bf16 True \
  --learning_rate 5e-5 \
  --lora_r 4 \
  --deepspeed scripts/zero3.json \
  --mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml' \
  --data_ratio '1' \
  --switch_bs 4

下面逐项解释。


1. export NCCL_P2P_DISABLE="1"

作用:关闭 NCCL 的 P2P 通信。

NCCL 是多 GPU 训练时负责显卡之间通信的库。关闭 P2P 的常见原因是:部分机器多卡拓扑、驱动、PCIe/NVLink 兼容不好时,P2P 通信容易卡死或报错。

如果你是 单卡 RTX 5090,这个参数通常不会成为核心问题;如果是多卡训练,遇到 NCCL 卡死时它很有用。


2. export NCCL_IB_DISABLE="1"

作用:关闭 InfiniBand 通信。

如果你不是多节点集群,而是单机训练,InfiniBand 一般用不到。关闭它可以减少通信初始化问题。

你的情况大概率是:

text 复制代码
单机 + RTX 5090

所以这两个 NCCL 环境变量可以保留。


3. deepspeed --master_port=29500

deepspeed 是训练启动器。

不用普通:

bash 复制代码
python segearth_r2/train/train.py

而是用:

bash 复制代码
deepspeed segearth_r2/train/train.py

原因是这个项目包含:

text 复制代码
Mipha-3B 大语言模型
SigLIP 视觉塔
Swin / Mask2Former 分割模块
LoRA
大尺寸遥感图像

显存压力比较大,所以用 DeepSpeed 做显存优化、梯度同步和分布式训练。

--master_port=29500 是分布式通信端口。如果端口被占用,会报类似:

text 复制代码
Address already in use

解决方法:

bash 复制代码
--master_port=29501

或者换成其他端口。


4. --include localhost:4

这个参数非常重要。

官方脚本写的是:

bash 复制代码
--include localhost:4

含义是:只使用本机第 4 号 GPU。

如果你只有一张 RTX 5090,应该改成:

bash 复制代码
--include localhost:0

如果你有 4 张卡:

bash 复制代码
--include localhost:0,1,2,3

如果这里不改,最常见错误是:

text 复制代码
No slot '4' specified on host 'localhost'

或者 DeepSpeed 找不到对应 GPU。


5. --model_name_or_path

bash 复制代码
--model_name_or_path "pretrained_model/mllm/Mipha-3B"

这是基础多模态语言模型路径。

它负责:

text 复制代码
理解文本 prompt
接收图像 token
生成语言输出
输出 hidden states
提供 [SEG] token 的语义表征

注意,它不是专门做 mask 的模块。真正产生 mask 的是后面的 Mask2Former 分割头。


6. --vision_tower

bash 复制代码
--vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384"

这是给多模态大语言模型使用的视觉塔,也就是 SigLIP。

它的作用是把图像变成视觉 token,再通过 projector 映射到语言模型空间。

你可以理解为:

text 复制代码
原始图像
  ↓
SigLIP
  ↓
视觉 patch embedding
  ↓
mm_projector
  ↓
语言模型可理解的图像 token

7. --vision_tower_mask

bash 复制代码
--vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl"

这个是 Mask2Former / Swin 分割分支的预训练权重。

注意,这和 vision_tower 不是一个东西。

SegEarth-R2 里有两套视觉处理逻辑:

text 复制代码
images_clip → SigLIP → 多模态语言理解
images      → Swin / Mask2Former → 像素级分割

这是你必须记住的第一个关键点。


8. --base_data_path

bash 复制代码
--base_data_path '/data1/xzp/data'

这是数据集根目录。

根据 dataset.py 的逻辑,如果是训练集,它会去找:

text 复制代码
base_data_path/train/images
base_data_path/train/annotations/train_data.json

如果是测试集,它会去找:

text 复制代码
base_data_path/test/images
base_data_path/test/annotations/xxx.json

LaSeRSDataset 会根据 split 参数判断使用 train 还是 test 子目录。(GitHub)

你的路径必须改成自己的数据目录,例如:

bash 复制代码
--base_data_path "/home/yourname/datasets/LaSeRS"

9. --output_dir

bash 复制代码
--output_dir output_folder

训练输出目录。

这里会保存:

text 复制代码
checkpoint
LoRA 权重
trainer state
模型配置
最终保存模型

建议你不要直接叫 output_folder,而是改成清楚的实验名:

bash 复制代码
--output_dir outputs/segearth_r2_lora_r4_bs1_5090

这样后面做多组实验不会混乱。


10. --max_steps 5000

训练 5000 step。

注意它不是 epoch,而是 step。因为使用 HuggingFace Trainer / DeepSpeed 时,经常通过 step 控制训练。

如果你第一次复现,不要直接跑 5000 step。建议先改成:

bash 复制代码
--max_steps 20

先确认:

text 复制代码
数据能读
模型能加载
forward 能通
loss 正常下降或至少能计算
checkpoint 能保存

跑通后再改成:

bash 复制代码
--max_steps 1000

最后再跑完整训练。


11. --per_device_train_batch_size 1

每张 GPU 上 batch size 为 1。

这很正常,因为:

text 复制代码
大语言模型很大
图像尺寸 1024×1024
还有 mask decoder
还要保存 hidden states 和 attention

显存压力非常高。

如果 RTX 5090 显存够,可以尝试 2,但第一次复现建议保持 1。


12. --bf16 True

使用 bfloat16 训练。

RTX 5090 支持 BF16,建议保留。BF16 比 FP16 更稳定,尤其是大模型训练。

不过如果你遇到某些算子 BF16 不支持,可以临时改成:

bash 复制代码
--bf16 False
--fp16 True

但优先建议 BF16。


13. --learning_rate 5e-5

LoRA 微调时 5e-5 是常见学习率。

如果你训练的是 LoRA + 分割头,这个学习率可以先保持。后续做实验可以试:

text 复制代码
1e-5
2e-5
5e-5
1e-4

但第一次复现不要乱改。


14. --lora_r 4

LoRA 秩为 4。

LoRA 的作用是:不全量微调大语言模型,而是在部分线性层上增加低秩可训练矩阵。

你可以理解为:

text 复制代码
冻结大模型大部分参数
只训练少量 LoRA 参数
降低显存
降低训练成本
降低过拟合风险

train.py 中,项目会通过 find_linear_layers() 找到可插入 LoRA 的线性层,再用 get_peft_model() 注入 LoRA。(GitHub)


15. --deepspeed scripts/zero3.json

使用 DeepSpeed ZeRO-3。

ZeRO-3 会切分:

text 复制代码
模型参数
梯度
优化器状态

显存最省,但通信和调试复杂度更高。

如果你单卡 RTX 5090,只是先跑通,可以考虑:

bash 复制代码
--deepspeed scripts/zero2.json

甚至先不用 DeepSpeed 做最小调试,但官方脚本是 DeepSpeed 路线。


16. --mask_config

bash 复制代码
--mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml'

这是 Mask2Former 分割模块的配置文件。

它决定:

text 复制代码
Swin backbone 配置
Mask2Former hidden_dim
query 数量
loss 权重
decoder 层数
训练采样点数

llava_phi.py 中,初始化 mask module 时会读取这个配置,构造 pixel decoder、predictor、criterion 等模块。(GitHub)


17. --data_ratio '1'

控制数据集混合比例。

当前只有 LaSeRS 一个数据源,所以是:

bash 复制代码
--data_ratio '1'

如果以后你把自己的遥感数据加入进去,可以扩展成:

bash 复制代码
--data_ratio '1||1'

表示两个数据集混合。

train.pymake_unify_datamodule() 会读取 data_ratio,然后构建 LaSeRSDatasetUnifyDatasetSingleDatasetForBatch。(GitHub)


18. --switch_bs 4

这个参数用于 UnifyDatasetSingleDatasetForBatch

简单理解:

text 复制代码
每隔 switch_bs 个 batch 切换一次数据集

当前只有一个数据集时,它的意义不大;多个数据集混训时才重要。


三、你应该把 train.sh 改成这样先跑通

你的 RTX 5090 单卡调试版建议先这样:

bash 复制代码
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"

deepspeed --master_port=29501 --include localhost:0 segearth_r2/train/train.py \
  --model_name_or_path "pretrained_model/mllm/Mipha-3B" \
  --vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384" \
  --vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl" \
  --base_data_path "/your/path/LaSeRS" \
  --output_dir outputs/debug_5090_lora_r4 \
  --max_steps 20 \
  --per_device_train_batch_size 1 \
  --save_strategy "steps" \
  --save_steps 10 \
  --bf16 True \
  --save_total_limit 1 \
  --learning_rate 5e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --tf32 False \
  --model_max_length 2048 \
  --gradient_checkpointing False \
  --dataloader_num_workers 2 \
  --lora_r 4 \
  --deepspeed scripts/zero3.json \
  --mask_config 'segearth_r2/model/mask_decoder/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml' \
  --data_ratio '1' \
  --switch_bs 4

第一次只跑 20 step。不要一开始跑 5000 step。

你的第一目标不是训练出好结果,而是确认:

text 复制代码
1. 环境没问题
2. 数据能读
3. 模型能加载
4. forward 能通
5. loss 能算
6. checkpoint 能保存

四、第 2 讲:数据链路整体理解

现在开始看 dataset.py。这个文件的作用是:

text 复制代码
原始 JSON + 图片 + mask
        ↓
LaSeRSDataset.__getitem__()
        ↓
构造单样本 data_dict
        ↓
DataCollatorForCOCODatasetV2.__call__()
        ↓
拼成 batch
        ↓
送入 model.forward()

你必须先理解它产生了哪些东西。

一个 batch 大致长这样:

python 复制代码
{
    "input_ids": ...,
    "labels": ...,
    "attention_mask": ...,
    "images_clip": ...,
    "images": ...,
    "seg_info": ...,
    "dataset_type": ...,
    "token_refer_id": ...,
    "SEG_token_embedding_indices": ...,
    "mask_num": ...
}

每个字段含义:

字段 作用
input_ids 文本 prompt + answer token 序列
labels 语言模型监督标签,人类指令部分会被设为 IGNORE_INDEX
attention_mask 标记哪些 token 是有效 token
images_clip 给 SigLIP / 多模态 LLM 使用的图像
images 给 Swin / Mask2Former 使用的 1024 图像
seg_info 存放 mask、image_id、height、width 等分割监督信息
token_refer_id referring instruction 的 token
SEG_token_embedding_indices 标记 input_ids 中哪些位置是 [SEG]
mask_num 当前样本中 [SEG] 数量,也就是需要预测几个 mask

五、dataset.py 的关键函数 1:preprocess_mask

这个函数的作用是:把 mask 按比例缩放,并 padding 到固定大小。

核心逻辑:

text 复制代码
输入 mask
    ↓
如果是二维 mask,就扩展成 [1, H, W]
    ↓
逐个 mask 处理
    ↓
保持长宽比 resize
    ↓
不足部分 padding 0
    ↓
输出 [N, image_size, image_size]

为什么 mask resize 要用最近邻插值?

因为 mask 是类别标签,不是自然图像。

如果用双线性插值,0/1 mask 会变成 0.2、0.5、0.8 这种灰度值,类别边界会被污染。

所以代码里使用的是:

python 复制代码
cv2.INTER_NEAREST

这很关键。


六、dataset.py 的关键函数 2:preprocess_image

这个函数处理的是给 Mask2Former 分割分支使用的图像。

逻辑是:

text 复制代码
读取图像
    ↓
按照短边 resize
    ↓
限制最大边
    ↓
padding 到 1024×1024
    ↓
返回 numpy array

这一步得到的图像后面会进入:

text 复制代码
Swin backbone
pixel decoder
Mask2Former predictor

而不是直接给 SigLIP。

注意这里和 images_clip 是两条不同路线。


七、dataset.py 的关键函数 3:tokenizer_special_tokens

这个函数负责处理特殊 token。

代码里有两个特殊占位:

text 复制代码
IMAGE_TOKEN_INDEX
REFER_TOKEN_INDEX

它通过正则把 prompt 切开,一旦遇到特殊 token,就不用普通 tokenizer 编码,而是直接替换成特殊 index。(GitHub)

你可以理解为:

text 复制代码
普通文本 → tokenizer.encode()
图像占位符 → IMAGE_TOKEN_INDEX
refer 占位符 → REFER_TOKEN_INDEX

这一步是多模态模型常见做法。

因为图像不是普通文字,但它要插入语言模型序列,所以需要一个特殊 token 位置告诉模型:

text 复制代码
这里要插入图像特征

八、dataset.py 的关键函数 4:preprocess_llama2

这个函数是文本监督的核心。

它做了三件事:

第一,构造对话模板

数据被包装成人类和助手的对话:

python 复制代码
[
    {"from": "human", "value": "...instruction..."},
    {"from": "gpt", "value": "...answer..."}
]

然后根据 conversation template 拼成完整 prompt。


第二,把 prompt token 化

调用:

python 复制代码
self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt')

得到:

python 复制代码
input_ids

第三,构造 labels,并屏蔽人类问题部分

训练语言模型时,不应该让模型学习"人类输入问题",而应该让模型学习"助手回答"。

所以代码会把 human instruction 部分的 label 设置为:

python 复制代码
IGNORE_INDEX

这样 loss 只在 assistant answer 部分计算。

你可以理解为:

text 复制代码
Human: segment the airport
Assistant: airport [SEG]

训练时:
Human 部分不算 loss
Assistant 部分算 language loss

这也是 LLaVA、Vicuna、Instruct tuning 类项目的标准做法。


九、LaSeRSDataset.__init__

初始化阶段主要做:

text 复制代码
设置 pixel mean/std
保存 base_data_path
保存 tokenizer
根据 split 判断 train/test
读取 annotation json
记录 [SEG] token id

代码逻辑中,如果 split 包含 train,就读取 train/imagestrain/annotations;如果 split 包含 test,就读取 test/imagestest/annotations。(GitHub)

这里要注意一个细节:

python 复制代码
self.SEG_token_id = self.tokenizer.convert_tokens_to_ids("[SEG]")

这说明 [SEG] 必须已经被加入 tokenizer。

而它是在 train.py 里通过:

python 复制代码
tokenizer.add_tokens("[SEG]")
model.resize_token_embeddings(len(tokenizer))

加入的。(GitHub)

所以顺序是:

text 复制代码
train.py 加 [SEG] 到 tokenizer
        ↓
构建 LaSeRSDataset
        ↓
dataset.py 才能拿到 [SEG] token id

如果顺序错了,[SEG] 可能会被当成普通未知 token,整个分割桥梁就断了。


十、LaSeRSDataset.__getitem__ 逐步理解

这是数据文件最重要的函数。


1. 取出一条标注

python 复制代码
data_info = self.reason_file[idx]

一条数据通常包含:

text 复制代码
image_name
description
answer
id
mask

含义:

字段 含义
image_name 图像文件名
description 文本指令
answer 模型应该学习的回答
id 样本 id
mask RLE 格式分割 mask

2. 读取图片路径

python 复制代码
image_path = os.path.join(self.LaSeRS_image_path, data_info['image_name'])

得到完整图像路径。

如果这里报错,通常是:

text 复制代码
base_data_path 不对
image_name 与实际文件名不一致
train/images 目录结构不对

3. 读取文本指令和答案

python 复制代码
ref = data_info['description']
answer = data_info['answer']

description 是给模型看的 instruction。

answer 是模型要学习输出的内容。它里面非常关键的是 [SEG]

例如:

text 复制代码
The target is building [SEG].

模型后面会通过 [SEG] 的 hidden state 预测 mask。


4. 解码 RLE mask

如果 data_info 中有 mask 字段,代码会用 pycocotools 解码:

python 复制代码
mask = M.decode(rle)

然后多个 mask 堆叠成:

python 复制代码
[N, H, W]

这里的 N 应该和 answer 中 [SEG] 的数量一致。

如果:

text 复制代码
answer 中有 2 个 [SEG]
mask 只有 1 个

后面就会错位。

所以你调试数据集时必须检查:

python 复制代码
assert answer.count("[SEG]") == len(mask_list)

5. 读取图像高宽

代码会用 OpenCV 读取 BGR 图像:

python 复制代码
image_BGR = cv2.imread(image_path)
image_height = image_BGR.shape[0]
image_width = image_BGR.shape[1]

这些原始尺寸会保存到 data_dict,后面用于 mask 还原、评估或可视化。


6. 构造分割分支图像

python 复制代码
image_RGB = preprocess_image(image_path)
image_tensor = torch.as_tensor(np.ascontiguousarray(image_RGB.transpose(2, 0, 1)))
data_dict['image'] = (image_tensor - self.pixel_mean) / self.pixel_std

这里得到的是给 Mask2Former 分支的图像。

注意几件事:

text 复制代码
1. 图像变成 C×H×W
2. H 和 W 最终是 1024
3. 使用 ImageNet mean/std 归一化
4. 这个 image 不是给 SigLIP 的

7. 根据 [SEG] 数量构造 annotations

python 复制代码
mask_num = answer.count("[SEG]")

这行极其重要。

它说明:

text 复制代码
answer 中有几个 [SEG]
就认为要预测几个 mask

然后循环:

python 复制代码
for i in range(mask_num):
    data_dict['annotations'].append(...)

每个 annotation 里包含:

text 复制代码
data_id
mask_id
mask
image_path
height
width
image_id

也就是说,[SEG] 和 mask 是一一对应的。

你必须记住这句话:

text 复制代码
SegEarth-R2 的分割监督不是靠类别 id 对齐,而是靠 answer 中的 [SEG] 顺序和 mask 顺序对齐。

8. 构造 human prompt

代码里有类似:

python 复制代码
prefix_inst = 'This is an image \n\n, please doing Reasoning Segmentation according to the following instruction:'
instruction = ref.strip()

然后构造:

python 复制代码
sources = [
    [
        {'from': 'human', 'value': prefix_inst + '\n <|assistant|>'},
        {'from': 'gpt', 'value': '\n' + answer}
    ]
]

这里要理解:

text 复制代码
human 部分:告诉模型看图并根据 instruction 做 reasoning segmentation
gpt 部分:标准答案,里面包含 [SEG]

有一个细节:instruction = ref.strip() 得到了文本指令,但你要认真检查官方代码中它是否真正拼入了 human prompt。如果你后面做自己的数据格式,建议确保 instruction 明确进入 prompt,否则模型可能看不到具体目标描述。


9. 构造 input_ids 和 labels

python 复制代码
text_dict = self.preprocess_llama2(sources, self.tokenizer)

得到:

python 复制代码
input_ids
labels

其中:

text 复制代码
input_ids:完整 token 序列
labels:训练监督标签

human 输入部分的 labels 会被设为 IGNORE_INDEX,assistant answer 部分才计算语言模型 loss。


10. 构造 [SEG] 位置标记

python 复制代码
SEG_token_embedding_indices = torch.zeros_like(input_ids)
SEG_token_embedding_indices[input_ids == self.SEG_token_id] = 1

这一步是整个项目的核心之一。

它的含义是:

text 复制代码
在 input_ids 中找到所有 [SEG]
标记这些位置
后面从 LLM hidden states 中取这些位置的 hidden state

之后在 llava_phi.py 里会调用:

python 复制代码
self.get_SEG_embedding(hidden_states, SEG_token_embedding_indices)

取出 [SEG] 对应 hidden state,再经过 SEG_token_projector 映射到 Mask2Former hidden dim。llava_phi.py 中可以看到 get_SEG_embedding 负责根据索引取出 SEG embedding,随后用 SEG_token_projector 处理。(GitHub)


十一、DataCollatorForCOCODatasetV2.__call__

这个类负责把多个样本拼成一个 batch。

你可以理解为:

text 复制代码
Dataset.__getitem__ 只返回一个样本
DataCollator 把多个样本整理成 batch

它主要做 5 件事。


1. padding 文本

因为每个样本文本长度不同,所以需要 padding。

python 复制代码
input_ids = pad_sequence(...)
labels = pad_sequence(...)

然后生成:

python 复制代码
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

2. 构造 images_clip

代码会重新读取原图,然后用 SigLIP 的 image processor 处理:

python 复制代码
self.clip_image_processor.preprocess(..., return_tensors="pt")

这一路给 SigLIP / 多模态 LLM 用。

注意:这是 images_clip,不是 images


3. 构造 images

python 复制代码
images = [instance['image'] for instance in instances]
batch['images'] = torch.stack(images)

这一路是前面 preprocess_image() 得到的 1024 图像,用于 Swin / Mask2Former 分割分支。

所以 batch 中有两个图像字段:

text 复制代码
images_clip:语言视觉塔
images:分割视觉塔

这个设计非常重要。


4. 构造 seg_info

每个 annotation 被整理到 seg_info 里。

其中 mask 会被转成 tensor:

python 复制代码
seg['mask'] = torch.as_tensor(seg['mask'], dtype=torch.uint8)

后面 llava_phi.py 会根据 seg_info 构造 mask targets,并传给 criterion 计算 loss_maskloss_dice。相关代码中可以看到 targets 包含 labelsmasks,之后调用 self.criterion(mask_outputs, targets),并汇总 loss_maskloss_dice。(GitHub)


5. padding [SEG] 位置标记

python 复制代码
SEG_token_embedding_indices = pad_sequence(...)

因为不同样本长度不同,所以 [SEG] 标记也要 padding。

最终 batch 会包含:

python 复制代码
batch['SEG_token_embedding_indices']

这个字段后面会直接传入模型 forward。


十二、到这里你应该形成第一张"数据流图"

你现在必须能自己画出下面这张图:

text 复制代码
train_data.json
    ↓
data_info = reason_file[idx]
    ↓
image_name → image_path
description → instruction
answer → answer.count("[SEG]")
mask RLE → M.decode → masks
    ↓
preprocess_image → images
SigLIP processor → images_clip
preprocess_llama2 → input_ids + labels
input_ids == [SEG] → SEG_token_embedding_indices
    ↓
DataCollator
    ↓
batch
    ↓
model.forward(
    input_ids,
    labels,
    images,
    images_clip,
    seg_info,
    SEG_token_embedding_indices,
    mask_num
)

你如果能把这张图讲清楚,就说明你已经吃透了数据部分的 50%。


十三、复现时必须先做的 5 个数据检查

在正式训练前,你一定要写一个小调试脚本,或者在 train.py 里临时打印一个 batch。

你要检查:

python 复制代码
print(batch.keys())
print(batch["input_ids"].shape)
print(batch["labels"].shape)
print(batch["images"].shape)
print(batch["images_clip"].shape)
print(batch["SEG_token_embedding_indices"].shape)
print(batch["mask_num"])
print(len(batch["seg_info"]))

你希望看到类似:

text 复制代码
input_ids: [B, L]
labels: [B, L]
images: [B, 3, 1024, 1024]
images_clip: [B, 3, 384, 384] 或类似 SigLIP 输入尺寸
SEG_token_embedding_indices: [B, L]
mask_num: [每个样本的 mask 数量]
seg_info: 总 mask 数量

再检查:

python 复制代码
print(batch["SEG_token_embedding_indices"].sum())
print(sum(batch["mask_num"]))

这两个值应该相等。

因为:

text 复制代码
batch 中 [SEG] 的总数 = batch 中需要预测的 mask 总数

如果不相等,后面 [SEG] embedding 和 mask target 就会错位。


十四、你现在应该掌握的关键结论

到这一讲为止,你要记住 8 个结论。

结论 1

train.sh 是复现入口,不是普通配置文件。所有关键路径、训练策略、LoRA、DeepSpeed、mask config 都从这里进入。

结论 2

官方脚本默认 --include localhost:4,你的 RTX 5090 单卡应该改成 localhost:0

结论 3

SegEarth-R2 有两路图像输入:

text 复制代码
images_clip → SigLIP / LLM
images      → Swin / Mask2Former

结论 4

answer.count("[SEG]") 决定当前样本预测几个 mask。

结论 5

[SEG] token 是语言模型和分割模型之间的桥。

结论 6

SEG_token_embedding_indices 是为了告诉模型:哪些 hidden state 要拿去做 mask query。

结论 7

labels 不是简单等于 input_ids,human instruction 部分会被 IGNORE_INDEX 屏蔽。

结论 8

训练前必须检查:

text 复制代码
[SEG] 数量 == mask 数量

否则训练一定出问题。


十五、下一阶段的阅读任务

下一阶段进入 train.py 主训练流程。重点看:

text 复制代码
1. ModelArguments
2. DataArguments
3. TrainingArguments
4. get_mask_config()
5. SegEarthR2.from_pretrained()
6. model.initial_mask_module()
7. tokenizer.add_tokens("[SEG]")
8. model.resize_token_embeddings()
9. find_linear_layers()
10. get_peft_model()
11. train_module_list
12. make_unify_datamodule()
13. LLaVATrainer

下一讲你要重点吃透这句话:

text 复制代码
train.py 决定了:模型怎么加载、哪些参数训练、哪些参数冻结、[SEG] 怎么加入、LoRA 怎么注入、数据怎么送进 Trainer。

而再下一讲才进入最核心的:

text 复制代码
llava_phi.py:从 [SEG] hidden state 到 mask 的全过程。
相关推荐
KaMeidebaby1 小时前
卡梅德生物技术快报 | Fab 合成文库构建与抗体筛选实验流程及数据解析
人工智能·python·tcp/ip·算法·机器学习
装不满的克莱因瓶1 小时前
掌握3D CNN模型结构——从时空特征建模到视频理解与医学影像核心架构
人工智能·pytorch·python·深度学习·神经网络·3d·cnn
YOLO数据集集合1 小时前
无人机航拍RGBT双模态行人检测数据集 | 可见光红外对齐 低空小目标检测 多模态计算机视觉基准数据
人工智能·深度学习·目标检测·计算机视觉·无人机
古希腊掌管代码的神THU1 小时前
解析 MiniMax M3 多模态大模型的架构/源码?
人工智能·深度学习·自然语言处理·面试
动物园猫1 小时前
用于实验室智能识别的目标检测数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·目标检测
君为先-bey2 小时前
LightningDiT----重建与生成:在潜在扩散模型中驯服优化困境
深度学习·扩散模型·视频生成·潜在扩散模型
jay神2 小时前
基于 YOLOv8 + CRNN 的车牌识别系统
深度学习·yolo·目标检测·计算机视觉·车牌识别
装不满的克莱因瓶2 小时前
掌握基于YOLO v5实现车牌目标检测任务的完整流程——从数据到部署的工业级实践
人工智能·python·深度学习·yolo·目标检测·计算机视觉·目标跟踪
安逸sgr2 小时前
《图解机器学习-第六章》:线性回归和逻辑回归:最简单但最重要的机器学习模型
机器学习·逻辑回归·线性回归