继续。下面进入 第 5 讲:llava_phi.py 模型核心结构精读 。这是 SegEarth-R2 最重要的源码文件之一,官方 GitHub 显示该文件约 945 行,位于 segearth_r2/model/language_model/llava_phi.py,里面定义了模型输出结构、注意力监督损失、SegEarthR2Model、SegEarthR2 主类、mask decoder 初始化、[SEG] embedding 提取、训练 forward、评估 eval_seg 和推理接口。(GitHub)
第 5 讲:llava_phi.py 总体定位
你可以把 llava_phi.py 理解成 SegEarth-R2 的"大脑 + 分割头连接器"。
前面我们已经知道:
text
dataset.py 负责造 batch
train.py 负责加载模型和训练配置
llava_phi.py 负责真正执行 forward
llava_phi.py 主要解决一个核心问题:
text
语言模型输出的是 token hidden state,
但分割任务需要 pixel-level mask,
那么如何把 token hidden state 变成 mask?
SegEarth-R2 的答案是:
text
[SEG] token hidden state
↓
SEG_token_projector
↓
Mask2Former query
↓
pixel_decoder + predictor
↓
pred_masks
这就是整个项目的灵魂。
一、先看文件里的 4 个核心对象
llava_phi.py 里最重要的是这几个对象:
text
1. CausalOutputWithMask
2. AttentionLoss
3. SegEarthR2Model
4. SegEarthR2
1. CausalOutputWithMask
这个类继承自 HuggingFace 的 CausalLMOutputWithPast。
普通语言模型输出通常包括:
text
loss
logits
past_key_values
hidden_states
attentions
但 SegEarth-R2 还要输出分割相关 loss,所以它额外加入:
text
loss_mask
loss_dice
loss_llm
loss_attention
也就是说,这个模型不是普通 LLM,而是"语言建模 + 分割监督"的复合模型。源码中 CausalOutputWithMask 明确包含 loss_mask、loss_dice、loss_llm 和 loss_attention 字段。(GitHub)
你读到这里时要建立一个意识:
text
logits 负责语言输出
pred_masks 负责分割输出
loss_llm 监督文本生成
loss_mask / loss_dice 监督分割结果
loss_attention 监督 [SEG] 对图像区域的注意力
2. AttentionLoss
AttentionLoss 的作用是让 [SEG] token 对图像 token 的注意力更接近真实 mask 区域。
通俗讲:
text
如果问题是"分割机场跑道",
那么 [SEG] token 不应该平均看整张图,
而应该更多关注跑道所在区域。
代码逻辑大致是:
text
取出 [SEG] 对图像 token 的 attention
根据 gt_mask 把 attention 分成目标区域和非目标区域
计算目标区域 attention 与背景均值之间的差异
形成 attention loss
源码中 AttentionLoss.forward() 会从 attention map 中分别取出 mask 区域和非 mask 区域,并用差异构造损失。(GitHub)
这对论文理解很重要,因为它说明 SegEarth-R2 不是只靠 mask loss 训练分割头,还试图让语言模型的注意力机制也对齐到目标区域。
3. SegEarthR2Model
SegEarthR2Model 继承自 MiphaPhiModel。它更接近"基础模型骨架",主要负责:
text
继承 Mipha/Phi 语言模型能力
保存 mask decoder 配置
构建 vision_tower_mask
提供 get_vision_tower()
提供 get_vision_tower_mask()
初始化视觉模块
源码中 SegEarthR2Model 在初始化时会根据 swin_type 构建 build_swin_b 或 build_swin_l,并把 IVSDatasetMapper 作为 mask 视觉分支的 image processor。(GitHub)
这里要注意两个名字:
text
vision_tower → 多模态 LLM 的视觉塔,通常是 SigLIP
vision_tower_mask → 分割分支视觉骨干,通常是 Swin
这和前面 dataset.py 里的两个图像输入完全对应:
text
images_clip → vision_tower
images → vision_tower_mask
4. SegEarthR2
SegEarthR2 继承自 MiphaPhiForCausalLM,这是整个工程真正的主模型。
它负责:
text
语言模型 forward
多模态输入拼接
mask decoder 初始化
[SEG] embedding 提取
mask 输出
loss 计算
eval_seg
inference
源码中 SegEarthR2 初始化时会创建 SegEarthR2Model,设置 lm_head,并根据 mask_decode_train 判断是否初始化 mask decoder。(GitHub)
二、constants.py:几个特殊 token 的意义
理解 llava_phi.py 前,必须理解特殊 token index。
官方 constants.py 中定义了:
text
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
SEG_TOKEN_INDEX = -201
CLS_TOKEN_INDEX = -202
REGION_TOKEN_INDEX = -203
REFER_TOKEN_INDEX = -204
TEMPORAL_TOKEN_INDEX = -205
其中最重要的是:
text
IGNORE_INDEX → 语言 loss 中被忽略的位置
IMAGE_TOKEN_INDEX → 文本序列中图像占位符
REFER_TOKEN_INDEX → referring instruction 占位符
这些负数 index 不是真正词表 token,而是工程中用于"占位"的特殊标记。constants.py 明确将 IGNORE_INDEX 设为 -100,IMAGE_TOKEN_INDEX 设为 -200,REFER_TOKEN_INDEX 设为 -204。(GitHub)
这很关键,因为 concat_image_seg_cls_embeds() 会扫描 input_ids,遇到普通 token 就用 embedding,遇到 IMAGE_TOKEN_INDEX 就插入图像特征,遇到 REFER_TOKEN_INDEX 就插入 refer embedding。源码中 concat_image_seg_cls_embeds() 明确断言当前只支持一个 image index,并在遇到 IMAGE_TOKEN_INDEX 时插入图像特征。(GitHub)
三、核心机制 1:Mask decoder 怎么初始化
函数:
python
initial_mask_module()
它做了 5 件事:
text
1. 创建 AttentionLoss
2. 根据 mask_config 创建 pixel_decoder
3. 创建 predictor
4. 创建 SEG_token_projector
5. 初始化 criterion,也就是 mask / dice loss 计算器
最关键的一句是:
python
self.SEG_token_projector = nn.Linear(
self.config.hidden_size,
self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM
)
含义是:
text
LLM hidden size
↓
Mask2Former hidden_dim
也就是说,语言模型的 hidden state 不能直接给 Mask2Former 用,需要先经过一个线性投影层。源码中 initial_mask_module() 会创建 pixel_decoder、predictor、SEG_token_projector,并初始化 mask decoder 的训练损失。(GitHub)
你可以把它理解成:
text
LLM 世界:token embedding
Mask2Former 世界:query embedding
SEG_token_projector:两个世界之间的翻译器
四、核心机制 2:两路视觉特征
llava_phi.py 里有两个非常容易混淆的函数。
1. encode_images(images)
这个函数服务于 多模态语言模型。
逻辑是:
text
images_clip
↓
vision_tower / SigLIP
↓
mm_projector
↓
image token embeddings
↓
插入 LLM 输入序列
它把 SigLIP 输出的视觉特征映射到语言模型 hidden size。
2. get_vision_tower_feature(images)
这个函数服务于 分割分支。
逻辑是:
text
images
↓
vision_tower_mask / Swin
↓
res2, res3, res4, res5
↓
pixel_decoder
↓
multi_scale_features + mask_features
源码中 get_vision_tower_feature() 会调用 get_vision_tower_mask()(images),然后组织成 res2、res3、res4、res5 四级特征。(GitHub)
所以你必须记住:
text
images_clip 不是为了分割像素,而是为了让 LLM 看懂图像
images 不是为了语言理解,而是为了产生高分辨率分割特征
五、核心机制 3:多模态输入怎么拼进 LLM
关键函数:
python
prepare_inputs_labels_for_multimodal()
它的任务是:
text
把 input_ids 里的 <image> 占位符替换成真实图像 token embedding
把普通文本 token 变成文本 embedding
同步更新 labels
同步更新 attention_mask
同步更新 SEG_token_embedding_indices
同步记录哪些位置是 image feature
源码中 prepare_inputs_labels_for_multimodal() 会先通过 encode_images(images) 得到图像特征,然后逐个 batch 样本处理 input ids,并构造新的 inputs_embeds、labels、SEG_token_embedding_indices 和 image_features_indices。(GitHub)
这一步非常关键,因为原始 input_ids 是这样的:
text
[文本 token, IMAGE_TOKEN_INDEX, 文本 token, REFER_TOKEN_INDEX, 文本 token, [SEG]]
经过处理后变成:
text
[文本 embedding, 图像 patch embeddings, 文本 embedding, refer embedding, 文本 embedding, [SEG] embedding]
也就是说,LLM 最终看到的不是单纯文本,而是"文本 embedding + 图像 embedding"的混合序列。
六、核心机制 4:concat_image_seg_cls_embeds()
这个函数是"替换占位符"的具体执行者。
它会扫描单个样本的 input_id:
text
普通 token
↓
embed_tokens()
IMAGE_TOKEN_INDEX
↓
替换为 img_feature
REFER_TOKEN_INDEX
↓
替换为 refer_embedding
它还会同步处理三类伴随信息:
text
labels
SEG_token_embedding_indices
image_features_indices
为什么这很重要?
因为插入图像 token 后,序列长度变了。例如原始 <image> 只占 1 个位置,但 SigLIP 图像 patch 可能占几百个 token。此时 [SEG] 的位置、label 的位置、attention mask 的位置都必须一起重新对齐。
你可以这样理解:
text
原序列:
A B <image> C [SEG]
替换后:
A B img1 img2 img3 ... img729 C [SEG]
所以 [SEG] 原本可能是第 5 个 token,替换后可能变成第 733 个位置。如果 SEG_token_embedding_indices 不同步更新,后面就会取错 hidden state。
七、核心机制 5:get_SEG_embedding()
函数:
python
get_SEG_embedding(hidden_states, SEG_embedding_indices)
作用非常直接:
text
遍历 batch 中每个样本
根据 SEG_embedding_indices 找到 [SEG] 位置
取出 hidden_states 中对应 token 的向量
拼接成 SEG_embedding
源码中 get_SEG_embedding() 会根据每个样本的 bool index 从 hidden_states 取出对应状态,然后 torch.cat 并 unsqueeze(1)。(GitHub)
这一步是 SegEarth-R2 的核心桥梁:
text
语言模型输出 hidden_states
↓
取出 [SEG] 位置
↓
得到 SEG_embedding
↓
投影到 Mask2Former query 空间
你可以把 [SEG] 理解成"语言模型生成的分割查询"。
八、forward() 主流程逐步拆解
现在进入最核心的 forward()。
函数输入包括:
python
input_ids
attention_mask
past_key_values
inputs_embeds
labels
images
images_clip
seg_info
token_refer_id
SEG_token_embedding_indices
global_step
mask_num
dataset_type
源码中 forward() 参数明确同时接收 images 和 images_clip,也接收 seg_info、SEG_token_embedding_indices 和 mask_num。(GitHub)
第 1 步:检查 batch 数据类型
源码中会检查:
text
如果 dataset_type 不为空,
则 batch 内所有 dataset_type 必须一致。
这说明作者原本考虑过多数据集混训,但一个 batch 内不希望混入不同任务类型。
第 2 步:打开 attention 输出
text
output_attentions = True
为什么一定要输出 attention?
因为后面要计算 loss_attention。
也就是说,这个 forward 不只是需要 hidden states,还需要每层 attention map,用来约束 [SEG] 对图像区域的关注。
第 3 步:如果有 [SEG],先提取分割图像特征
代码逻辑是:
text
如果 SEG_token_embedding_indices 不为空,
且其中确实存在 [SEG],
并且当前不是单 token 生成阶段,
则调用 get_vision_tower_feature(images)
这一步得到的是分割分支的 Swin 多尺度特征。源码中在 SEG_token_embedding_indices 存在且非空时,会调用 get_vision_tower_feature(images) 得到 image_features。(GitHub)
注意:这里用的是 images,不是 images_clip。
第 4 步:构造 LLM 的 inputs_embeds
然后调用:
python
prepare_inputs_labels_for_multimodal(
input_ids,
attention_mask,
past_key_values,
labels,
images_clip,
token_refer_id=token_refer_id,
SEG_token_embedding_indices=SEG_token_embedding_indices
)
注意这里传入的是 images_clip。
它会把 SigLIP 特征插入到语言模型 token 序列中。源码中 forward 会调用 prepare_inputs_labels_for_multimodal(..., images_clip, ...) 来生成 inputs_embeds、新的 attention_mask、新的 labels、新的 [SEG] 索引和图像特征索引。(GitHub)
所以 forward 内部再次体现了双视觉流:
text
images → Swin → mask decoder
images_clip → SigLIP → LLM input embeddings
第 5 步:进入语言模型
代码逻辑:
text
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=True,
return_dict=return_dict
)
随后:
text
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
源码中 forward() 会调用底层模型得到 outputs,再取 last_hidden_state 并通过 lm_head 得到 logits。(GitHub)
含义:
text
hidden_states → 用于取 [SEG] embedding
logits → 用于计算语言模型 loss
attentions → 用于 attention loss
第 6 步:取出 [SEG] embedding
关键逻辑:
text
SEG_embedding =
SEG_token_projector(
get_SEG_embedding(hidden_states, SEG_token_embedding_indices)
)
源码中这一句在 forward 里非常明确:先从 hidden states 提取 [SEG] embedding,再经过 SEG_token_projector。(GitHub)
这一步之后,SEG_embedding 已经不再只是语言模型向量,而是可以作为 mask decoder query 使用的向量。
你要把这句话背下来:
text
SegEarth-R2 的分割 query 不是随机学习的静态 query,
而是来自语言模型 [SEG] token 的上下文 hidden state。
这也是它能够做到"语言引导分割"的关键。
第 7 步:pixel decoder 提取 mask features
如果 image_features 不为空,执行:
text
mask_features, transformer_encoder_features, multi_scale_features =
pixel_decoder.forward_features(image_features)
源码中 forward 会调用 pixel_decoder.forward_features(image_features) 生成 mask_features 和 multi_scale_features。(GitHub)
这一步的含义是:
text
Swin backbone 的 res2/res3/res4/res5
↓
MSDeformAttnPixelDecoder
↓
Mask2Former 可用的多尺度特征
第 8 步:根据 mask_num 复制图像特征
这一段很重要,很多人第一次看会不理解:
text
mask_features = repeat_interleave(mask_features, repeats=mask_num, dim=0)
multi_scale_features = [
repeat_interleave(feat, repeats=mask_num, dim=0)
for feat in multi_scale_features
]
为什么要复制?
因为一个 batch 中,一张图可能有多个 [SEG],例如:
text
图像 A:2 个 [SEG] → 2 个 mask
图像 B:1 个 [SEG] → 1 个 mask
原始图像特征数量是按图像来的:
text
B 张图 → B 份图像特征
但 mask query 是按 [SEG] 来的:
text
总共 3 个 [SEG] → 3 个 SEG_embedding
所以要把图像特征复制到和 SEG_embedding 数量一致。源码中 forward 会把 mask_features 和 multi_scale_features 按 mask_num 做 repeat_interleave,再传给 predictor。(GitHub)
第 9 步:predictor 输出 mask
关键逻辑:
text
mask_outputs = predictor(
multi_scale_features,
mask_features,
None,
None,
SEG_embedding
)
源码中 mask_outputs = self.predictor(..., SEG_embedding),说明 Mask2Former predictor 是被 [SEG] embedding 条件化驱动的。(GitHub)
输出里最重要的是:
text
mask_outputs["pred_masks"]
这就是模型预测的分割 mask。
九、训练 loss 怎么算
forward() 中 loss 分为 3 类:
text
1. llm_loss
2. mask_loss + dice_loss
3. attention_loss
1. llm_loss
源码中计算语言 loss 的方式是标准 causal LM:
text
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
CrossEntropyLoss(shift_logits, shift_labels)
也就是"前一个 token 预测下一个 token"。源码中 forward 明确用 CrossEntropyLoss 计算 llm_loss。(GitHub)
这部分监督模型生成 answer,例如:
text
The target is [SEG].
注意,human prompt 部分的 label 在 dataset 中已经被设成 IGNORE_INDEX,所以这里主要监督 assistant answer。
2. mask_loss 和 dice_loss
如果 seg_info 不为空,源码会根据 seg_info 构造 targets。对于包含 mask 的情况,每个 target 包括:
text
labels
masks
valid
inst_id
然后调用:
text
mask_losses = self.criterion(mask_outputs, targets)
接着按 weight_dict 汇总 loss_mask 和 loss_dice。源码中 forward 在 seg_info 存在时会构造 targets,并通过 self.criterion(mask_outputs, targets) 计算分割损失。(GitHub)
含义:
text
mask_loss 关注像素级预测是否正确
dice_loss 关注预测区域和真实区域重叠程度
3. attention_loss
源码里会把真实 mask resize 到 800×800,再下采样到 27×27,最后和 [SEG] token 对图像 token 的 attention 做监督。forward() 中会构造 masks_down,然后遍历 attentions,取出 [SEG] 对 image features 的 attention,并加入 loss_attention。(GitHub)
你可以这样理解:
text
mask_loss 监督最后输出的 mask
attention_loss 监督中间注意力是否看向目标
最终总 loss 是:
text
loss = llm_loss + mask_loss + 0.01 * loss_attention
源码中 forward 明确使用 loss = llm_loss + mask_loss + 0.01 * loss_attention,并返回 loss_mask、loss_dice、loss_llm 和 loss_attention。(GitHub)
十、把 forward 画成完整数据流
你现在应该能画出这张图:
text
batch from dataset.py
│
├── input_ids
├── labels
├── attention_mask
├── images_clip
├── images
├── SEG_token_embedding_indices
├── seg_info
└── mask_num
↓
llava_phi.py forward
↓
images_clip
↓
SigLIP vision_tower
↓
mm_projector
↓
image token embeddings
↓
prepare_inputs_labels_for_multimodal
↓
LLM input embeddings
images
↓
Swin vision_tower_mask
↓
res2/res3/res4/res5
↓
pixel_decoder
↓
mask_features + multi_scale_features
LLM input embeddings
↓
Phi/Mipha language model
↓
hidden_states + logits + attentions
hidden_states + SEG_token_embedding_indices
↓
get_SEG_embedding
↓
SEG_token_projector
↓
SEG_embedding
SEG_embedding + mask_features + multi_scale_features
↓
Mask2Former predictor
↓
pred_masks
logits + labels
↓
llm_loss
pred_masks + seg_info
↓
mask_loss + dice_loss
attentions + gt mask
↓
attention_loss
最终:
loss = llm_loss + mask_loss + 0.01 * attention_loss
这张图是你吃透 SegEarth-R2 的核心图。后面所有代码都围绕这张图展开。
十一、eval_seg():评估阶段怎么生成 mask
训练 forward 是为了算 loss,eval_seg() 是为了输出预测结果。
eval_seg() 的主要流程是:
text
1. get_vision_tower_feature(images)
2. prepare_inputs_labels_for_multimodal(..., images_clip, ...)
3. self.model(...) 得到 hidden_states
4. get_SEG_embedding()
5. SEG_token_projector()
6. pixel_decoder.forward_features()
7. repeat_interleave 图像特征
8. predictor(..., SEG_embedding)
9. 取 pred_masks
10. resize 到图像尺寸
11. 生成 processed_results
源码中 eval_seg() 会先提取 image_features,然后构造多模态输入,取 last_hidden_state,提取 [SEG] embedding,经过 pixel decoder 和 predictor 得到 pred_masks。(GitHub)
评估阶段还会把预测 mask 插值到图像 tensor 尺寸,并把预测 mask、GT mask、image name、data id 和 mask id 放进 processed_results。源码中 eval_seg() 将 mask_pred_results 插值到 ImageList 的尺寸,并返回包含 pred、gt、image_name、id、mask_id 的结果字典。(GitHub)
十二、训练 forward 和 eval_seg 的区别
你要把这两个区别讲清楚:
text
forward:
训练用
计算 logits
计算 llm_loss
计算 mask_loss / dice_loss
计算 attention_loss
返回 CausalOutputWithMask
eval_seg:
测试用
主要生成 mask
不计算训练 loss
返回 processed_results
也就是说:
text
forward 关心 loss
eval_seg 关心 pred mask
十三、这部分最容易出错的 8 个地方
1. [SEG] 数量和 mask 数量不一致
表现:
text
repeat_interleave shape 对不上
criterion 报 target 数量不匹配
pred_masks 和 gt masks 对不齐
检查:
python
print(SEG_token_embedding_indices.sum())
print(sum(mask_num))
print(len(seg_info))
三者应该一致。
2. images 和 images_clip 搞反
表现:
text
SigLIP 输入尺寸不对
Swin 输入尺寸不对
pixel_decoder shape 报错
正确关系:
text
images_clip → prepare_inputs_labels_for_multimodal → LLM
images → get_vision_tower_feature → pixel_decoder
3. SEG_token_embedding_indices 没有随图像 token 插入同步扩展
如果你自己改 prompt 或 tokenizer,最容易破坏这里。prepare_inputs_labels_for_multimodal() 和 concat_image_seg_cls_embeds() 会同步更新 [SEG] 位置,否则就会从 hidden states 中取错 token。
4. mask_num 不是 list 或 tensor 形状异常
mask_num 用于 repeat_interleave。如果 batch 中某个样本没有 [SEG],或者 mask_num 与 batch size 不一致,就容易报错。
5. seg_info 为空
如果 seg_info 为空,mask loss 无法计算。
检查 dataset:
text
annotation 是否有 mask
RLE 是否能 decode
answer 是否包含 [SEG]
DataCollator 是否把 seg_info 传进 batch
6. Mask2Former 预训练权重 key 不匹配
initial_mask_module() 里会从 checkpoint 中抽取 sem_seg_head.pixel_decoder 和 sem_seg_head.predictor 权重,并做部分 key name 转换。如果权重版本不匹配,可能出现 missing keys 或 unexpected keys。源码中初始化函数会对 pixel decoder 和 predictor 权重做 key 处理,并分别 load_state_dict(..., strict=False)。(GitHub)
7. attention loss 出 NaN
可能原因:
text
mask 全 0
attention_map_target 为空
BF16 下数值不稳定
masks_down 尺寸不符合 image token 数
AttentionLoss 内部有 epsilon 避免 log(0),但如果输入本身异常,仍可能有问题。源码中 AttentionLoss 使用 epsilon = 1e-8,并在 target 区域为空时跳过。(GitHub)
8. 图像 token 数和 attention mask 下采样尺寸不一致
代码里 attention loss 把 mask 下采样到 27×27,对应 729 个图像 token。源码中注释也显示 attention 的目标形状类似 [1, 729] 或 [4, 729]。(GitHub)
如果你换了 SigLIP 输入尺寸、patch size 或视觉塔,27×27 可能就不对。这是后续改模型时必须注意的点。
十四、博士视角:这个文件可以怎么改
你做"多模态融合驱动的遥感开放世界目标识别研究",最值得改的不是 train.sh,而是 llava_phi.py 这几处:
方向 1:改 SEG_token_projector
当前是简单线性层:
text
LLM hidden state → Mask2Former hidden dim
你可以改成:
text
MLP projector
门控 projector
MoE projector
模态感知 projector
类别不确定性感知 projector
对应你的博士创新:
text
多模态协同表征与跨模态语义对齐
方向 2:改 predictor(..., SEG_embedding)
现在 predictor 使用 [SEG] embedding 条件化 mask decoding。
你可以思考:
text
是否可以加入多个 [SEG] 之间的关系建模?
是否可以加入未知类别原型?
是否可以加入开放世界置信度?
是否可以让 query 同时包含文本语义和视觉区域原型?
对应你的博士创新:
text
开放世界目标识别与未知类别发现机制
方向 3:改 attention loss
当前 attention loss 主要约束 [SEG] 对图像 token 的注意力和 GT mask 的关系。
你可以改成:
text
边界增强 attention loss
小目标增强 attention loss
多尺度 attention loss
跨模态一致性 attention loss
未知区域对比 attention loss
对应你的博士创新:
text
已知目标精准判别 + 未知目标主动发现
方向 4:改两路视觉分支
现在是:
text
SigLIP → LLM
Swin → Mask2Former
你可以做:
text
光学 / SAR / 红外 多模态输入
不同模态单独 encoder
共享语义 projector
MoE 动态融合
缺失模态条件化表示
对应你的博士创新:
text
物理-语义联合跨模态对齐
十五、本讲你必须掌握的 12 个结论
llava_phi.py是 SegEarth-R2 的核心模型文件。CausalOutputWithMask让模型同时返回语言 loss 和分割 loss。SegEarthR2Model管基础多模态模型和 Swin mask 视觉塔。SegEarthR2管完整 forward、mask decoder、loss 和推理。images_clip进入 SigLIP 和 LLM。images进入 Swin 和 Mask2Former。prepare_inputs_labels_for_multimodal()把图像 token 插入 LLM 输入序列。concat_image_seg_cls_embeds()负责替换IMAGE_TOKEN_INDEX和REFER_TOKEN_INDEX。get_SEG_embedding()从 hidden states 中取出[SEG]对应向量。SEG_token_projector把 LLM hidden state 映射成 Mask2Former query。predictor(..., SEG_embedding)输出语言引导的 mask。- 总 loss 是
llm_loss + mask_loss + 0.01 * attention_loss。
十六、下一讲继续拆什么
下一讲进入 第 6 讲:eval.py、test.sh 和推理可视化流程。
重点解决:
text
1. 测试脚本怎么启动
2. 模型怎么加载
3. LoRA 合并后怎么推理
4. eval.py 怎么构造 prompt
5. eval_seg() 怎么被调用
6. pred mask 怎么保存
7. 如何验证复现结果是否正确
8. 如何做自己的遥感图片推理
到这里,你已经吃透了 SegEarth-R2 最关键的一条链:
text
[SEG] token hidden state → SEG_token_projector → Mask2Former query → pred_mask
后面只要把评估和环境报错解决,你就具备复现和二次改造这个项目的基础了。