Segearth-R2-02

下面开始按"吃透 + 能复现 + 能改进成自己博士课题能力 "来推进。我们先不急着逐行钻代码,而是先建立一张完整地图:这个项目解决什么问题、数据怎么进来、模型怎么走、loss 怎么算、训练怎么跑、复现哪里最容易出错


一、先给你一个总体判断

SegEarth-R2 的定位是:面向遥感图像的语言引导分割模型 。官方 README 标题就是 Towards Comprehensive Language-guided Segmentation for Remote Sensing Images ,仓库说明代码已在 2026-05-24 发布,LaSeRS 数据集在 2026-05-23 发布。(GitHub)

你可以把它理解为:

text 复制代码
遥感图像 + 文本指令
        ↓
多模态大模型理解文本和图像
        ↓
生成 [SEG] 分割语义 token
        ↓
Mask2Former 类分割解码器根据 [SEG] token 输出 mask
        ↓
得到语言引导的遥感目标分割结果

它和普通语义分割不同。普通语义分割是:

text 复制代码
图像 → 固定类别 mask

SegEarth-R2 是:

text 复制代码
图像 + 语言描述 → 与语言对应的目标 mask

所以它对你的博士课题"多模态融合驱动的遥感开放世界目标识别研究"非常重要,因为它已经从"封闭类别分割"走向了"语言指令驱动的开放目标定位/分割"。


二、项目整体结构先这样理解

官方仓库主目录很清楚,核心包括 docsscriptssegearth_r2requirements.txt。其中 segearth_r2 下进一步分成 datasetsevalmodeltrainutils。(GitHub)

text 复制代码
SegEarth-R2/
├── docs/
│   ├── Preparation.md        # 数据和预训练权重准备
│   ├── Installation.md       # 环境安装
│   ├── Training.md           # 训练说明
│   └── Evaluation.md         # 测试说明
│
├── scripts/
│   ├── train.sh              # 训练入口
│   ├── test.sh               # 测试入口
│   ├── merge_lora_weights.sh # LoRA 权重合并
│   ├── zero1.json
│   ├── zero2.json
│   └── zero3.json            # DeepSpeed 显存优化配置
│
└── segearth_r2/
    ├── datasets/
    │   └── dataset.py        # 数据读取、图像处理、文本处理、mask 构造
    │
    ├── train/
    │   ├── train.py          # 训练主程序
    │   ├── llava_trainer.py  # 自定义 Trainer
    │   └── merge_lora_weights_and_save_hf_model.py
    │
    ├── eval/
    │   └── eval.py           # 评估主程序
    │
    ├── model/
    │   ├── language_model/
    │   │   └── llava_phi.py  # SegEarth-R2 主模型,最核心
    │   ├── mipha/            # Mipha 多模态语言模型基础结构
    │   ├── mask_encoder/     # Swin Transformer 图像特征编码
    │   ├── mask_decoder/     # Mask2Former 分割解码器
    │   └── datasets_mapper/
    │
    └── utils/
        ├── constants.py
        ├── conversation.py
        ├── builder.py
        └── mm_utils.py

真正核心是:

text 复制代码
dataset.py
train.py
llava_phi.py
llava_trainer.py
eval.py

你只要先吃透这五个文件,就能理解 70% 以上的工程逻辑。


三、项目的三条主线

主线 1:数据流

官方准备文档要求数据大致按下面结构放置:训练集包括 train/annotations/train_ann.jsontrain/images,测试集包括 test/annotationstest/images。预训练权重需要包括 Mipha-3B、SigLIP 和 Mask2Former,并放到 pretrained_model 目录。(GitHub)

代码里 LaSeRSDataset 会读取:

text 复制代码
image_name
description
answer
mask
id

然后构造:

text 复制代码
image tensor
input_ids
labels
mask
[SEG] token 位置
refer instruction token
seg_info

也就是说,数据不是简单的"图像 + mask",而是:

text 复制代码
图像
文本指令 description
模型回答 answer
答案中的 [SEG]
对应目标 mask

其中 [SEG] 是整个项目的关键。answer 中出现几个 [SEG],就说明要预测几个 mask。dataset.py 里通过 answer.count("[SEG]") 统计 mask 数量,并构造 SEG_token_embedding_indices。(GitHub)


主线 2:训练流

官方训练脚本使用 DeepSpeed 启动 segearth_r2/train/train.py,参数包括 Mipha-3B 路径、SigLIP 路径、Mask2Former 权重、数据路径、输出路径、LoRA、DeepSpeed 配置和 mask config。(GitHub)

训练流可以画成:

text 复制代码
scripts/train.sh
        ↓
segearth_r2/train/train.py
        ↓
解析 ModelArguments / DataArguments / TrainingArguments
        ↓
加载 SegEarthR2.from_pretrained()
        ↓
初始化 Mask2Former 分割模块
        ↓
初始化 vision tower
        ↓
加载 tokenizer
        ↓
添加 [SEG] token
        ↓
构建 LaSeRSDataset
        ↓
构建 DataCollator
        ↓
LoRA 注入
        ↓
LLaVATrainer.train()
        ↓
保存 checkpoint

train.py 中能看到 ModelArgumentsDataArgumentsTrainingArguments 三类参数,也能看到它加载 SegEarthR2,初始化 mask module,添加 [SEG] token,并使用 LoRA 进行训练。(GitHub)


主线 3:模型流

模型流是最重要的:

text 复制代码
images_clip → SigLIP → mm_projector → LLM token 序列
images      → Swin → pixel decoder → mask feature
hidden state of [SEG] → SEG_token_projector → mask query
mask query + mask feature → Mask2Former predictor → pred_masks

llava_phi.py 里定义了 SegEarthR2ModelSegEarthR2。它会初始化 Swin 视觉分支、Mask2Former pixel decoder、transformer predictor、SEG_token_projector,并在 forward 中把 [SEG] 位置的 hidden states 取出用于预测 mask。(GitHub)

所以这个工程的灵魂不是"LLM 生成文字",而是:

text 复制代码
LLM 的 [SEG] token hidden state
        ↓
变成 mask query
        ↓
驱动 Mask2Former 产生分割结果

四、你要按这个计划学

我建议分成 8 个阶段。每个阶段都对应"能复现"的检查点。


阶段 0:先建立认知地图

目标:知道项目做什么,不陷入代码细节。

你要掌握:

text 复制代码
1. 什么是语言引导分割
2. 什么是 [SEG] token
3. Mipha-3B 负责什么
4. SigLIP 负责什么
5. Swin + Mask2Former 负责什么
6. LoRA 训练的作用
7. DeepSpeed 为什么必须用

阶段检查点:

text 复制代码
你能够用自己的话解释:
输入是什么?
输出是什么?
[SEG] 为什么重要?
mask 是怎么被预测出来的?

阶段 1:环境复现

官方安装说明要求 Linux、Python ≥ 3.10、PyTorch ≥ 2.0,并给了 torch==2.1.0 + cu121 的安装命令,同时需要额外安装 Detectron2,并编译 MSDeformAttn CUDA kernel。(GitHub)

但是你的显卡是 RTX 5090 ,不要直接照搬官方 cu121 环境。RTX 5090 属于 Blackwell 架构,PyTorch 社区曾明确讨论 Blackwell sm_100/sm_120 需要源码构建或 CUDA 12.8 相关构建支持。(PyTorch Forums)

你的复现环境建议按这个方向:

text 复制代码
Ubuntu 22.04 / 24.04
NVIDIA Driver 新版本
CUDA 12.8
Python 3.10
PyTorch cu128 版本
torchvision 匹配 PyTorch
Detectron2 从源码编译
MSDeformAttn 重新编译

阶段检查点:

bash 复制代码
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0))"

还要检查:

bash 复制代码
python -c "import detectron2; print('detectron2 ok')"

以及:

bash 复制代码
cd segearth_r2/model/mask_decoder/Mask2Former_Simplify/modeling/pixel_decoder/ops
sh make.sh

阶段 2:数据与权重准备

官方要求准备三类权重:

text 复制代码
Mipha-3B
SigLIP so400m patch14 384
Mask2Former model_final

并放到:

text 复制代码
pretrained_model/
├── CLIP/siglip-so400m-patch14-384/
├── mask2former/model.pkl
└── mllm/Mipha-3B/

这一步最容易出错的是路径。训练脚本里默认路径是:

bash 复制代码
--model_name_or_path "pretrained_model/mllm/Mipha-3B"
--vision_tower "pretrained_model/CLIP/siglip-so400m-patch14-384"
--vision_tower_mask "pretrained_model/mask2former/model_final_54b88a.pkl"

官方训练脚本也明确需要改自己的数据路径 --base_data_path。(GitHub)

阶段检查点:

text 复制代码
1. train/annotations/train_ann.json 能被读取
2. train/images/ 下图片存在
3. Mipha-3B tokenizer 能加载
4. SigLIP image processor 能加载
5. Mask2Former 权重路径正确

阶段 3:读懂 scripts/train.sh

这个文件是复现入口。你要逐行理解:

bash 复制代码
export NCCL_P2P_DISABLE="1"
export NCCL_IB_DISABLE="1"

作用是关闭部分 NCCL 通信路径,减少某些单机训练环境中的通信问题。

bash 复制代码
deepspeed --master_port=29500 --include localhost:4 segearth_r2/train/train.py

含义是用 DeepSpeed 启动训练,并指定 GPU。官方脚本里是 localhost:4,说明它默认只用第 4 张卡。你如果单卡 RTX 5090,应该改成:

bash 复制代码
--include localhost:0

如果多卡:

bash 复制代码
--include localhost:0,1,2,3

阶段检查点:

text 复制代码
你能解释 train.sh 里每一个参数:
model_name_or_path
vision_tower
vision_tower_mask
base_data_path
output_dir
max_steps
bf16
learning_rate
lora_r
deepspeed
mask_config
switch_bs

阶段 4:读懂 dataset.py

这是数据进入模型的地方。

重点函数:

python 复制代码
preprocess_image()
preprocess_mask()
tokenizer_special_tokens()
preprocess_llama2()
LaSeRSDataset.__getitem__()
DataCollatorForCOCODatasetV2.__call__()

你要重点理解:

text 复制代码
1. 图像为什么被处理成 1024×1024
2. mask 为什么要保持最近邻插值
3. description 如何变成 instruction
4. answer 中的 [SEG] 如何标记
5. input_ids、labels、attention_mask 怎么构造
6. images 和 images_clip 为什么同时存在

这里有一个非常关键的双图像流:

text 复制代码
images_clip:给 SigLIP / LLM 视觉塔使用
images:给 Swin / Mask2Former 分割分支使用

这说明项目不是单一视觉编码器,而是:

text 复制代码
一个视觉塔服务多模态语言理解
一个视觉骨干服务像素级分割

阶段检查点:

text 复制代码
你能手动打印一个 batch,并说清楚 batch 中每个 key 的含义。

阶段 5:读懂 train.py

这个文件是训练调度核心。

重点模块:

python 复制代码
ModelArguments
DataArguments
TrainingArguments
make_unify_datamodule()
find_linear_layers()
safe_save_model_for_hf_trainer()
train()

特别重要的是:

python 复制代码
tokenizer.add_tokens("[SEG]")
model.resize_token_embeddings(len(tokenizer))

这表示 [SEG] 不是普通文本符号,而是新加入模型词表的特殊语义分割 token。

还有:

python 复制代码
train_module_list = [
    "lm_head",
    "pixel_decoder",
    "predictor",
    "SEG_token_projector",
]

这说明训练时重点更新:

text 复制代码
语言输出头
分割 pixel decoder
Mask2Former predictor
SEG token 投影层
LoRA 参数

这对你未来做博士课题很重要,因为你可以从这里改:

text 复制代码
多模态融合模块
开放世界类别发现模块
轻量化部署模块

阶段检查点:

text 复制代码
你能说清楚哪些参数冻结,哪些参数训练。

阶段 6:吃透 llava_phi.py

这是整个项目的核心。

你要分 6 层读:

text 复制代码
第1层:SegEarthR2Model 初始化视觉分支
第2层:SegEarthR2 初始化 mask decoder
第3层:encode_images 把 SigLIP 特征送入 mm_projector
第4层:prepare_inputs_labels_for_multimodal 拼接图像 token 和文本 token
第5层:get_SEG_embedding 取出 [SEG] hidden state
第6层:forward 计算 LLM loss + mask loss + attention loss

核心 forward 逻辑是:

text 复制代码
input_ids + images_clip
        ↓
prepare_inputs_labels_for_multimodal
        ↓
LLM forward
        ↓
hidden_states
        ↓
提取 [SEG] hidden state
        ↓
SEG_token_projector
        ↓
Mask2Former predictor
        ↓
pred_masks
        ↓
loss_llm + loss_mask + loss_dice + loss_attention

llava_phi.py 中可以看到,forward 会计算语言建模的交叉熵 loss,也会根据 mask_outputsseg_info 计算 mask/dice 类损失,还会额外构造 attention loss。(GitHub)

阶段检查点:

text 复制代码
你能画出 forward 数据流。
你能找到 [SEG] hidden state 在哪里被提取。
你能找到 mask loss 在哪里计算。
你能找到 pred_masks 在哪里生成。

阶段 7:训练、合并、测试

官方训练后需要运行 scripts/merge_lora_weights.sh,把 LoRA adapter 合并进基础模型,用于后续推理和评估。(GitHub)

流程是:

text 复制代码
训练
  ↓
保存 LoRA checkpoint
  ↓
merge_lora_weights_and_save_hf_model.py
  ↓
得到完整模型
  ↓
test.sh / eval.py
  ↓
输出预测 mask

官方测试脚本通过 DeepSpeed 调用 segearth_r2/eval/eval.py,需要指定 base_data_pathmodel_pathvision_tower_maskmask_configoutput_dir。(GitHub)

阶段检查点:

text 复制代码
你能完成一次小样本训练。
你能合并 LoRA。
你能跑通 eval。
你能在 output/res 里看到预测结果。

五、最终学习路线表

阶段 目标 核心文件 结果
0 建立整体理解 README、docs 知道项目解决什么问题
1 环境复现 Installation.md Python/PyTorch/Detectron2/MSDeformAttn 可用
2 数据准备 Preparation.md 数据和权重路径正确
3 训练入口 scripts/train.sh 能解释所有训练参数
4 数据读取 dataset.py 能打印并解释一个 batch
5 训练逻辑 train.py 能知道模型怎么加载、哪些参数训练
6 模型核心 llava_phi.py 能理解 [SEG] → mask 的全过程
7 评估推理 eval.pytest.sh 能复现预测结果
8 博士改进 model/ 能设计自己的创新点

六、你现在先记住 5 个关键结论

第一,SegEarth-R2 的核心不是普通分割,而是语言引导分割。

它输入的是:

text 复制代码
遥感图像 + 语言描述

输出的是:

text 复制代码
语言对应目标的 mask

第二,SEG 是连接语言模型和分割模型的桥。

语言模型不是直接输出 mask,而是输出 [SEG] token。模型取 [SEG] 对应的 hidden state,再送入 Mask2Former 分割头。

第三,它有两条视觉分支。

text 复制代码
SigLIP:服务语言模型,看图理解语义
Swin + Mask2Former:服务像素级分割

第四,训练 loss 至少包含三类。

text 复制代码
语言建模 loss
mask loss
dice loss
attention loss

第五,复现最容易错在三个地方。

text 复制代码
RTX 5090 CUDA / PyTorch / Detectron2 版本
MSDeformAttn 编译
数据路径和预训练权重路径

七、后续拆解顺序

接下来建议按这个顺序继续:

text 复制代码
第一讲:scripts/train.sh 逐行解释
第二讲:docs/Preparation.md + 数据格式完整解释
第三讲:dataset.py 逐函数解释
第四讲:train.py 逐模块解释
第五讲:llava_phi.py 模型初始化部分
第六讲:llava_phi.py forward 全流程
第七讲:loss 计算与 mask decoder
第八讲:eval.py 推理与可视化
第九讲:RTX 5090 复现和报错修复
第十讲:如何改成你的博士课题方向

现在你先建立这个整体框架。下一段我会直接从 scripts/train.sh 逐行讲起 ,然后进入 dataset.py,最后再攻克最核心的 llava_phi.py

相关推荐
Z-D-K1 小时前
考验AI的“自我和意识“-AI对《红楼梦》后40回的改写(22)
人工智能·ai·aigc·agent·agi
CIO_Alliance1 小时前
(企业AI化转型)选对iPaaS系统集成厂家是制造业数字化转型的生死线
大数据·数据库·人工智能·企业数字化转型·ipaas·系统集成
生成论实验室1 小时前
六十四卦态势操作系统技术白皮书
人工智能·语言模型·系统架构·机器人·自动驾驶·agi·安全架构
qcx231 小时前
【AI Daily 2026-06-05】 AI 方向的基础设施化,能力从模型层下沉到工具链和工作流
人工智能·ai·llm·agent·agi
一次旅行1 小时前
AI领域每日资讯报告 | 2026年6月15日
人工智能
workflower1 小时前
互联网与大数据环境下制造服务模式
人工智能·自然语言处理·数据挖掘·自动驾驶·动态规划·制造
WangN21 小时前
【通识】RSL-RL快速上手
人工智能·python·机器学习·机器人
lijgvnns1 小时前
散户做股票研究与复盘,主流AI工具的场景化使用指南
大数据·人工智能·数据挖掘
weixin_446260851 小时前
学习协调偏好用于多目标多智能体强化学习
人工智能·多智能体