RSAR的前端可视化界面

代码

论文复现的是《RSAR: Restricted State Angle Resolver and Rotated SAR Benchmark》,设计了一个前端界面供参考!

python 复制代码
import gradio as gr
import cv2
import numpy as np
import torch
import mmcv
from mmdet.apis import init_detector, inference_detector
from mmengine.structures import InstanceData
import matplotlib.pyplot as plt
import os
from PIL import Image
import warnings

warnings.filterwarnings('ignore')

# 设置路径
config_file = 'D:/PythonProject/RSAR/configs/redet/redet-le90_re50_refpn_1x_rsar.py'
checkpoint_file = 'D:/PythonProject/RSAR/epoch_22.pth'

# 类别标签(根据RSAR论文)
CLASSES = ['Ship', 'Tank', 'Bridge', 'Aircraft', 'Harbor', 'Car']

# 颜色映射
COLORS = [
    (255, 0, 0),  # 红色 - Ship
    (0, 255, 0),  # 绿色 - Tank
    (0, 0, 255),  # 蓝色 - Bridge
    (255, 255, 0),  # 黄色 - Aircraft
    (255, 0, 255),  # 紫色 - Harbor
    (0, 255, 255),  # 青色 - Car
]


# 初始化模型
def init_model():
    """初始化RSAR模型"""
    print("正在初始化RSAR模型...")
    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    print("模型初始化完成!")
    return model


def decode_rbox_to_corners(rbox):
    """
    将旋转框(x, y, w, h, theta)解码为四个角点坐标

    参数:
        rbox: [x, y, w, h, theta] theta为弧度

    返回:
        corners: 四个角点坐标 [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
    """
    x, y, w, h, theta = rbox

    # 计算半宽半高
    w2 = w / 2.0
    h2 = h / 2.0

    # 旋转矩阵
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    # 四个角点(相对于中心)
    corners = np.array([
        [-w2, -h2],
        [w2, -h2],
        [w2, h2],
        [-w2, h2]
    ])

    # 旋转
    rot_matrix = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    corners_rotated = np.dot(corners, rot_matrix.T)

    # 平移
    corners_rotated[:, 0] += x
    corners_rotated[:, 1] += y

    return corners_rotated


def extract_detections(result, score_thr=0.3):
    """
    从DetDataSample中提取检测结果

    参数:
        result: DetDataSample对象
        score_thr: 置信度阈值

    返回:
        detections: 按类别分组的检测结果列表
    """
    if hasattr(result, 'pred_instances'):
        instances = result.pred_instances

        # 检查是否有检测结果
        if instances is None or len(instances) == 0:
            return [[] for _ in range(len(CLASSES))]

        # 获取边界框、分数和标签
        bboxes = instances.bboxes.cpu().numpy() if hasattr(instances, 'bboxes') else None
        scores = instances.scores.cpu().numpy() if hasattr(instances, 'scores') else None
        labels = instances.labels.cpu().numpy() if hasattr(instances, 'labels') else None

        if bboxes is None or scores is None or labels is None:
            return [[] for _ in range(len(CLASSES))]

        # 按类别分组检测结果
        detections = [[] for _ in range(len(CLASSES))]

        for i in range(len(bboxes)):
            score = scores[i]
            if score < score_thr:
                continue

            bbox = bboxes[i]
            label = int(labels[i])

            if label < len(CLASSES):
                detections[label].append(bbox)

        return detections
    else:
        # 回退到旧的格式处理
        detections = []
        for i in range(len(CLASSES)):
            if hasattr(result, f'pred_instances_{i}'):
                instances = getattr(result, f'pred_instances_{i}')
                if instances is not None and len(instances) > 0:
                    detections.append(instances)
                else:
                    detections.append([])
            else:
                detections.append([])
        return detections


def draw_rotated_boxes_cv(image, detections, score_thr=0.3):
    """
    使用OpenCV绘制旋转边界框

    参数:
        image: 原始图像
        detections: 检测结果列表
        score_thr: 置信度阈值

    返回:
        绘制了边界框的图像和检测统计
    """
    if isinstance(image, str):
        img = cv2.imread(image)
    else:
        img = image.copy()

    if img is None:
        return None, {}

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = img_rgb.shape[:2]

    detection_count = {}

    for class_id, class_dets in enumerate(detections):
        if len(class_dets) == 0:
            continue

        for det in class_dets:
            # 跳过置信度低的检测
            if len(det) >= 6 and det[5] < score_thr:
                continue

            # 提取旋转框参数
            if len(det) >= 6:  # 格式: [x, y, w, h, theta, score]
                x, y, w, h, theta, score = det[:6]
            elif len(det) == 5:  # 格式: [x, y, w, h, theta]
                x, y, w, h, theta = det
                score = 1.0
            else:
                continue

            # 解码为四个角点
            corners = decode_rbox_to_corners([x, y, w, h, theta])

            # 绘制旋转框
            corners_int = corners.astype(np.int32)
            color = COLORS[class_id]
            cv2.polylines(img_rgb, [corners_int], isClosed=True, color=color, thickness=2)

            # 在左上角显示类别和置信度
            label = f'{CLASSES[class_id]}: {score:.2f}'

            # 找到最左边的角点
            min_x_idx = np.argmin(corners[:, 0])
            min_x, min_y = corners[min_x_idx]

            # 确保坐标在图像范围内
            min_x = max(0, min(min_x, img_w - 1))
            min_y = max(0, min(min_y, img_h - 1))

            # 计算文本大小
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2

            (text_width, text_height), baseline = cv2.getTextSize(
                label, font, font_scale, thickness
            )

            # 调整文本框位置
            text_x = int(min_x)
            text_y = int(min_y) - 5
            if text_y < text_height + 5:
                text_y = int(min_y) + text_height + 10

            # 绘制文本框背景
            box_coords = (
                (text_x - 2, text_y - text_height - 5),
                (text_x + text_width + 2, text_y + 5)
            )
            cv2.rectangle(img_rgb, box_coords[0], box_coords[1], color, cv2.FILLED)

            # 绘制文本
            cv2.putText(
                img_rgb,
                label,
                (text_x, text_y),
                font,
                font_scale,
                (255, 255, 255),  # 白色文字
                thickness
            )

            # 统计检测数量
            class_name = CLASSES[class_id]
            detection_count[class_name] = detection_count.get(class_name, 0) + 1

    # 在图像左上角添加统计信息
    if detection_count:
        y_offset = 30
        cv2.putText(img_rgb, "Detection Summary:", (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        cv2.putText(img_rgb, "Detection Summary:", (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1)

        for i, (class_name, count) in enumerate(detection_count.items()):
            y_offset += 25
            color = COLORS[CLASSES.index(class_name)] if class_name in CLASSES else (255, 255, 255)
            color_bgr = (color[2], color[1], color[0])  # RGB转BGR

            # 绘制颜色块
            cv2.rectangle(img_rgb, (10, y_offset - 15), (25, y_offset), color_bgr, -1)
            cv2.rectangle(img_rgb, (10, y_offset - 15), (25, y_offset), (255, 255, 255), 1)

            # 添加文本
            text = f"{class_name}: {count}"
            cv2.putText(img_rgb, text, (30, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            cv2.putText(img_rgb, text, (30, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)

    return img_rgb, detection_count


def process_image(image, score_threshold):
    """
    处理图像并进行检测

    参数:
        image: 输入图像
        score_threshold: 置信度阈值

    返回:
        检测结果图像和统计信息
    """
    if image is None:
        error_img = np.ones((300, 400, 3), dtype=np.uint8) * 255
        cv2.putText(error_img, "请上传图像", (100, 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        error_img_rgb = cv2.cvtColor(error_img, cv2.COLOR_BGR2RGB)
        return error_img_rgb, "请上传图像"

    try:
        # 将PIL图像转换为numpy数组
        if isinstance(image, Image.Image):
            img_np = np.array(image)
            if len(img_np.shape) == 2:  # 灰度图
                img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
            elif img_np.shape[2] == 4:  # RGBA
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)
            else:  # RGB
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        else:
            img_np = image.copy()

        # 临时保存图像
        temp_path = 'temp_input.jpg'
        cv2.imwrite(temp_path, img_np)

        # 进行检测 - 直接返回DetDataSample
        result = inference_detector(model, temp_path)

        # 提取检测结果
        detections = extract_detections(result, score_threshold)

        # 绘制结果
        result_img, detection_count = draw_rotated_boxes_cv(img_np, detections, score_threshold)

        # 创建统计文本
        total_detections = sum(detection_count.values())
        stats_text = f"总检测数: {total_detections}\n"
        for class_name, count in detection_count.items():
            stats_text += f"{class_name}: {count}\n"

        if total_detections == 0:
            stats_text += "\n⚠️ 未检测到目标,尝试降低置信度阈值"

        # 转换为RGB格式用于显示
        result_img_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)

        return result_img_rgb, stats_text

    except Exception as e:
        print(f"处理错误: {e}")
        import traceback
        traceback.print_exc()

        # 显示错误图像
        error_img = np.ones((300, 400, 3), dtype=np.uint8) * 255
        cv2.putText(error_img, f"处理错误", (100, 120),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
        cv2.putText(error_img, f"请检查图像格式", (80, 160),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1)
        error_img_rgb = cv2.cvtColor(error_img, cv2.COLOR_BGR2RGB)
        return error_img_rgb, f"处理错误: {str(e)}"

    finally:
        # 清理临时文件
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except:
                pass


def process_batch_images(files, score_threshold):
    """
    处理批量图像

    参数:
        files: 文件列表
        score_threshold: 置信度阈值

    返回:
        处理后的图像列表和统计信息
    """
    if not files:
        return [], "没有图像"

    processed_images = []
    all_stats = []

    for i, file_info in enumerate(files):
        try:
            # Gradio文件对象处理
            if hasattr(file_info, 'name'):
                # 从临时文件读取
                temp_path = file_info.name
                image = Image.open(temp_path)
            else:
                # 直接路径
                image = Image.open(file_info)

            result_img, stats = process_image(image, score_threshold)
            if result_img is not None:
                # 转换为PIL Image用于Gradio Gallery
                result_pil = Image.fromarray(result_img)
                processed_images.append(result_pil)
                all_stats.append(f"图像 {i + 1}:\n{stats}\n")
            else:
                processed_images.append(None)
                all_stats.append(f"图像 {i + 1}: 处理失败 - 无结果\n")

        except Exception as e:
            print(f"处理图像 {i + 1} 错误: {e}")
            error_img = np.ones((300, 400, 3), dtype=np.uint8) * 255
            cv2.putText(error_img, f"处理失败", (120, 150),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            error_img_rgb = cv2.cvtColor(error_img, cv2.COLOR_BGR2RGB)
            processed_images.append(Image.fromarray(error_img_rgb))
            all_stats.append(f"图像 {i + 1}: 处理失败 - {str(e)}\n")

    combined_stats = "批量处理结果:\n" + "=" * 50 + "\n" + "\n".join(all_stats)
    return processed_images, combined_stats


# 初始化模型
model = init_model()

# 创建Gradio界面
with gr.Blocks(title="RSAR旋转SAR目标检测系统", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛰️ RSAR: 旋转SAR目标检测可视化系统")
    gr.Markdown("### 指导教师:赵作鹏 ")

    with gr.Tabs():
        with gr.TabItem("📷 单图像检测"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 输入设置")
                    image_input = gr.Image(
                        label="上传SAR图像",
                        type="pil",
                        height=300
                    )

                    score_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.3,
                        step=0.05,
                        label="置信度阈值",
                        info="值越高,检测结果越严格"
                    )

                    with gr.Row():
                        detect_btn = gr.Button("🚀 开始检测", variant="primary", scale=2)
                        clear_btn = gr.Button("🗑️ 清空", variant="secondary", scale=1)

                    gr.Markdown("### 类别颜色说明")
                    color_html = ""
                    for i, class_name in enumerate(CLASSES):
                        color_hex = f"#{COLORS[i][0]:02x}{COLORS[i][1]:02x}{COLORS[i][2]:02x}"
                        color_html += f'<div style="margin:5px 0;"><span style="display:inline-block;width:20px;height:20px;background-color:{color_hex};border:1px solid #ccc;margin-right:10px;"></span><b>{class_name}</b></div>'
                    gr.HTML(color_html)

                with gr.Column(scale=1):
                    gr.Markdown("### 检测结果")
                    image_output = gr.Image(
                        label="检测结果",
                        height=300,
                        type="pil"
                    )

                    stats_output = gr.Textbox(
                        label="📊 检测统计",
                        lines=8,
                        interactive=False
                    )

            # 按钮事件
            detect_btn.click(
                fn=process_image,
                inputs=[image_input, score_slider],
                outputs=[image_output, stats_output]
            )

            clear_btn.click(
                fn=lambda: (None, "等待检测..."),
                outputs=[image_output, stats_output]
            )

        with gr.TabItem("📚 批量检测"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 批量上传")
                    batch_image_input = gr.Files(
                        label="选择多个图像文件",
                        file_types=["image"],
                        file_count="multiple"
                    )

                    batch_score_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.3,
                        step=0.05,
                        label="置信度阈值"
                    )

                    batch_detect_btn = gr.Button("🚀 批量检测", variant="primary")

                with gr.Column():
                    gr.Markdown("### 检测结果图库")
                    gallery_output = gr.Gallery(
                        label="检测结果",
                        columns=3,
                        height=500,
                        object_fit="contain"
                    )

                    batch_stats_output = gr.Textbox(
                        label="批量统计",
                        lines=12
                    )

            batch_detect_btn.click(
                fn=process_batch_images,
                inputs=[batch_image_input, batch_score_slider],
                outputs=[gallery_output, batch_stats_output]
            )

        with gr.TabItem("ℹ️ 关于"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("""
                    ## 系统简介

                    **RSAR旋转SAR目标检测系统**是基于最新研究的旋转目标检测算法,专门用于合成孔径雷达(SAR)图像中的多类别目标识别。

                    ### 主要功能
                    - **旋转目标检测**: 精准检测任意角度的目标
                    - **多类别识别**: 支持6类典型SAR目标
                    - **批量处理**: 高效处理多张图像
                    - **可视化分析**: 直观展示检测结果

                    ### 技术特点
                    - 基于ReDet旋转等变检测器
                    - 使用Unit Cycle Resolver(UCR)提高角度预测精度
                    - 在RSAR数据集上训练(95,842张图像)
                    """)

                    gr.Markdown("### 检测类别")
                    classes_html = '<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;">'
                    for i, class_name in enumerate(CLASSES):
                        color_hex = f"#{COLORS[i][0]:02x}{COLORS[i][1]:02x}{COLORS[i][2]:02x}"
                        classes_html += f'''
                        <div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background: white;">
                            <div style="display: flex; align-items: center;">
                                <div style="width: 20px; height: 20px; background-color: {color_hex}; margin-right: 10px; border: 1px solid #ccc;"></div>
                                <b>{class_name}</b>
                            </div>
                        </div>
                        '''
                    classes_html += '</div>'
                    gr.HTML(classes_html)

                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### 使用说明

                    1. **单图像检测**
                       - 在"单图像检测"标签页上传图像
                       - 调整置信度阈值(推荐0.3-0.5)
                       - 点击"开始检测"按钮

                    2. **批量检测**
                       - 在"批量检测"标签页选择多个图像
                       - 设置阈值并点击"批量检测"
                       - 查看图库和统计结果

                    ### 注意事项
                    - 支持常见图像格式:JPG, PNG, BMP
                    - 建议图像尺寸:800×800像素以上
                    - 检测效果受图像质量和目标大小影响

                    ### 联系信息
                    **研究团队**: 南开大学VCIP实验室  
                    **论文**: RSAR: Restricted State Angle Resolver and Rotated SAR Benchmark  
                    **GitHub**: [zhasion/RSAR](https://github.com/zhasion/RSAR)
                    """)

# 启动应用
if __name__ == "__main__":
    print("=" * 60)
    print("🚀 RSAR检测系统启动中...")
    print("🌐 请访问: http://localhost:7860")
    print("=" * 60)

    try:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            debug=False
        )
    except Exception as e:
        print(f"启动失败: {e}")
        print("请检查端口7860是否被占用")
相关推荐
asdfg12589632 小时前
数组去重(JS)
java·前端·javascript
鹏多多2 小时前
前端大数字精度解决:big.js的教程和原理解析
前端·javascript·vue.js
恋猫de小郭2 小时前
八年开源,GSY 用五种技术开发了同一个 Github 客户端,这次轮到 AI + Compose
android·前端·flutter
少年姜太公8 小时前
什么?还不知道git cherry pick?
前端·javascript·git
白兰地空瓶9 小时前
🏒 前端 AI 应用实战:用 Vue3 + Coze,把宠物一键变成冰球运动员!
前端·vue.js·coze
Liu.77410 小时前
vue3使用vue3-print-nb打印
前端·javascript·vue.js
松涛和鸣11 小时前
Linux Makefile : From Basic Syntax to Multi-File Project Compilation
linux·运维·服务器·前端·windows·哈希算法
dly_blog11 小时前
Vue 逻辑复用的多种方案对比!
前端·javascript·vue.js
万少12 小时前
HarmonyOS6 接入分享,原来也是三分钟的事情
前端·harmonyos