U-Net 模型训练——图像标注脚本工具

python 复制代码
r"""
多图 mask 标注工具 v1.4
======================
功能:
  - 列方向涂条:V 模式下左/右键点击即画贯穿全高的竖条
  - 直线辅助  :L 模式下两次点击画/擦直线
  - 多图切换  :A/D 上/下一张,自动保存
  - 进度续接  :label_log.json 记录每张状态

路径:
  ROI 输入  : E:\Deeplearning\U-Net\multi_train\roi\
  Mask 输出 : E:\Deeplearning\U-Net\multi_train\mask\
  进度日志  : E:\Deeplearning\U-Net\multi_train\label_log.json

操作:
  左键拖动 / 右键拖动  - 涂白 / 涂黑
  滚轮 或 [ ]          - 调画笔大小
  V                    - 切换 圆形画笔 / 竖条模式
  L                    - 切换 直线辅助模式
                          第1次点击:定位起点
                          第2次左键:画白线(添加)
                          第2次右键:画黑线(消除)
  Z / C / B            - 撤销 / 清空 / 切视图
  S / A / D            - 保存 / 上一张 / 下一张
  Q / ESC              - 退出
"""

from pathlib import Path
from collections import deque
import cv2
import numpy as np
import json


# ------------------------- configuration -------------------------
# Dataset root; the ROI, mask and log locations below all hang off it.
ROOT = Path(r"E:\Deeplearning\U-Net\multi_train")
ROI_DIR = ROOT.joinpath("roi")               # input ROI images
MASK_DIR = ROOT.joinpath("mask")             # output binary masks
LOG_PATH = ROOT.joinpath("label_log.json")   # per-image progress log

# Lower-case file suffixes accepted as ROI images.
EXTS = {".bmp", ".png", ".jpg", ".jpeg", ".tif", ".tiff"}

# The preview is downscaled so it fits inside this box (pixels).
DISPLAY_MAX_WIDTH, DISPLAY_MAX_HEIGHT = 1600, 900

# Brush radius in display pixels: start value, lower and upper clamp.
BRUSH_INIT, BRUSH_MIN, BRUSH_MAX = 12, 1, 200
# Maximum number of undo snapshots kept for the image being edited.
UNDO_STACK = 30

# Overlay blend weight and colour (BGR, bright yellow).
OVERLAY_ALPHA = 0.65
OVERLAY_COLOR = (0, 255, 255)
# -----------------------------------------------------------------


def load_log(path=None):
    """Load the labelling progress log.

    Args:
        path: log file to read; defaults to LOG_PATH.

    Returns:
        dict: ROI file name -> log entry. An empty dict when the file does not
        exist, or when it is not a valid JSON object. In the latter case the
        bad file is first renamed to ``<name>.corrupt`` so it is kept for
        inspection instead of being overwritten by the next save.
    """
    path = LOG_PATH if path is None else Path(path)
    if not path.exists():
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            return data
        reason = f"顶层不是 JSON 对象 ({type(data).__name__})"
    except ValueError as err:  # json.JSONDecodeError / UnicodeDecodeError
        reason = str(err)

    # A truncated or malformed log must not block the whole tool: keep a copy
    # and start from an empty log (masks on disk are still loaded per image).
    backup = path.with_name(path.name + ".corrupt")
    path.replace(backup)
    print(f"! 进度日志无法解析 ({reason}), 已备份到 {backup}, 从空日志开始")
    return {}


def save_log(log, path=None):
    """Persist the progress log as indented UTF-8 JSON.

    The JSON is written to a sibling ``<name>.tmp`` file first and then moved
    over the target with ``Path.replace``. An interrupted write (crash,
    Ctrl+C) therefore leaves the previous log intact instead of a truncated
    file that the next start-up cannot parse.

    Args:
        log: dict to serialise.
        path: destination file; defaults to LOG_PATH.
    """
    path = LOG_PATH if path is None else Path(path)
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2, ensure_ascii=False)
    tmp.replace(path)


def list_rois(roi_dir=None, exts=None):
    """Return the ROI image files of *roi_dir*, sorted by path.

    Args:
        roi_dir: directory to scan; defaults to ROI_DIR.
        exts: accepted lower-case suffixes (including the dot); defaults to
            EXTS. File suffixes are lower-cased before the comparison.

    Returns:
        list[Path]: regular files whose suffix is accepted. An empty list when
        the directory does not exist, so the caller can report "no ROI found"
        instead of crashing in ``iterdir()``.
    """
    roi_dir = ROI_DIR if roi_dir is None else Path(roi_dir)
    exts = EXTS if exts is None else exts
    if not roi_dir.is_dir():
        return []
    return [p for p in sorted(roi_dir.iterdir())
            if p.is_file() and p.suffix.lower() in exts]


class MaskEditor:
    """Interactive editor for the binary mask of one ROI image.

    The working mask is a uint8 array with the ROI's height and width whose
    pixels are 0 (background) or 255 (foreground). Mouse coordinates arrive in
    display pixels and are mapped back to original pixels via
    ``display_scale`` (at most 1.0, chosen so the view fits inside
    DISPLAY_MAX_WIDTH x DISPLAY_MAX_HEIGHT).
    """

    def __init__(self, roi_path, mask_path):
        """Load the ROI image and, when it exists, its current mask.

        Args:
            roi_path: ROI image path (read as 3-channel BGR).
            mask_path: path an existing mask is loaded from; the caller later
                saves to ``self.mask_path``.

        Raises:
            FileNotFoundError: if the ROI image cannot be read.
        """
        self.roi_path = roi_path
        self.mask_path = mask_path
        self.roi = self._imread(roi_path, cv2.IMREAD_COLOR)
        if self.roi is None:
            raise FileNotFoundError(f"无法读取 ROI: {roi_path}")
        h, w = self.roi.shape[:2]
        self.h, self.w = h, w

        self.mask, self.is_new = self._load_mask(mask_path, h, w)

        scale_w = DISPLAY_MAX_WIDTH / w if w > DISPLAY_MAX_WIDTH else 1.0
        scale_h = DISPLAY_MAX_HEIGHT / h if h > DISPLAY_MAX_HEIGHT else 1.0
        self.display_scale = min(scale_w, scale_h, 1.0)

        self.brush_radius = BRUSH_INIT           # in display pixels
        self.undo_stack = deque(maxlen=UNDO_STACK)
        self.last_paint_pos = None               # previous stamp (orig px) during a drag
        self.current_button = -1                 # 0 = left, 1 = right, -1 = none
        self.cursor_pos = None                   # last mouse position (display px)
        self.show_overlay = True
        self.is_dirty = self.is_new              # a brand-new mask counts as unsaved
        self.vertical_mode = False

        # ---- line-assist mode ----
        self.line_mode = False
        self.line_start = None                   # first click (orig px) or None
        self.line_pending_button = -1

    @staticmethod
    def _imread(path, flags):
        """Read an image like ``cv2.imread`` with a byte-level fallback.

        ``cv2.imread`` returns None instead of raising on failure (a common
        cause is a non-ASCII path on Windows). In that case the file bytes are
        read with NumPy and decoded with ``cv2.imdecode``. Returns None if
        both attempts fail.
        """
        img = cv2.imread(str(path), flags)
        if img is not None:
            return img
        try:
            data = np.fromfile(str(path), dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, flags)

    def _load_mask(self, mask_path, h, w):
        """Return ``(mask, is_new)`` for an h x w ROI.

        A missing file yields an all-black mask with ``is_new=True``. An
        unreadable file also yields an all-black mask, but with
        ``is_new=False``, so it is not auto-saved (overwritten) unless the user
        edits it. A mask of the wrong size is resized with nearest-neighbour
        interpolation instead of crashing ``render``.
        """
        if not mask_path.exists():
            print("  空白 mask (全黑)")
            return np.zeros((h, w), dtype=np.uint8), True

        m = self._imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if m is None:
            print("  ! 已有 mask 无法读取, 显示为空白 (未修改前不会覆盖)")
            return np.zeros((h, w), dtype=np.uint8), False
        if m.shape[:2] != (h, w):
            print(f"  ! mask 尺寸 {m.shape[1]}x{m.shape[0]} 与 ROI {w}x{h} "
                  f"不一致, 已按最近邻缩放")
            m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)
        print("  继续编辑已有 mask")
        return ((m > 127).astype(np.uint8)) * 255, False

    def _disp_to_orig(self, x, y):
        """Map display coordinates to original-image coordinates (truncated)."""
        return int(x / self.display_scale), int(y / self.display_scale)

    def _save_undo(self):
        """Push a copy of the current mask onto the bounded undo stack."""
        self.undo_stack.append(self.mask.copy())

    def undo(self):
        """Restore the most recent snapshot, if any."""
        if self.undo_stack:
            self.mask = self.undo_stack.pop()
            self.is_dirty = True

    def clear(self):
        """Reset the whole mask to black (undoable)."""
        self._save_undo()
        self.mask[:] = 0
        self.is_dirty = True

    def paint(self, x, y, color, save_undo=True):
        """Apply the regular brush (circle or vertical bar) at display (x, y).

        Circle mode stamps a filled disc and, while dragging, joins it to the
        previous stamp with a thick line. Vertical-bar mode fills full-height
        columns around x. While dragging, the columns between the previous and
        the current x are filled too, so fast mouse moves leave no gaps.

        Args:
            x, y: mouse position in display pixels (may lie outside the image).
            color: 255 to add foreground, 0 to erase.
            save_undo: push an undo snapshot first (done once per stroke).
        """
        ox, oy = self._disp_to_orig(x, y)
        radius = max(1, int(self.brush_radius / self.display_scale))
        if save_undo:
            self._save_undo()

        if self.vertical_mode:
            left = right = ox
            if self.last_paint_pos is not None:
                prev_x = self.last_paint_pos[0]
                left, right = min(ox, prev_x), max(ox, prev_x)
            x1 = max(0, left - radius)
            x2 = min(self.w, right + radius + 1)
            # Guard: with the cursor left of the image x2 is negative, and
            # mask[:, 0:x2] would paint almost every column.
            if x1 < x2:
                self.mask[:, x1:x2] = color
        else:
            if self.last_paint_pos is not None:
                cv2.line(self.mask, self.last_paint_pos, (ox, oy),
                         color, thickness=radius * 2, lineType=cv2.LINE_8)
            cv2.circle(self.mask, (ox, oy), radius, color, -1, cv2.LINE_8)

        # Keep the mask strictly 0/255.
        self.mask = (self.mask > 127).astype(np.uint8) * 255
        self.last_paint_pos = (ox, oy)
        self.is_dirty = True

    def paint_line(self, start_orig, end_orig, color):
        """Draw a thick straight line between two original-image points.

        The line thickness is twice the brush radius converted to original
        pixels. An undo snapshot is pushed first.
        """
        self._save_undo()
        radius = max(1, int(self.brush_radius / self.display_scale))
        p1 = (int(start_orig[0]), int(start_orig[1]))
        p2 = (int(end_orig[0]), int(end_orig[1]))
        cv2.line(self.mask, p1, p2, color,
                 thickness=radius * 2, lineType=cv2.LINE_8)
        self.mask = (self.mask > 127).astype(np.uint8) * 255
        self.is_dirty = True

    def render(self):
        """Return the BGR frame to display.

        The frame is either the ROI with the foreground blended in
        OVERLAY_COLOR, or the raw mask. It is scaled by ``display_scale`` and
        has brush / line-preview guides drawn at the cursor.
        """
        if self.show_overlay:
            mask_3 = np.zeros_like(self.roi)
            mask_3[self.mask == 255] = OVERLAY_COLOR
            blended = cv2.addWeighted(self.roi, 1 - OVERLAY_ALPHA,
                                      mask_3, OVERLAY_ALPHA, 0)
            view = self.roi.copy()
            view[self.mask == 255] = blended[self.mask == 255]
        else:
            view = cv2.cvtColor(self.mask, cv2.COLOR_GRAY2BGR)

        if self.display_scale != 1.0:
            nw = int(view.shape[1] * self.display_scale)
            nh = int(view.shape[0] * self.display_scale)
            view = cv2.resize(view, (nw, nh), interpolation=cv2.INTER_AREA)

        if self.cursor_pos is not None:
            cx, cy = self.cursor_pos

            if self.line_mode and self.line_start is not None:
                # Rubber-band preview from the stored start point to the cursor.
                sx = int(self.line_start[0] * self.display_scale)
                sy = int(self.line_start[1] * self.display_scale)
                cv2.line(view, (sx, sy), (cx, cy),
                         (0, 255, 255), 1, cv2.LINE_AA)
                cv2.circle(view, (sx, sy), 5, (0, 255, 0), -1)
                cv2.circle(view, (cx, cy), 5, (0, 200, 255), -1)
            elif self.vertical_mode:
                # Centre line plus the left/right edges of the bar.
                cv2.line(view, (cx, 0), (cx, view.shape[0] - 1),
                         (0, 255, 255), 1)
                cv2.line(view, (cx - self.brush_radius, 0),
                         (cx - self.brush_radius, view.shape[0] - 1),
                         (0, 200, 200), 1)
                cv2.line(view, (cx + self.brush_radius, 0),
                         (cx + self.brush_radius, view.shape[0] - 1),
                         (0, 200, 200), 1)
            else:
                cv2.circle(view, self.cursor_pos, self.brush_radius,
                           (0, 255, 255), 1)
                cv2.circle(view, self.cursor_pos, 2, (0, 255, 255), -1)

        return view

    def get_final_mask(self):
        """Return the mask binarised to 0/255 uint8 (threshold 127)."""
        return ((self.mask > 127).astype(np.uint8)) * 255

    def count_components(self):
        """Count 8-connected foreground components with area > 50 px and
        bounding-box width > 3 px."""
        final = self.get_final_mask()
        num, _, stats, _ = cv2.connectedComponentsWithStats(final, 8)
        valid = sum(1 for i in range(1, num)
                    if stats[i, cv2.CC_STAT_AREA] > 50
                    and stats[i, cv2.CC_STAT_WIDTH] > 3)
        return valid


def save_mask(editor, log):
    """Write *editor*'s mask to ``editor.mask_path`` and record it in *log*.

    Args:
        editor: MaskEditor whose mask is saved.
        log: progress-log dict (ROI file name -> entry). It is updated in
            place and persisted with save_log().

    Returns:
        int: the component count from ``editor.count_components()``.

    Raises:
        OSError: if the mask could not be encoded or written. In that case
            ``editor.is_dirty`` is left unchanged and the log is not updated,
            so a failed save is never reported as saved.
    """
    final = editor.get_final_mask()
    path = editor.mask_path
    params = [cv2.IMWRITE_PNG_COMPRESSION, 3]
    if not cv2.imwrite(str(path), final, params):
        # imwrite reports failure only through its return value (a typical
        # cause is a non-ASCII path on Windows). Retry by encoding in memory
        # and writing the bytes through NumPy's regular file I/O.
        ok, buf = cv2.imencode(path.suffix, final, params)
        if not ok:
            raise OSError(f"无法保存 mask: {path}")
        buf.tofile(str(path))
    editor.is_dirty = False
    cnt = editor.count_components()
    log[editor.roi_path.name] = {
        "mask": str(path),
        "components": cnt,
    }
    save_log(log)
    return cnt


# Normalised codes that decode_key() returns for the Left / Right arrow keys.
# They are negative so they can never collide with a character code.
KEY_PREV = -2
KEY_NEXT = -3

# Raw cv2.waitKeyEx() arrow-key codes on the common HighGUI backends:
# Win32 reports the virtual-key code shifted left by 16 bits, GTK/X11 reports
# the keysym in the low 16 bits, and Cocoa reports the function-key character.
_WIN_LEFT, _WIN_RIGHT = 0x250000, 0x270000
_SYM_LEFT = {0xFF51, 0xF702}
_SYM_RIGHT = {0xFF53, 0xF703}

# ROI suffixes whose masks are stored as PNG instead. JPEG is lossy and would
# smear the 0/255 labels.
_LOSSY_MASK_SUFFIXES = {".jpg", ".jpeg"}


def decode_key(raw):
    """Map a raw ``cv2.waitKeyEx`` code to the value the main loop dispatches on.

    Returns:
        None for "no key" (-1) and for special keys the tool does not use;
        KEY_PREV / KEY_NEXT for the Left / Right arrows; otherwise the
        character code (1-255) with any modifier bits stripped.

    Plain ``waitKey() & 0xFF`` would fold the GTK arrow keysyms 0xFF51 / 0xFF53
    onto ord('Q') / ord('S'), so Left would quit and Right would save.
    """
    if raw == -1:
        return None
    low16 = raw & 0xFFFF
    if raw == _WIN_LEFT or low16 in _SYM_LEFT:
        return KEY_PREV
    if raw == _WIN_RIGHT or low16 in _SYM_RIGHT:
        return KEY_NEXT
    if low16 == 0 or low16 > 0xFF:
        return None
    return low16


def mask_filename(roi_name):
    """Return the mask file name for ROI file *roi_name*.

    The ROI's name is kept, except that lossy JPEG suffixes become ``.png`` so
    the binary labels are stored losslessly.
    """
    p = Path(roi_name)
    if p.suffix.lower() in _LOSSY_MASK_SUFFIXES:
        return p.with_suffix(".png").name
    return p.name


def _wheel_resize_brush(ed, flags):
    """Grow or shrink *ed*'s brush by 2 px per wheel event, clamped."""
    if cv2.getMouseWheelDelta(flags) > 0:
        ed.brush_radius = min(BRUSH_MAX, ed.brush_radius + 2)
    else:
        ed.brush_radius = max(BRUSH_MIN, ed.brush_radius - 2)


def main():
    """Run the interactive labelling session over every ROI in ROI_DIR.

    The session starts at the first ROI without a progress-log entry, or at
    the first ROI when all are logged. Dirty masks are auto-saved when
    switching images or quitting, and a summary is printed on exit.
    """
    MASK_DIR.mkdir(parents=True, exist_ok=True)
    log = load_log()
    rois = list_rois()
    if not rois:
        print(f"! 没找到 ROI: {ROI_DIR}")
        return

    print(f"共 {len(rois)} 张 ROI,已标 {len(log)} 张")
    print("=" * 60)

    # Resume at the first ROI that has no progress-log entry yet.
    idx = next((i for i, p in enumerate(rois) if p.name not in log), None)
    if idx is None:
        idx = 0
        print("所有图已标过,从第一张开始可继续修改")

    win = "Label Mask Multi"
    cv2.namedWindow(win, cv2.WINDOW_AUTOSIZE)

    def open_editor(i):
        """Create the MaskEditor for rois[i] and route mouse events to it."""
        roi_path = rois[i]
        mask_path = MASK_DIR / mask_filename(roi_path.name)
        legacy_path = MASK_DIR / roi_path.name
        print(f"\n[{i+1}/{len(rois)}] {roi_path.name}")
        if (legacy_path != mask_path and legacy_path.exists()
                and not mask_path.exists()):
            # A mask saved in a lossy format by an earlier run: load it, but
            # direct the next save to the lossless name.
            ed = MaskEditor(roi_path, legacy_path)
            ed.mask_path = mask_path
            ed.is_dirty = True
            print(f"  旧 mask 将另存为 {mask_path.name}")
        else:
            ed = MaskEditor(roi_path, mask_path)

        def mouse_cb(event, x, y, flags, param):
            ed.cursor_pos = (x, y)

            if event == cv2.EVENT_MOUSEWHEEL:
                _wheel_resize_brush(ed, flags)
                return

            # ======== line-assist mode: two clicks per segment ========
            if ed.line_mode:
                if event in (cv2.EVENT_LBUTTONDOWN, cv2.EVENT_RBUTTONDOWN):
                    if ed.line_start is None:
                        # 1st click: remember the start point (original px).
                        ed.line_start = ed._disp_to_orig(x, y)
                    else:
                        # 2nd click: left = draw white (add), right = black (erase).
                        end_orig = ed._disp_to_orig(x, y)
                        color = 255 if event == cv2.EVENT_LBUTTONDOWN else 0
                        ed.paint_line(ed.line_start, end_orig, color)
                        action = "画线" if color == 255 else "擦线"
                        print(f"  [line] {action} "
                              f"({ed.line_start[0]},{ed.line_start[1]})"
                              f" -> ({end_orig[0]},{end_orig[1]})")
                        ed.line_start = None
                return

            # ======== regular brush: press, drag, release ========
            if event == cv2.EVENT_LBUTTONDOWN:
                ed.current_button = 0
                ed.last_paint_pos = None
                ed.paint(x, y, 255, save_undo=True)
            elif event == cv2.EVENT_RBUTTONDOWN:
                ed.current_button = 1
                ed.last_paint_pos = None
                ed.paint(x, y, 0, save_undo=True)
            elif event == cv2.EVENT_MOUSEMOVE:
                if ed.current_button == 0:
                    ed.paint(x, y, 255, save_undo=False)
                elif ed.current_button == 1:
                    ed.paint(x, y, 0, save_undo=False)
            elif event in (cv2.EVENT_LBUTTONUP, cv2.EVENT_RBUTTONUP):
                ed.current_button = -1
                ed.last_paint_pos = None

        cv2.setMouseCallback(win, mouse_cb)
        return ed

    editor = open_editor(idx)

    print("\n========== 操作 ==========")
    print("  左键=画  右键=擦  滚轮/[]=笔刷")
    print("  V=竖条模式  L=直线辅助模式")
    print("    直线:第1次点击定起点,第2次左键=画,右键=擦")
    print("  Z=撤销  C=清空  B=切视图")
    print("  S=保存  A=上一张  D=下一张")
    print("  Q/ESC=退出")
    print("==========================\n")

    while True:
        img = editor.render()
        view_name = "OVERLAY" if editor.show_overlay else "MASK"
        if editor.line_mode:
            mode_name = "LINE"
        elif editor.vertical_mode:
            mode_name = "VERT-BAR"
        else:
            mode_name = "CIRCLE"
        cnt = editor.count_components()

        line_hint = ""
        if editor.line_mode:
            if editor.line_start is not None:
                line_hint = "  [2nd: L-Draw R-Erase]"
            else:
                line_hint = "  [1st: set start]"

        info1 = (f"[{idx+1}/{len(rois)}] {editor.roi_path.name}  "
                 f"comps={cnt}")
        info2 = (f"{view_name}  {mode_name}  "
                 f"brush={editor.brush_radius}  "
                 f"{'*DIRTY*' if editor.is_dirty else 'saved'}"
                 f"{line_hint}")

        cv2.putText(img, info1, (10, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 2)
        cv2.putText(img, info2, (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        cv2.imshow(win, img)

        k = decode_key(cv2.waitKeyEx(20))
        # Closing the window with its title-bar button ends the session like Q.
        window_closed = cv2.getWindowProperty(win, cv2.WND_PROP_VISIBLE) < 1

        if window_closed or k in (ord('q'), ord('Q'), 27):
            if editor.is_dirty:
                save_mask(editor, log)
                print("  [save] 保存退出")
            break
        if k is None:
            continue

        if k in (ord('s'), ord('S')):
            cnt = save_mask(editor, log)
            print(f"  [save] 保存 (comps={cnt})")
        elif k in (ord('d'), ord('D'), KEY_NEXT):
            if editor.is_dirty:
                cnt = save_mask(editor, log)
                print(f"  [save] 自动保存 (comps={cnt})")
            if idx < len(rois) - 1:
                idx += 1
                editor = open_editor(idx)
            else:
                print("  已是最后一张")
        elif k in (ord('a'), ord('A'), KEY_PREV):
            if editor.is_dirty:
                cnt = save_mask(editor, log)
                print(f"  [save] 自动保存 (comps={cnt})")
            if idx > 0:
                idx -= 1
                editor = open_editor(idx)
            else:
                print("  已是第一张")
        elif k in (ord('v'), ord('V')):
            editor.line_mode = False
            editor.line_start = None
            editor.vertical_mode = not editor.vertical_mode
            print(f"  模式: {'竖条' if editor.vertical_mode else '圆形画笔'}")
        elif k in (ord('l'), ord('L')):
            editor.line_mode = not editor.line_mode
            editor.line_start = None
            if editor.line_mode:
                editor.vertical_mode = False
            print(f"  模式: {'直线辅助' if editor.line_mode else '圆形画笔'}")
            if editor.line_mode:
                print("    第1次点击定起点")
                print("    第2次左键=画白线  右键=画黑线(消除)")
        elif k in (ord('z'), ord('Z')):
            editor.undo()
        elif k in (ord('c'), ord('C')):
            editor.clear()
            print("  清空")
        elif k in (ord('b'), ord('B')):
            editor.show_overlay = not editor.show_overlay
        elif k == ord('['):
            editor.brush_radius = max(BRUSH_MIN, editor.brush_radius - 2)
        elif k == ord(']'):
            editor.brush_radius = min(BRUSH_MAX, editor.brush_radius + 2)

    cv2.destroyAllWindows()
    print("\n========== 退出 ==========")
    print(f"  已标: {len(log)}/{len(rois)}")
    print(f"  日志: {LOG_PATH}")


if __name__ == "__main__":
    main()
相关推荐
码界筑梦坊1 小时前
119-基于Python的各类企业排行数据可视化分析系统
开发语言·python·信息可视化·数据分析·毕业设计·echarts·fastapi
习明然1 小时前
记录下解决Python在windows 2008 Server 无法启动
开发语言·windows·python
duke8692672141 小时前
C# 文件上传的服务器端加密 C#如何在存储到S3或Azure Blob时启用加密
jvm·数据库·python
凯瑟琳.奥古斯特1 小时前
IP组播跨子网传输核心技术解析
java·开发语言·网络·网络协议·职场和发展
SOC罗三炮1 小时前
Hermes Agent v0.14.0:不用装 WSL 了,Windows 原生支持来了(Early Beta)
python
用户78937733908531 小时前
前端转后端生存指南(中):化身架构师,用 ORM 魔法掌控数据库
后端·python
xyq20241 小时前
Razor VB 循环
开发语言
古城小栈1 小时前
Bun从Zig迁移至Rust:有何重大意义?
开发语言·后端·rust
༒࿈南林࿈༒1 小时前
某川数据接口逆向、SM系列国密算法
python·js逆向·国密(sm系列)