代码
本代码复现的论文是《RSAR: Restricted State Angle Resolver and Rotated SAR Benchmark》,并设计了一个前端界面供参考。
python
import gradio as gr
import cv2
import numpy as np
import torch
import mmcv
from mmdet.apis import init_detector, inference_detector
from mmengine.structures import InstanceData
import matplotlib.pyplot as plt
import os
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
# Locations of the detector config and the trained checkpoint.
config_file = 'D:/PythonProject/RSAR/configs/redet/redet-le90_re50_refpn_1x_rsar.py'
checkpoint_file = 'D:/PythonProject/RSAR/epoch_22.pth'

# Class names (per the RSAR paper), looked up by integer class id.
# NOTE(review): the order must match the label order the checkpoint was
# trained with - confirm against the dataset definition in the config.
CLASSES = [
    'Ship',
    'Tank',
    'Bridge',
    'Aircraft',
    'Harbor',
    'Car',
]

# One colour per class, at the same index as in CLASSES.
COLORS = [
    (255, 0, 0),      # red     - Ship
    (0, 255, 0),      # green   - Tank
    (0, 0, 255),      # blue    - Bridge
    (255, 255, 0),    # yellow  - Aircraft
    (255, 0, 255),    # magenta - Harbor
    (0, 255, 255),    # cyan    - Car
]
def init_model(device='cuda:0'):
    """Build the RSAR detector from the module-level config and checkpoint.

    Args:
        device: Device string forwarded to ``init_detector``. Defaults to
            ``'cuda:0'``, the value that was previously hard-coded, so
            existing zero-argument calls behave as before.

    Returns:
        The model object returned by ``init_detector``.
    """
    print("正在初始化RSAR模型...")
    model = init_detector(config_file, checkpoint_file, device=device)
    print("模型初始化完成!")
    return model
def decode_rbox_to_corners(rbox):
    """Convert a rotated box into the coordinates of its four corners.

    Args:
        rbox: Five values ``(x, y, w, h, theta)`` - box centre, width,
            height and rotation angle in radians.

    Returns:
        A ``(4, 2)`` array of ``[x, y]`` points. Row order follows the
        unrotated offsets ``(-w/2, -h/2)``, ``(+w/2, -h/2)``,
        ``(+w/2, +h/2)``, ``(-w/2, +h/2)``.
    """
    cx, cy, width, height, angle = rbox
    half_w = width / 2.0
    half_h = height / 2.0

    # Corner offsets relative to the centre, before any rotation.
    offsets = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])

    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])

    # Points are stored as rows, so multiply by the transposed matrix.
    corners = offsets.dot(rotation.T)

    # Shift from centre-relative to absolute coordinates.
    corners[:, 0] += cx
    corners[:, 1] += cy
    return corners
def extract_detections(result, score_thr=0.3):
    """Group the detections of one inference result by class id.

    Args:
        result: Inference result. When it has a ``pred_instances``
            attribute, that object's ``bboxes``, ``scores`` and ``labels``
            are read (each via ``.cpu().numpy()``).
        score_thr: Detections scoring below this value are dropped.

    Returns:
        A list with one entry per class in ``CLASSES``. On the
        ``pred_instances`` path each entry is a list of 1-D arrays holding
        the box values followed by the score, so the confidence stays
        available to whoever draws the box. If ``result`` has no
        ``pred_instances``, the legacy fallback returns, per class, either
        the ``pred_instances_<i>`` attribute or an empty list.
    """
    num_classes = len(CLASSES)

    if not hasattr(result, 'pred_instances'):
        # Legacy fallback: look for one attribute per class.
        detections = []
        for class_id in range(num_classes):
            instances = getattr(result, f'pred_instances_{class_id}', None)
            if instances is not None and len(instances) > 0:
                detections.append(instances)
            else:
                detections.append([])
        return detections

    detections = [[] for _ in range(num_classes)]
    instances = result.pred_instances
    if instances is None or len(instances) == 0:
        return detections

    required = ('bboxes', 'scores', 'labels')
    if not all(hasattr(instances, field) for field in required):
        return detections

    bboxes = instances.bboxes.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    labels = instances.labels.cpu().numpy()

    for bbox, score, label in zip(bboxes, scores, labels):
        if score < score_thr:
            continue
        class_id = int(label)
        # Reject negative ids as well: they would otherwise wrap around
        # and be counted under a class at the end of the list.
        if 0 <= class_id < num_classes:
            # Keep the score with the box; previously it was discarded and
            # every drawn label fell back to a confidence of 1.00.
            detections[class_id].append(np.array([*bbox, score]))
    return detections
def _draw_box_label(canvas, text, anchor_x, anchor_y, color):
    """Draw ``text`` in white on a filled ``color`` box near the anchor point."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, font_scale, thickness)

    text_x = int(anchor_x)
    text_y = int(anchor_y) - 5
    # Not enough room above the anchor: place the label below it instead.
    if text_y < text_h + 5:
        text_y = int(anchor_y) + text_h + 10

    cv2.rectangle(
        canvas,
        (text_x - 2, text_y - text_h - 5),
        (text_x + text_w + 2, text_y + 5),
        color,
        cv2.FILLED,
    )
    cv2.putText(canvas, text, (text_x, text_y), font, font_scale,
                (255, 255, 255), thickness)


def _draw_detection_summary(canvas, detection_count):
    """Draw a per-class count legend in the top-left corner of ``canvas``."""
    if not detection_count:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    y_offset = 30
    # Every text is drawn twice - a thick white pass and a thin black pass
    # on top - so it remains readable on light and dark backgrounds.
    cv2.putText(canvas, "Detection Summary:", (10, y_offset), font, 0.7, (255, 255, 255), 2)
    cv2.putText(canvas, "Detection Summary:", (10, y_offset), font, 0.7, (0, 0, 0), 1)

    for class_name, count in detection_count.items():
        y_offset += 25
        # Use the very same colour tuple the boxes were drawn with. The
        # canvas is already RGB here, so the former channel swap made the
        # swatch disagree with the box outline of the same class.
        if class_name in CLASSES:
            color = COLORS[CLASSES.index(class_name)]
        else:
            color = (255, 255, 255)
        cv2.rectangle(canvas, (10, y_offset - 15), (25, y_offset), color, -1)
        cv2.rectangle(canvas, (10, y_offset - 15), (25, y_offset), (255, 255, 255), 1)

        text = f"{class_name}: {count}"
        cv2.putText(canvas, text, (30, y_offset), font, 0.6, (255, 255, 255), 2)
        cv2.putText(canvas, text, (30, y_offset), font, 0.6, (0, 0, 0), 1)


def draw_rotated_boxes_cv(image, detections, score_thr=0.3):
    """Draw rotated boxes, labels and a count summary with OpenCV.

    Args:
        image: Either a file path (read with ``cv2.imread``) or an image
            array. A 3-channel array is converted with ``COLOR_BGR2RGB``,
            i.e. it is treated as BGR; a 2-D array is treated as grayscale.
        detections: One sequence of detections per class id. Each detection
            is ``[x, y, w, h, theta]`` or ``[x, y, w, h, theta, score]``;
            entries of any other length are skipped.
        score_thr: Detections that carry a score below this are skipped.

    Returns:
        ``(rgb_image, detection_count)`` where ``rgb_image`` is a new RGB
        array (the input is not modified) and ``detection_count`` maps class
        name to the number of boxes drawn. ``(None, {})`` if no image could
        be obtained.
    """
    img = cv2.imread(image) if isinstance(image, str) else image
    if img is None:
        return None, {}

    # cvtColor allocates a new array, so drawing never touches the input.
    if img.ndim == 2:
        canvas = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        canvas = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = canvas.shape[:2]

    detection_count = {}
    for class_id, class_dets in enumerate(detections):
        if len(class_dets) == 0:
            continue
        color = COLORS[class_id]
        class_name = CLASSES[class_id]

        for det in class_dets:
            if len(det) >= 6:
                x, y, w, h, theta, score = det[:6]
                if score < score_thr:
                    continue
            elif len(det) == 5:
                x, y, w, h, theta = det
                score = 1.0  # no score supplied with this detection
            else:
                continue

            corners = decode_rbox_to_corners([x, y, w, h, theta])
            cv2.polylines(canvas, [corners.astype(np.int32)], isClosed=True,
                          color=color, thickness=2)

            # Anchor the label at the left-most corner, clamped to the image.
            left_x, left_y = corners[np.argmin(corners[:, 0])]
            anchor_x = max(0, min(left_x, img_w - 1))
            anchor_y = max(0, min(left_y, img_h - 1))
            _draw_box_label(canvas, f'{class_name}: {score:.2f}',
                            anchor_x, anchor_y, color)

            detection_count[class_name] = detection_count.get(class_name, 0) + 1

    _draw_detection_summary(canvas, detection_count)
    return canvas, detection_count
def _message_image(lines):
    """Return a white 400x300 RGB image carrying red text.

    ``lines`` is a list of ``(text, origin, font_scale, thickness)`` tuples.
    The text must be ASCII: ``cv2.putText`` only knows the Hershey fonts and
    draws question marks for anything else (e.g. Chinese characters).
    """
    canvas = np.ones((300, 400, 3), dtype=np.uint8) * 255
    for text, origin, font_scale, thickness in lines:
        cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (255, 0, 0), thickness)
    return canvas


def process_image(image, score_threshold):
    """Run the detector on one image and render the result.

    Args:
        image: A PIL image, an image array, or ``None``.
        score_threshold: Minimum confidence for a detection to be kept.

    Returns:
        ``(rgb_image, stats_text)``. On missing input or on any processing
        error a placeholder image and a message are returned instead of
        raising.
    """
    if image is None:
        return _message_image([("No image uploaded", (70, 150), 0.8, 2)]), "请上传图像"

    try:
        if isinstance(image, Image.Image):
            # Normalise every PIL mode (L, P, RGBA, ...) to 3-channel RGB
            # first, then switch to the BGR order OpenCV works in. The old
            # RGBA branch stopped at RGB and so fed swapped channels on.
            img_bgr = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
        else:
            # NOTE(review): arrays are passed through untouched and are
            # treated as BGR downstream - confirm against callers.
            img_bgr = image.copy()

        # Hand the array to the detector directly. The previous round trip
        # through a fixed 'temp_input.jpg' was shared by all requests and
        # its cleanup referenced the path even when it was never assigned.
        result = inference_detector(model, img_bgr)

        detections = extract_detections(result, score_threshold)
        result_img, detection_count = draw_rotated_boxes_cv(
            img_bgr, detections, score_threshold
        )

        total_detections = sum(detection_count.values())
        parts = [f"总检测数: {total_detections}\n"]
        parts.extend(f"{class_name}: {count}\n"
                     for class_name, count in detection_count.items())
        if total_detections == 0:
            parts.append("\n⚠️ 未检测到目标,尝试降低置信度阈值")

        # draw_rotated_boxes_cv already returns RGB; converting once more
        # would hand red/blue-swapped pixels to the UI.
        return result_img, "".join(parts)
    except Exception as e:
        print(f"处理错误: {e}")
        import traceback
        traceback.print_exc()
        error_img = _message_image([
            ("Processing error", (80, 120), 0.8, 2),
            ("Check the image format", (80, 160), 0.6, 1),
        ])
        return error_img, f"处理错误: {str(e)}"
def process_batch_images(files, score_threshold):
    """Run ``process_image`` over several uploaded files.

    Args:
        files: Uploaded files; each item is either an object with a
            ``name`` attribute holding the path, or the path itself.
            May be empty or ``None``.
        score_threshold: Minimum confidence passed on to ``process_image``.

    Returns:
        ``(images, stats_text)`` - a list of PIL images for the gallery and
        one combined statistics string. A file that fails contributes a
        placeholder image and an error line rather than aborting the batch.
    """
    if not files:
        return [], "没有图像"

    processed_images = []
    all_stats = []
    for index, file_info in enumerate(files, start=1):
        try:
            path = file_info.name if hasattr(file_info, 'name') else file_info
            # Use the image inside a context manager so the file handle is
            # released as soon as this file has been processed.
            with Image.open(path) as image:
                result_img, stats = process_image(image, score_threshold)

            if result_img is None:
                # Only record the failure: the gallery cannot display a
                # None entry, so nothing is appended to the image list.
                all_stats.append(f"图像 {index}: 处理失败 - 无结果\n")
                continue

            processed_images.append(Image.fromarray(result_img))
            all_stats.append(f"图像 {index}:\n{stats}\n")
        except Exception as e:
            print(f"处理图像 {index} 错误: {e}")
            # White RGB placeholder with red ASCII text; cv2.putText cannot
            # draw Chinese characters (they come out as question marks).
            error_img = np.ones((300, 400, 3), dtype=np.uint8) * 255
            cv2.putText(error_img, "Processing failed", (90, 150),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
            processed_images.append(Image.fromarray(error_img))
            all_stats.append(f"图像 {index}: 处理失败 - {str(e)}\n")

    combined_stats = "批量处理结果:\n" + "=" * 50 + "\n" + "\n".join(all_stats)
    return processed_images, combined_stats
# Load the detector once, at import time.
model = init_model()


def _rgb_hex(rgb):
    """Format a 3-tuple of 0-255 ints as a ``#rrggbb`` CSS colour."""
    return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"


# Build the Gradio interface.
# NOTE(review): the nesting below was reconstructed from a paste that had lost
# its indentation - confirm the layout against the running application.
with gr.Blocks(title="RSAR旋转SAR目标检测系统", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛰️ RSAR: 旋转SAR目标检测可视化系统")
    gr.Markdown("### 指导教师:赵作鹏 ")
    with gr.Tabs():
        # ---- Tab 1: single-image detection ----
        with gr.TabItem("📷 单图像检测"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 输入设置")
                    image_input = gr.Image(
                        label="上传SAR图像",
                        type="pil",
                        height=300
                    )
                    score_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.3,
                        step=0.05,
                        label="置信度阈值",
                        info="值越高,检测结果越严格"
                    )
                    with gr.Row():
                        detect_btn = gr.Button("🚀 开始检测", variant="primary", scale=2)
                        clear_btn = gr.Button("🗑️ 清空", variant="secondary", scale=1)
                    gr.Markdown("### 类别颜色说明")
                    # One colour swatch plus class name per legend row.
                    color_html = "".join(
                        f'<div style="margin:5px 0;"><span style="display:inline-block;width:20px;height:20px;background-color:{_rgb_hex(COLORS[i])};border:1px solid #ccc;margin-right:10px;"></span><b>{class_name}</b></div>'
                        for i, class_name in enumerate(CLASSES)
                    )
                    gr.HTML(color_html)
                with gr.Column(scale=1):
                    gr.Markdown("### 检测结果")
                    image_output = gr.Image(
                        label="检测结果",
                        height=300,
                        type="pil"
                    )
                    stats_output = gr.Textbox(
                        label="📊 检测统计",
                        lines=8,
                        interactive=False
                    )
            # Button wiring.
            detect_btn.click(
                fn=process_image,
                inputs=[image_input, score_slider],
                outputs=[image_output, stats_output]
            )
            # Resets the two output widgets only.
            clear_btn.click(
                fn=lambda: (None, "等待检测..."),
                outputs=[image_output, stats_output]
            )
        # ---- Tab 2: batch detection ----
        with gr.TabItem("📚 批量检测"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 批量上传")
                    batch_image_input = gr.Files(
                        label="选择多个图像文件",
                        file_types=["image"],
                        file_count="multiple"
                    )
                    batch_score_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.3,
                        step=0.05,
                        label="置信度阈值"
                    )
                    batch_detect_btn = gr.Button("🚀 批量检测", variant="primary")
                with gr.Column():
                    gr.Markdown("### 检测结果图库")
                    gallery_output = gr.Gallery(
                        label="检测结果",
                        columns=3,
                        height=500,
                        object_fit="contain"
                    )
                    batch_stats_output = gr.Textbox(
                        label="批量统计",
                        lines=12
                    )
            batch_detect_btn.click(
                fn=process_batch_images,
                inputs=[batch_image_input, batch_score_slider],
                outputs=[gallery_output, batch_stats_output]
            )
        # ---- Tab 3: about ----
        with gr.TabItem("ℹ️ 关于"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("""
## 系统简介
**RSAR旋转SAR目标检测系统**是基于最新研究的旋转目标检测算法,专门用于合成孔径雷达(SAR)图像中的多类别目标识别。
### 主要功能
- **旋转目标检测**: 精准检测任意角度的目标
- **多类别识别**: 支持6类典型SAR目标
- **批量处理**: 高效处理多张图像
- **可视化分析**: 直观展示检测结果
### 技术特点
- 基于ReDet旋转等变检测器
- 使用Unit Cycle Resolver(UCR)提高角度预测精度
- 在RSAR数据集上训练(95,842张图像)
""")
                    gr.Markdown("### 检测类别")
                    # One card per class inside a two-column CSS grid.
                    class_cards = [
                        f'''
<div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background: white;">
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background-color: {_rgb_hex(COLORS[i])}; margin-right: 10px; border: 1px solid #ccc;"></div>
<b>{class_name}</b>
</div>
</div>
'''
                        for i, class_name in enumerate(CLASSES)
                    ]
                    classes_html = (
                        '<div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 10px;">'
                        + "".join(class_cards)
                        + '</div>'
                    )
                    gr.HTML(classes_html)
                with gr.Column(scale=1):
                    gr.Markdown("""
### 使用说明
1. **单图像检测**
- 在"单图像检测"标签页上传图像
- 调整置信度阈值(推荐0.3-0.5)
- 点击"开始检测"按钮
2. **批量检测**
- 在"批量检测"标签页选择多个图像
- 设置阈值并点击"批量检测"
- 查看图库和统计结果
### 注意事项
- 支持常见图像格式:JPG, PNG, BMP
- 建议图像尺寸:800×800像素以上
- 检测效果受图像质量和目标大小影响
### 联系信息
**研究团队**: 南开大学VCIP实验室
**论文**: RSAR: Restricted State Angle Resolver and Rotated SAR Benchmark
**GitHub**: [zhasion/RSAR](https://github.com/zhasion/RSAR)
""")
# Start the web application when run as a script.
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("🚀 RSAR检测系统启动中...")
    print("🌐 请访问: http://localhost:7860")
    print(banner)

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=False,
    )
    try:
        demo.launch(**launch_options)
    except Exception as e:
        # Any launch failure is reported together with a port-in-use hint.
        print(f"启动失败: {e}")
        print("请检查端口7860是否被占用")