U-Net 训练光纤识别

U-Net

U-Net和DeepLab系列是语义分割领域两种最经典的模型,各有侧重。简单来说:U-Net 以其对称的 U 型结构和跳跃连接见长,擅长精细分割,尤其适用于医学图像领域 ;而 DeepLab 系列则通过空洞卷积和ASPP模块,在多尺度特征提取和大场景语义理解上更具优势

它们的核心异同,可以从下表直观地看个大概:

对比维度 👨‍⚕️ U-Net 🏙️ DeepLab (v3+)
核心架构 对称的编码器-解码器结构,呈U形 编码器-解码器 + 空洞空间金字塔池化(ASPP)
核心创新 跳跃连接,直接融合编码器的浅层细节与解码器的深层语义 空洞卷积 扩大感受野 + ASPP捕获多尺度上下文信息
关键优势 对小目标和精细边界分割效果好;小样本学习能力出色,适合数据稀缺场景 多尺度特征提取 能力强,能有效处理图像中大小不一的目标;综合精度通常更高
典型应用 医学影像分析(细胞、肿瘤分割)、遥感建筑物提取 自动驾驶(道路、行人识别)、遥感地物分类

👨‍⚕️ U-Net:医学分割领域的"精确手术刀"

U-Net在2015年由Ronneberger等人提出,最初是为生物医学图像分析设计的。

  • 核心架构:对称的"U"型结构

    其结构由一个用于捕获上下文信息的收缩路径(编码器) 和一个用于精确定位的扩展路径(解码器) 组成,整体呈对称的U形。编码器通过池化层逐步降低图像分辨率、提取高级语义;解码器则通过上采样逐步恢复图像细节和分辨率。

  • 核心创新:跳跃连接

    这是U-Net的精髓所在。它将编码器中具有丰富空间信息的高分辨率特征图 与解码器中具有丰富语义信息的低分辨率特征图进行拼接(Concatenate),实现跨层特征融合。这能帮助模型在分割时"看见"更精确的边界。

  • 关键优势

    正如一位技艺精湛的外科医生,U-Net的优势在于"精准":

    • 精细分割:在医学影像等需要边界清晰的任务中表现卓越。
    • 小样本学习:能够在标注数据稀缺(如仅几十张图)的情况下,通过数据增强等技术获得出色的分割效果。

🏙️ DeepLab (v3+):自然场景理解的"多面手"

DeepLab系列是Google团队提出的一套语义分割模型,历经v1, v2, v3,到v3+版本达到了集大成者的高度。

  • 核心架构:编码器-解码器 + ASPP

    DeepLabv3+结合了编码器-解码器结构和ASPP模块。编码器负责提取特征,再由ASPP进行多尺度处理,最后解码器恢复分辨率。

  • 两大核心技术

    1. 空洞卷积 :通过在卷积核元素间插入"空洞"来扩大感受野,让模型在不增加计算量的情况下"看"到更大范围的上下文信息。
    2. ASPP模块:并行的空洞卷积层,分别捕捉不同尺度的上下文信息。这使得模型能同时"看"清近距离的小物体和远处的大物体,非常灵活。
  • 关键优势

    DeepLab更像一个视野开阔的"多面手",优势在于"全面":

    • 多尺度感知:通过ASPP模块,能更好地理解场景中大小不一的目标,在复杂环境(如自动驾驶、遥感)中表现优异。
    • 整体精度高:在PASCAL VOC、Cityscapes等通用语义分割基准测试中,DeepLab系列通常能达到顶尖水平。

🆚 对比总结:如何选择?

U-Net和DeepLab各有千秋,选择哪个取决于你的具体任务和数据。下面这张图可以帮你在不同应用场景下做出判断:
语义分割任务
应用场景?
"医学影像分析(如细胞、肿瘤分割)"
"自动驾驶场景(如道路、行人识别)"
"遥感图像解译(如地物分类、建筑物提取)"
"工业瑕疵检测(如产品表面缺陷)"
"首选 U-Net优势:小样本学习能力强分割边界精细"
"DeepLab 系列优势:多尺度特征提取应对复杂场景"
"地物尺度差异大:DeepLab小规则地物:U-Net"
"瑕疵尺度小:U-Net瑕疵类型多样:DeepLab"
"综合考量模型复杂度计算资源与推理速度要求"

  • 选择 U-Net :如果你的数据量不大(尤其在医学影像领域),或任务对分割边界的精细度要求极高,U-Net是极佳的起点和强大的基线模型。其轻量化设计也更易于部署。
  • 选择 DeepLab :如果处理的是自然场景图像,目标尺寸变化剧烈(如自动驾驶、遥感分析),且追求更高的整体精度,DeepLab系列通常是更合适的选择。
  • 关注最新进展:两者都有丰富的变体,如U-Net++、Res-UNet,以及轻量化的MFA-DeepLabv3+等,可根据实际需求进一步选择。

U-Net介绍:

U-Net是语义分割领域一个影响力深远的经典模型,它凭借精巧的"U"型结构和创新的"跳跃连接"设计,在数据稀缺的任务中展现出了非凡的能力。

🎯 U-Net:从医学影像出发的"像素级艺术家"

U-Net诞生于一个非常实际的需求------医学图像分割 。在医学领域,获取大量高质量的标注数据极为困难且昂贵。传统方法难以兼顾"理解"图像整体和"看清"局部细节。为了解决这个核心矛盾,2015年,Olaf Ronneberger等人在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中提出了U-Net。它最初的成功是一个标志性事件:在仅有30张左右标注图像的情况下,它在ISBI细胞分割挑战赛中就以显著优势夺得了冠军。

🧬 核心架构:对称的"U"型设计

U-Net的名字源于其标志性的对称"U"型结构。这个结构主要由两条路径和一个精妙的设计组成:

  • 收缩路径(编码器) :位于网络左侧,负责提取图像特征并"理解"语境。它通过重复应用卷积(如3x3卷积)和池化(如2x2最大池化)操作,逐步压缩图像的空间尺寸,同时增加通道数,从而捕捉更抽象、更高级的语义特征。例如,输入一张512x512的图像,经过4次下采样后,会缩小到32x32大小,但通道数可能从最初的个位数增加到1024,信息高度浓缩。

  • 扩展路径(解码器) :位于网络右侧,负责恢复图像细节并进行精确定位。它使用上采样(如转置卷积)逐步将特征图恢复到原始输入图像的尺寸。

  • 跳跃连接(Skip Connections) :这是U-Net的核心创新,也是其成功的关键 。它将编码器每一层产生的、富含空间位置信息的高分辨率特征图 ,直接"跳跃"连接到解码器对应的层上,并在通道维度上进行拼接(Concatenation)。这解决了深层网络易丢失空间信息的问题,使得最终分割结果的边界异常清晰、精准

  • 最终输出 :经过一系列上采样和特征融合后,网络最后使用1x1卷积将通道数压缩为目标类别数(如肿瘤、血管等),并对每个像素应用Softmax函数,生成一张与原图大小相同的分割掩码(Segmentation Mask)。掩码的每个像素值代表了该点属于每个类别的概率。

📚 训练策略:应对数据稀缺的"魔法"

U-Net之所以能在少量数据上训练,关键在于其强大的数据增强(Data Augmentation) 策略。

作者在训练时对有限的训练图像进行大量的、密集的随机变形,以模拟真实世界中可能出现的各种形态变化,从而生成大量"新"的训练样本 。常见的策略包括随机旋转、弹性形变、缩放、剪切和亮度调整等。此外,为了应对医学图像中前景(病灶)与背景像素数量极不平衡的问题,U-Net采用了加权交叉熵损失函数(Weighted Cross-Entropy Loss),给予前景更高的权重,引导模型更加关注学习目标区域。

💡 关键优势与局限性

下表总结了U-Net的主要优势与局限性:

✅ 优势 (Strengths) ❌ 局限性 (Limitations)
小样本学习能力强:在标注数据稀缺(如几十张图像)的情况下表现优异,是医学等领域的首选。 感受野固定:标准卷积核(如3x3)捕捉全局或长距离依赖关系的能力有限。
分割精度高,边界清晰:跳跃连接的设计能有效保留空间细节,分割结果的边界尤其精准。 对小目标不友好:深层网络中,小目标的特征信息容易在多次下采样过程中衰减或丢失。
模型相对轻量:参数量适中(约7.8M),易于训练和部署。 语义信息与细节的权衡:跳跃连接结合浅层纹理和深层语义时,对尺度差异大的物体效果欠佳。
结构灵活,可扩展性强:其对称的U型框架非常经典,催生了U-Net++、3D U-Net、TransUNet等大量变体。 速度并非极致:相比一些实时性的模型(如某些轻量级Transformer),U-Net在极速推理场景下不占优势。

🌍 应用领域:从显微镜下到广阔天地

起初为细胞分割而生的U-Net,如今已被广泛应用到众多领域:

  • 🩺 医学影像分析 :依然是其最核心的应用场景,如肿瘤(肝、肺、脑等)分割器官(心脏、肝脏等)勾画血管和细胞的精准分割等。
  • 🛰️ 遥感图像分析 :用于建筑物提取道路网络构建土地利用分类森林砍伐监测等。
  • 🏭 工业视觉检测 :广泛应用于产品表面缺陷检测 (如划痕、凹坑)、焊缝质量评估印刷电路板瑕疵检测等领域。

🚀 演进之路:U-Net的主要变体

U-Net的成功催生了大量的改进工作,使其框架生命力愈发旺盛。部分有代表性的变体包括:

  • U-Net++ :核心是嵌套和密集的跳跃连接,加强了不同层级特征间的融合,在处理多尺度目标的复杂场景时效果更好。
  • Attention U-Net :在跳跃连接前引入注意力门(Attention Gate),让模型自动聚焦图像中的目标区域,有效抑制背景噪声。
  • Res-UNet :将残差块融入U-Net,通过跨层连接构建残差学习,支持训练更深的网络。
  • 3D U-Net:将标准U-Net的2D卷积/池化操作扩展到3D,直接处理CT、MRI等三维体积数据。
  • nnU-Net (no-new-Net):自动为新的医学图像分割任务动态配置最优的预处理、网络结构和训练方案,显著降低了使用门槛。
  • TransUNet :将Transformer作为编码器,利用其强大的自注意力机制来捕捉全局上下文信息,结合U-Net的解码器恢复细节。

更前沿的研究还探索了将U-Net与Mamba状态空间模型扩散生成模型 等新技术结合,或用于图像超分辨率等新领域。

💎 总结

U-Net不仅仅是一个模型,更成为了一种跨越多个行业的设计模式与基准。 它深刻洞察了"理解全局"与"保持细节"这一图像分割的核心矛盾,并通过"U形"架构和"跳跃连接"给出了经典又极具影响力的解决方案。

1D 序列模型:

与图像类似,序列也蕴含着丰富的"空间"结构,不过是沿着一个维度(通常是时间) 方向变化的。如果说2D模型像像素"画家",那1D模型就更像一位能理解顺序和节律的"捕手"。它的核心任务是理解数据的前后依赖关系,并从连续的序列中,精准地找出特定模式的开始与结束位置(即"语义分割")。

想象一下你想在一段长达几天的脑电波记录里,精准地找出睡眠周期切换的那些瞬间,这正是1D序列模型最擅长的。

与U-Net这些2D模型相比,它们最核心的区别可以这样理解:

特性 2D模型 (如 U-Net, DeepLab) 1D序列模型
处理数据 平面上延伸的像素,例如图像、特征图。 直线上延伸的符号,例如时间序列、音频波形、文本句子、DNA序列。
核心操作 2D卷积核在高 × 宽的平面上滑动,捕捉局部纹理、形状。 1D卷积核沿着单一线性轴(通常是时间轴)滑动,捕捉局部的时序依赖或上下文关系。
物理意义 特征在二维空间中的位置关系(如上、下、左、右)至关重要,代表了视觉结构。 顺序和时序上的前后依赖至关重要。在某个时间点之前发生了什么,强烈影响着对当前时刻的判断。

简单来说,2D模型在空间里找形状,而1D模型在时间里找节律。

✨ 主要架构

1D序列模型家族成员众多,各自拥有独特的能力:

  • 1D U-Net :通过将原版U-Net的二维卷积替换为一维操作,能够为序列中的每个时间点生成高精度的分割掩码,非常适合在脑电图、心电图等生物医学信号中精准定位异常事件。
  • 时序卷积网络 (TCN) :通过堆叠一维卷积并使用空洞卷积 (Dilated Convolution)技术,能指数级地扩大感受野,同时计算高度并行,特别适合需要全局视角分析长序列的任务。
  • 循环神经网络 (RNN/LSTM):这类网络拥有一种"循环"的思想,像人类阅读一样,按顺序"记住"前面看到的信息并传递给下一步,对处理动态变化明显的序列尤其关键。
  • Transformer :其核心是自注意力机制。它能直接"看到"序列中任意两个位置之间的关系,不存在信息遗忘的问题。通过将信号切分成Token进行学习,在处理非局部的干扰模式时优势显著。

🌍 应用领域

这些模型的实用性早已渗透到众多领域:

  • 🫀 生命科学 :通过分析脑电图(EEG) 自动识别癫痫发作,或通过心电图(ECG) 实时监测心律失常。
  • 🎶 音频处理 :在嘈杂环境中实现语音识别;或自动识别音乐中的歌手、流派等;以及分离歌曲中的人声与伴奏等。
  • 📝 文本分析 :在句子中检测并标记有害片段;或对整篇文章进行情感倾向分析。
  • ⚙️ 工业与金融 :分析传感器数据提前进行设备故障预测与维护 ;或分析股票价格、交易量数据进行时间序列预测

随着技术的演进,1D与2D的界限也正在被打破。像Mamba这样的新型状态空间模型,就在尝试通过特殊的扫描策略,用更接近1D的处理方式高效解决2D视觉任务,实现了效率与效果的新平衡。

以上是对1D序列模型的整体介绍。如果你想对某个具体架构,比如1D U-Net、LSTM或Transformer,进行更深入的细节探讨,随时可以告诉我。

claude总结的项目信息:

我在做一个 WPF 桌面工具(IntelligentSorting),用于产线工人对带光纤的颜色排序检测。

现在想用深度学习模型替代当前的规则法颜色识别。

===== 项目基本信息 =====

  • 技术栈:.NET 10 + WPF + MVVM + SQLite + OpenCvSharp4
  • 三层架构:IntelligentSorting(UI) / IntelligentSorting.Core(算法) / IntelligentSorting.Data(数据)
  • 颜色库存储在 SQLite,字段用 BLab(避免与 Color 类型冲突)
  • 现有 ICameraService 抽象,支持 Uvcham + Uvcsam 双 SDK 自动探测
  • 白平衡校准已实现(开机自动 AWB → 冻结增益)

===== 产品形态 =====

  • 产品 = N 根并排光纤(竖直排列,水平铺开)
  • 工人能告诉我们 N(根数)
  • 工人在画面上拉一个横长方形 ROI 横穿所有光纤
  • N 通常是 4、6、12、16 等,光纤宽度大致一致(同型号同直径)
  • 颜色种类有限(可能 20-30 种全行业总和)
  • 同一产品的光纤间可能有窄黑缝(光纤间隙)
  • 圆柱光纤中央有高光,两侧有暗影

===== 已经完成的(规则法) =====

  1. 拍照识别功能完整跑通(产品输入面板的"📷 拍照识别色序"按钮)
  2. 本地图片识别功能完整跑通("📁 从图片识别"按钮)
  3. 工人输入根数 N 的对话框 FiberCountInputDialog
  4. 识别结果确认对话框 ProductRecognitionConfirmDialog(支持UseExisting/Overwrite/CreateNew)
  5. 当前算法 RecognizeByEqualPartition:工人 N + ROI → ROI 等 N 分 → 每段中心 40% 区域采样 RGB → LAB 距离匹配颜色库
  6. 调试图保存功能 SaveRoiDebugImage:把 ROI + 切刀线 + 采样窗口画出来存到桌面 BMP

===== 规则法的剩余问题 =====

  • ROI 必须工人精确拉到光纤束两端(否则等分会偏移,出现"段位置错位 1 根"和"段含背景")
  • 已经设计了"自动边界检测"DetectFiberBundleBounds,基于亮度阈值找光纤束真实左右边界,但尚未测试
  • 接下来工作还需要验证自动边界 + 等分能否做到 90%+ 准确

===== 现在要做的:用深度学习模型替换"段 → 颜色"这一步 =====

走的路线:逐根分类(Image Classification)

  • 切分:沿用规则法(等分中心采样)→ 切出 N 个小图块(约 50×100 像素)
  • 分类:每个小图块丢进小型 CNN,预测它属于哪种颜色
  • 推理:Python 训练 PyTorch 模型 → 导出 ONNX → C# 用 Microsoft.ML.OnnxRuntime 加载

===== 模型设计 =====

  • 输入:64×64 RGB(从 50×100 小图块 resize)
  • 模型:3 层 CNN(Conv16/32/64 + ReLU + Pool + GAP + FC)
  • 输出:K 类颜色(K = 全行业可能的颜色数,初步 20-30)
  • 模型大小:~30 KB(几千参数)

===== 数据标注方案 =====

  1. 每种产品采集 30-50 张照片(同光照、同相机、不同水平摆位 + 微小旋转)
  2. 用现有的"本地图片识别"工具,先用规则法切分,得到小图块作为初稿
  3. 工人按颜色分文件夹归类:
    data/
    ├── 蓝/img001.png
    ├── 红/img001.png
    └── ...
  4. 工作量预估:30 张原图 = 360 个小图,人工归类 20 分钟

===== 训练 + 部署路线图 =====

阶段 工作 时长
1 收集 30-50 张产品照片 1 天
2 标注:小图块按颜色归类 0.5 天
3 PyTorch 训练 CNN 0.5 天
4 导出 ONNX 1 小时
5 C# 接入 ONNX Runtime 0.5 天
6 测试 + 微调 1 天

===== 当前需要你做的 =====

请先给我:

  1. 完整的 PyTorch 训练脚本(包含数据加载、模型定义、训练循环、ONNX 导出)
  2. C# 端 ONNX Runtime 推理代码模板(可以集成到 IntelligentSorting.Core 项目)
  3. 数据集组织规范(目录结构、命名规则、训练/验证集划分)
  4. 训练时的注意事项(数据增强、防过拟合、迁移学习选择等)

===== 关键约束 =====

  • 训练用 Python,推理用 C#(.NET 10)
  • 模型推理速度要快(单张 < 10ms,产线检测时段实时性要求高)
  • 模型要能在没有 GPU 的工业电脑上跑(CPU only)
  • 颜色种类后续可能扩展,模型要方便重新训练(增量学习/全量重训都行)
  • 不依赖 OpenCV Python 版(C# 端已有 OpenCvSharp,Python 端用 PIL/numpy 即可)

===== 颜色种类范围(初步) =====

深蓝、棕、绿、暗绿、暗棕、灰白、白、红、深红、暗红、黑、黄绿、亮黄绿、紫、粉紫、浅紫蓝、浅蓝、橙、黄、青绿(具体需采集后确定)

请你针对以上情况开始指导:从数据集组织开始。

收到。先把你担心的两个问题先解掉,再给完整训练方案。


一、关于"贴死光纤靠后处理切"------我有多大把握

直说:80% 把握能解决,但需要一个兜底机制

为什么有把握 :你这个产品有一个极其有利 的物理约束------同一根光纤束里,所有光纤直径完全一样(同型号同直径)。这意味着:

  • 模型输出一整块白色连通域,宽度 W;
  • 这一束总共 N 根(工人会告诉你);
  • 每根光纤标称宽度 w = W / N;
  • 即便中间没有缝隙,我也能按 w 等分切开,切出来的位置就是物理上的真实位置

这不是猜,是几何必然。只要 U-Net 把"光纤束的左右总边界"找准了,内部怎么切都是确定的。

为什么不是 100%:

  • 如果整张图里有多个独立连通域 (比如左边 5 根贴死成一块、中间有个明显大缝、右边 7 根贴死成一块),要正确分配"哪一块有几根"。解法:按宽度比例分配,即左块占总宽 5/12 → 5 根,右块占总宽 7/12 → 7 根。这一步在 OpenCV 里 10 行代码搞定
  • 如果 U-Net 把光纤束的外边界 判错了(比如把蓝色光纤外缘的一两个像素误判为背景),W 会偏小,等分位置全部漂移。这才是真风险

兜底机制:

  1. 切完后回查:每根切出来后,取中心 60% 区域算平均颜色;如果某根的平均颜色和左右邻居的颜色"过渡太平滑"(LAB 距离<阈值),说明切刀位置可能压在两根之间,触发警报让工人确认。
  2. 失败回退:如果连通域总数 K > N(模型把一根中间切断了),或 总宽 W 异常小,自动切回纯规则法的"工人 ROI 等分",并提示工人重拉 ROI。

这两层保护下来,实际可用率应该在 95%+。

结论:可以这么干,我有把握。


二、关于"像素可调,光纤占多少像素还能算吗"------这是好问题

你说得对,我之前算的"单根 30 像素宽"是基于你那张特定分辨率的图。像素数本身不是关键,关键是"光纤束在画面里的相对占比"

正确的思考方式:

  • 相机原始分辨率:1920×1080 或 2560×1440,看你 SDK 输出。
  • 工人拉的 ROI:横长方形,假设宽 W_roi 像素,高 H_roi 像素。
  • 送入 U-Net 的输入 :固定尺寸 ,我建议 512×128(宽×高)。所有 ROI 不管原始多大,先 resize 到 512×128 再喂给模型。
  • 模型输出 mask:也是 512×128,再 resize 回原始 ROI 尺寸做后处理。

为什么固定 512×128:

  • 512 宽度足够容纳 32 根光纤(每根至少 16 像素宽,U-Net 切得动)。
  • 128 高度足够采集到光纤的高光/暗影特征(不丢圆柱体反光信息)。
  • 4:1 比例匹配 ROI 的横长方形形状。
  • U-Net 输入尺寸固定后,推理时间也固定,产线节拍可预测。

所以以后你看这张你刚发的新图(1920×1538 之类),不用关心"光纤占多少像素",关心的是"光纤束占画面横向的多少比例"。新图我大概数了下,12 根光纤束占画面横向约 18% ,也就是约 350 像素宽。resize 到 512 后,12 根占约 92 像素,每根 7-8 像素宽------够 U-Net 处理了。32 根也就是每根 2.5 像素,这就到极限了,所以 32 根上线时可能要把输入分辨率提到 768×192。

这一点先记住,后面训练时会再用到


三、对这张新图的关键观察

这张全景图给了我两个之前没看到的信息:

1. 背景比第一张更暗、更纯

第一张是灰色绒布,这张几乎是黑色绒布。说明产线背景不是固定一种 ,可能换工位换批次会变。训练集必须包含至少 2 种背景,否则模型会过拟合到某一种背景颜色。

这是我之前没强调够的 :50 张图里,至少分两批拍,一批黑底、一批灰底(用你第一张图的那块绒布)。如果产线还会出现其他底色,也要拍到。

2. 光纤"高光线"非常明显

每根光纤中央有一条几乎纯白的反光线。这条线在视觉上极其显眼,U-Net 会很快学会用它定位光纤中心。这是好事,意味着即便贴死的两根,模型也可能从两条高光线之间看出"这里其实是两根"------某些情况下可以辅助切分。

3. 颜色覆盖

这一张图里我数到的颜色:蓝、棕、绿、黑、灰白、白、深红、暗红/黑(可能两根)、亮黄绿、淡紫、浅蓝、深蓝(可能就是最左那根)。一张图就覆盖了 10+ 种颜色,说明 50 张图很容易覆盖到 20-30 种全行业颜色,数据效率很高。


四、完整训练方案

下面是给你的执行方案,按时间顺序写,你照着做就行。

阶段 1:数据采集(1 天)

数量:50 张

  • 40 张 12 根产品
  • 10 张其他根数混搭(4 根 / 6 根 / 8 根各 3-4 张)
  • 暂不拍 32 根,留到验证后再决定

拍照规范:

  • 一半黑底绒布 / 一半灰底绒布
  • 光照变化 3 档:正常顶灯、加侧光、关一盏灯
  • 光纤束在画面里位置变化:偏左、居中、偏右
  • 轻微角度倾斜 0-5°(不要超过)
  • 故意拍几张"难样本":有灰尘、有划痕、光纤束没完全水平

目录结构:

复制代码
dataset/
├── raw/                    # 原始相机拍的图,完整画面
│   ├── 12fiber_001.png
│   ├── 12fiber_002.png
│   ├── 04fiber_001.png
│   ├── 06fiber_001.png
│   └── 08fiber_001.png
├── roi/                    # ROI 截图(从 raw 里裁出来,横长方形)
│   ├── 12fiber_001.png     # 注意:这里只是 ROI,不需要 resize 到 512×128
│   └── ...                 # resize 在训练脚本里做
└── mask/                   # mask 标注,和 roi 文件名一一对应
    ├── 12fiber_001.png     # 单通道 PNG,像素值 0=背景,1=光纤
    └── ...

为什么 raw 和 roi 都存:

  • raw 留着,以后如果发现 ROI 拉得不好,可以重新裁;
  • roi 是给训练直接用的;
  • 训练时只用 roi/ 和 mask/。

阶段 2:标注(半自动,1 天)

核心思路:用你现有的规则法生成 mask 初稿,人工只做修正

具体流程:

Step 1:用你现有的 DetectFiberBundleBounds + 等分代码,跑一遍所有 ROI,自动生成初稿 mask

逻辑:

  1. 工人在每张 ROI 上手动指定 N(写在文件名里就行,比如 12fiber_001.png);
  2. 你的代码自动检测光纤束左右边界 → 等分 N 段 → 每段中心 80% 标为"光纤",段间标为"背景"或"缝隙";
  3. 输出 PNG mask(0=背景,1=光纤)。

这一步会生成 50 张粗 mask,准确率大概 70-85%

Step 2:用 Labelme 或 ISAT 等工具人工修正

推荐工具:ISAT_with_segment_anything (https://github.com/yatengLG/ISAT_with_segment_anything),开源、轻量、支持载入初稿 mask 后微调。或者更简单粗暴:直接用 PIL/PS 画------你 mask 就 0 和 1 两个值,用黑白画笔涂就行。

修正重点:

  • 光纤束左右最外边界必须精确(误差不超过 2 像素);
  • 内部条带宽度可以容忍小误差(因为后处理会按 N 等分);
  • 高光区域、暗影区域、缝隙位置统统标为光纤(类别 1)------这是个关键点,后面解释。

预估时间:每张 1-2 分钟,50 张约 1-1.5 小时。

Step 3:标注 review

标完之后,把所有 mask 叠加到原图上,人眼快速过一遍看有没有漏。这步 10 分钟。

阶段 3:模型设计(已定)

架构:轻量 U-Net

复制代码
输入: 3×128×512  (CHW, RGB, float32, 归一化到 [0,1])

Encoder:
  Conv(3→16) + BN + ReLU + Conv(16→16) + BN + ReLU + MaxPool   →  16×64×256
  Conv(16→32) + BN + ReLU + Conv(32→32) + BN + ReLU + MaxPool  →  32×32×128
  Conv(32→64) + BN + ReLU + Conv(64→64) + BN + ReLU + MaxPool  →  64×16×64

Bottleneck:
  Conv(64→128) + BN + ReLU + Conv(128→128) + BN + ReLU         →  128×16×64

Decoder:
  Upsample + Conv(128+64→64) + BN + ReLU + Conv(64→64)         →  64×32×128
  Upsample + Conv(64+32→32) + BN + ReLU + Conv(32→32)          →  32×64×256
  Upsample + Conv(32+16→16) + BN + ReLU + Conv(16→16)          →  16×128×512

Head:
  Conv(16→1) + Sigmoid                                         →  1×128×512

参数量:约 50-80 万,模型文件 ONNX 后约 2-3 MB。

推理时间预估(i5-12600KF CPU,ONNX Runtime):单张 ROI 推理 20-40ms。

输出:单通道概率图,每个像素 [0,1] 表示"是光纤"的概率,阈值 0.5 二值化。

阶段 4:损失函数与训练策略

损失函数 :BCEWithLogitsLoss + DiceLoss 加权组合

  • BCE 负责像素级精度;
  • Dice 负责整体区域形状(对小目标和边界更友好);
  • 权重 0.5 : 0.5。

为什么不用纯 BCE:你的 mask 里光纤占比可能只有 15-25%(光纤束占画面一小部分),类别极度不平衡,纯 BCE 会让模型偷懒预测"全背景"。Dice 对类别不平衡免疫。

优化器:AdamW,初始 lr=1e-3,CosineAnnealing 调度,训练 100 epoch。

Batch size:8(3060 12G 完全吃得下,可以拉到 16,但 8 已经够稳)。

训练/验证划分 :42 张训练 / 8 张验证(其中验证集里故意放进 2-3 张非 12 根的图,用来盯泛化)。

阶段 5:数据增强(最关键的一环)

这是 50 张图能不能撑起来的关键

必须做的增强:

  1. 水平随机缩放 RandomResize(scale=[0.4, 1.6])------这是模拟 4/8/24/32 根的核心手段。12 根的图按 0.4× 水平压缩后,光纤变得很细,逼近 32 根的视觉效果;按 1.6× 拉伸后,光纤变粗,逼近 4 根的视觉效果。
  2. 亮度/对比度抖动 ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2)------模拟不同光照。
  3. 水平翻转 HorizontalFlip(p=0.5)------光纤束左右翻转,色序虽变但分割任务不在乎。
  4. 轻微旋转 Rotate(limit=5°)------模拟摆放不正。
  5. 高斯噪声 GaussNoise(var=10-30)------模拟相机噪点。
  6. 随机裁剪 RandomCrop------从大 ROI 里裁出不同位置的子区域,包括只包含部分光纤的子区域(强制模型学"局部也是光纤")。

不做的增强:

  • 不做色相旋转(HueShift)------会把蓝光纤变成红光纤,虽然分割任务不在乎,但万一以后扩展任务会出问题。
  • 不做垂直翻转------光纤是有方向的(高光在中央),垂直翻转后视觉一致,但不增加信息。

库用 albumentations,这些增强都是现成的。

阶段 6:训练验证指标

主要看两个:

  1. IoU(Intersection over Union):训练目标,验证集上要 > 0.92。
  2. 根数计数准确率 :训练完后,在验证集上跑完整 pipeline(分割 → 后处理 → 数根数),看预测 N 和真实 N 是否一致。这才是最终业务指标,要 > 95%。

如果 IoU 高但根数准确率低,说明边界精度不够,要加大 mask 边界部分的损失权重。

阶段 7:导出 ONNX + C# 集成

ONNX 导出参数:

  • opset_version=17
  • dynamic_axes:输入 batch 维度动态(其他维度固定 128×512)
  • onnxsim 简化模型图

C# 端:

  • 包:Microsoft.ML.OnnxRuntime(CPU 版,体积小)
  • 接口设计:IFiberSegmentationService 抽象,有 SegmentationResult Segment(Mat roi, int expectedN) 方法,返回 N 个光纤的 (int xStart, int xEnd) 列表。
  • 推理前后处理:输入 BGR→RGB→归一化→NCHW;输出 sigmoid→阈值→连通域分析→宽度等分。

五、整体时间表

阶段 工作 时长 累计
1 拍 50 张图 1 天 1 天
2 半自动标注 50 张 mask 1 天 2 天
3 PyTorch 训练脚本 + 训练完成 0.5 天 2.5 天
4 验证 + 调参 0.5 天 3 天
5 ONNX 导出 + C# 集成 1 天 4 天
6 产线测试 + 兜底逻辑联调 1 天 5 天

5 个工作日交付可用版本


六、下一步

方案就这样定下来。我建议的下一步动作:

  1. 你先去拍照------这是当前唯一阻塞项,我写代码再快也得等图。
  2. 拍照过程中,如果发现"难样本"特别难(比如某种颜色组合特别容易混淆),拍多几张,告诉我。
  3. 拍完 10-15 张后先发我看看,我帮你确认拍照规范执行到位了没有,避免拍完 50 张才发现都不合规
  4. 拍照同步进行的话,我可以先开始写 PyTorch 训练脚本和 C# 推理接口的骨架,等你图来了直接训。

确认下:我现在就开始写训练脚本和 C# 集成骨架,还是等你先拍 10 张回来再开始?

开始写代码,我的计划

按以下顺序输出,每个文件都会写完整代码 + 注释,你拿到就能用:

Python 端(训练):

  1. requirements.txt --- 依赖清单
  2. preprocess_roi.py --- 把 3840×3040 BMP 裁成 ROI(交互式或自动)
  3. generate_initial_masks.py --- 用规则法生成 mask 初稿(给标注省时间)
  4. dataset.py --- PyTorch Dataset,含 albumentations 数据增强
  5. model.py --- 轻量 U-Net 定义
  6. train.py --- 训练主循环 + 验证 + 保存最佳模型
  7. export_onnx.py --- PyTorch → ONNX 导出 + onnxsim 简化
  8. test_inference.py --- Python 端 ONNX 推理测试,用于验证导出正确

C# 端(推理,集成到 IntelligentSorting.Core) :

  1. IFiberSegmentationService.cs --- 接口定义

  2. OnnxFiberSegmentationService.cs --- ONNX Runtime 实现

  3. FiberSegmentationResult.cs --- 结果数据结构

  4. PostProcessor.cs --- 连通域分析 + 宽度等分后处理

  5. NuGet 包清单和注册方式说明

  6. 文档*:

  7. README.md --- 完整使用流程:从拍照到产线部署

可以验证,而且你要先理解一个核心点:

conda 环境之间默认是隔离的。

也就是说:

你以前装过的 PyTorch:

  • 大概率只存在于"某个环境"

  • 新环境默认看不到

  • 不能自动复用


一、先查看你有哪些 conda 环境

执行:

bash 复制代码
conda env list

或者:

bash 复制代码
conda info --envs

会看到类似:

text 复制代码
base                  *  D:\Anaconda
fiber                    D:\Anaconda\envs\fiber
yolo                     D:\Anaconda\envs\yolo
torch                    D:\Anaconda\envs\torch

* 的是当前环境。


二、检查某个环境里是否装了 PyTorch

比如:

bash 复制代码
conda activate yolo

然后:

bash 复制代码
pip show torch

如果已经装了,会显示:

text 复制代码
Name: torch
Version: 2.4.0
Location: D:\Anaconda\envs\yolo\Lib\site-packages

如果没装:

text 复制代码
WARNING: Package(s) not found: torch

三、验证 GPU 是否正常

继续执行:

bash 复制代码
python

进入 Python 后:

python 复制代码
import torch

print(torch.__version__)
print(torch.cuda.is_available())

if torch.cuda.is_available():
    print(torch.version.cuda)
    print(torch.cuda.get_device_name(0))

正常应该类似:

text 复制代码
2.4.0
True
12.1
NVIDIA GeForce RTX 3060

四、已装的 PyTorch 能不能"复用"到新环境

正常答案:不能直接复用

因为:

每个 conda 环境都有:

text 复制代码
自己的:
- python
- site-packages
- torch
- numpy
- opencv

相互隔离。

所以:

text 复制代码
yolo 环境里的 torch
≠
fiber 环境里的 torch

五、但是你可以"克隆环境"(推荐)

这是最好的方式。

比如你已经有一个:

text 复制代码
yolo

环境,里面:

  • torch

  • cuda

  • opencv

都好了。

那直接:

bash 复制代码
conda create -n fiber --clone yolo

即可。

优点:

  • 不用重新下载几十GB

  • CUDA/PyTorch 直接复制

  • 环境完全一致

  • 非常适合工业项目


六、这是我最推荐你的方案

方案A(推荐)

如果老环境已经稳定:

bash 复制代码
conda create -n fiber --clone 老环境名

比如:

bash 复制代码
conda create -n fiber --clone yolo

方案B

如果老环境乱七八糟:

重新创建干净环境。


七、你现在应该先做这个

先执行:

bash 复制代码
conda env list

然后:

bash 复制代码
conda activate 你的旧环境

再:

bash 复制代码
python

执行:

python 复制代码
import torch
print(torch.__version__)
print(torch.cuda.is_available())

如果你愿意,可以把输出贴给我。

我可以直接帮你判断:

  • 环境是否干净

  • CUDA 是否正常

  • 是否适合继续用于 YOLO/工业视觉

  • 是否建议 clone 还是重建

完全可以,而且强烈推荐你这么干。理由:

  1. 你这个 yolo 环境已经验证过 CUDA 可用(torch.cuda.is_available() = True,显卡识别正常),省去你从头折腾 CUDA/cuDNN 版本匹配。
  2. PyTorch 2.6.0 + CUDA 12.4 比我前面写的 2.4.0 + CUDA 12.1 还新,完全兼容我们的训练代码,U-Net 这种基础架构不挑版本。
  3. 克隆比新建快 10 倍以上,几秒钟搞定。

一、执行克隆

bash 复制代码
conda create -n fiber --clone yolo
conda activate fiber

完成后验证一下:

bash 复制代码
python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_name(0))"

应该输出:

复制代码
2.6.0+cu124 True NVIDIA GeForce RTX 3060

二、补装我们项目额外需要的包

yolo 环境里已经有 torch / torchvision / opencv / numpy / pillow,但ultralytics 体系不一定装了我们需要的这几个,补一下:

bash 复制代码
conda activate fiber

pip install albumentations==1.4.14 onnx==1.16.2 onnxruntime==1.18.1 onnxsim==0.4.36

tqdm / matplotlib 大概率已经有了(ultralytics 依赖它们),但保险起见也加上:

bash 复制代码
pip install tqdm matplotlib

如果提示"Requirement already satisfied",忽略就行。


三、版本兼容性说明(给你心里有底)

你环境里的版本和我代码里写的有一点差异,但全部兼容,放心:

我原计划版本 你环境版本 兼容性
torch 2.4.0+cu121 2.6.0+cu124 ✅ 完全兼容
torchvision 0.19.0 0.21.0(配对 2.6.0) ✅ 完全兼容
Python 3.10 3.10.20 ✅ 一致
opencv 4.10.0 你环境里有 ✅ 我们用的 API 都很基础,任何 4.x 都行
numpy 1.26.4 大概率 1.26.x

唯一要注意的:ultralytics 包不要卸,留着不影响我们;但训练 U-Net 不依赖 ultralytics,我们自己定义模型。


四、ONNX Runtime 的一个特别提醒

我让你装的 onnxruntime==1.18.1CPU 版 。这是故意的:

  • Python 端训练完导出 ONNX 后,我们要做一次"Python 端 CPU 推理测试",验证导出正确。
  • C# 端产线推理也是 CPU 版,Python 端用同样的 CPU 版测试,结果完全对齐
  • 如果装 onnxruntime-gpu,Python 测试用 GPU、C# 跑 CPU,数值会有微小差异,排查 bug 时会蒙圈。

所以装 CPU 版就对了,不要装 gpu 版。训练用 PyTorch 直接吃 GPU,跟 ONNX Runtime 无关。


五、接下来你的操作

  1. conda create -n fiber --clone yolo + conda activate fiber

  2. 跑上面那条 pip install albumentations ... onnxsim 补包

  3. 按我上一条消息里的步骤,建 D:\FiberSeg 目录、放 60 张 BMP、存 preprocess_roi.py

  4. 在 fiber 环境里跑预处理:

bash 复制代码
conda activate fibercd D:\FiberSegpython scripts\preprocess_roi.py

裁完 60 张 ROI 之后,把 meta.csv 发我,我立刻发第二批代码(标注 + 训练 + 导出)。

有任何报错立刻贴,不要憋着自己 debug,效率低。

我在做一个 WPF 桌面工具(IntelligentSorting),用于产线工人对带光纤的颜色排序检测。

现在正在把规则法识别替换为 U-Net 深度学习方案。

===== 项目基本信息 =====

  • 技术栈:.NET 10 + WPF + MVVM + SQLite + OpenCvSharp4 + Microsoft.ML.OnnxRuntime
  • 三层架构:IntelligentSorting(UI) / IntelligentSorting.Core(算法) / IntelligentSorting.Data(数据)
  • 颜色库存储在 SQLite,字段用 BLab(避免与 Color 类型冲突)
  • 双相机抽象 ICameraService,支持 Uvcham + Uvcsam 双 SDK 自动探测
  • 白平衡校准已实现(开机自动 AWB → 冻结增益)
  • 图像数据全部 BGR 顺序(CameraFrame.Bgr 是 byte[],每像素 3 字节,B,G,R)

===== 项目目录结构 =====

IntelligentSorting/ ← WPF 主项目(UI 层)

├── App.xaml.cs ← DI 注册,启动入口

├── MainWindow.xaml(.cs) ← 主窗口

├── appsettings.json ← 配置(Sampling/Inspection 等)

├── Models/ ← 部署目录

│ └── fiber_unet_sim.onnx ← U-Net 模型(2.5MB)

├── Services/

│ ├── UvchamCameraService.cs ← 老相机 SDK

│ └── UvcsamCameraService.cs ← 新相机 SDK

├── ViewModels/

│ ├── MainViewModel.cs ← 主 VM,含 OnPhotoRecognizeAsync / OnImageFileRecognizeAsync

│ ├── Converters.cs ← BoolToVisibility 等,命名空间 IntelligentSorting.ViewModels

│ ├── ProductEditorViewModel.cs ← 产品输入面板 VM,用 CommunityToolkit.Mvvm 的 [RelayCommand]

│ └── ProductRecognitionViewModel.cs ← 识别确认页 VM(含 SegmentRowViewModel + SegmentAction enum)

└── Views/

├── InspectionPanel.xaml ← 标定/检测面板

├── ProductEditorPanel.xaml ← 产品输入面板(6行Grid,含 ROI 滑块 + 相机预览)

├── ProductRecognitionConfirmDialog.xaml(.cs) ← 识别结果确认对话框

├── FiberCountInputDialog.xaml(.cs) ← 工人输入根数 N 对话框

└── ImageFileRecognitionDialog.xaml(.cs) ← 本地图片+ROI+识别对话框

IntelligentSorting.Core/ ← 算法层(纯算法,不引用 WPF)

├── Abstractions/

│ ├── ICameraService.cs ← 相机抽象

│ └── IColorRepository.cs / IProductRepository.cs / IInspectionRepository.cs

├── Models/

│ ├── ColorEntry.cs(含 BLab 字段)

│ ├── Product.cs / FiberDirection.cs / RoiSettings.cs

│ └── InspectionResult.cs

└── Vision/

├── ColorSampler.cs ← 颜色采样(标定用)

├── ColorMatcher.cs ← LAB ΔE 匹配

├── ColorDistanceLab.cs ← CIEDE2000

├── InspectionService.cs ← 标定/检测服务(色序匹配版,t02 PASS)

├── FiberSegmenter.cs ← ★ 新增:U-Net ONNX 推理

├── ProductRecognitionService.cs ← ★ 拍照识别服务(含 RecognizeByEqualPartition + EnableUNet)

└── ProductRecognitionResult.cs ← 识别结果模型

IntelligentSorting.Data/ ← 数据层(SQLite)

├── DatabaseContext.cs

├── ColorRepository.cs / ProductRepository.cs / InspectionRepository.cs

===== 当前状态 =====

【规则法基础】已经全部完成:

  • 拍照识别功能("📷 拍照识别色序"按钮),含 FiberCountInputDialog 输入 N
  • 本地图片识别功能("📁 从图片识别"按钮),工人拉 ROI 测试
  • 识别结果确认对话框(UseExisting/Overwrite/CreateNew 三种动作)
  • 产品 t02(4 根:白/浅蓝/黄绿/蓝)已 PASS

【U-Net 深度学习方案】Python 端完成:

  • 数据集 48 张训练 mask(主力 12 根产品 41 张 + 8/5/4 根少量)
  • 训练完 val_iou=0.977
  • 导出 fiber_unet_sim.onnx(2.5MB)
  • CPU 推理 9-18ms
  • 输入 3×128×512(BGR→RGB→归一化 ImageNet 标准)
  • 输出 1×1×128×512 logits(sigmoid+阈值 0.5 得 mask)

【C# 集成】刚刚完成:

  • 新增 FiberSegmenter 类(Core/Vision/),封装 ONNX Runtime
  • ProductRecognitionService.RecognizeByEqualPartition 改造为:
    1. ExtractRoiMat 把 BGR byte[] 包成 OpenCV Mat
    2. _segmenter.Segment 输出二值 mask
    3. FindBundleBoundsFromMask 基于列前景占比找光纤束真实左右边界
    4. 在 mask 边界内等分 N 段,中心 40% 区域采样 RGB
    5. LAB ΔE 匹配颜色库
  • App.xaml.cs 注册时调 svc.EnableUNet(modelPath) 加载 ONNX
  • ProductRecognitionService 的 OnLog 委托记录详细诊断

【NuGet 依赖】

  • IntelligentSorting.Core 加了 Microsoft.ML.OnnxRuntime 1.18.1
  • 项目原有 OpenCvSharp4 + OpenCvSharp4.runtime.win

【模型部署】

  • fiber_unet_sim.onnx 放在 IntelligentSorting/Models/
  • csproj 配 PreserveNewest
  • 启动加载路径:AppDomain.CurrentDomain.BaseDirectory + "Models/fiber_unet_sim.onnx"

===== 关键设计决策(不可推翻) =====

  1. 拍照识别用规则法切分(等分中心采样),模型只做"切分"(分割光纤束 vs 背景),颜色匹配仍用 LAB
  2. 工人必须输入 N(根数),从 FiberCountInputDialog 取得
  3. ROI 是工人手动拉的,横长方形,横穿所有光纤
  4. 同一产品的光纤宽度大致一致(支持等分前提)
  5. 调试图保存到桌面的逻辑已删除(原 SaveRoiDebugImage)

===== 算法常量(ProductRecognitionService) =====

  • UseExistingThreshold = 5.0 (ΔE,匹配现有颜色)
  • AmbiguousThreshold = 12.0 (ΔE,模糊需确认)
  • PreviewMaxSize = 60 像素

===== 期待的日志输出 =====

Recognize\] ===== U-Net 分割 + 等分采样 N=12 ===== \[Recognize\] 帧=1920×1080 \| 原始ROI=(750,300,700×80) \| 颜色库=12 \[Recognize\] ROI Mat 准备好 700×80 \[Seg\] 输入ROI 700×80,缩放至 512×128 \[Seg\] ONNX 推理耗时 15ms \[Seg\] 前景像素占比 38.2% \[Recognize\] mask 找边界: \[120,520\] (mask 宽 700) \[Recognize\] 光纤束在 ROI 内: X=\[120,520\] 宽度=401px \[Recognize\] 每段宽度 = 33.42px \[Recognize\] 段\[1\] 像素X=\[870,902\] 宽33px RGB=(35,55,135) ... ... \[Recognize\] ✅ 完成 共 12 段 ===== 我现在的位置 ===== 代码集成全部完成,App.xaml.cs DI 注册也整理好了,接下来要: 1. 编译跑通,看启动日志是否成功加载 ONNX 2. 用本地图片识别功能测一张 12 根产品图 3. 看是否准确切出 12 段,色序对不对 4. 如果识别效果好,继续优化产线 UI(工人拉粗 ROI → 模型自动找边界) ===== 待办 ===== * 测试模型实际推理效果 * 如果光纤束边界检测在某些情况不准,考虑改用连通域分析 * 产线 UI 改造:工人拉粗 ROI 即可(不需要精确) * 接好相机(目前用本地图片测,产线相机色彩饱和度不足)定义目录标题) 贴下训练代码: 收到,数据集状态确认: 12 根: 41 张 8 根: 6 张 5 根: 3 张 (真实产品,保留) 4 根: 9 张 总计: 59 张 分辨率: 全部 3840×3040 ✓ **59 张,够用**。开始发第二批代码。 *** ** * ** *** ### 一、第二批代码的整体说明 这批代码做的事情: 1. **自动生成 mask 初稿**(基于规则法,你只需要人工修正,不用从零画) 2. 定义 **Dataset 类 + 数据增强**(albumentations,包含模拟不同根数的水平缩放) 3. 定义 **轻量 U-Net**(50-80 万参数,CPU 推理 20-40ms) 4. **训练脚本**(自动划分训练/验证集、保存最佳模型、记录指标) 5. **ONNX 导出 + Python 验证** **目录结构最终长这样**: D:/FiberSeg/ ├── dataset/ │ ├── raw/ # 你的原始 BMP │ ├── roi/ # 59 张 ROI(已完成 ✓) │ ├── mask_initial/ # ← 本批生成的 mask 初稿(规则法粗版) │ ├── mask/ # ← 你人工修正后的最终 mask(下一批讲怎么修) │ └── meta.csv # ✓ ├── scripts/ │ ├── preprocess_roi.py # ✓ │ ├── generate_initial_masks.py # ← 新增 │ ├── dataset.py # ← 新增(库,不直接跑) │ ├── model.py # ← 新增(库,不直接跑) │ ├── train.py # ← 新增 │ ├── export_onnx.py # ← 新增 │ └── test_inference.py # ← 新增 ├── checkpoints/ # 自动生成,存训练的模型 └── outputs/ # 自动生成,存可视化结果 *** ** * ** *** ### 二、`generate_initial_masks.py` --- 自动生成 mask 初稿 **这是核心省时工具** 。基于"光纤束亮度突变 + N 等分"的规则,自动生成粗 mask。质量大概 70-85%,**你人工修正只需要拖边界,不用从零画**。 ```python """ 基于规则法,自动生成 mask 初稿 逻辑: 1. 读取每张 ROI 和对应的 fiber_count 2. 沿垂直方向投影,找出光纤束的真实左右边界 3. 把光纤束区域标为前景(1),其他标为背景(0) 4. 保存到 dataset/mask_initial/ 后续: 你用图像编辑器(PS / 画图 / labelme / ISAT)在 mask_initial 基础上 修正边界,保存到 dataset/mask/ 使用: python scripts/generate_initial_masks.py """ import csv from pathlib import Path import cv2 import numpy as np from tqdm import tqdm PROJECT_ROOT = Path(__file__).parent.parent ROI_DIR = PROJECT_ROOT / "dataset" / "roi" MASK_INIT_DIR = PROJECT_ROOT / "dataset" / "mask_initial" META_CSV = PROJECT_ROOT / "dataset" / "meta.csv" PREVIEW_DIR = PROJECT_ROOT / "dataset" / "mask_initial_preview" # 检测参数 SMOOTH_KERNEL = 5 # 投影曲线平滑窗口 EDGE_THRESHOLD = 0.15 # 边界检测阈值(相对于最大亮度) MARGIN_RATIO = 0.02 # 边界外扩比例(2%) def detect_bundle_bounds(roi_bgr: np.ndarray) -> tuple[int, int]: """ 检测光纤束的左右边界 思路: - 转灰度后,沿 y 轴平均,得到 1D 水平亮度曲线 - 光纤区域亮度明显高于背景(黑色绒布) - 找到亮度 > 阈值的连续区域作为光纤束 返回: (x_left, x_right) ROI 内的边界 x 坐标 """ h, w = roi_bgr.shape[:2] gray = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2GRAY) # 沿 y 轴平均(只看每一列的平均亮度) proj = gray.mean(axis=0).astype(np.float32) # 平滑 kernel = np.ones(SMOOTH_KERNEL) / SMOOTH_KERNEL proj_smooth = np.convolve(proj, kernel, mode='same') # 阈值法找前景 bg_level = np.percentile(proj_smooth, 20) # 背景亮度估计 fg_level = np.percentile(proj_smooth, 90) # 前景亮度估计 threshold = bg_level + (fg_level - bg_level) * EDGE_THRESHOLD is_fiber = proj_smooth > threshold # 找最长连续 True 区间作为光纤束 if not is_fiber.any(): return 0, w # 兜底:整个 ROI 都是光纤 indices = np.where(is_fiber)[0] # 简单粗暴:取首尾,而不是最长区间(避免被中间小缝隙断开) x_left = int(indices[0]) x_right = int(indices[-1]) + 1 # 安全检查 if x_right - x_left < w * 0.1: return 0, w # 检测失败,兜底 return x_left, x_right def generate_mask(roi_bgr: np.ndarray, fiber_count: int) -> np.ndarray: """ 生成单张 ROI 的 mask 初稿 类别: 0 = 背景 1 = 光纤(整个光纤束区域,不细分单根) 返回: (H, W) uint8 数组,值 0 或 1 """ h, w = roi_bgr.shape[:2] mask = np.zeros((h, w), dtype=np.uint8) # 检测左右边界 x_left, x_right = detect_bundle_bounds(roi_bgr) # 外扩 2%(避免边界压在光纤上) margin = int((x_right - x_left) * MARGIN_RATIO) x_left = max(0, x_left - margin) x_right = min(w, x_right + margin) # 整个光纤束区域标为 1 # 注意:不在垂直方向区分,因为 ROI 上下应该都是背景 # 这里假设光纤覆盖 ROI 完整高度 mask[:, x_left:x_right] = 1 return mask def save_preview(roi_bgr: np.ndarray, mask: np.ndarray, save_path: Path): """保存原图 + mask 叠加预览,方便人眼检查""" overlay = roi_bgr.copy() # mask 区域涂红色 mask_3ch = np.zeros_like(roi_bgr) mask_3ch[mask == 1] = (0, 0, 200) overlay = cv2.addWeighted(overlay, 0.7, mask_3ch, 0.3, 0) # 画边界线 if mask.any(): ys, xs = np.where(mask == 1) x_left = xs.min() x_right = xs.max() cv2.line(overlay, (x_left, 0), (x_left, mask.shape[0]), (0, 255, 0), 2) cv2.line(overlay, (x_right, 0), (x_right, mask.shape[0]), (0, 255, 0), 2) cv2.imwrite(str(save_path), overlay, [cv2.IMWRITE_PNG_COMPRESSION, 3]) def main(): MASK_INIT_DIR.mkdir(parents=True, exist_ok=True) PREVIEW_DIR.mkdir(parents=True, exist_ok=True) # 读 meta if not META_CSV.exists(): print(f"⚠ meta.csv 不存在: {META_CSV}") return rows = [] with open(META_CSV, "r", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: if row.get("status") == "ok": rows.append(row) print(f"找到 {len(rows)} 条有效 ROI 记录") print(f"开始生成 mask 初稿...\n") success = 0 failed = 0 for row in tqdm(rows, desc="生成 mask"): filename = row["filename"] fiber_count = int(row["fiber_count"]) roi_path = ROI_DIR / f"{filename}.png" if not roi_path.exists(): print(f"\n⚠ ROI 不存在: {roi_path}") failed += 1 continue roi_bgr = cv2.imread(str(roi_path), cv2.IMREAD_COLOR) if roi_bgr is None: print(f"\n⚠ 读取失败: {roi_path}") failed += 1 continue mask = generate_mask(roi_bgr, fiber_count) # 保存 mask(0 / 255 方便人眼看,训练时再转 0/1) mask_out = MASK_INIT_DIR / f"{filename}.png" cv2.imwrite(str(mask_out), mask * 255, [cv2.IMWRITE_PNG_COMPRESSION, 3]) # 保存预览 preview_out = PREVIEW_DIR / f"{filename}.png" save_preview(roi_bgr, mask, preview_out) success += 1 print(f"\n========== 完成 ==========") print(f" 成功: {success}") print(f" 失败: {failed}") print(f" mask: {MASK_INIT_DIR}") print(f" 预览: {PREVIEW_DIR}") print(f"\n下一步: 打开 {PREVIEW_DIR} 检查预览图") print(f" 绿色竖线 = 检测出的光纤束左右边界") print(f" 红色半透明 = mask 覆盖区域") print(f" 如果边界基本对(误差 5px 内),直接复制 mask_initial → mask") print(f" 如果边界差太多,人工修正(用 PS/画图,把红色区域改对)") if __name__ == "__main__": main() ``` **运行**: ```bash cd D:\FiberSeg python scripts\generate_initial_masks.py ``` 会生成: * `dataset/mask_initial/` --- 59 张 mask(值 0 / 255 的 PNG) * `dataset/mask_initial_preview/` --- 59 张叠加预览图,方便你检查 **重要** :跑完后**先打开 preview 目录,人眼快速过一遍**,看绿色竖线(检测出的边界)是不是大致贴合光纤束的左右边缘。 *** ** * ** *** ### 三、`dataset.py` --- PyTorch Dataset + 数据增强 ```python """ PyTorch Dataset 定义 输入: - dataset/roi/*.png ROI 图像 - dataset/mask/*.png 二值 mask(0 = 背景, 255 = 光纤) - dataset/meta.csv 元数据(fiber_count 等) 输出: - 训练样本: (image_tensor 3x128x512, mask_tensor 1x128x512) 数据增强(训练集): - 随机水平缩放 0.4~1.6x (模拟不同根数 4/8/12/24/32) - 颜色抖动(亮度/对比度/饱和度) - 水平翻转 - 轻微旋转 ±5° - 高斯噪声 - 随机裁剪 """ import csv from pathlib import Path from typing import List, Tuple import cv2 import numpy as np import torch from torch.utils.data import Dataset try: import albumentations as A from albumentations.pytorch import ToTensorV2 except ImportError: raise ImportError("请先安装 albumentations: pip install albumentations==1.4.14") # ============== 配置 ============== INPUT_W = 512 INPUT_H = 128 # 归一化参数(ImageNet 标准,U-Net 用这个稳定) MEAN = [0.485, 0.456, 0.406] STD = [0.229, 0.224, 0.225] # ================================== def build_train_transform() -> A.Compose: """训练集增强:模拟不同根数 + 光照变化""" return A.Compose([ # 第一步:先随机缩放(模拟不同根数的光纤宽度) # scale_limit=0.6 表示 0.4x ~ 1.6x A.RandomScale(scale_limit=0.6, p=0.8), # 强制 resize 到目标尺寸(必须的,模型输入固定) A.Resize(INPUT_H, INPUT_W, interpolation=cv2.INTER_LINEAR), # 水平翻转(光纤左右翻转,分割任务等价) A.HorizontalFlip(p=0.5), # 轻微旋转(模拟摆放不正) A.Rotate(limit=5, border_mode=cv2.BORDER_CONSTANT, fill=0, fill_mask=0, p=0.5), # 颜色抖动(模拟光照变化) A.ColorJitter( brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05, p=0.7 ), # 高斯噪声(模拟相机噪点) A.GaussNoise(std_range=(0.04, 0.12), p=0.3), # 归一化 + 转 Tensor A.Normalize(mean=MEAN, std=STD), ToTensorV2(), ]) def build_val_transform() -> A.Compose: """验证集:只 resize + 归一化,不做随机增强""" return A.Compose([ A.Resize(INPUT_H, INPUT_W, interpolation=cv2.INTER_LINEAR), A.Normalize(mean=MEAN, std=STD), ToTensorV2(), ]) class FiberSegDataset(Dataset): """光纤分割数据集""" def __init__( self, roi_dir: Path, mask_dir: Path, sample_names: List[str], is_train: bool = True, ): """ roi_dir: ROI 图像目录 mask_dir: mask 目录 sample_names: 这个 dataset 包含哪些样本(文件名,不含扩展名) is_train: True 用训练增强,False 用验证增强 """ self.roi_dir = Path(roi_dir) self.mask_dir = Path(mask_dir) self.sample_names = sample_names self.is_train = is_train self.transform = build_train_transform() if is_train else build_val_transform() def __len__(self): return len(self.sample_names) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: name = self.sample_names[idx] # 读图 roi_path = self.roi_dir / f"{name}.png" mask_path = self.mask_dir / f"{name}.png" image = cv2.imread(str(roi_path), cv2.IMREAD_COLOR) if image is None: raise FileNotFoundError(f"无法读取 ROI: {roi_path}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) if mask is None: raise FileNotFoundError(f"无法读取 mask: {mask_path}") # mask 二值化:>127 → 1,否则 0 mask = (mask > 127).astype(np.uint8) # 数据增强 transformed = self.transform(image=image, mask=mask) image_tensor = transformed["image"].float() # 3x128x512 mask_tensor = transformed["mask"].float().unsqueeze(0) # 1x128x512 return image_tensor, mask_tensor def load_sample_names(meta_csv: Path) -> List[str]: """从 meta.csv 加载所有有效样本名""" names = [] with open(meta_csv, "r", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: if row.get("status") == "ok": names.append(row["filename"]) return names def split_train_val( sample_names: List[str], val_ratio: float = 0.15, seed: int = 42, ) -> Tuple[List[str], List[str]]: """划分训练集 / 验证集""" rng = np.random.RandomState(seed) indices = np.arange(len(sample_names)) rng.shuffle(indices) n_val = max(1, int(len(sample_names) * val_ratio)) val_indices = set(indices[:n_val].tolist()) train_names = [n for i, n in enumerate(sample_names) if i not in val_indices] val_names = [n for i, n in enumerate(sample_names) if i in val_indices] return train_names, val_names def denormalize(tensor: torch.Tensor) -> np.ndarray: """把归一化的 tensor 转回 [0,255] uint8 图像(可视化用)""" mean = torch.tensor(MEAN).view(3, 1, 1) std = torch.tensor(STD).view(3, 1, 1) img = tensor.cpu() * std + mean img = img.clamp(0, 1).permute(1, 2, 0).numpy() img = (img * 255).astype(np.uint8) return img ``` *** ** * ** *** ### 四、`model.py` --- 轻量 U-Net ```python """ 轻量 U-Net 用于光纤束分割 输入: 3 × 128 × 512 (RGB, 归一化) 输出: 1 × 128 × 512 (logits,sigmoid 后是"是光纤"的概率) 参数量: 约 60 万,模型文件 ONNX 后约 2.5 MB CPU 推理(i5-12600KF): 单张 20-40ms """ import torch import torch.nn as nn def conv_block(in_ch, out_ch): """双层 Conv3x3 + BN + ReLU""" return nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), ) class UpBlock(nn.Module): """上采样 + 拼接 skip + 双 conv""" def __init__(self, in_ch, skip_ch, out_ch): super().__init__() self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) self.conv = conv_block(in_ch + skip_ch, out_ch) def forward(self, x, skip): x = self.up(x) # 防止 upsample 后尺寸和 skip 不一致(理论上 128/512 整除不会出问题) if x.shape[-2:] != skip.shape[-2:]: x = nn.functional.interpolate( x, size=skip.shape[-2:], mode='bilinear', align_corners=False ) x = torch.cat([x, skip], dim=1) x = self.conv(x) return x class FiberUNet(nn.Module): """ Encoder: 16 → 32 → 64 → 128 Decoder: 64 → 32 → 16 Head: 1 channel (logits) """ def __init__(self, in_channels=3, base_ch=16): super().__init__() c1, c2, c3, c4 = base_ch, base_ch*2, base_ch*4, base_ch*8 # Encoder self.enc1 = conv_block(in_channels, c1) # 128x512 self.pool1 = nn.MaxPool2d(2) # 64x256 self.enc2 = conv_block(c1, c2) self.pool2 = nn.MaxPool2d(2) # 32x128 self.enc3 = conv_block(c2, c3) self.pool3 = nn.MaxPool2d(2) # 16x64 # Bottleneck self.bottleneck = conv_block(c3, c4) # 16x64 # Decoder self.up3 = UpBlock(c4, c3, c3) # 32x128 self.up2 = UpBlock(c3, c2, c2) # 64x256 self.up1 = UpBlock(c2, c1, c1) # 128x512 # Output head self.head = nn.Conv2d(c1, 1, kernel_size=1) def forward(self, x): # Encoder e1 = self.enc1(x) # 16 × 128 × 512 e2 = self.enc2(self.pool1(e1)) # 32 × 64 × 256 e3 = self.enc3(self.pool2(e2)) # 64 × 32 × 128 # Bottleneck b = self.bottleneck(self.pool3(e3)) # 128 × 16 × 64 # Decoder d3 = self.up3(b, e3) # 64 × 32 × 128 d2 = self.up2(d3, e2) # 32 × 64 × 256 d1 = self.up1(d2, e1) # 16 × 128 × 512 # Output out = self.head(d1) # 1 × 128 × 512 (logits) return out def count_params(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters() if p.requires_grad) if __name__ == "__main__": # 自测 model = FiberUNet() x = torch.randn(2, 3, 128, 512) y = model(x) print(f"输入: {x.shape}") print(f"输出: {y.shape}") print(f"参数量: {count_params(model):,}") ``` *** ** * ** *** ### 五、`train.py` --- 训练主循环 ```python """ 训练主循环 使用: cd D:\\FiberSeg python scripts/train.py 会输出: - 控制台日志(每个 epoch 的 loss/IoU) - checkpoints/best.pth (验证集 IoU 最高的权重) - checkpoints/last.pth (最新权重) - outputs/train_log.csv (所有指标) - outputs/val_samples/ (每 10 epoch 保存验证集可视化) """ import sys import csv import time from pathlib import Path import cv2 import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR from tqdm import tqdm # 加 scripts 到 path sys.path.insert(0, str(Path(__file__).parent)) from dataset import ( FiberSegDataset, load_sample_names, split_train_val, denormalize, INPUT_W, INPUT_H, ) from model import FiberUNet, count_params # ============== 配置 ============== PROJECT_ROOT = Path(__file__).parent.parent ROI_DIR = PROJECT_ROOT / "dataset" / "roi" MASK_DIR = PROJECT_ROOT / "dataset" / "mask" META_CSV = PROJECT_ROOT / "dataset" / "meta.csv" CKPT_DIR = PROJECT_ROOT / "checkpoints" OUT_DIR = PROJECT_ROOT / "outputs" VAL_VIS_DIR = OUT_DIR / "val_samples" # 训练超参 EPOCHS = 100 BATCH_SIZE = 8 LR = 1e-3 WEIGHT_DECAY = 1e-4 VAL_RATIO = 0.15 # 验证集比例 SEED = 42 NUM_WORKERS = 2 # Windows 下不要太大,2 通常够 SAVE_VIS_EVERY = 10 # 每 N 个 epoch 存一次验证集可视化 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ================================== # ============== 损失函数 ============== class BCEDiceLoss(nn.Module): """BCE + Dice 加权""" def __init__(self, bce_weight=0.5): super().__init__() self.bce = nn.BCEWithLogitsLoss() self.bce_weight = bce_weight def forward(self, logits, target): bce_loss = self.bce(logits, target) # Dice probs = torch.sigmoid(logits) smooth = 1.0 inter = (probs * target).sum(dim=(1,2,3)) union = probs.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3)) dice = (2 * inter + smooth) / (union + smooth) dice_loss = 1 - dice.mean() return self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss # ============== 指标 ============== def compute_iou(logits, target, threshold=0.5): """计算 IoU(每个样本独立,然后 batch 平均)""" probs = torch.sigmoid(logits) preds = (probs > threshold).float() inter = (preds * target).sum(dim=(1,2,3)) union = (preds + target).clamp(0, 1).sum(dim=(1,2,3)) iou = (inter + 1e-6) / (union + 1e-6) return iou.mean().item() # ============== 可视化 ============== def save_val_visualization(model, val_loader, epoch, save_dir): """保存几张验证集可视化结果""" save_dir.mkdir(parents=True, exist_ok=True) model.eval() saved = 0 max_save = 4 with torch.no_grad(): for images, masks in val_loader: images = images.to(DEVICE) logits = model(images) probs = torch.sigmoid(logits).cpu() preds = (probs > 0.5).float() for i in range(images.shape[0]): if saved >= max_save: return img_np = denormalize(images[i]) # RGB uint8 gt_np = (masks[i, 0].cpu().numpy() * 255).astype(np.uint8) pred_np = (preds[i, 0].numpy() * 255).astype(np.uint8) # 拼接:原图 | GT | 预测 img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) gt_bgr = cv2.cvtColor(gt_np, cv2.COLOR_GRAY2BGR) pred_bgr = cv2.cvtColor(pred_np, cv2.COLOR_GRAY2BGR) # 标题 h = img_bgr.shape[0] divider = np.ones((h, 5, 3), dtype=np.uint8) * 128 combined = np.hstack([img_bgr, divider, gt_bgr, divider, pred_bgr]) out_path = save_dir / f"epoch{epoch:03d}_sample{saved}.png" cv2.imwrite(str(out_path), combined, [cv2.IMWRITE_PNG_COMPRESSION, 3]) saved += 1 # ============== 训练循环 ============== def train_one_epoch(model, loader, optimizer, criterion): model.train() total_loss = 0 total_iou = 0 n = 0 pbar = tqdm(loader, desc=" Train", leave=False) for images, masks in pbar: images = images.to(DEVICE, non_blocking=True) masks = masks.to(DEVICE, non_blocking=True) optimizer.zero_grad() logits = model(images) loss = criterion(logits, masks) loss.backward() optimizer.step() with torch.no_grad(): iou = compute_iou(logits, masks) bs = images.size(0) total_loss += loss.item() * bs total_iou += iou * bs n += bs pbar.set_postfix(loss=f"{loss.item():.4f}", iou=f"{iou:.3f}") return total_loss / n, total_iou / n @torch.no_grad() def validate(model, loader, criterion): model.eval() total_loss = 0 total_iou = 0 n = 0 for images, masks in tqdm(loader, desc=" Val", leave=False): images = images.to(DEVICE, non_blocking=True) masks = masks.to(DEVICE, non_blocking=True) logits = model(images) loss = criterion(logits, masks) iou = compute_iou(logits, masks) bs = images.size(0) total_loss += loss.item() * bs total_iou += iou * bs n += bs return total_loss / n, total_iou / n # ============== 主函数 ============== def main(): CKPT_DIR.mkdir(parents=True, exist_ok=True) OUT_DIR.mkdir(parents=True, exist_ok=True) # 检查 mask 目录是否存在 if not MASK_DIR.exists() or not any(MASK_DIR.iterdir()): print(f"⚠ {MASK_DIR} 为空或不存在") print(f" 请先生成 mask 初稿,人工修正后放到 dataset/mask/") return # 设置随机种子 torch.manual_seed(SEED) np.random.seed(SEED) # 加载样本 all_names = load_sample_names(META_CSV) # 只保留 mask 存在的样本 all_names = [n for n in all_names if (MASK_DIR / f"{n}.png").exists()] train_names, val_names = split_train_val(all_names, VAL_RATIO, SEED) print(f"训练集: {len(train_names)} | 验证集: {len(val_names)}") print(f"验证集样本: {val_names}\n") train_ds = FiberSegDataset(ROI_DIR, MASK_DIR, train_names, is_train=True) val_ds = FiberSegDataset(ROI_DIR, MASK_DIR, val_names, is_train=False) train_loader = DataLoader( train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, ) val_loader = DataLoader( val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, ) # 模型 model = FiberUNet().to(DEVICE) print(f"模型参数量: {count_params(model):,}") print(f"设备: {DEVICE}") if DEVICE.type == "cuda": print(f"GPU: {torch.cuda.get_device_name(0)}") # 损失 + 优化器 + 调度器 criterion = BCEDiceLoss(bce_weight=0.5) optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY) scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5) # 训练日志 log_path = OUT_DIR / "train_log.csv" with open(log_path, "w", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow(["epoch", "lr", "train_loss", "train_iou", "val_loss", "val_iou"]) best_iou = -1 best_epoch = 0 print(f"\n开始训练 {EPOCHS} epochs...\n") start_time = time.time() for epoch in range(1, EPOCHS + 1): epoch_start = time.time() current_lr = optimizer.param_groups[0]["lr"] print(f"Epoch {epoch}/{EPOCHS} lr={current_lr:.6f}") tr_loss, tr_iou = train_one_epoch(model, train_loader, optimizer, criterion) val_loss, val_iou = validate(model, val_loader, criterion) scheduler.step() epoch_time = time.time() - epoch_start print(f" train: loss={tr_loss:.4f} iou={tr_iou:.4f} " f"val: loss={val_loss:.4f} iou={val_iou:.4f} " f"({epoch_time:.1f}s)") # 记录日志 with open(log_path, "a", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow([epoch, current_lr, tr_loss, tr_iou, val_loss, val_iou]) # 保存 last torch.save({ "epoch": epoch, "model_state": model.state_dict(), "val_iou": val_iou, }, CKPT_DIR / "last.pth") # 保存 best if val_iou > best_iou: best_iou = val_iou best_epoch = epoch torch.save({ "epoch": epoch, "model_state": model.state_dict(), "val_iou": val_iou, }, CKPT_DIR / "best.pth") print(f" ★ best iou {best_iou:.4f} → 已保存 best.pth") # 周期性保存可视化 if epoch % SAVE_VIS_EVERY == 0 or epoch == 1: save_val_visualization(model, val_loader, epoch, VAL_VIS_DIR) total_time = time.time() - start_time print(f"\n========== 训练完成 ==========") print(f" 总耗时: {total_time/60:.1f} 分钟") print(f" 最佳 IoU: {best_iou:.4f} (epoch {best_epoch})") print(f" best 权重: {CKPT_DIR / 'best.pth'}") print(f" 日志: {log_path}") print(f" 下一步: python scripts/export_onnx.py") if __name__ == "__main__": main() ``` *** ** * ** *** ### 六、`export_onnx.py` --- 导出 ONNX ```python """ PyTorch → ONNX 导出 使用: python scripts/export_onnx.py 输出: checkpoints/fiber_unet.onnx 原始导出 checkpoints/fiber_unet_sim.onnx onnxsim 简化版(C# 端用这个) """ import sys from pathlib import Path import torch import onnx import onnxsim sys.path.insert(0, str(Path(__file__).parent)) from model import FiberUNet from dataset import INPUT_W, INPUT_H PROJECT_ROOT = Path(__file__).parent.parent CKPT_DIR = PROJECT_ROOT / "checkpoints" INPUT_CKPT = CKPT_DIR / "best.pth" OUTPUT_ONNX = CKPT_DIR / "fiber_unet.onnx" OUTPUT_ONNX_SIM = CKPT_DIR / "fiber_unet_sim.onnx" OPSET = 17 def main(): if not INPUT_CKPT.exists(): print(f"⚠ 找不到权重: {INPUT_CKPT}") return # 加载模型 model = FiberUNet() ckpt = torch.load(INPUT_CKPT, map_location="cpu", weights_only=False) model.load_state_dict(ckpt["model_state"]) model.eval() print(f"加载权重: {INPUT_CKPT}") print(f" 来自 epoch {ckpt['epoch']}, val_iou {ckpt['val_iou']:.4f}") # 构造 dummy input dummy = torch.randn(1, 3, INPUT_H, INPUT_W) # 导出 print(f"\n导出 ONNX (opset={OPSET})...") torch.onnx.export( model, dummy, str(OUTPUT_ONNX), export_params=True, opset_version=OPSET, do_constant_folding=True, input_names=["input"], output_names=["logits"], dynamic_axes={ "input": {0: "batch"}, "logits": {0: "batch"}, }, ) print(f" → {OUTPUT_ONNX}") # 验证 ONNX onnx_model = onnx.load(str(OUTPUT_ONNX)) onnx.checker.check_model(onnx_model) print(" ONNX 模型校验通过") # 简化 print("\n简化 ONNX (onnxsim)...") simplified, check_ok = onnxsim.simplify(onnx_model) if not check_ok: print(" ⚠ 简化后校验失败,使用原始版本") simplified = onnx_model onnx.save(simplified, str(OUTPUT_ONNX_SIM)) print(f" → {OUTPUT_ONNX_SIM}") # 显示大小 print(f"\n文件大小:") print(f" 原始: {OUTPUT_ONNX.stat().st_size / 1024:.1f} KB") print(f" 简化: {OUTPUT_ONNX_SIM.stat().st_size / 1024:.1f} KB") print(f"\n下一步: python scripts/test_inference.py") if __name__ == "__main__": main() ``` *** ** * ** *** ### 七、`test_inference.py` --- Python 端 ONNX 推理验证 ```python """ Python 端 ONNX 推理测试 目的: - 验证 ONNX 导出正确(数值和 PyTorch 一致) - 验证 C# 端用同样的预处理逻辑能复现 使用: python scripts/test_inference.py 会从 dataset/roi/ 拿前 3 张图测试, 可视化结果保存到 outputs/onnx_test/ """ import sys from pathlib import Path import cv2 import numpy as np import onnxruntime as ort import torch sys.path.insert(0, str(Path(__file__).parent)) from model import FiberUNet from dataset import INPUT_W, INPUT_H, MEAN, STD PROJECT_ROOT = Path(__file__).parent.parent ROI_DIR = PROJECT_ROOT / "dataset" / "roi" CKPT_DIR = PROJECT_ROOT / "checkpoints" ONNX_PATH = CKPT_DIR / "fiber_unet_sim.onnx" PTH_PATH = CKPT_DIR / "best.pth" OUT_DIR = PROJECT_ROOT / "outputs" / "onnx_test" def preprocess(bgr: np.ndarray) -> np.ndarray: """ROI BGR → 模型输入 NCHW float32""" # BGR → RGB rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) # Resize rgb = cv2.resize(rgb, (INPUT_W, INPUT_H), interpolation=cv2.INTER_LINEAR) # 归一化 rgb = rgb.astype(np.float32) / 255.0 mean = np.array(MEAN, dtype=np.float32).reshape(1, 1, 3) std = np.array(STD, dtype=np.float32).reshape(1, 1, 3) rgb = (rgb - mean) / std # HWC → CHW → NCHW chw = rgb.transpose(2, 0, 1) return chw[np.newaxis, ...].astype(np.float32) def postprocess(logits: np.ndarray, orig_w: int, orig_h: int, threshold: float = 0.5) -> np.ndarray: """logits → 原 ROI 尺寸的二值 mask""" probs = 1.0 / (1.0 + np.exp(-logits)) # sigmoid mask_small = (probs[0, 0] > threshold).astype(np.uint8) # 128x512 mask_orig = cv2.resize(mask_small, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST) return mask_orig def main(): if not ONNX_PATH.exists(): print(f"⚠ ONNX 不存在: {ONNX_PATH}") return OUT_DIR.mkdir(parents=True, exist_ok=True) # 加载 ONNX print(f"加载 ONNX: {ONNX_PATH}") sess = ort.InferenceSession( str(ONNX_PATH), providers=["CPUExecutionProvider"], ) # 加载 PyTorch(对比验证) print(f"加载 PyTorch: {PTH_PATH}") pt_model = FiberUNet() pt_model.load_state_dict(torch.load(PTH_PATH, map_location="cpu", weights_only=False)["model_state"]) pt_model.eval() # 找几张测试图 test_imgs = sorted(ROI_DIR.glob("*.png"))[:3] print(f"\n测试 {len(test_imgs)} 张图...") for img_path in test_imgs: bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR) h, w = bgr.shape[:2] # 预处理 x = preprocess(bgr) # ONNX 推理 import time t0 = time.time() for _ in range(5): # 跑 5 次取平均 onnx_out = sess.run(None, {"input": x})[0] onnx_time = (time.time() - t0) / 5 * 1000 # PyTorch 推理(对比) with torch.no_grad(): pt_out = pt_model(torch.from_numpy(x)).numpy() # 数值差异 diff = np.abs(onnx_out - pt_out).max() print(f"\n {img_path.name}:") print(f" ROI 尺寸: {w}x{h}") print(f" ONNX 推理: {onnx_time:.1f} ms") print(f" ONNX vs PyTorch 最大差异: {diff:.6f} {'✓' if diff < 1e-3 else '⚠'}") # 后处理 mask = postprocess(onnx_out, w, h) # 可视化:原图 | mask mask_3ch = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR) # 叠加 overlay = bgr.copy() overlay[mask == 1] = (overlay[mask == 1] * 0.5 + np.array([0, 0, 200]) * 0.5).astype(np.uint8) divider = np.ones((h, 5, 3), dtype=np.uint8) * 128 combined = np.hstack([bgr, divider, mask_3ch, divider, overlay]) out_path = OUT_DIR / f"test_{img_path.stem}.png" cv2.imwrite(str(out_path), combined, [cv2.IMWRITE_PNG_COMPRESSION, 3]) print(f" 可视化: {out_path}") print(f"\n========== 测试完成 ==========") print(f" 结果: {OUT_DIR}") if __name__ == "__main__": main() ``` *** ** * ** *** ### 八、完整工作流(从现在到训练完成) 按这个顺序执行,**每一步不通过不要进下一步**: #### Step 1:生成 mask 初稿 ```bash cd D:\FiberSeg python scripts\generate_initial_masks.py ``` **完成后** :打开 `dataset/mask_initial_preview/` 用 Windows 图片查看器 **快速过一遍**(箭头键翻图)。 **检查标准**: * 绿色竖线大致贴合光纤束左右边缘 * 红色半透明覆盖整个光纤束区域 * 没有明显漏掉或多框 **预估**:59 张里大概 40-50 张是合格的,10-20 张需要修正。 #### Step 2:人工修正 mask 这一步**是整个项目最关键的人工环节** 。我下一条消息单独讲怎么操作。**先不要进入下一步**,把 Step 1 跑完,告诉我结果。 #### Step 3:训练(等 mask 准备好) ```bash python scripts\train.py ``` 3060 上预估 5-10 分钟训完 100 epoch。 #### Step 4:导出 + 验证 ```bash python scripts\export_onnx.py python scripts\test_inference.py ``` *** ** * ** *** ### 九、现在你做的事 1. **保存 5 个 Python 文件到 `D:\FiberSeg\scripts\`**: * `generate_initial_masks.py` * `dataset.py` * `model.py` * `train.py` * `export_onnx.py` * `test_inference.py`(6 个,数错了) 2. **跑 generate_initial_masks**: ```bash cd D:\FiberSeg python scripts\generate_initial_masks.py ``` 3. **告诉我结果**: * 跑通没?有没有报错? * 打开 `dataset/mask_initial_preview/` 看几张,**截图发我 2-3 张**(挑一张看起来对的、一张看起来错的) * 我看完之后告诉你怎么进行人工修正(Step 2 的详细操作指南) **有任何报错立刻贴**。 """ 交互式 mask 修正工具 工作流: 1. 自动从 dataset/mask_initial/ 加载初稿 2. 显示 ROI 原图 + mask 红色半透明叠加 3. 鼠标左键涂白(标光纤),右键涂黑(标背景) 4. Enter 保存到 dataset/mask/,自动下一张 键盘操作: 鼠标左键拖动 - 涂白色(标光纤) 鼠标右键拖动 - 涂黑色(标背景) 滚轮 - 调画笔大小 Z - 撤销上一笔 R - 重置为初稿 F - 标记为"完美"(直接复制初稿到 mask/) Enter - 保存并下一张 S - 跳过(不保存) A 或 ← - 上一张(不保存当前) D 或 → - 下一张(不保存当前) Q 或 ESC - 退出 使用: python scripts/refine_masks.py """ import sys from pathlib import Path from collections import deque import cv2 import numpy as np ## ============== 配置 ============== PROJECT_ROOT = Path(**file** ).parent.parent ROI_DIR = PROJECT_ROOT / "dataset" / "roi" MASK_INIT_DIR = PROJECT_ROOT / "dataset" / "mask_initial" MASK_DIR = PROJECT_ROOT / "dataset" / "mask" ## 显示窗口最大宽度(ROI 通常 1000-1700px,放屏幕够大) DISPLAY_MAX_WIDTH = 1600 DISPLAY_MAX_HEIGHT = 900 ## 画笔参数 BRUSH_INIT = 30 # 初始画笔半径 BRUSH_MIN = 3 BRUSH_MAX = 200 UNDO_STACK = 30 # 撤销栈深度 ## 叠加显示参数 OVERLAY_ALPHA = 0.45 # mask 半透明叠加强度 OVERLAY_COLOR = (0, 0, 220) # 红色(BGR) ## ================================== class MaskEditor: def **init** (self): self.roi: np.ndarray \| None = None # 原图 BGR self.mask: np.ndarray \| None = None # 当前 mask (0/255) self.mask_initial: np.ndarray \| None = None # 初稿副本(用于重置) self.display: np.ndarray \| None = None # 缩放后用于显示的图 self.display_scale = 1.0 self.brush_radius = BRUSH_INIT self.is_dirty = False # 是否有修改 self.undo_stack: deque = deque(maxlen=UNDO_STACK) self.last_paint_pos: tuple \| None = None self.current_button: int = -1 # -1 没按, 0 左键白, 1 右键黑 self.cursor_pos: tuple \| None = None # 当前鼠标位置(显示坐标) def load(self, roi_path: Path, mask_path: Path): """加载一张图和它的 mask""" self.roi = cv2.imread(str(roi_path), cv2.IMREAD_COLOR) if self.roi is None: raise FileNotFoundError(f"无法读取 ROI: {roi_path}") if mask_path.exists(): mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) # 二值化为纯 0/255 self.mask = ((mask_raw > 127).astype(np.uint8)) * 255 else: # 没有 mask,全黑(全背景) h, w = self.roi.shape[:2] self.mask = np.zeros((h, w), dtype=np.uint8) self.mask_initial = self.mask.copy() self.is_dirty = False self.undo_stack.clear() self._compute_display_scale() def _compute_display_scale(self): """根据屏幕大小,计算显示缩放系数""" h, w = self.roi.shape[:2] scale_w = DISPLAY_MAX_WIDTH / w if w > DISPLAY_MAX_WIDTH else 1.0 scale_h = DISPLAY_MAX_HEIGHT / h if h > DISPLAY_MAX_HEIGHT else 1.0 self.display_scale = min(scale_w, scale_h, 1.0) def _display_to_original(self, x: int, y: int) -> tuple: """显示坐标 → 原图坐标""" ox = int(x / self.display_scale) oy = int(y / self.display_scale) return ox, oy def _save_undo(self): """保存当前 mask 到撤销栈""" self.undo_stack.append(self.mask.copy()) def undo(self): """撤销上一步""" if self.undo_stack: self.mask = self.undo_stack.pop() self.is_dirty = True def reset(self): """重置为初稿""" self._save_undo() self.mask = self.mask_initial.copy() self.is_dirty = True def paint(self, x: int, y: int, color: int, save_undo: bool = True): """ 在 (x, y) 显示坐标处涂一笔 color: 255=白(光纤), 0=黑(背景) """ ox, oy = self._display_to_original(x, y) radius_orig = max(1, int(self.brush_radius / self.display_scale)) if save_undo: self._save_undo() # 如果有上一个位置,画线段(避免快速拖动时点之间断开) if self.last_paint_pos is not None: cv2.line(self.mask, self.last_paint_pos, (ox, oy), color, thickness=radius_orig * 2, lineType=cv2.LINE_8) cv2.circle(self.mask, (ox, oy), radius_orig, color, thickness=-1, lineType=cv2.LINE_8) # 强制二值化(防止任何渐变) self.mask = (self.mask > 127).astype(np.uint8) * 255 self.last_paint_pos = (ox, oy) self.is_dirty = True def render(self) -> np.ndarray: """渲染显示图:原图 + mask 红色叠加 + 画笔光标""" # 叠加 mask overlay = self.roi.copy() mask_3ch = np.zeros_like(self.roi) mask_3ch[self.mask == 255] = OVERLAY_COLOR blended = cv2.addWeighted(overlay, 1 - OVERLAY_ALPHA, mask_3ch, OVERLAY_ALPHA, 0) # 在原图区域只显示有 mask 的地方,其他保持原图 result = self.roi.copy() mask_region = self.mask == 255 result[mask_region] = blended[mask_region] # 缩放到显示尺寸 if self.display_scale != 1.0: new_w = int(result.shape[1] * self.display_scale) new_h = int(result.shape[0] * self.display_scale) result = cv2.resize(result, (new_w, new_h), interpolation=cv2.INTER_AREA) # 画画笔光标 if self.cursor_pos is not None: cv2.circle(result, self.cursor_pos, self.brush_radius, (0, 255, 255), 1) # 黄圈 cv2.circle(result, self.cursor_pos, 2, (0, 255, 255), -1) return result def get_final_mask(self) -> np.ndarray: """返回最终 mask(确保纯 0/255)""" return ((self.mask > 127).astype(np.uint8)) * 255 def list_samples() -\> list\[str\]: """列出所有 ROI 文件名(不含扩展名)""" roi_files = sorted(ROI_DIR.glob("\*.png")) return \[p.stem for p in roi_files

def main():

MASK_DIR.mkdir(parents=True, exist_ok=True)

复制代码
samples = list_samples()
if not samples:
    print(f"⚠ {ROI_DIR} 下没有 PNG 图")
    return

print(f"找到 {len(samples)} 张 ROI")

# 统计已完成
done = set()
for s in samples:
    if (MASK_DIR / f"{s}.png").exists():
        done.add(s)
print(f"已完成: {len(done)} / {len(samples)}")

# 询问从哪开始
start_idx = 0
if done:
    ans = input("是否跳过已完成的? (Y/n): ").strip().lower()
    if ans != "n":
        # 找第一个未完成的
        for i, s in enumerate(samples):
            if s not in done:
                start_idx = i
                break
        print(f"从第 {start_idx+1} 张开始: {samples[start_idx]}")

editor = MaskEditor()
win_name = "Mask Editor"
cv2.namedWindow(win_name, cv2.WINDOW_AUTOSIZE)

def mouse_cb(event, x, y, flags, param):
    editor.cursor_pos = (x, y)

    if event == cv2.EVENT_LBUTTONDOWN:
        editor.current_button = 0
        editor.last_paint_pos = None
        editor.paint(x, y, 255, save_undo=True)
    elif event == cv2.EVENT_RBUTTONDOWN:
        editor.current_button = 1
        editor.last_paint_pos = None
        editor.paint(x, y, 0, save_undo=True)
    elif event == cv2.EVENT_MOUSEMOVE:
        if editor.current_button == 0:
            editor.paint(x, y, 255, save_undo=False)
        elif editor.current_button == 1:
            editor.paint(x, y, 0, save_undo=False)
    elif event in (cv2.EVENT_LBUTTONUP, cv2.EVENT_RBUTTONUP):
        editor.current_button = -1
        editor.last_paint_pos = None
    elif event == cv2.EVENT_MOUSEWHEEL:
        if flags > 0:
            editor.brush_radius = min(BRUSH_MAX, editor.brush_radius + 3)
        else:
            editor.brush_radius = max(BRUSH_MIN, editor.brush_radius - 3)

cv2.setMouseCallback(win_name, mouse_cb)

print("\n=========== 操作说明 ===========")
print("  鼠标左键拖动  → 涂白(光纤)")
print("  鼠标右键拖动  → 涂黑(背景)")
print("  滚轮          → 调画笔大小")
print("  Z             → 撤销")
print("  R             → 重置为初稿")
print("  F             → 完美(初稿不改,直接保存)")
print("  Enter         → 保存并下一张")
print("  S             → 跳过(不保存)")
print("  A / D         → 上/下一张(不保存)")
print("  Q / ESC       → 退出")
print("===============================\n")

idx = start_idx
while 0 <= idx < len(samples):
    name = samples[idx]
    roi_path = ROI_DIR / f"{name}.png"
    mask_path = MASK_DIR / f"{name}.png"
    init_path = MASK_INIT_DIR / f"{name}.png"

    # 优先加载 mask/ 已有的,其次加载 mask_initial/
    source_mask = mask_path if mask_path.exists() else init_path

    try:
        editor.load(roi_path, source_mask)
    except Exception as e:
        print(f"⚠ 加载失败 {name}: {e}")
        idx += 1
        continue

    h, w = editor.roi.shape[:2]
    status = "✓已完成" if name in done else "新"
    title = f"[{idx+1}/{len(samples)}] {name}.png  ROI:{w}x{h}  {status}"

    while True:
        display_img = editor.render()

        # 在顶部画状态条
        info = f"[{idx+1}/{len(samples)}] {name} | brush:{editor.brush_radius} | {'MODIFIED' if editor.is_dirty else 'unchanged'}"
        cv2.putText(display_img, info, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        cv2.imshow(win_name, display_img)
        cv2.setWindowTitle(win_name, title)
        key = cv2.waitKey(20) & 0xFF

        if key == 13:  # Enter
            # 保存
            final = editor.get_final_mask()
            cv2.imwrite(str(mask_path), final,
                        [cv2.IMWRITE_PNG_COMPRESSION, 3])
            print(f"  ✓ 保存 {name}.png  (修改:{editor.is_dirty})")
            done.add(name)
            idx += 1
            break

        elif key in (ord("f"), ord("F")):
            # 标记完美:直接把初稿复制过去
            cv2.imwrite(str(mask_path), editor.mask_initial,
                        [cv2.IMWRITE_PNG_COMPRESSION, 3])
            print(f"  ✓ 完美 {name}.png  (使用初稿)")
            done.add(name)
            idx += 1
            break

        elif key in (ord("s"), ord("S")):
            print(f"  跳过 {name}.png")
            idx += 1
            break

        elif key in (ord("a"), ord("A"), 2424832):  # A 或 左箭头
            if idx > 0:
                idx -= 1
            else:
                print("  已经是第一张")
                continue
            break

        elif key in (ord("d"), ord("D"), 2555904):  # D 或 右箭头
            idx += 1
            break

        elif key in (ord("z"), ord("Z")):
            editor.undo()

        elif key in (ord("r"), ord("R")):
            editor.reset()
            print(f"  重置 {name}.png")

        elif key in (ord("q"), ord("Q"), 27):
            if editor.is_dirty:
                ans = input("\n当前有未保存修改,确定退出? (y/N): ").strip().lower()
                if ans != "y":
                    continue
            cv2.destroyAllWindows()
            print(f"\n========== 退出 ==========")
            print(f"  已完成: {len(done)} / {len(samples)}")
            return

cv2.destroyAllWindows()
print(f"\n========== 全部完成 ==========")
print(f"  保存目录: {MASK_DIR}")
print(f"  下一步: 检查 mask/ 目录,然后跑 python scripts/train.py")

if name == "main ":

main()

相关推荐
轻口味1 小时前
HarmonyOS 6.1 全栈实战录 - 13 流量增长新引擎:全场景归因与 App Linking 链接深度开发实战
pytorch·深度学习·harmonyos
搜佛说1 小时前
一多操作系统性能篇
人工智能
月诸清酒1 小时前
63-260516 AI 科技日报 (X推荐算法开源,核心驱动转向Grok模型)
人工智能·算法·推荐算法
逐米时代2 小时前
成都企业做大模型本地化部署,如何从试点走向生产?
人工智能
RSTJ_16252 小时前
PYTHON+AI LLM DAY FOURTY-SEVEN
开发语言·人工智能·python·深度学习
阳艳讲ai2 小时前
中小企业AI项目落地技术指南:常见问题与实施框架
大数据·人工智能·企业ai培训·ai获客·九尾狐ai
踏着七彩祥云的小丑2 小时前
AI——Dify企业级最佳实践
人工智能·ai
阳明山水2 小时前
零售销量预测为何选LightGBM
人工智能·机器学习·微信·微信公众平台·微信开放平台
2zcode2 小时前
基于深度学习的智能职业匹配系统设计与实现
人工智能·深度学习