Segearth-r2-05

继续。下面进入 第 5 讲:llava_phi.py 模型核心结构精读 。这是 SegEarth-R2 最重要的源码文件之一,官方 GitHub 显示该文件约 945 行,位于 segearth_r2/model/language_model/llava_phi.py,里面定义了模型输出结构、注意力监督损失、SegEarthR2Model、SegEarthR2 主类、mask decoder 初始化、[SEG] embedding 提取、训练 forward、评估 eval_seg 和推理接口。(GitHub)


第 5 讲:llava_phi.py 总体定位

你可以把 llava_phi.py 理解成 SegEarth-R2 的"大脑 + 分割头连接器"。

前面我们已经知道:

text 复制代码
dataset.py 负责造 batch
train.py 负责加载模型和训练配置
llava_phi.py 负责真正执行 forward

llava_phi.py 主要解决一个核心问题:

text 复制代码
语言模型输出的是 token hidden state,
但分割任务需要 pixel-level mask,
那么如何把 token hidden state 变成 mask?

SegEarth-R2 的答案是:

text 复制代码
[SEG] token hidden state
        ↓
SEG_token_projector
        ↓
Mask2Former query
        ↓
pixel_decoder + predictor
        ↓
pred_masks

这就是整个项目的灵魂。


一、先看文件里的 4 个核心对象

llava_phi.py 里最重要的是这几个对象:

text 复制代码
1. CausalOutputWithMask
2. AttentionLoss
3. SegEarthR2Model
4. SegEarthR2

1. CausalOutputWithMask

这个类继承自 HuggingFace 的 CausalLMOutputWithPast

普通语言模型输出通常包括:

text 复制代码
loss
logits
past_key_values
hidden_states
attentions

但 SegEarth-R2 还要输出分割相关 loss,所以它额外加入:

text 复制代码
loss_mask
loss_dice
loss_llm
loss_attention

也就是说,这个模型不是普通 LLM,而是"语言建模 + 分割监督"的复合模型。源码中 CausalOutputWithMask 明确包含 loss_maskloss_diceloss_llmloss_attention 字段。(GitHub)

你读到这里时要建立一个意识:

text 复制代码
logits 负责语言输出
pred_masks 负责分割输出
loss_llm 监督文本生成
loss_mask / loss_dice 监督分割结果
loss_attention 监督 [SEG] 对图像区域的注意力

2. AttentionLoss

AttentionLoss 的作用是让 [SEG] token 对图像 token 的注意力更接近真实 mask 区域。

通俗讲:

text 复制代码
如果问题是"分割机场跑道",
那么 [SEG] token 不应该平均看整张图,
而应该更多关注跑道所在区域。

代码逻辑大致是:

text 复制代码
取出 [SEG] 对图像 token 的 attention
根据 gt_mask 把 attention 分成目标区域和非目标区域
计算目标区域 attention 与背景均值之间的差异
形成 attention loss

源码中 AttentionLoss.forward() 会从 attention map 中分别取出 mask 区域和非 mask 区域,并用差异构造损失。(GitHub)

这对论文理解很重要,因为它说明 SegEarth-R2 不是只靠 mask loss 训练分割头,还试图让语言模型的注意力机制也对齐到目标区域。


3. SegEarthR2Model

SegEarthR2Model 继承自 MiphaPhiModel。它更接近"基础模型骨架",主要负责:

text 复制代码
继承 Mipha/Phi 语言模型能力
保存 mask decoder 配置
构建 vision_tower_mask
提供 get_vision_tower()
提供 get_vision_tower_mask()
初始化视觉模块

源码中 SegEarthR2Model 在初始化时会根据 swin_type 构建 build_swin_bbuild_swin_l,并把 IVSDatasetMapper 作为 mask 视觉分支的 image processor。(GitHub)

这里要注意两个名字:

text 复制代码
vision_tower       → 多模态 LLM 的视觉塔,通常是 SigLIP
vision_tower_mask  → 分割分支视觉骨干,通常是 Swin

这和前面 dataset.py 里的两个图像输入完全对应:

text 复制代码
images_clip → vision_tower
images      → vision_tower_mask

4. SegEarthR2

SegEarthR2 继承自 MiphaPhiForCausalLM,这是整个工程真正的主模型。

它负责:

text 复制代码
语言模型 forward
多模态输入拼接
mask decoder 初始化
[SEG] embedding 提取
mask 输出
loss 计算
eval_seg
inference

源码中 SegEarthR2 初始化时会创建 SegEarthR2Model,设置 lm_head,并根据 mask_decode_train 判断是否初始化 mask decoder。(GitHub)


二、constants.py:几个特殊 token 的意义

理解 llava_phi.py 前,必须理解特殊 token index。

官方 constants.py 中定义了:

text 复制代码
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
SEG_TOKEN_INDEX = -201
CLS_TOKEN_INDEX = -202
REGION_TOKEN_INDEX = -203
REFER_TOKEN_INDEX = -204
TEMPORAL_TOKEN_INDEX = -205

其中最重要的是:

text 复制代码
IGNORE_INDEX      → 语言 loss 中被忽略的位置
IMAGE_TOKEN_INDEX → 文本序列中图像占位符
REFER_TOKEN_INDEX → referring instruction 占位符

这些负数 index 不是真正词表 token,而是工程中用于"占位"的特殊标记。constants.py 明确将 IGNORE_INDEX 设为 -100,IMAGE_TOKEN_INDEX 设为 -200,REFER_TOKEN_INDEX 设为 -204。(GitHub)

这很关键,因为 concat_image_seg_cls_embeds() 会扫描 input_ids,遇到普通 token 就用 embedding,遇到 IMAGE_TOKEN_INDEX 就插入图像特征,遇到 REFER_TOKEN_INDEX 就插入 refer embedding。源码中 concat_image_seg_cls_embeds() 明确断言当前只支持一个 image index,并在遇到 IMAGE_TOKEN_INDEX 时插入图像特征。(GitHub)


三、核心机制 1:Mask decoder 怎么初始化

函数:

python 复制代码
initial_mask_module()

它做了 5 件事:

text 复制代码
1. 创建 AttentionLoss
2. 根据 mask_config 创建 pixel_decoder
3. 创建 predictor
4. 创建 SEG_token_projector
5. 初始化 criterion,也就是 mask / dice loss 计算器

最关键的一句是:

python 复制代码
self.SEG_token_projector = nn.Linear(
    self.config.hidden_size,
    self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM
)

含义是:

text 复制代码
LLM hidden size
        ↓
Mask2Former hidden_dim

也就是说,语言模型的 hidden state 不能直接给 Mask2Former 用,需要先经过一个线性投影层。源码中 initial_mask_module() 会创建 pixel_decoderpredictorSEG_token_projector,并初始化 mask decoder 的训练损失。(GitHub)

你可以把它理解成:

text 复制代码
LLM 世界:token embedding
Mask2Former 世界:query embedding
SEG_token_projector:两个世界之间的翻译器

四、核心机制 2:两路视觉特征

llava_phi.py 里有两个非常容易混淆的函数。


1. encode_images(images)

这个函数服务于 多模态语言模型

逻辑是:

text 复制代码
images_clip
    ↓
vision_tower / SigLIP
    ↓
mm_projector
    ↓
image token embeddings
    ↓
插入 LLM 输入序列

它把 SigLIP 输出的视觉特征映射到语言模型 hidden size。


2. get_vision_tower_feature(images)

这个函数服务于 分割分支

逻辑是:

text 复制代码
images
    ↓
vision_tower_mask / Swin
    ↓
res2, res3, res4, res5
    ↓
pixel_decoder
    ↓
multi_scale_features + mask_features

源码中 get_vision_tower_feature() 会调用 get_vision_tower_mask()(images),然后组织成 res2res3res4res5 四级特征。(GitHub)

所以你必须记住:

text 复制代码
images_clip 不是为了分割像素,而是为了让 LLM 看懂图像
images 不是为了语言理解,而是为了产生高分辨率分割特征

五、核心机制 3:多模态输入怎么拼进 LLM

关键函数:

python 复制代码
prepare_inputs_labels_for_multimodal()

它的任务是:

text 复制代码
把 input_ids 里的 <image> 占位符替换成真实图像 token embedding
把普通文本 token 变成文本 embedding
同步更新 labels
同步更新 attention_mask
同步更新 SEG_token_embedding_indices
同步记录哪些位置是 image feature

源码中 prepare_inputs_labels_for_multimodal() 会先通过 encode_images(images) 得到图像特征,然后逐个 batch 样本处理 input ids,并构造新的 inputs_embedslabelsSEG_token_embedding_indicesimage_features_indices。(GitHub)

这一步非常关键,因为原始 input_ids 是这样的:

text 复制代码
[文本 token, IMAGE_TOKEN_INDEX, 文本 token, REFER_TOKEN_INDEX, 文本 token, [SEG]]

经过处理后变成:

text 复制代码
[文本 embedding, 图像 patch embeddings, 文本 embedding, refer embedding, 文本 embedding, [SEG] embedding]

也就是说,LLM 最终看到的不是单纯文本,而是"文本 embedding + 图像 embedding"的混合序列。


六、核心机制 4:concat_image_seg_cls_embeds()

这个函数是"替换占位符"的具体执行者。

它会扫描单个样本的 input_id

text 复制代码
普通 token
    ↓
embed_tokens()

IMAGE_TOKEN_INDEX
    ↓
替换为 img_feature

REFER_TOKEN_INDEX
    ↓
替换为 refer_embedding

它还会同步处理三类伴随信息:

text 复制代码
labels
SEG_token_embedding_indices
image_features_indices

为什么这很重要?

因为插入图像 token 后,序列长度变了。例如原始 <image> 只占 1 个位置,但 SigLIP 图像 patch 可能占几百个 token。此时 [SEG] 的位置、label 的位置、attention mask 的位置都必须一起重新对齐。

你可以这样理解:

text 复制代码
原序列:
A B <image> C [SEG]

替换后:
A B img1 img2 img3 ... img729 C [SEG]

所以 [SEG] 原本可能是第 5 个 token,替换后可能变成第 733 个位置。如果 SEG_token_embedding_indices 不同步更新,后面就会取错 hidden state。


七、核心机制 5:get_SEG_embedding()

函数:

python 复制代码
get_SEG_embedding(hidden_states, SEG_embedding_indices)

作用非常直接:

text 复制代码
遍历 batch 中每个样本
根据 SEG_embedding_indices 找到 [SEG] 位置
取出 hidden_states 中对应 token 的向量
拼接成 SEG_embedding

源码中 get_SEG_embedding() 会根据每个样本的 bool index 从 hidden_states 取出对应状态,然后 torch.catunsqueeze(1)。(GitHub)

这一步是 SegEarth-R2 的核心桥梁:

text 复制代码
语言模型输出 hidden_states
        ↓
取出 [SEG] 位置
        ↓
得到 SEG_embedding
        ↓
投影到 Mask2Former query 空间

你可以把 [SEG] 理解成"语言模型生成的分割查询"。


八、forward() 主流程逐步拆解

现在进入最核心的 forward()

函数输入包括:

python 复制代码
input_ids
attention_mask
past_key_values
inputs_embeds
labels
images
images_clip
seg_info
token_refer_id
SEG_token_embedding_indices
global_step
mask_num
dataset_type

源码中 forward() 参数明确同时接收 imagesimages_clip,也接收 seg_infoSEG_token_embedding_indicesmask_num。(GitHub)


第 1 步:检查 batch 数据类型

源码中会检查:

text 复制代码
如果 dataset_type 不为空,
则 batch 内所有 dataset_type 必须一致。

这说明作者原本考虑过多数据集混训,但一个 batch 内不希望混入不同任务类型。


第 2 步:打开 attention 输出

text 复制代码
output_attentions = True

为什么一定要输出 attention?

因为后面要计算 loss_attention

也就是说,这个 forward 不只是需要 hidden states,还需要每层 attention map,用来约束 [SEG] 对图像区域的关注。


第 3 步:如果有 [SEG],先提取分割图像特征

代码逻辑是:

text 复制代码
如果 SEG_token_embedding_indices 不为空,
且其中确实存在 [SEG],
并且当前不是单 token 生成阶段,
则调用 get_vision_tower_feature(images)

这一步得到的是分割分支的 Swin 多尺度特征。源码中在 SEG_token_embedding_indices 存在且非空时,会调用 get_vision_tower_feature(images) 得到 image_features。(GitHub)

注意:这里用的是 images,不是 images_clip


第 4 步:构造 LLM 的 inputs_embeds

然后调用:

python 复制代码
prepare_inputs_labels_for_multimodal(
    input_ids,
    attention_mask,
    past_key_values,
    labels,
    images_clip,
    token_refer_id=token_refer_id,
    SEG_token_embedding_indices=SEG_token_embedding_indices
)

注意这里传入的是 images_clip

它会把 SigLIP 特征插入到语言模型 token 序列中。源码中 forward 会调用 prepare_inputs_labels_for_multimodal(..., images_clip, ...) 来生成 inputs_embeds、新的 attention_mask、新的 labels、新的 [SEG] 索引和图像特征索引。(GitHub)

所以 forward 内部再次体现了双视觉流:

text 复制代码
images      → Swin → mask decoder
images_clip → SigLIP → LLM input embeddings

第 5 步:进入语言模型

代码逻辑:

text 复制代码
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    inputs_embeds=inputs_embeds,
    output_attentions=True,
    return_dict=return_dict
)

随后:

text 复制代码
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)

源码中 forward() 会调用底层模型得到 outputs,再取 last_hidden_state 并通过 lm_head 得到 logits。(GitHub)

含义:

text 复制代码
hidden_states → 用于取 [SEG] embedding
logits        → 用于计算语言模型 loss
attentions    → 用于 attention loss

第 6 步:取出 [SEG] embedding

关键逻辑:

text 复制代码
SEG_embedding =
    SEG_token_projector(
        get_SEG_embedding(hidden_states, SEG_token_embedding_indices)
    )

源码中这一句在 forward 里非常明确:先从 hidden states 提取 [SEG] embedding,再经过 SEG_token_projector。(GitHub)

这一步之后,SEG_embedding 已经不再只是语言模型向量,而是可以作为 mask decoder query 使用的向量。

你要把这句话背下来:

text 复制代码
SegEarth-R2 的分割 query 不是随机学习的静态 query,
而是来自语言模型 [SEG] token 的上下文 hidden state。

这也是它能够做到"语言引导分割"的关键。


第 7 步:pixel decoder 提取 mask features

如果 image_features 不为空,执行:

text 复制代码
mask_features, transformer_encoder_features, multi_scale_features =
    pixel_decoder.forward_features(image_features)

源码中 forward 会调用 pixel_decoder.forward_features(image_features) 生成 mask_featuresmulti_scale_features。(GitHub)

这一步的含义是:

text 复制代码
Swin backbone 的 res2/res3/res4/res5
        ↓
MSDeformAttnPixelDecoder
        ↓
Mask2Former 可用的多尺度特征

第 8 步:根据 mask_num 复制图像特征

这一段很重要,很多人第一次看会不理解:

text 复制代码
mask_features = repeat_interleave(mask_features, repeats=mask_num, dim=0)

multi_scale_features = [
    repeat_interleave(feat, repeats=mask_num, dim=0)
    for feat in multi_scale_features
]

为什么要复制?

因为一个 batch 中,一张图可能有多个 [SEG],例如:

text 复制代码
图像 A:2 个 [SEG] → 2 个 mask
图像 B:1 个 [SEG] → 1 个 mask

原始图像特征数量是按图像来的:

text 复制代码
B 张图 → B 份图像特征

但 mask query 是按 [SEG] 来的:

text 复制代码
总共 3 个 [SEG] → 3 个 SEG_embedding

所以要把图像特征复制到和 SEG_embedding 数量一致。源码中 forward 会把 mask_featuresmulti_scale_featuresmask_numrepeat_interleave,再传给 predictor。(GitHub)


第 9 步:predictor 输出 mask

关键逻辑:

text 复制代码
mask_outputs = predictor(
    multi_scale_features,
    mask_features,
    None,
    None,
    SEG_embedding
)

源码中 mask_outputs = self.predictor(..., SEG_embedding),说明 Mask2Former predictor 是被 [SEG] embedding 条件化驱动的。(GitHub)

输出里最重要的是:

text 复制代码
mask_outputs["pred_masks"]

这就是模型预测的分割 mask。


九、训练 loss 怎么算

forward() 中 loss 分为 3 类:

text 复制代码
1. llm_loss
2. mask_loss + dice_loss
3. attention_loss

1. llm_loss

源码中计算语言 loss 的方式是标准 causal LM:

text 复制代码
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
CrossEntropyLoss(shift_logits, shift_labels)

也就是"前一个 token 预测下一个 token"。源码中 forward 明确用 CrossEntropyLoss 计算 llm_loss。(GitHub)

这部分监督模型生成 answer,例如:

text 复制代码
The target is [SEG].

注意,human prompt 部分的 label 在 dataset 中已经被设成 IGNORE_INDEX,所以这里主要监督 assistant answer。


2. mask_lossdice_loss

如果 seg_info 不为空,源码会根据 seg_info 构造 targets。对于包含 mask 的情况,每个 target 包括:

text 复制代码
labels
masks
valid
inst_id

然后调用:

text 复制代码
mask_losses = self.criterion(mask_outputs, targets)

接着按 weight_dict 汇总 loss_maskloss_dice。源码中 forward 在 seg_info 存在时会构造 targets,并通过 self.criterion(mask_outputs, targets) 计算分割损失。(GitHub)

含义:

text 复制代码
mask_loss 关注像素级预测是否正确
dice_loss 关注预测区域和真实区域重叠程度

3. attention_loss

源码里会把真实 mask resize 到 800×800,再下采样到 27×27,最后和 [SEG] token 对图像 token 的 attention 做监督。forward() 中会构造 masks_down,然后遍历 attentions,取出 [SEG] 对 image features 的 attention,并加入 loss_attention。(GitHub)

你可以这样理解:

text 复制代码
mask_loss 监督最后输出的 mask
attention_loss 监督中间注意力是否看向目标

最终总 loss 是:

text 复制代码
loss = llm_loss + mask_loss + 0.01 * loss_attention

源码中 forward 明确使用 loss = llm_loss + mask_loss + 0.01 * loss_attention,并返回 loss_maskloss_diceloss_llmloss_attention。(GitHub)


十、把 forward 画成完整数据流

你现在应该能画出这张图:

text 复制代码
batch from dataset.py
    │
    ├── input_ids
    ├── labels
    ├── attention_mask
    ├── images_clip
    ├── images
    ├── SEG_token_embedding_indices
    ├── seg_info
    └── mask_num

            ↓

llava_phi.py forward

            ↓

images_clip
    ↓
SigLIP vision_tower
    ↓
mm_projector
    ↓
image token embeddings
    ↓
prepare_inputs_labels_for_multimodal
    ↓
LLM input embeddings

images
    ↓
Swin vision_tower_mask
    ↓
res2/res3/res4/res5
    ↓
pixel_decoder
    ↓
mask_features + multi_scale_features

LLM input embeddings
    ↓
Phi/Mipha language model
    ↓
hidden_states + logits + attentions

hidden_states + SEG_token_embedding_indices
    ↓
get_SEG_embedding
    ↓
SEG_token_projector
    ↓
SEG_embedding

SEG_embedding + mask_features + multi_scale_features
    ↓
Mask2Former predictor
    ↓
pred_masks

logits + labels
    ↓
llm_loss

pred_masks + seg_info
    ↓
mask_loss + dice_loss

attentions + gt mask
    ↓
attention_loss

最终:
loss = llm_loss + mask_loss + 0.01 * attention_loss

这张图是你吃透 SegEarth-R2 的核心图。后面所有代码都围绕这张图展开。


十一、eval_seg():评估阶段怎么生成 mask

训练 forward 是为了算 loss,eval_seg() 是为了输出预测结果。

eval_seg() 的主要流程是:

text 复制代码
1. get_vision_tower_feature(images)
2. prepare_inputs_labels_for_multimodal(..., images_clip, ...)
3. self.model(...) 得到 hidden_states
4. get_SEG_embedding()
5. SEG_token_projector()
6. pixel_decoder.forward_features()
7. repeat_interleave 图像特征
8. predictor(..., SEG_embedding)
9. 取 pred_masks
10. resize 到图像尺寸
11. 生成 processed_results

源码中 eval_seg() 会先提取 image_features,然后构造多模态输入,取 last_hidden_state,提取 [SEG] embedding,经过 pixel decoder 和 predictor 得到 pred_masks。(GitHub)

评估阶段还会把预测 mask 插值到图像 tensor 尺寸,并把预测 mask、GT mask、image name、data id 和 mask id 放进 processed_results。源码中 eval_seg()mask_pred_results 插值到 ImageList 的尺寸,并返回包含 predgtimage_nameidmask_id 的结果字典。(GitHub)


十二、训练 forward 和 eval_seg 的区别

你要把这两个区别讲清楚:

text 复制代码
forward:
    训练用
    计算 logits
    计算 llm_loss
    计算 mask_loss / dice_loss
    计算 attention_loss
    返回 CausalOutputWithMask

eval_seg:
    测试用
    主要生成 mask
    不计算训练 loss
    返回 processed_results

也就是说:

text 复制代码
forward 关心 loss
eval_seg 关心 pred mask

十三、这部分最容易出错的 8 个地方

1. [SEG] 数量和 mask 数量不一致

表现:

text 复制代码
repeat_interleave shape 对不上
criterion 报 target 数量不匹配
pred_masks 和 gt masks 对不齐

检查:

python 复制代码
print(SEG_token_embedding_indices.sum())
print(sum(mask_num))
print(len(seg_info))

三者应该一致。


2. imagesimages_clip 搞反

表现:

text 复制代码
SigLIP 输入尺寸不对
Swin 输入尺寸不对
pixel_decoder shape 报错

正确关系:

text 复制代码
images_clip → prepare_inputs_labels_for_multimodal → LLM
images      → get_vision_tower_feature → pixel_decoder

3. SEG_token_embedding_indices 没有随图像 token 插入同步扩展

如果你自己改 prompt 或 tokenizer,最容易破坏这里。prepare_inputs_labels_for_multimodal()concat_image_seg_cls_embeds() 会同步更新 [SEG] 位置,否则就会从 hidden states 中取错 token。


4. mask_num 不是 list 或 tensor 形状异常

mask_num 用于 repeat_interleave。如果 batch 中某个样本没有 [SEG],或者 mask_num 与 batch size 不一致,就容易报错。


5. seg_info 为空

如果 seg_info 为空,mask loss 无法计算。

检查 dataset:

text 复制代码
annotation 是否有 mask
RLE 是否能 decode
answer 是否包含 [SEG]
DataCollator 是否把 seg_info 传进 batch

6. Mask2Former 预训练权重 key 不匹配

initial_mask_module() 里会从 checkpoint 中抽取 sem_seg_head.pixel_decodersem_seg_head.predictor 权重,并做部分 key name 转换。如果权重版本不匹配,可能出现 missing keys 或 unexpected keys。源码中初始化函数会对 pixel decoder 和 predictor 权重做 key 处理,并分别 load_state_dict(..., strict=False)。(GitHub)


7. attention loss 出 NaN

可能原因:

text 复制代码
mask 全 0
attention_map_target 为空
BF16 下数值不稳定
masks_down 尺寸不符合 image token 数

AttentionLoss 内部有 epsilon 避免 log(0),但如果输入本身异常,仍可能有问题。源码中 AttentionLoss 使用 epsilon = 1e-8,并在 target 区域为空时跳过。(GitHub)


8. 图像 token 数和 attention mask 下采样尺寸不一致

代码里 attention loss 把 mask 下采样到 27×27,对应 729 个图像 token。源码中注释也显示 attention 的目标形状类似 [1, 729][4, 729]。(GitHub)

如果你换了 SigLIP 输入尺寸、patch size 或视觉塔,27×27 可能就不对。这是后续改模型时必须注意的点。


十四、博士视角:这个文件可以怎么改

你做"多模态融合驱动的遥感开放世界目标识别研究",最值得改的不是 train.sh,而是 llava_phi.py 这几处:

方向 1:改 SEG_token_projector

当前是简单线性层:

text 复制代码
LLM hidden state → Mask2Former hidden dim

你可以改成:

text 复制代码
MLP projector
门控 projector
MoE projector
模态感知 projector
类别不确定性感知 projector

对应你的博士创新:

text 复制代码
多模态协同表征与跨模态语义对齐

方向 2:改 predictor(..., SEG_embedding)

现在 predictor 使用 [SEG] embedding 条件化 mask decoding。

你可以思考:

text 复制代码
是否可以加入多个 [SEG] 之间的关系建模?
是否可以加入未知类别原型?
是否可以加入开放世界置信度?
是否可以让 query 同时包含文本语义和视觉区域原型?

对应你的博士创新:

text 复制代码
开放世界目标识别与未知类别发现机制

方向 3:改 attention loss

当前 attention loss 主要约束 [SEG] 对图像 token 的注意力和 GT mask 的关系。

你可以改成:

text 复制代码
边界增强 attention loss
小目标增强 attention loss
多尺度 attention loss
跨模态一致性 attention loss
未知区域对比 attention loss

对应你的博士创新:

text 复制代码
已知目标精准判别 + 未知目标主动发现

方向 4:改两路视觉分支

现在是:

text 复制代码
SigLIP → LLM
Swin → Mask2Former

你可以做:

text 复制代码
光学 / SAR / 红外 多模态输入
不同模态单独 encoder
共享语义 projector
MoE 动态融合
缺失模态条件化表示

对应你的博士创新:

text 复制代码
物理-语义联合跨模态对齐

十五、本讲你必须掌握的 12 个结论

  1. llava_phi.py 是 SegEarth-R2 的核心模型文件。
  2. CausalOutputWithMask 让模型同时返回语言 loss 和分割 loss。
  3. SegEarthR2Model 管基础多模态模型和 Swin mask 视觉塔。
  4. SegEarthR2 管完整 forward、mask decoder、loss 和推理。
  5. images_clip 进入 SigLIP 和 LLM。
  6. images 进入 Swin 和 Mask2Former。
  7. prepare_inputs_labels_for_multimodal() 把图像 token 插入 LLM 输入序列。
  8. concat_image_seg_cls_embeds() 负责替换 IMAGE_TOKEN_INDEXREFER_TOKEN_INDEX
  9. get_SEG_embedding() 从 hidden states 中取出 [SEG] 对应向量。
  10. SEG_token_projector 把 LLM hidden state 映射成 Mask2Former query。
  11. predictor(..., SEG_embedding) 输出语言引导的 mask。
  12. 总 loss 是 llm_loss + mask_loss + 0.01 * attention_loss

十六、下一讲继续拆什么

下一讲进入 第 6 讲:eval.pytest.sh 和推理可视化流程

重点解决:

text 复制代码
1. 测试脚本怎么启动
2. 模型怎么加载
3. LoRA 合并后怎么推理
4. eval.py 怎么构造 prompt
5. eval_seg() 怎么被调用
6. pred mask 怎么保存
7. 如何验证复现结果是否正确
8. 如何做自己的遥感图片推理

到这里,你已经吃透了 SegEarth-R2 最关键的一条链:

text 复制代码
[SEG] token hidden state → SEG_token_projector → Mask2Former query → pred_mask

后面只要把评估和环境报错解决,你就具备复现和二次改造这个项目的基础了。

相关推荐
奔袭的算法工程师1 小时前
论文解读--Sparse4D v3: Advancing End-to-End 3D Detection and Tracking
人工智能·目标检测·计算机视觉·自动驾驶·信号处理
sali-tec1 天前
C# 基于OpenCv的视觉工作流-章84-包胶有无检测
图像处理·人工智能·opencv·算法·计算机视觉
安逸sgr1 天前
《图解机器学习-第三章》:训练、验证、测试:三分数据,缺一不可!
人工智能·深度学习·机器学习·计算机视觉
韩师傅1 天前
用 7 年前的 SHWD 数据,带你读一遍 Ultralytics predict 链
计算机视觉
硅谷秋水1 天前
ProDrive:基于自身-环境协同演化的自动驾驶主动规划
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
菜鸟‍1 天前
【论文学习】2026.5 || 分解式视觉-语言对齐用于细粒度开放词汇分割
人工智能·深度学习·计算机视觉
郝学胜-神的一滴1 天前
[简化版 GAMES 101] 计算机图形学 13:从光栅化到着色——赋予三维像素光影灵魂
c++·计算机视觉·unity·godot·图形渲染·opengl·unreal
硅谷秋水1 天前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
YOLO数据集集合2 天前
无人机山地灾害巡检数据集 | 滑坡多区域实例分割 遥感影像解译 地质灾害预警深度学习数据10296期
人工智能·深度学习·目标检测·计算机视觉·无人机