import argparse
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
from PIL import Image
# Separator between the image path and the mask path on each line of the
# dataset txt files. Per the dataset format (and the inline note this replaces),
# the fields are separated by THREE spaces — the single-space value here looks
# like whitespace collapse, so the documented three-space separator is restored.
SEP = "   "
@dataclass
class Sample:
    """One image/mask pair parsed from a single line of a dataset txt file."""

    image_path: str  # path to the input image
    mask_path: str   # path to the corresponding label mask
    source_txt: str  # the txt file this line was read from
    line_no: int     # 1-based line number inside source_txt (for error reporting)
def _normalize_path(path: str, case_insensitive: bool) -> str:
s = path.strip().replace("\\", "/")
while "//" in s:
s = s.replace("//", "/")
return s.lower() if case_insensitive else s
def _make_key(path: str, match_mode: str, case_insensitive: bool) -> str:
    """Build the lookup key used to match FSD images against RM images.

    In ``"basename"`` mode only the file name is used; otherwise the full
    normalized path is the key.
    """
    normalized = _normalize_path(path, case_insensitive)
    return os.path.basename(normalized) if match_mode == "basename" else normalized
def parse_txt(txt_path: str) -> List[Sample]:
    """Read a dataset txt and return one Sample per non-empty line.

    Each line must contain exactly an image path and a mask path separated by
    SEP; anything else raises ValueError with the offending file and line.
    """
    parsed: List[Sample] = []
    with open(txt_path, "r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            stripped = raw_line.strip()
            if not stripped:
                continue  # silently skip blank lines
            fields = stripped.split(SEP)
            if len(fields) != 2:
                raise ValueError(
                    f"Invalid line format in {txt_path}:{line_no}. Expect '<img>{SEP}<mask>', got: {stripped}"
                )
            img_path, mask_path = fields
            parsed.append(Sample(img_path.strip(), mask_path.strip(), txt_path, line_no))
    return parsed
def load_mask(path: str) -> np.ndarray:
    """Load a label mask from *path* as a uint8 array.

    Multi-band images (e.g. RGB) are reduced to their first band before
    conversion, so the result is always a single-channel array.
    """
    img = Image.open(path)
    if len(img.getbands()) > 1:
        # Keep only band 0 — e.g. the R channel of an RGB mask.
        img = img.split()[0]
    return np.array(img, dtype=np.uint8)
def make_output_mask_path(fsd_mask_path: str, suffix: str) -> str:
    """Return *fsd_mask_path* with *suffix* inserted before the file extension."""
    stem, ext = os.path.splitext(fsd_mask_path)
    return stem + suffix + ext
def write_txt(path: str, lines: List[Tuple[str, str]]):
    """Write (image_path, mask_path) pairs to *path*, one SEP-joined pair per line.

    The parent directory is created on demand.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(f"{img}{SEP}{mask}\n" for img, mask in lines)
def build_rm_index(
    rm_txts: List[str], match_mode: str, case_insensitive: bool
) -> Dict[str, List[Sample]]:
    """Index every RM sample by its image-matching key.

    Several RM samples may map to the same key (e.g. basename collisions),
    so each key holds a list in txt read order.
    """
    index: Dict[str, List[Sample]] = {}
    for rm_txt in rm_txts:
        for sample in parse_txt(rm_txt):
            image_key = _make_key(sample.image_path, match_mode, case_insensitive)
            index.setdefault(image_key, []).append(sample)
    return index
def merge_one_fsd_txt(
    fsd_txt: str,
    rm_index: Dict[str, List[Sample]],
    args: argparse.Namespace,
):
    """Merge RM curb labels into the FSD masks listed in one FSD txt.

    For each FSD sample whose image key hits ``rm_index``, pixels equal to
    ``args.rm_curb_value`` in the RM mask are stamped as class
    ``args.fsd_curb_class`` into a copy of the FSD mask, which is saved next
    to the original with ``args.mask_suffix`` inserted before the extension
    (existing files are kept unless ``args.overwrite``). Samples whose masks
    cannot be read or whose shapes disagree fall back to the original mask
    path. Two txt files are written under ``args.out_dir``: the full dataset
    and the RM-hit-only subset.

    Returns a dict of per-txt statistics.
    """
    fsd_samples = parse_txt(fsd_txt)
    all_lines: List[Tuple[str, str]] = []
    hit_only_lines: List[Tuple[str, str]] = []
    total = len(fsd_samples)
    hit_rm = 0
    wrote_luyan = 0
    size_mismatch = 0
    multi_match = 0
    merged_mask_written = 0
    for fsd in fsd_samples:
        key = _make_key(fsd.image_path, args.match_mode, args.case_insensitive)
        cands = rm_index.get(key, [])
        # Default: keep the original mask path. Used when there is no RM hit,
        # or when the merge is skipped due to a read failure / size mismatch.
        # (The original line here was a bare comment missing its '#', which
        # made the file a SyntaxError — fixed.)
        out_mask_path = fsd.mask_path
        if cands:
            hit_rm += 1
            if len(cands) > 1:
                multi_match += 1
            rm = cands[0]  # on multiple hits, take the first (RM txt read order)
            try:
                fsd_mask = load_mask(fsd.mask_path)
                rm_mask = load_mask(rm.mask_path)
            except Exception as e:
                # Best-effort: log and keep the original mask for this sample.
                print(f"[WARN] read mask failed: fsd={fsd.mask_path}, rm={rm.mask_path}, err={e}")
                all_lines.append((fsd.image_path, out_mask_path))
                continue
            if fsd_mask.shape != rm_mask.shape:
                size_mismatch += 1
                all_lines.append((fsd.image_path, out_mask_path))
                continue
            curb_area = rm_mask == args.rm_curb_value
            merged = fsd_mask.copy()
            if np.any(curb_area):
                merged[curb_area] = args.fsd_curb_class
                wrote_luyan += 1
            out_mask_path = make_output_mask_path(fsd.mask_path, args.mask_suffix)
            out_dir = os.path.dirname(out_mask_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # Skip writing if the merged mask already exists, unless --overwrite.
            if args.overwrite or (not os.path.exists(out_mask_path)):
                Image.fromarray(merged).save(out_mask_path)
                merged_mask_written += 1
            hit_only_lines.append((fsd.image_path, out_mask_path))
        all_lines.append((fsd.image_path, out_mask_path))
    base = os.path.splitext(os.path.basename(fsd_txt))[0]
    out_all_txt = os.path.join(args.out_dir, f"{base}{args.txt_suffix}.txt")
    out_hit_txt = os.path.join(args.out_dir, f"{base}{args.txt_suffix}_hit_rm_only.txt")
    write_txt(out_all_txt, all_lines)
    write_txt(out_hit_txt, hit_only_lines)
    summary = {
        "fsd_txt": fsd_txt,
        "output_all_txt": out_all_txt,
        "output_hit_txt": out_hit_txt,
        "total_fsd_samples": total,
        "hit_rm_samples": hit_rm,
        "wrote_luyan_pixel_samples": wrote_luyan,
        "size_mismatch_samples": size_mismatch,
        "multi_match_samples": multi_match,
        "merged_mask_written": merged_mask_written,
        "hit_only_dataset_size": len(hit_only_lines),
    }
    return summary
def main():
    """CLI entry point: merge RM curb labels into FSD datasets and print stats."""
    parser = argparse.ArgumentParser(
        description="离线融合RM路沿标签到FSD标签,输出训练可用txt(全量 + 命中RM子集)"
    )
    parser.add_argument("--fsd-txt", nargs="+", required=True, help="FSD txt列表(如train/val)")
    parser.add_argument("--rm-txt", nargs="+", required=True, help="RM txt列表(如train/val)")
    parser.add_argument(
        "--match-mode",
        choices=["fullpath", "basename"],
        default="fullpath",
        help="FSD与RM图片匹配方式:完整路径或文件名",
    )
    parser.add_argument(
        "--case-insensitive",
        action="store_true",
        help="路径匹配大小写不敏感(默认false)",
    )
    parser.add_argument(
        "--rm-curb-value",
        type=int,
        default=233,
        help="RM中路沿像素值(默认233)",
    )
    parser.add_argument(
        "--fsd-curb-class",
        type=int,
        default=2,
        help="FSD中路沿类别id(默认2)",
    )
    parser.add_argument(
        "--mask-suffix",
        default="_luyan_mirrorfold",
        help="融合后mask后缀,保留原扩展名",
    )
    parser.add_argument(
        "--txt-suffix",
        default="_luyan_mirrorfold",
        help="输出txt文件名后缀",
    )
    parser.add_argument(
        "--out-dir",
        required=True,
        help="输出txt目录(会生成全量txt和hit_rm_only txt)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="若目标融合mask已存在则覆盖",
    )
    args = parser.parse_args()

    rm_index = build_rm_index(args.rm_txt, args.match_mode, args.case_insensitive)

    # Process each FSD txt, printing its summary as it completes.
    summaries = []
    for fsd_txt in args.fsd_txt:
        per_txt_summary = merge_one_fsd_txt(fsd_txt, rm_index, args)
        summaries.append(per_txt_summary)
        print(json.dumps(per_txt_summary, ensure_ascii=False, indent=2))

    # Aggregate the numeric metrics across all processed txt files.
    metric_keys = (
        "total_fsd_samples",
        "hit_rm_samples",
        "wrote_luyan_pixel_samples",
        "size_mismatch_samples",
        "multi_match_samples",
        "merged_mask_written",
        "hit_only_dataset_size",
    )
    total_summary = {"num_fsd_txt": len(summaries)}
    for metric in metric_keys:
        total_summary[metric] = sum(s[metric] for s in summaries)

    print("=== TOTAL SUMMARY ===")
    print(json.dumps(total_summary, ensure_ascii=False, indent=2))
# Script entry guard. The original compared the bare names `name`/"main",
# which raises NameError at import time; the dunder form is the correct idiom.
if __name__ == "__main__":
    main()
=============
已帮你实现离线融合脚本,满足你提的全部要求。
已实现内容
新增脚本:tools/build_fsd_luyan_from_rm.py
它支持:
- 读取多个 FSD txt(例如 train/val)和多个 RM txt(例如 train/val)
- 每行按 三个空格 分隔(与你 datasettxt.txt 格式一致)
- 两种匹配方式(可开关):
- fullpath:按图片完整路径匹配
- basename:按文件名匹配
- 路沿值可配(你现在默认 RM路沿=233)
- FSD 路沿类别可配(默认写成 2)
- 融合后标签命名规则:
- 原:xxx_gt.png
- 新:xxx_gt_luyan_mirrorfold.png
- 多匹配处理策略:
- 默认取第一个(按 RM txt 读取顺序)
- 同时统计 multi_match_samples
- 输出两类训练可用 txt:
- 全量(未命中的保留原mask路径)
- 仅命中RM子集(只包含成功命中且可融合样本)
- 输出你要求的统计:
- total_fsd_samples
- hit_rm_samples
- wrote_luyan_pixel_samples
- size_mismatch_samples
- 另外还给了 multi_match_samples、merged_mask_written
运行示例(按你4个txt场景)
python tools/build_fsd_luyan_from_rm.py ^
--fsd-txt fsd_train.txt fsd_val.txt ^
--rm-txt rm_train.txt rm_val.txt ^
--match-mode fullpath ^
--rm-curb-value 233 ^
--fsd-curb-class 2 ^
--out-dir merged_txt ^
--overwrite
如果你想按文件名匹配:
--match-mode basename
输出结果说明
假设 fsd_train.txt、fsd_val.txt 输入,则会在 --out-dir 下生成:
- fsd_train_luyan_mirrorfold.txt
- fsd_train_luyan_mirrorfold_hit_rm_only.txt
- fsd_val_luyan_mirrorfold.txt
- fsd_val_luyan_mirrorfold_hit_rm_only.txt
其中 mask 路径会指向你要求的后缀文件,例如:
.../NOR_xxx_gt_luyan_mirrorfold.png
如果你愿意,我下一步可以再给你加一个 --summary-json 参数,把总统计和每个txt统计落盘成json,方便你汇报时直接引用。